rpc.c 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. #include <u.h>
  2. #include <libc.h>
  3. #include <venti.h>
  4. #include "session.h"
  5. struct {
  6. int version;
  7. char *s;
  8. } vtVersions[] = {
  9. VtVersion02, "02",
  10. 0, 0,
  11. };
  /* Error strings handed to vtSetError by the string/packet/version code below. */
static char EBigString[]  = "string too long";
static char EBigPacket[]  = "packet too long";
static char ENullString[] = "missing string";
static char EBadVersion[] = "bad format in version string";
  16. static Packet *vtRPC(VtSession *z, int op, Packet *p);
  17. VtSession *
  18. vtAlloc(void)
  19. {
  20. VtSession *z;
  21. z = vtMemAllocZ(sizeof(VtSession));
  22. z->lk = vtLockAlloc();
  23. // z->inHash = vtSha1Alloc();
  24. z->inLock = vtLockAlloc();
  25. z->part = packetAlloc();
  26. // z->outHash = vtSha1Alloc();
  27. z->outLock = vtLockAlloc();
  28. z->fd = -1;
  29. z->uid = vtStrDup("anonymous");
  30. z->sid = vtStrDup("anonymous");
  31. return z;
  32. }
  33. void
  34. vtReset(VtSession *z)
  35. {
  36. vtLock(z->lk);
  37. z->cstate = VtStateAlloc;
  38. if(z->fd >= 0){
  39. vtFdClose(z->fd);
  40. z->fd = -1;
  41. }
  42. vtUnlock(z->lk);
  43. }
  44. int
  45. vtConnected(VtSession *z)
  46. {
  47. return z->cstate == VtStateConnected;
  48. }
  49. void
  50. vtDisconnect(VtSession *z, int error)
  51. {
  52. Packet *p;
  53. uchar *b;
  54. vtDebug(z, "vtDisconnect\n");
  55. vtLock(z->lk);
  56. if(z->cstate == VtStateConnected && !error && z->vtbl == nil) {
  57. /* clean shutdown */
  58. p = packetAlloc();
  59. b = packetHeader(p, 2);
  60. b[0] = VtQGoodbye;
  61. b[1] = 0;
  62. vtSendPacket(z, p);
  63. }
  64. if(z->fd >= 0)
  65. vtFdClose(z->fd);
  66. z->fd = -1;
  67. z->cstate = VtStateClosed;
  68. vtUnlock(z->lk);
  69. }
  70. void
  71. vtClose(VtSession *z)
  72. {
  73. vtDisconnect(z, 0);
  74. }
  75. void
  76. vtFree(VtSession *z)
  77. {
  78. if(z == nil)
  79. return;
  80. vtLockFree(z->lk);
  81. vtSha1Free(z->inHash);
  82. vtLockFree(z->inLock);
  83. packetFree(z->part);
  84. vtSha1Free(z->outHash);
  85. vtLockFree(z->outLock);
  86. vtMemFree(z->uid);
  87. vtMemFree(z->sid);
  88. vtMemFree(z->vtbl);
  89. memset(z, 0, sizeof(VtSession));
  90. z->fd = -1;
  91. vtMemFree(z);
  92. }
  93. char *
  94. vtGetUid(VtSession *s)
  95. {
  96. return s->uid;
  97. }
  98. char *
  99. vtGetSid(VtSession *z)
  100. {
  101. return z->sid;
  102. }
  103. int
  104. vtSetDebug(VtSession *z, int debug)
  105. {
  106. int old;
  107. vtLock(z->lk);
  108. old = z->debug;
  109. z->debug = debug;
  110. vtUnlock(z->lk);
  111. return old;
  112. }
  113. int
  114. vtSetFd(VtSession *z, int fd)
  115. {
  116. vtLock(z->lk);
  117. if(z->cstate != VtStateAlloc) {
  118. vtSetError("bad state");
  119. vtUnlock(z->lk);
  120. return 0;
  121. }
  122. if(z->fd >= 0)
  123. vtFdClose(z->fd);
  124. z->fd = fd;
  125. vtUnlock(z->lk);
  126. return 1;
  127. }
  128. int
  129. vtGetFd(VtSession *z)
  130. {
  131. return z->fd;
  132. }
  133. int
  134. vtSetCryptoStrength(VtSession *z, int c)
  135. {
  136. if(z->cstate != VtStateAlloc) {
  137. vtSetError("bad state");
  138. return 0;
  139. }
  140. if(c != VtCryptoStrengthNone) {
  141. vtSetError("not supported yet");
  142. return 0;
  143. }
  144. return 1;
  145. }
  146. int
  147. vtGetCryptoStrength(VtSession *s)
  148. {
  149. return s->cryptoStrength;
  150. }
  151. int
  152. vtSetCompression(VtSession *z, int fd)
  153. {
  154. vtLock(z->lk);
  155. if(z->cstate != VtStateAlloc) {
  156. vtSetError("bad state");
  157. vtUnlock(z->lk);
  158. return 0;
  159. }
  160. z->fd = fd;
  161. vtUnlock(z->lk);
  162. return 1;
  163. }
  164. int
  165. vtGetCompression(VtSession *s)
  166. {
  167. return s->compression;
  168. }
  169. int
  170. vtGetCrypto(VtSession *s)
  171. {
  172. return s->crypto;
  173. }
  174. int
  175. vtGetCodec(VtSession *s)
  176. {
  177. return s->codec;
  178. }
  179. char *
  180. vtGetVersion(VtSession *z)
  181. {
  182. int v, i;
  183. v = z->version;
  184. if(v == 0)
  185. return "unknown";
  186. for(i=0; vtVersions[i].version; i++)
  187. if(vtVersions[i].version == v)
  188. return vtVersions[i].s;
  189. assert(0);
  190. return 0;
  191. }
  192. /* hold z->inLock */
  193. static int
  194. vtVersionRead(VtSession *z, char *prefix, int *ret)
  195. {
  196. char c;
  197. char buf[VtMaxStringSize];
  198. char *q, *p, *pp;
  199. int i;
  200. q = prefix;
  201. p = buf;
  202. for(;;) {
  203. if(p >= buf + sizeof(buf)) {
  204. vtSetError(EBadVersion);
  205. return 0;
  206. }
  207. if(!vtFdReadFully(z->fd, (uchar*)&c, 1))
  208. return 0;
  209. if(z->inHash)
  210. vtSha1Update(z->inHash, (uchar*)&c, 1);
  211. if(c == '\n') {
  212. *p = 0;
  213. break;
  214. }
  215. if(c < ' ' || c > 0x7f || *q && c != *q) {
  216. vtSetError(EBadVersion);
  217. return 0;
  218. }
  219. *p++ = c;
  220. if(*q)
  221. q++;
  222. }
  223. vtDebug(z, "version string in: %s\n", buf);
  224. p = buf + strlen(prefix);
  225. for(;;) {
  226. for(pp=p; *pp && *pp != ':' && *pp != '-'; pp++)
  227. ;
  228. for(i=0; vtVersions[i].version; i++) {
  229. if(strlen(vtVersions[i].s) != pp-p)
  230. continue;
  231. if(memcmp(vtVersions[i].s, p, pp-p) == 0) {
  232. *ret = vtVersions[i].version;
  233. return 1;
  234. }
  235. }
  236. p = pp;
  237. if(*p != ':')
  238. return 0;
  239. p++;
  240. }
  241. }
  242. Packet*
  243. vtRecvPacket(VtSession *z)
  244. {
  245. uchar buf[10], *b;
  246. int n;
  247. Packet *p;
  248. int size, len;
  249. if(z->cstate != VtStateConnected) {
  250. vtSetError("session not connected");
  251. return 0;
  252. }
  253. vtLock(z->inLock);
  254. p = z->part;
  255. /* get enough for head size */
  256. size = packetSize(p);
  257. while(size < 2) {
  258. b = packetTrailer(p, MaxFragSize);
  259. assert(b != nil);
  260. n = vtFdRead(z->fd, b, MaxFragSize);
  261. if(n <= 0)
  262. goto Err;
  263. size += n;
  264. packetTrim(p, 0, size);
  265. }
  266. if(!packetConsume(p, buf, 2))
  267. goto Err;
  268. len = (buf[0] << 8) | buf[1];
  269. size -= 2;
  270. while(size < len) {
  271. n = len - size;
  272. if(n > MaxFragSize)
  273. n = MaxFragSize;
  274. b = packetTrailer(p, n);
  275. if(!vtFdReadFully(z->fd, b, n))
  276. goto Err;
  277. size += n;
  278. }
  279. p = packetSplit(p, len);
  280. vtUnlock(z->inLock);
  281. return p;
  282. Err:
  283. vtUnlock(z->inLock);
  284. return nil;
  285. }
  286. int
  287. vtSendPacket(VtSession *z, Packet *p)
  288. {
  289. IOchunk ioc;
  290. int n;
  291. uchar buf[2];
  292. /* add framing */
  293. n = packetSize(p);
  294. if(n >= (1<<16)) {
  295. vtSetError(EBigPacket);
  296. packetFree(p);
  297. return 0;
  298. }
  299. buf[0] = n>>8;
  300. buf[1] = n;
  301. packetPrefix(p, buf, 2);
  302. for(;;) {
  303. n = packetFragments(p, &ioc, 1, 0);
  304. if(n == 0)
  305. break;
  306. if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) {
  307. packetFree(p);
  308. return 0;
  309. }
  310. packetConsume(p, nil, n);
  311. }
  312. packetFree(p);
  313. return 1;
  314. }
  315. int
  316. vtGetString(Packet *p, char **ret)
  317. {
  318. uchar buf[2];
  319. int n;
  320. char *s;
  321. if(!packetConsume(p, buf, 2))
  322. return 0;
  323. n = (buf[0]<<8) + buf[1];
  324. if(n > VtMaxStringSize) {
  325. vtSetError(EBigString);
  326. return 0;
  327. }
  328. s = vtMemAlloc(n+1);
  329. if(!packetConsume(p, (uchar*)s, n)) {
  330. vtMemFree(s);
  331. return 0;
  332. }
  333. s[n] = 0;
  334. *ret = s;
  335. return 1;
  336. }
  337. int
  338. vtAddString(Packet *p, char *s)
  339. {
  340. uchar buf[2];
  341. int n;
  342. if(s == nil) {
  343. vtSetError(ENullString);
  344. return 0;
  345. }
  346. n = strlen(s);
  347. if(n > VtMaxStringSize) {
  348. vtSetError(EBigString);
  349. return 0;
  350. }
  351. buf[0] = n>>8;
  352. buf[1] = n;
  353. packetAppend(p, buf, 2);
  354. packetAppend(p, (uchar*)s, n);
  355. return 1;
  356. }
  357. int
  358. vtConnect(VtSession *z, char *password)
  359. {
  360. char buf[VtMaxStringSize], *p, *ep, *prefix;
  361. int i;
  362. USED(password);
  363. vtLock(z->lk);
  364. if(z->cstate != VtStateAlloc) {
  365. vtSetError("bad session state");
  366. vtUnlock(z->lk);
  367. return 0;
  368. }
  369. if(z->fd < 0){
  370. vtSetError("%s", z->fderror);
  371. vtUnlock(z->lk);
  372. return 0;
  373. }
  374. /* be a little anal */
  375. vtLock(z->inLock);
  376. vtLock(z->outLock);
  377. prefix = "venti-";
  378. p = buf;
  379. ep = buf + sizeof(buf);
  380. p = seprint(p, ep, "%s", prefix);
  381. p += strlen(p);
  382. for(i=0; vtVersions[i].version; i++) {
  383. if(i != 0)
  384. *p++ = ':';
  385. p = seprint(p, ep, "%s", vtVersions[i].s);
  386. }
  387. p = seprint(p, ep, "-libventi\n");
  388. assert(p-buf < sizeof(buf));
  389. if(z->outHash)
  390. vtSha1Update(z->outHash, (uchar*)buf, p-buf);
  391. if(!vtFdWrite(z->fd, (uchar*)buf, p-buf))
  392. goto Err;
  393. vtDebug(z, "version string out: %s", buf);
  394. if(!vtVersionRead(z, prefix, &z->version))
  395. goto Err;
  396. vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z));
  397. vtUnlock(z->inLock);
  398. vtUnlock(z->outLock);
  399. z->cstate = VtStateConnected;
  400. vtUnlock(z->lk);
  401. if(z->vtbl)
  402. return 1;
  403. if(!vtHello(z))
  404. goto Err;
  405. return 1;
  406. Err:
  407. if(z->fd >= 0)
  408. vtFdClose(z->fd);
  409. z->fd = -1;
  410. vtUnlock(z->inLock);
  411. vtUnlock(z->outLock);
  412. z->cstate = VtStateClosed;
  413. vtUnlock(z->lk);
  414. return 0;
  415. }