123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- /* vim: set expandtab ts=4 sw=4: */
- /*
- * You may redistribute this program and/or modify it under the terms of
- * the GNU General Public License as published by the Free Software Foundation,
- * either version 3 of the License, or (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <https://www.gnu.org/licenses/>.
- */
- #include "interface/FramingIface.h"
- #include "interface/Iface.h"
- #include "memory/Allocator.h"
- #include "util/Identity.h"
- #include "wire/Error.h"
- #include "util/Assert.h"
- #define REQUIRED_PADDING 512
- struct MessageList;
- struct MessageList {
- struct Message* msg;
- struct MessageList* next;
- };
- struct FramingIface_pvt {
- struct Iface messageIf;
- struct Iface streamIf;
- const uint32_t maxMessageSize;
- struct Allocator* alloc;
- // fields specific to this frame.
- uint32_t bytesRemaining;
- struct Allocator* frameAlloc;
- struct MessageList* frameParts;
- union {
- uint32_t length_be;
- uint8_t bytes[4];
- } header;
- uint32_t headerIndex;
- Identity
- };
- static struct Message* mergeMessage(struct FramingIface_pvt* fi, struct Message* last)
- {
- int length = Message_getLength(last);
- // The only accurate way to get the full length because this last might contain
- // the beginning of the next frame.
- struct MessageList* part = fi->frameParts;
- for (part = fi->frameParts; part; part = part->next) {
- length += Message_getLength(part->msg);
- }
- struct Message* out = Message_new(0, length + REQUIRED_PADDING, fi->frameAlloc);
- Er_assert(Message_epush(out, last->msgbytes, Message_getLength(last)));
- int fd = Message_getAssociatedFd(last);
- for (part = fi->frameParts; part; part = part->next) {
- Er_assert(Message_epush(out, part->msg->msgbytes, Message_getLength(part->msg)));
- if (fd == -1) {
- fd = Message_getAssociatedFd(part->msg);
- }
- }
- if (fd > -1) {
- Message_setAssociatedFd(out, fd);
- }
- Assert_true(length <= Message_getLength(out));
- return out;
- }
- static Iface_DEFUN receiveMessage(struct Message* msg, struct Iface* streamIf)
- {
- struct FramingIface_pvt* fi = Identity_containerOf(streamIf, struct FramingIface_pvt, streamIf);
- if (fi->bytesRemaining > fi->maxMessageSize) {
- // Oversize message
- Assert_ifTesting(0);
- return Error(msg, "OVERSIZE_MESSAGE");
- }
- if (fi->frameParts) {
- if (fi->bytesRemaining <= (uint32_t)Message_getLength(msg)) {
- struct Message* wholeMessage = mergeMessage(fi, msg);
- fi->bytesRemaining = 0;
- fi->frameParts = NULL;
- Assert_true(fi->headerIndex == 0);
- struct Allocator* frameAlloc = fi->frameAlloc;
- fi->frameAlloc = NULL;
- // Run the message through again since it's almost certainly not perfect size.
- Iface_CALL(receiveMessage, wholeMessage, streamIf);
- Allocator_free(frameAlloc);
- return NULL;
- }
- fi->bytesRemaining -= Message_getLength(msg);
- Allocator_adopt(fi->frameAlloc, Message_getAlloc(msg));
- struct MessageList* parts = Allocator_calloc(fi->frameAlloc, sizeof(struct MessageList), 1);
- parts->msg = msg;
- parts->next = fi->frameParts;
- fi->frameParts = parts;
- return NULL;
- }
- for (;;) {
- while (fi->headerIndex < 4) {
- if (!Message_getLength(msg)) {
- return NULL;
- }
- fi->header.bytes[fi->headerIndex] = msg->msgbytes[0];
- Er_assert(Message_eshift(msg, -1));
- fi->headerIndex++;
- }
- fi->headerIndex = 0;
- fi->bytesRemaining = Endian_bigEndianToHost32(fi->header.length_be);
- if (fi->bytesRemaining > fi->maxMessageSize) {
- // oversize
- Assert_ifTesting(0);
- return Error(msg, "OVERSIZE_MESSAGE");
- }
- if (fi->bytesRemaining == (uint32_t)Message_getLength(msg)) {
- fi->bytesRemaining = 0;
- return Iface_next(&fi->messageIf, msg);
- } else if (fi->bytesRemaining < (uint32_t)Message_getLength(msg)) {
- struct Allocator* alloc = Allocator_child(Message_getAlloc(msg));
- struct Message* m = Message_new(fi->bytesRemaining, REQUIRED_PADDING, alloc);
- Message_setAssociatedFd(m, Message_getAssociatedFd(msg));
- Bits_memcpy(m->msgbytes, msg->msgbytes, fi->bytesRemaining);
- Er_assert(Message_eshift(msg, -fi->bytesRemaining));
- fi->bytesRemaining = 0;
- Iface_send(&fi->messageIf, m);
- Allocator_free(alloc);
- continue;
- } else {
- fi->frameAlloc = Allocator_child(fi->alloc);
- struct Message* m = Message_new(0, Message_getLength(msg) + 4, fi->frameAlloc);
- Message_setAssociatedFd(m, Message_getAssociatedFd(msg));
- Er_assert(Message_epush(m, msg->msgbytes, Message_getLength(msg)));
- Er_assert(Message_epush(m, fi->header.bytes, 4));
- fi->bytesRemaining -= Message_getLength(msg);
- fi->frameParts = Allocator_calloc(fi->frameAlloc, sizeof(struct MessageList), 1);
- fi->frameParts->msg = m;
- fi->frameParts->next = NULL;
- }
- return NULL;
- }
- }
- static Iface_DEFUN sendMessage(struct Message* msg, struct Iface* messageIf)
- {
- struct FramingIface_pvt* fi =
- Identity_containerOf(messageIf, struct FramingIface_pvt, messageIf);
- Er_assert(Message_epush32be(msg, Message_getLength(msg)));
- return Iface_next(&fi->streamIf, msg);
- }
- struct Iface* FramingIface_new(uint32_t maxMsgSize, struct Iface* toWrap, struct Allocator* alloc)
- {
- struct FramingIface_pvt* context =
- Allocator_clone(alloc, (&(struct FramingIface_pvt) {
- .maxMessageSize = maxMsgSize,
- .alloc = alloc,
- .streamIf = { .send = receiveMessage },
- .messageIf = { .send = sendMessage }
- }));
- Identity_set(context);
- Iface_plumb(toWrap, &context->streamIf);
- return &context->messageIf;
- }
|