123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470 |
- #include <u.h>
- #include <libc.h>
- #include <venti.h>
- #include "session.h"
/*
 * Protocol versions this library can speak, with the spelling used for
 * each in the textual version exchange.  vtConnect advertises them in
 * table order; vtVersionRead and vtGetVersion look entries up here.
 * The table ends with an all-zero sentinel entry.
 */
struct {
	int version;	/* protocol version constant (nonzero) */
	char *s;	/* its name in the version string, e.g. "02" */
} vtVersions[] = {
	{ VtVersion02, "02" },
	{ 0, 0 },
};
- static char EBigString[] = "string too long";
- static char EBigPacket[] = "packet too long";
- static char ENullString[] = "missing string";
- static char EBadVersion[] = "bad format in version string";
- static Packet *vtRPC(VtSession *z, int op, Packet *p);
- VtSession *
- vtAlloc(void)
- {
- VtSession *z;
- z = vtMemAllocZ(sizeof(VtSession));
- z->lk = vtLockAlloc();
- // z->inHash = vtSha1Alloc();
- z->inLock = vtLockAlloc();
- z->part = packetAlloc();
- // z->outHash = vtSha1Alloc();
- z->outLock = vtLockAlloc();
- z->fd = -1;
- z->uid = vtStrDup("anonymous");
- z->sid = vtStrDup("anonymous");
- return z;
- }
- void
- vtReset(VtSession *z)
- {
- vtLock(z->lk);
- z->cstate = VtStateAlloc;
- if(z->fd >= 0){
- vtFdClose(z->fd);
- z->fd = -1;
- }
- vtUnlock(z->lk);
- }
- int
- vtConnected(VtSession *z)
- {
- return z->cstate == VtStateConnected;
- }
- void
- vtDisconnect(VtSession *z, int error)
- {
- Packet *p;
- uchar *b;
- vtDebug(z, "vtDisconnect\n");
- vtLock(z->lk);
- if(z->cstate == VtStateConnected && !error && z->vtbl == nil) {
- /* clean shutdown */
- p = packetAlloc();
- b = packetHeader(p, 2);
- b[0] = VtQGoodbye;
- b[1] = 0;
- vtSendPacket(z, p);
- }
- if(z->fd >= 0)
- vtFdClose(z->fd);
- z->fd = -1;
- z->cstate = VtStateClosed;
- vtUnlock(z->lk);
- }
- void
- vtClose(VtSession *z)
- {
- vtDisconnect(z, 0);
- }
- void
- vtFree(VtSession *z)
- {
- if(z == nil)
- return;
- vtLockFree(z->lk);
- vtSha1Free(z->inHash);
- vtLockFree(z->inLock);
- packetFree(z->part);
- vtSha1Free(z->outHash);
- vtLockFree(z->outLock);
- vtMemFree(z->uid);
- vtMemFree(z->sid);
- vtMemFree(z->vtbl);
- memset(z, 0, sizeof(VtSession));
- z->fd = -1;
- vtMemFree(z);
- }
- char *
- vtGetUid(VtSession *s)
- {
- return s->uid;
- }
- char *
- vtGetSid(VtSession *z)
- {
- return z->sid;
- }
- int
- vtSetDebug(VtSession *z, int debug)
- {
- int old;
- vtLock(z->lk);
- old = z->debug;
- z->debug = debug;
- vtUnlock(z->lk);
- return old;
- }
- int
- vtSetFd(VtSession *z, int fd)
- {
- vtLock(z->lk);
- if(z->cstate != VtStateAlloc) {
- vtSetError("bad state");
- vtUnlock(z->lk);
- return 0;
- }
- if(z->fd >= 0)
- vtFdClose(z->fd);
- z->fd = fd;
- vtUnlock(z->lk);
- return 1;
- }
- int
- vtGetFd(VtSession *z)
- {
- return z->fd;
- }
- int
- vtSetCryptoStrength(VtSession *z, int c)
- {
- if(z->cstate != VtStateAlloc) {
- vtSetError("bad state");
- return 0;
- }
- if(c != VtCryptoStrengthNone) {
- vtSetError("not supported yet");
- return 0;
- }
- return 1;
- }
- int
- vtGetCryptoStrength(VtSession *s)
- {
- return s->cryptoStrength;
- }
- int
- vtSetCompression(VtSession *z, int fd)
- {
- vtLock(z->lk);
- if(z->cstate != VtStateAlloc) {
- vtSetError("bad state");
- vtUnlock(z->lk);
- return 0;
- }
- z->fd = fd;
- vtUnlock(z->lk);
- return 1;
- }
- int
- vtGetCompression(VtSession *s)
- {
- return s->compression;
- }
- int
- vtGetCrypto(VtSession *s)
- {
- return s->crypto;
- }
- int
- vtGetCodec(VtSession *s)
- {
- return s->codec;
- }
- char *
- vtGetVersion(VtSession *z)
- {
- int v, i;
-
- v = z->version;
- if(v == 0)
- return "unknown";
- for(i=0; vtVersions[i].version; i++)
- if(vtVersions[i].version == v)
- return vtVersions[i].s;
- assert(0);
- return 0;
- }
- /* hold z->inLock */
- static int
- vtVersionRead(VtSession *z, char *prefix, int *ret)
- {
- char c;
- char buf[VtMaxStringSize];
- char *q, *p, *pp;
- int i;
- q = prefix;
- p = buf;
- for(;;) {
- if(p >= buf + sizeof(buf)) {
- vtSetError(EBadVersion);
- return 0;
- }
- if(!vtFdReadFully(z->fd, (uchar*)&c, 1))
- return 0;
- if(z->inHash)
- vtSha1Update(z->inHash, (uchar*)&c, 1);
- if(c == '\n') {
- *p = 0;
- break;
- }
- if(c < ' ' || c > 0x7f || *q && c != *q) {
- vtSetError(EBadVersion);
- return 0;
- }
- *p++ = c;
- if(*q)
- q++;
- }
-
- vtDebug(z, "version string in: %s\n", buf);
- p = buf + strlen(prefix);
- for(;;) {
- for(pp=p; *pp && *pp != ':' && *pp != '-'; pp++)
- ;
- for(i=0; vtVersions[i].version; i++) {
- if(strlen(vtVersions[i].s) != pp-p)
- continue;
- if(memcmp(vtVersions[i].s, p, pp-p) == 0) {
- *ret = vtVersions[i].version;
- return 1;
- }
- }
- p = pp;
- if(*p != ':')
- return 0;
- p++;
- }
- }
- Packet*
- vtRecvPacket(VtSession *z)
- {
- uchar buf[10], *b;
- int n;
- Packet *p;
- int size, len;
- if(z->cstate != VtStateConnected) {
- vtSetError("session not connected");
- return 0;
- }
- vtLock(z->inLock);
- p = z->part;
- /* get enough for head size */
- size = packetSize(p);
- while(size < 2) {
- b = packetTrailer(p, MaxFragSize);
- assert(b != nil);
- n = vtFdRead(z->fd, b, MaxFragSize);
- if(n <= 0)
- goto Err;
- size += n;
- packetTrim(p, 0, size);
- }
- if(!packetConsume(p, buf, 2))
- goto Err;
- len = (buf[0] << 8) | buf[1];
- size -= 2;
- while(size < len) {
- n = len - size;
- if(n > MaxFragSize)
- n = MaxFragSize;
- b = packetTrailer(p, n);
- if(!vtFdReadFully(z->fd, b, n))
- goto Err;
- size += n;
- }
- p = packetSplit(p, len);
- vtUnlock(z->inLock);
- return p;
- Err:
- vtUnlock(z->inLock);
- return nil;
- }
- int
- vtSendPacket(VtSession *z, Packet *p)
- {
- IOchunk ioc;
- int n;
- uchar buf[2];
-
- /* add framing */
- n = packetSize(p);
- if(n >= (1<<16)) {
- vtSetError(EBigPacket);
- packetFree(p);
- return 0;
- }
- buf[0] = n>>8;
- buf[1] = n;
- packetPrefix(p, buf, 2);
- for(;;) {
- n = packetFragments(p, &ioc, 1, 0);
- if(n == 0)
- break;
- if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) {
- packetFree(p);
- return 0;
- }
- packetConsume(p, nil, n);
- }
- packetFree(p);
- return 1;
- }
- int
- vtGetString(Packet *p, char **ret)
- {
- uchar buf[2];
- int n;
- char *s;
- if(!packetConsume(p, buf, 2))
- return 0;
- n = (buf[0]<<8) + buf[1];
- if(n > VtMaxStringSize) {
- vtSetError(EBigString);
- return 0;
- }
- s = vtMemAlloc(n+1);
- if(!packetConsume(p, (uchar*)s, n)) {
- vtMemFree(s);
- return 0;
- }
- s[n] = 0;
- *ret = s;
- return 1;
- }
- int
- vtAddString(Packet *p, char *s)
- {
- uchar buf[2];
- int n;
- if(s == nil) {
- vtSetError(ENullString);
- return 0;
- }
- n = strlen(s);
- if(n > VtMaxStringSize) {
- vtSetError(EBigString);
- return 0;
- }
- buf[0] = n>>8;
- buf[1] = n;
- packetAppend(p, buf, 2);
- packetAppend(p, (uchar*)s, n);
- return 1;
- }
- int
- vtConnect(VtSession *z, char *password)
- {
- char buf[VtMaxStringSize], *p, *ep, *prefix;
- int i;
- USED(password);
- vtLock(z->lk);
- if(z->cstate != VtStateAlloc) {
- vtSetError("bad session state");
- vtUnlock(z->lk);
- return 0;
- }
- if(z->fd < 0){
- vtSetError("%s", z->fderror);
- vtUnlock(z->lk);
- return 0;
- }
- /* be a little anal */
- vtLock(z->inLock);
- vtLock(z->outLock);
- prefix = "venti-";
- p = buf;
- ep = buf + sizeof(buf);
- p = seprint(p, ep, "%s", prefix);
- p += strlen(p);
- for(i=0; vtVersions[i].version; i++) {
- if(i != 0)
- *p++ = ':';
- p = seprint(p, ep, "%s", vtVersions[i].s);
- }
- p = seprint(p, ep, "-libventi\n");
- assert(p-buf < sizeof(buf));
- if(z->outHash)
- vtSha1Update(z->outHash, (uchar*)buf, p-buf);
- if(!vtFdWrite(z->fd, (uchar*)buf, p-buf))
- goto Err;
-
- vtDebug(z, "version string out: %s", buf);
- if(!vtVersionRead(z, prefix, &z->version))
- goto Err;
-
- vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z));
- vtUnlock(z->inLock);
- vtUnlock(z->outLock);
- z->cstate = VtStateConnected;
- vtUnlock(z->lk);
- if(z->vtbl)
- return 1;
- if(!vtHello(z))
- goto Err;
- return 1;
- Err:
- if(z->fd >= 0)
- vtFdClose(z->fd);
- z->fd = -1;
- vtUnlock(z->inLock);
- vtUnlock(z->outLock);
- z->cstate = VtStateClosed;
- vtUnlock(z->lk);
- return 0;
- }
|