#include #include #include #include "session.h" static char EAuthState[] = "bad authentication state"; static char ENotServer[] = "not a server session"; static char EVersion[] = "incorrect version number"; static char EProtocolBotch[] = "venti protocol botch"; VtSession * vtServerAlloc(VtServerVtbl *vtbl) { VtSession *z = vtAlloc(); z->vtbl = vtMemAlloc(sizeof(VtServerVtbl)); *z->vtbl = *vtbl; return z; } static int srvHello(VtSession *z, char *version, char *uid, int , uchar *, int , uchar *, int ) { vtLock(z->lk); if(z->auth.state != VtAuthHello) { vtSetError(EAuthState); goto Err; } if(strcmp(version, vtGetVersion(z)) != 0) { vtSetError(EVersion); goto Err; } vtMemFree(z->uid); z->uid = vtStrDup(uid); z->auth.state = VtAuthOK; vtUnlock(z->lk); return 1; Err: z->auth.state = VtAuthFailed; vtUnlock(z->lk); return 0; } static int dispatchHello(VtSession *z, Packet **pkt) { char *version, *uid; uchar *crypto, *codec; uchar buf[10]; int ncrypto, ncodec, cryptoStrength; int ret; Packet *p; p = *pkt; version = nil; uid = nil; crypto = nil; codec = nil; ret = 0; if(!vtGetString(p, &version)) goto Err; if(!vtGetString(p, &uid)) goto Err; if(!packetConsume(p, buf, 2)) goto Err; cryptoStrength = buf[0]; ncrypto = buf[1]; crypto = vtMemAlloc(ncrypto); if(!packetConsume(p, crypto, ncrypto)) goto Err; if(!packetConsume(p, buf, 1)) goto Err; ncodec = buf[0]; codec = vtMemAlloc(ncodec); if(!packetConsume(p, codec, ncodec)) goto Err; if(packetSize(p) != 0) { vtSetError(EProtocolBotch); goto Err; } if(!srvHello(z, version, uid, cryptoStrength, crypto, ncrypto, codec, ncodec)) { packetFree(p); *pkt = nil; } else { if(!vtAddString(p, vtGetSid(z))) goto Err; buf[0] = vtGetCrypto(z); buf[1] = vtGetCodec(z); packetAppend(p, buf, 2); } ret = 1; Err: vtMemFree(version); vtMemFree(uid); vtMemFree(crypto); vtMemFree(codec); return ret; } static int dispatchRead(VtSession *z, Packet **pkt) { Packet *p; int type, n; uchar score[VtScoreSize], buf[4]; p = *pkt; if(!packetConsume(p, score, VtScoreSize)) return 0; if(!packetConsume(p, buf, 4)) return 0; type = buf[0]; n = (buf[2]<<8) | buf[3]; if(packetSize(p) != 0) { vtSetError(EProtocolBotch); return 0; } packetFree(p); *pkt = (*z->vtbl->read)(z, score, type, n); return 1; } static int dispatchWrite(VtSession *z, Packet **pkt) { Packet *p; int type; uchar score[VtScoreSize], buf[4]; p = *pkt; if(!packetConsume(p, buf, 4)) return 0; type = buf[0]; if(!(z->vtbl->write)(z, score, type, p)) { *pkt = 0; } else { *pkt = packetAlloc(); packetAppend(*pkt, score, VtScoreSize); } return 1; } static int dispatchSync(VtSession *z, Packet **pkt) { (z->vtbl->sync)(z); if(packetSize(*pkt) != 0) { vtSetError(EProtocolBotch); return 0; } return 1; } int vtExport(VtSession *z) { Packet *p; uchar buf[10], *hdr; int op, tid, clean; if(z->vtbl == nil) { vtSetError(ENotServer); return 0; } /* fork off slave */ switch(rfork(RFNOWAIT|RFMEM|RFPROC)){ case -1: vtOSError(); return 0; case 0: break; default: return 1; } p = nil; clean = 0; vtAttach(); if(!vtConnect(z, nil)) goto Exit; vtDebug(z, "server connected!\n"); if(0) vtSetDebug(z, 1); for(;;) { p = vtRecvPacket(z); if(p == nil) { break; } vtDebug(z, "server recv: "); vtDebugMesg(z, p, "\n"); if(!packetConsume(p, buf, 2)) { vtSetError(EProtocolBotch); break; } op = buf[0]; tid = buf[1]; switch(op) { default: vtSetError(EProtocolBotch); goto Exit; case VtQPing: break; case VtQGoodbye: clean = 1; goto Exit; case VtQHello: if(!dispatchHello(z, &p)) goto Exit; break; case VtQRead: if(!dispatchRead(z, &p)) goto Exit; break; case VtQWrite: if(!dispatchWrite(z, &p)) goto Exit; break; case VtQSync: if(!dispatchSync(z, &p)) goto Exit; break; } if(p != nil) { hdr = packetHeader(p, 2); hdr[0] = op+1; hdr[1] = tid; } else { p = packetAlloc(); hdr = packetHeader(p, 2); hdr[0] = VtRError; hdr[1] = tid; if(!vtAddString(p, vtGetError())) goto Exit; } vtDebug(z, "server send: "); vtDebugMesg(z, p, "\n"); if(!vtSendPacket(z, p)) { p = nil; goto Exit; } } Exit: if(p != nil) packetFree(p); if(z->vtbl->closing) z->vtbl->closing(z, clean); vtClose(z); vtFree(z); vtDetach(); exits(0); return 0; /* never gets here */ }