123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- #include <u.h>
- #include <libc.h>
- #include <venti.h>
- #include "session.h"
/*
 * Error strings handed to vtSetError by the server code in this file.
 */
static char EAuthState[] = "bad authentication state";	/* hello when auth state is not VtAuthHello */
static char ENotServer[] = "not a server session";	/* vtExport on a session with no server vtbl */
static char EProtocolBotch[] = "venti protocol botch";	/* malformed or unknown request */
static char EVersion[] = "incorrect version number";	/* client version != vtGetVersion(z) */
- VtSession *
- vtServerAlloc(VtServerVtbl *vtbl)
- {
- VtSession *z = vtAlloc();
- z->vtbl = vtMemAlloc(sizeof(VtServerVtbl));
- *z->vtbl = *vtbl;
- return z;
- }
- static int
- srvHello(VtSession *z, char *version, char *uid, int , uchar *, int , uchar *, int )
- {
- vtLock(z->lk);
- if(z->auth.state != VtAuthHello) {
- vtSetError(EAuthState);
- goto Err;
- }
- if(strcmp(version, vtGetVersion(z)) != 0) {
- vtSetError(EVersion);
- goto Err;
- }
- vtMemFree(z->uid);
- z->uid = vtStrDup(uid);
- z->auth.state = VtAuthOK;
- vtUnlock(z->lk);
- return 1;
- Err:
- z->auth.state = VtAuthFailed;
- vtUnlock(z->lk);
- return 0;
- }
- static int
- dispatchHello(VtSession *z, Packet **pkt)
- {
- char *version, *uid;
- uchar *crypto, *codec;
- uchar buf[10];
- int ncrypto, ncodec, cryptoStrength;
- int ret;
- Packet *p;
- p = *pkt;
- version = nil;
- uid = nil;
- crypto = nil;
- codec = nil;
- ret = 0;
- if(!vtGetString(p, &version))
- goto Err;
- if(!vtGetString(p, &uid))
- goto Err;
- if(!packetConsume(p, buf, 2))
- goto Err;
- cryptoStrength = buf[0];
- ncrypto = buf[1];
- crypto = vtMemAlloc(ncrypto);
- if(!packetConsume(p, crypto, ncrypto))
- goto Err;
- if(!packetConsume(p, buf, 1))
- goto Err;
- ncodec = buf[0];
- codec = vtMemAlloc(ncodec);
- if(!packetConsume(p, codec, ncodec))
- goto Err;
- if(packetSize(p) != 0) {
- vtSetError(EProtocolBotch);
- goto Err;
- }
- if(!srvHello(z, version, uid, cryptoStrength, crypto, ncrypto, codec, ncodec)) {
- packetFree(p);
- *pkt = nil;
- } else {
- if(!vtAddString(p, vtGetSid(z)))
- goto Err;
- buf[0] = vtGetCrypto(z);
- buf[1] = vtGetCodec(z);
- packetAppend(p, buf, 2);
- }
- ret = 1;
- Err:
- vtMemFree(version);
- vtMemFree(uid);
- vtMemFree(crypto);
- vtMemFree(codec);
- return ret;
- }
- static int
- dispatchRead(VtSession *z, Packet **pkt)
- {
- Packet *p;
- int type, n;
- uchar score[VtScoreSize], buf[4];
- p = *pkt;
- if(!packetConsume(p, score, VtScoreSize))
- return 0;
- if(!packetConsume(p, buf, 4))
- return 0;
- type = buf[0];
- n = (buf[2]<<8) | buf[3];
- if(packetSize(p) != 0) {
- vtSetError(EProtocolBotch);
- return 0;
- }
- packetFree(p);
- *pkt = (*z->vtbl->read)(z, score, type, n);
- return 1;
- }
- static int
- dispatchWrite(VtSession *z, Packet **pkt)
- {
- Packet *p;
- int type;
- uchar score[VtScoreSize], buf[4];
- p = *pkt;
- if(!packetConsume(p, buf, 4))
- return 0;
- type = buf[0];
- if(!(z->vtbl->write)(z, score, type, p)) {
- *pkt = 0;
- } else {
- *pkt = packetAlloc();
- packetAppend(*pkt, score, VtScoreSize);
- }
- return 1;
- }
- static int
- dispatchSync(VtSession *z, Packet **pkt)
- {
- (z->vtbl->sync)(z);
- if(packetSize(*pkt) != 0) {
- vtSetError(EProtocolBotch);
- return 0;
- }
- return 1;
- }
- int
- vtExport(VtSession *z)
- {
- Packet *p;
- uchar buf[10], *hdr;
- int op, tid, clean;
- if(z->vtbl == nil) {
- vtSetError(ENotServer);
- return 0;
- }
- /* fork off slave */
- switch(rfork(RFNOWAIT|RFMEM|RFPROC)){
- case -1:
- vtOSError();
- return 0;
- case 0:
- break;
- default:
- return 1;
- }
-
- p = nil;
- clean = 0;
- vtAttach();
- if(!vtConnect(z, nil))
- goto Exit;
- vtDebug(z, "server connected!\n");
- if(0) vtSetDebug(z, 1);
- for(;;) {
- p = vtRecvPacket(z);
- if(p == nil) {
- break;
- }
- vtDebug(z, "server recv: ");
- vtDebugMesg(z, p, "\n");
- if(!packetConsume(p, buf, 2)) {
- vtSetError(EProtocolBotch);
- break;
- }
- op = buf[0];
- tid = buf[1];
- switch(op) {
- default:
- vtSetError(EProtocolBotch);
- goto Exit;
- case VtQPing:
- break;
- case VtQGoodbye:
- clean = 1;
- goto Exit;
- case VtQHello:
- if(!dispatchHello(z, &p))
- goto Exit;
- break;
- case VtQRead:
- if(!dispatchRead(z, &p))
- goto Exit;
- break;
- case VtQWrite:
- if(!dispatchWrite(z, &p))
- goto Exit;
- break;
- case VtQSync:
- if(!dispatchSync(z, &p))
- goto Exit;
- break;
- }
- if(p != nil) {
- hdr = packetHeader(p, 2);
- hdr[0] = op+1;
- hdr[1] = tid;
- } else {
- p = packetAlloc();
- hdr = packetHeader(p, 2);
- hdr[0] = VtRError;
- hdr[1] = tid;
- if(!vtAddString(p, vtGetError()))
- goto Exit;
- }
- vtDebug(z, "server send: ");
- vtDebugMesg(z, p, "\n");
- if(!vtSendPacket(z, p)) {
- p = nil;
- goto Exit;
- }
- }
- Exit:
- if(p != nil)
- packetFree(p);
- if(z->vtbl->closing)
- z->vtbl->closing(z, clean);
- vtClose(z);
- vtFree(z);
- vtDetach();
- exits(0);
- return 0; /* never gets here */
- }
|