rpc.c 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. /*
  2. * This file is part of the UCB release of Plan 9. It is subject to the license
  3. * terms in the LICENSE file found in the top-level directory of this
  4. * distribution and at http://akaros.cs.berkeley.edu/files/Plan9License. No
  5. * part of the UCB release of Plan 9, including this file, may be copied,
  6. * modified, propagated, or distributed except according to the terms contained
  7. * in the LICENSE file.
  8. */
  9. #include <u.h>
  10. #include <libc.h>
  11. #include <oventi.h>
  12. #include "session.h"
  13. struct {
  14. int version;
  15. char *s;
  16. } vtVersions[] = {
  17. VtVersion02, "02",
  18. 0, 0,
  19. };
  20. static char EBigString[] = "string too long";
  21. static char EBigPacket[] = "packet too long";
  22. static char ENullString[] = "missing string";
  23. static char EBadVersion[] = "bad format in version string";
  24. static Packet *vtRPC(VtSession *z, int op, Packet *p);
  25. VtSession *
  26. vtAlloc(void)
  27. {
  28. VtSession *z;
  29. z = vtMemAllocZ(sizeof(VtSession));
  30. z->lk = vtLockAlloc();
  31. // z->inHash = vtSha1Alloc();
  32. z->inLock = vtLockAlloc();
  33. z->part = packetAlloc();
  34. // z->outHash = vtSha1Alloc();
  35. z->outLock = vtLockAlloc();
  36. z->fd = -1;
  37. z->uid = vtStrDup("anonymous");
  38. z->sid = vtStrDup("anonymous");
  39. return z;
  40. }
  41. void
  42. vtReset(VtSession *z)
  43. {
  44. vtLock(z->lk);
  45. z->cstate = VtStateAlloc;
  46. if(z->fd >= 0){
  47. vtFdClose(z->fd);
  48. z->fd = -1;
  49. }
  50. vtUnlock(z->lk);
  51. }
  52. int
  53. vtConnected(VtSession *z)
  54. {
  55. return z->cstate == VtStateConnected;
  56. }
  57. void
  58. vtDisconnect(VtSession *z, int error)
  59. {
  60. Packet *p;
  61. uint8_t *b;
  62. vtDebug(z, "vtDisconnect\n");
  63. vtLock(z->lk);
  64. if(z->cstate == VtStateConnected && !error && z->vtbl == nil) {
  65. /* clean shutdown */
  66. p = packetAlloc();
  67. b = packetHeader(p, 2);
  68. b[0] = VtQGoodbye;
  69. b[1] = 0;
  70. vtSendPacket(z, p);
  71. }
  72. if(z->fd >= 0)
  73. vtFdClose(z->fd);
  74. z->fd = -1;
  75. z->cstate = VtStateClosed;
  76. vtUnlock(z->lk);
  77. }
  78. void
  79. vtClose(VtSession *z)
  80. {
  81. vtDisconnect(z, 0);
  82. }
  83. void
  84. vtFree(VtSession *z)
  85. {
  86. if(z == nil)
  87. return;
  88. vtLockFree(z->lk);
  89. vtSha1Free(z->inHash);
  90. vtLockFree(z->inLock);
  91. packetFree(z->part);
  92. vtSha1Free(z->outHash);
  93. vtLockFree(z->outLock);
  94. vtMemFree(z->uid);
  95. vtMemFree(z->sid);
  96. vtMemFree(z->vtbl);
  97. memset(z, 0, sizeof(VtSession));
  98. z->fd = -1;
  99. vtMemFree(z);
  100. }
  101. char *
  102. vtGetUid(VtSession *s)
  103. {
  104. return s->uid;
  105. }
  106. char *
  107. vtGetSid(VtSession *z)
  108. {
  109. return z->sid;
  110. }
  111. int
  112. vtSetDebug(VtSession *z, int debug)
  113. {
  114. int old;
  115. vtLock(z->lk);
  116. old = z->debug;
  117. z->debug = debug;
  118. vtUnlock(z->lk);
  119. return old;
  120. }
  121. int
  122. vtSetFd(VtSession *z, int fd)
  123. {
  124. vtLock(z->lk);
  125. if(z->cstate != VtStateAlloc) {
  126. vtSetError("bad state");
  127. vtUnlock(z->lk);
  128. return 0;
  129. }
  130. if(z->fd >= 0)
  131. vtFdClose(z->fd);
  132. z->fd = fd;
  133. vtUnlock(z->lk);
  134. return 1;
  135. }
  136. int
  137. vtGetFd(VtSession *z)
  138. {
  139. return z->fd;
  140. }
  141. int
  142. vtSetCryptoStrength(VtSession *z, int c)
  143. {
  144. if(z->cstate != VtStateAlloc) {
  145. vtSetError("bad state");
  146. return 0;
  147. }
  148. if(c != VtCryptoStrengthNone) {
  149. vtSetError("not supported yet");
  150. return 0;
  151. }
  152. return 1;
  153. }
  154. int
  155. vtGetCryptoStrength(VtSession *s)
  156. {
  157. return s->cryptoStrength;
  158. }
  159. int
  160. vtSetCompression(VtSession *z, int fd)
  161. {
  162. vtLock(z->lk);
  163. if(z->cstate != VtStateAlloc) {
  164. vtSetError("bad state");
  165. vtUnlock(z->lk);
  166. return 0;
  167. }
  168. z->fd = fd;
  169. vtUnlock(z->lk);
  170. return 1;
  171. }
  172. int
  173. vtGetCompression(VtSession *s)
  174. {
  175. return s->compression;
  176. }
  177. int
  178. vtGetCrypto(VtSession *s)
  179. {
  180. return s->crypto;
  181. }
  182. int
  183. vtGetCodec(VtSession *s)
  184. {
  185. return s->codec;
  186. }
  187. char *
  188. vtGetVersion(VtSession *z)
  189. {
  190. int v, i;
  191. v = z->version;
  192. if(v == 0)
  193. return "unknown";
  194. for(i=0; vtVersions[i].version; i++)
  195. if(vtVersions[i].version == v)
  196. return vtVersions[i].s;
  197. assert(0);
  198. return 0;
  199. }
  200. /* hold z->inLock */
  201. static int
  202. vtVersionRead(VtSession *z, char *prefix, int *ret)
  203. {
  204. char c;
  205. char buf[VtMaxStringSize];
  206. char *q, *p, *pp;
  207. int i;
  208. q = prefix;
  209. p = buf;
  210. for(;;) {
  211. if(p >= buf + sizeof(buf)) {
  212. vtSetError(EBadVersion);
  213. return 0;
  214. }
  215. if(!vtFdReadFully(z->fd, (uint8_t*)&c, 1))
  216. return 0;
  217. if(z->inHash)
  218. vtSha1Update(z->inHash, (uint8_t*)&c, 1);
  219. if(c == '\n') {
  220. *p = 0;
  221. break;
  222. }
  223. if(c < ' ' || *q && c != *q) {
  224. vtSetError(EBadVersion);
  225. return 0;
  226. }
  227. *p++ = c;
  228. if(*q)
  229. q++;
  230. }
  231. vtDebug(z, "version string in: %s\n", buf);
  232. p = buf + strlen(prefix);
  233. for(;;) {
  234. for(pp=p; *pp && *pp != ':' && *pp != '-'; pp++)
  235. ;
  236. for(i=0; vtVersions[i].version; i++) {
  237. if(strlen(vtVersions[i].s) != pp-p)
  238. continue;
  239. if(memcmp(vtVersions[i].s, p, pp-p) == 0) {
  240. *ret = vtVersions[i].version;
  241. return 1;
  242. }
  243. }
  244. p = pp;
  245. if(*p != ':')
  246. return 0;
  247. p++;
  248. }
  249. }
  250. Packet*
  251. vtRecvPacket(VtSession *z)
  252. {
  253. uint8_t buf[10], *b;
  254. int n;
  255. Packet *p;
  256. int size, len;
  257. if(z->cstate != VtStateConnected) {
  258. vtSetError("session not connected");
  259. return 0;
  260. }
  261. vtLock(z->inLock);
  262. p = z->part;
  263. /* get enough for head size */
  264. size = packetSize(p);
  265. while(size < 2) {
  266. b = packetTrailer(p, MaxFragSize);
  267. assert(b != nil);
  268. n = vtFdRead(z->fd, b, MaxFragSize);
  269. if(n <= 0)
  270. goto Err;
  271. size += n;
  272. packetTrim(p, 0, size);
  273. }
  274. if(!packetConsume(p, buf, 2))
  275. goto Err;
  276. len = (buf[0] << 8) | buf[1];
  277. size -= 2;
  278. while(size < len) {
  279. n = len - size;
  280. if(n > MaxFragSize)
  281. n = MaxFragSize;
  282. b = packetTrailer(p, n);
  283. if(!vtFdReadFully(z->fd, b, n))
  284. goto Err;
  285. size += n;
  286. }
  287. p = packetSplit(p, len);
  288. vtUnlock(z->inLock);
  289. return p;
  290. Err:
  291. vtUnlock(z->inLock);
  292. return nil;
  293. }
  294. int
  295. vtSendPacket(VtSession *z, Packet *p)
  296. {
  297. IOchunk ioc;
  298. int n;
  299. uint8_t buf[2];
  300. /* add framing */
  301. n = packetSize(p);
  302. if(n >= (1<<16)) {
  303. vtSetError(EBigPacket);
  304. packetFree(p);
  305. return 0;
  306. }
  307. buf[0] = n>>8;
  308. buf[1] = n;
  309. packetPrefix(p, buf, 2);
  310. for(;;) {
  311. n = packetFragments(p, &ioc, 1, 0);
  312. if(n == 0)
  313. break;
  314. if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) {
  315. packetFree(p);
  316. return 0;
  317. }
  318. packetConsume(p, nil, n);
  319. }
  320. packetFree(p);
  321. return 1;
  322. }
  323. int
  324. vtGetString(Packet *p, char **ret)
  325. {
  326. uint8_t buf[2];
  327. int n;
  328. char *s;
  329. if(!packetConsume(p, buf, 2))
  330. return 0;
  331. n = (buf[0]<<8) + buf[1];
  332. if(n > VtMaxStringSize) {
  333. vtSetError(EBigString);
  334. return 0;
  335. }
  336. s = vtMemAlloc(n+1);
  337. setmalloctag(s, getcallerpc(&p));
  338. if(!packetConsume(p, (uint8_t*)s, n)) {
  339. vtMemFree(s);
  340. return 0;
  341. }
  342. s[n] = 0;
  343. *ret = s;
  344. return 1;
  345. }
  346. int
  347. vtAddString(Packet *p, char *s)
  348. {
  349. uint8_t buf[2];
  350. int n;
  351. if(s == nil) {
  352. vtSetError(ENullString);
  353. return 0;
  354. }
  355. n = strlen(s);
  356. if(n > VtMaxStringSize) {
  357. vtSetError(EBigString);
  358. return 0;
  359. }
  360. buf[0] = n>>8;
  361. buf[1] = n;
  362. packetAppend(p, buf, 2);
  363. packetAppend(p, (uint8_t*)s, n);
  364. return 1;
  365. }
  366. int
  367. vtConnect(VtSession *z, char *password)
  368. {
  369. char buf[VtMaxStringSize], *p, *ep, *prefix;
  370. int i;
  371. USED(password);
  372. vtLock(z->lk);
  373. if(z->cstate != VtStateAlloc) {
  374. vtSetError("bad session state");
  375. vtUnlock(z->lk);
  376. return 0;
  377. }
  378. if(z->fd < 0){
  379. vtSetError("%s", z->fderror);
  380. vtUnlock(z->lk);
  381. return 0;
  382. }
  383. /* be a little anal */
  384. vtLock(z->inLock);
  385. vtLock(z->outLock);
  386. prefix = "venti-";
  387. p = buf;
  388. ep = buf + sizeof(buf);
  389. p = seprint(p, ep, "%s", prefix);
  390. p += strlen(p);
  391. for(i=0; vtVersions[i].version; i++) {
  392. if(i != 0)
  393. *p++ = ':';
  394. p = seprint(p, ep, "%s", vtVersions[i].s);
  395. }
  396. p = seprint(p, ep, "-libventi\n");
  397. assert(p-buf < sizeof(buf));
  398. if(z->outHash)
  399. vtSha1Update(z->outHash, (uint8_t*)buf, p-buf);
  400. if(!vtFdWrite(z->fd, (uint8_t*)buf, p-buf))
  401. goto Err;
  402. vtDebug(z, "version string out: %s", buf);
  403. if(!vtVersionRead(z, prefix, &z->version))
  404. goto Err;
  405. vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z));
  406. vtUnlock(z->inLock);
  407. vtUnlock(z->outLock);
  408. z->cstate = VtStateConnected;
  409. vtUnlock(z->lk);
  410. if(z->vtbl)
  411. return 1;
  412. if(!vtHello(z))
  413. goto Err;
  414. return 1;
  415. Err:
  416. if(z->fd >= 0)
  417. vtFdClose(z->fd);
  418. z->fd = -1;
  419. vtUnlock(z->inLock);
  420. vtUnlock(z->outLock);
  421. z->cstate = VtStateClosed;
  422. vtUnlock(z->lk);
  423. return 0;
  424. }