rpc.c 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. /*
  2. * This file is part of the UCB release of Plan 9. It is subject to the license
  3. * terms in the LICENSE file found in the top-level directory of this
  4. * distribution and at http://akaros.cs.berkeley.edu/files/Plan9License. No
  5. * part of the UCB release of Plan 9, including this file, may be copied,
  6. * modified, propagated, or distributed except according to the terms contained
  7. * in the LICENSE file.
  8. */
  9. #include <u.h>
  10. #include <libc.h>
  11. #include <oventi.h>
  12. #include "session.h"
  13. struct {
  14. int version;
  15. char *s;
  16. } vtVersions[] = {
  17. {VtVersion02, "02"},
  18. {0, 0},
  19. };
  20. static char EBigString[] = "string too long";
  21. static char EBigPacket[] = "packet too long";
  22. static char ENullString[] = "missing string";
  23. static char EBadVersion[] = "bad format in version string";
  24. VtSession *
  25. vtAlloc(void)
  26. {
  27. VtSession *z;
  28. z = vtMemAllocZ(sizeof(VtSession));
  29. z->lk = vtLockAlloc();
  30. // z->inHash = vtSha1Alloc();
  31. z->inLock = vtLockAlloc();
  32. z->part = packetAlloc();
  33. // z->outHash = vtSha1Alloc();
  34. z->outLock = vtLockAlloc();
  35. z->fd = -1;
  36. z->uid = vtStrDup("anonymous");
  37. z->sid = vtStrDup("anonymous");
  38. return z;
  39. }
  40. void
  41. vtReset(VtSession *z)
  42. {
  43. vtLock(z->lk);
  44. z->cstate = VtStateAlloc;
  45. if(z->fd >= 0){
  46. vtFdClose(z->fd);
  47. z->fd = -1;
  48. }
  49. vtUnlock(z->lk);
  50. }
  51. int
  52. vtConnected(VtSession *z)
  53. {
  54. return z->cstate == VtStateConnected;
  55. }
  56. void
  57. vtDisconnect(VtSession *z, int error)
  58. {
  59. Packet *p;
  60. uint8_t *b;
  61. vtDebug(z, "vtDisconnect\n");
  62. vtLock(z->lk);
  63. if(z->cstate == VtStateConnected && !error && z->vtbl == nil) {
  64. /* clean shutdown */
  65. p = packetAlloc();
  66. b = packetHeader(p, 2);
  67. b[0] = VtQGoodbye;
  68. b[1] = 0;
  69. vtSendPacket(z, p);
  70. }
  71. if(z->fd >= 0)
  72. vtFdClose(z->fd);
  73. z->fd = -1;
  74. z->cstate = VtStateClosed;
  75. vtUnlock(z->lk);
  76. }
  77. void
  78. vtClose(VtSession *z)
  79. {
  80. vtDisconnect(z, 0);
  81. }
  82. void
  83. vtFree(VtSession *z)
  84. {
  85. if(z == nil)
  86. return;
  87. vtLockFree(z->lk);
  88. vtSha1Free(z->inHash);
  89. vtLockFree(z->inLock);
  90. packetFree(z->part);
  91. vtSha1Free(z->outHash);
  92. vtLockFree(z->outLock);
  93. vtMemFree(z->uid);
  94. vtMemFree(z->sid);
  95. vtMemFree(z->vtbl);
  96. memset(z, 0, sizeof(VtSession));
  97. z->fd = -1;
  98. vtMemFree(z);
  99. }
  100. char *
  101. vtGetUid(VtSession *s)
  102. {
  103. return s->uid;
  104. }
  105. char *
  106. vtGetSid(VtSession *z)
  107. {
  108. return z->sid;
  109. }
  110. int
  111. vtSetDebug(VtSession *z, int debug)
  112. {
  113. int old;
  114. vtLock(z->lk);
  115. old = z->debug;
  116. z->debug = debug;
  117. vtUnlock(z->lk);
  118. return old;
  119. }
  120. int
  121. vtSetFd(VtSession *z, int fd)
  122. {
  123. vtLock(z->lk);
  124. if(z->cstate != VtStateAlloc) {
  125. vtSetError("bad state");
  126. vtUnlock(z->lk);
  127. return 0;
  128. }
  129. if(z->fd >= 0)
  130. vtFdClose(z->fd);
  131. z->fd = fd;
  132. vtUnlock(z->lk);
  133. return 1;
  134. }
  135. int
  136. vtGetFd(VtSession *z)
  137. {
  138. return z->fd;
  139. }
  140. int
  141. vtSetCryptoStrength(VtSession *z, int c)
  142. {
  143. if(z->cstate != VtStateAlloc) {
  144. vtSetError("bad state");
  145. return 0;
  146. }
  147. if(c != VtCryptoStrengthNone) {
  148. vtSetError("not supported yet");
  149. return 0;
  150. }
  151. return 1;
  152. }
  153. int
  154. vtGetCryptoStrength(VtSession *s)
  155. {
  156. return s->cryptoStrength;
  157. }
  158. int
  159. vtSetCompression(VtSession *z, int fd)
  160. {
  161. vtLock(z->lk);
  162. if(z->cstate != VtStateAlloc) {
  163. vtSetError("bad state");
  164. vtUnlock(z->lk);
  165. return 0;
  166. }
  167. z->fd = fd;
  168. vtUnlock(z->lk);
  169. return 1;
  170. }
  171. int
  172. vtGetCompression(VtSession *s)
  173. {
  174. return s->compression;
  175. }
  176. int
  177. vtGetCrypto(VtSession *s)
  178. {
  179. return s->crypto;
  180. }
  181. int
  182. vtGetCodec(VtSession *s)
  183. {
  184. return s->codec;
  185. }
  186. char *
  187. vtGetVersion(VtSession *z)
  188. {
  189. int v, i;
  190. v = z->version;
  191. if(v == 0)
  192. return "unknown";
  193. for(i=0; vtVersions[i].version; i++)
  194. if(vtVersions[i].version == v)
  195. return vtVersions[i].s;
  196. assert(0);
  197. return 0;
  198. }
  199. /* hold z->inLock */
  200. static int
  201. vtVersionRead(VtSession *z, char *prefix, int *ret)
  202. {
  203. char c;
  204. char buf[VtMaxStringSize];
  205. char *q, *p, *pp;
  206. int i;
  207. q = prefix;
  208. p = buf;
  209. for(;;) {
  210. if(p >= buf + sizeof(buf)) {
  211. vtSetError(EBadVersion);
  212. return 0;
  213. }
  214. if(!vtFdReadFully(z->fd, (uint8_t*)&c, 1))
  215. return 0;
  216. if(z->inHash)
  217. vtSha1Update(z->inHash, (uint8_t*)&c, 1);
  218. if(c == '\n') {
  219. *p = 0;
  220. break;
  221. }
  222. if(c < ' ' || (*q && c != *q)) {
  223. vtSetError(EBadVersion);
  224. return 0;
  225. }
  226. *p++ = c;
  227. if(*q)
  228. q++;
  229. }
  230. vtDebug(z, "version string in: %s\n", buf);
  231. p = buf + strlen(prefix);
  232. for(;;) {
  233. for(pp=p; *pp && *pp != ':' && *pp != '-'; pp++)
  234. ;
  235. for(i=0; vtVersions[i].version; i++) {
  236. if(strlen(vtVersions[i].s) != pp-p)
  237. continue;
  238. if(memcmp(vtVersions[i].s, p, pp-p) == 0) {
  239. *ret = vtVersions[i].version;
  240. return 1;
  241. }
  242. }
  243. p = pp;
  244. if(*p != ':')
  245. return 0;
  246. p++;
  247. }
  248. }
  249. Packet*
  250. vtRecvPacket(VtSession *z)
  251. {
  252. uint8_t buf[10], *b;
  253. int n;
  254. Packet *p;
  255. int size, len;
  256. if(z->cstate != VtStateConnected) {
  257. vtSetError("session not connected");
  258. return 0;
  259. }
  260. vtLock(z->inLock);
  261. p = z->part;
  262. /* get enough for head size */
  263. size = packetSize(p);
  264. while(size < 2) {
  265. b = packetTrailer(p, MaxFragSize);
  266. assert(b != nil);
  267. n = vtFdRead(z->fd, b, MaxFragSize);
  268. if(n <= 0)
  269. goto Err;
  270. size += n;
  271. packetTrim(p, 0, size);
  272. }
  273. if(!packetConsume(p, buf, 2))
  274. goto Err;
  275. len = (buf[0] << 8) | buf[1];
  276. size -= 2;
  277. while(size < len) {
  278. n = len - size;
  279. if(n > MaxFragSize)
  280. n = MaxFragSize;
  281. b = packetTrailer(p, n);
  282. if(!vtFdReadFully(z->fd, b, n))
  283. goto Err;
  284. size += n;
  285. }
  286. p = packetSplit(p, len);
  287. vtUnlock(z->inLock);
  288. return p;
  289. Err:
  290. vtUnlock(z->inLock);
  291. return nil;
  292. }
  293. int
  294. vtSendPacket(VtSession *z, Packet *p)
  295. {
  296. IOchunk ioc;
  297. int n;
  298. uint8_t buf[2];
  299. /* add framing */
  300. n = packetSize(p);
  301. if(n >= (1<<16)) {
  302. vtSetError(EBigPacket);
  303. packetFree(p);
  304. return 0;
  305. }
  306. buf[0] = n>>8;
  307. buf[1] = n;
  308. packetPrefix(p, buf, 2);
  309. for(;;) {
  310. n = packetFragments(p, &ioc, 1, 0);
  311. if(n == 0)
  312. break;
  313. if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) {
  314. packetFree(p);
  315. return 0;
  316. }
  317. packetConsume(p, nil, n);
  318. }
  319. packetFree(p);
  320. return 1;
  321. }
  322. int
  323. vtGetString(Packet *p, char **ret)
  324. {
  325. uint8_t buf[2];
  326. int n;
  327. char *s;
  328. if(!packetConsume(p, buf, 2))
  329. return 0;
  330. n = (buf[0]<<8) + buf[1];
  331. if(n > VtMaxStringSize) {
  332. vtSetError(EBigString);
  333. return 0;
  334. }
  335. s = vtMemAlloc(n+1);
  336. setmalloctag(s, getcallerpc());
  337. if(!packetConsume(p, (uint8_t*)s, n)) {
  338. vtMemFree(s);
  339. return 0;
  340. }
  341. s[n] = 0;
  342. *ret = s;
  343. return 1;
  344. }
  345. int
  346. vtAddString(Packet *p, char *s)
  347. {
  348. uint8_t buf[2];
  349. int n;
  350. if(s == nil) {
  351. vtSetError(ENullString);
  352. return 0;
  353. }
  354. n = strlen(s);
  355. if(n > VtMaxStringSize) {
  356. vtSetError(EBigString);
  357. return 0;
  358. }
  359. buf[0] = n>>8;
  360. buf[1] = n;
  361. packetAppend(p, buf, 2);
  362. packetAppend(p, (uint8_t*)s, n);
  363. return 1;
  364. }
  365. int
  366. vtConnect(VtSession *z, char *password)
  367. {
  368. char buf[VtMaxStringSize], *p, *ep, *prefix;
  369. int i;
  370. USED(password);
  371. vtLock(z->lk);
  372. if(z->cstate != VtStateAlloc) {
  373. vtSetError("bad session state");
  374. vtUnlock(z->lk);
  375. return 0;
  376. }
  377. if(z->fd < 0){
  378. vtSetError("%s", z->fderror);
  379. vtUnlock(z->lk);
  380. return 0;
  381. }
  382. /* be a little anal */
  383. vtLock(z->inLock);
  384. vtLock(z->outLock);
  385. prefix = "venti-";
  386. p = buf;
  387. ep = buf + sizeof(buf);
  388. p = seprint(p, ep, "%s", prefix);
  389. p += strlen(p);
  390. for(i=0; vtVersions[i].version; i++) {
  391. if(i != 0)
  392. *p++ = ':';
  393. p = seprint(p, ep, "%s", vtVersions[i].s);
  394. }
  395. p = seprint(p, ep, "-libventi\n");
  396. assert(p-buf < sizeof(buf));
  397. if(z->outHash)
  398. vtSha1Update(z->outHash, (uint8_t*)buf, p-buf);
  399. if(!vtFdWrite(z->fd, (uint8_t*)buf, p-buf))
  400. goto Err;
  401. vtDebug(z, "version string out: %s", buf);
  402. if(!vtVersionRead(z, prefix, &z->version))
  403. goto Err;
  404. vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z));
  405. vtUnlock(z->inLock);
  406. vtUnlock(z->outLock);
  407. z->cstate = VtStateConnected;
  408. vtUnlock(z->lk);
  409. if(z->vtbl)
  410. return 1;
  411. if(!vtHello(z))
  412. goto Err;
  413. return 1;
  414. Err:
  415. if(z->fd >= 0)
  416. vtFdClose(z->fd);
  417. z->fd = -1;
  418. vtUnlock(z->inLock);
  419. vtUnlock(z->outLock);
  420. z->cstate = VtStateClosed;
  421. vtUnlock(z->lk);
  422. return 0;
  423. }