aan.c 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. #include <u.h>
  2. #include <libc.h>
  3. #include <auth.h>
  4. #include <fcall.h>
  5. #include <thread.h>
  /*
   * Time conversions: each macro turns a count of the named unit into
   * nanoseconds, the unit nsec() works in.  The argument is parenthesized
   * so that an expression such as MS(a+b) converts the whole sum rather
   * than just its first operand.
   */
  #define NS(x) ((vlong)(x))
  #define US(x) (NS(x) * 1000LL)
  #define MS(x) (US(x) * 1000LL)
  #define S(x) (MS(x) * 1000LL)
  #define LOGNAME "aan"

  /*
   * Interval between keep-alive headers, in nanoseconds.  8e9 exceeds the
   * int range that standard C guarantees for enumeration constants, so it
   * is a macro rather than an enum member.
   */
  #define Synctime S(8)

  enum {
  	Nbuf = 10,		/* number of Bufs circulating between the queues */
  	K = 1024,
  	Bufsize = 8 * K,	/* payload bytes per Buf */
  	Stacksize = 8 * K,	/* stack size for each proccreate'd proc */
  	Timer = 0,		// Alt channels.
  	Unsent = 1,
  	Maxto = 24 * 3600,	// A full day to reconnect (seconds).
  };
  21. typedef struct Endpoints Endpoints;
  22. struct Endpoints {
  23. char *lsys;
  24. char *lserv;
  25. char *rsys;
  26. char *rserv;
  27. };
  28. typedef struct {
  29. ulong nb; // Number of data bytes in this message
  30. ulong msg; // Message number
  31. ulong acked; // Number of messages acked
  32. } Hdr;
  33. typedef struct t_Buf {
  34. Hdr hdr;
  35. uchar buf[Bufsize];
  36. } Buf;
  37. static char *progname;
  38. static Channel *unsent;
  39. static Channel *unacked;
  40. static Channel *empty;
  41. static int netfd;
  42. static int inmsg;
  43. static char *devdir;
  44. static int debug;
  45. static int done;
  46. static char *dialstring;
  47. static int maxto = Maxto;
  48. static char *Logname = LOGNAME;
  49. static int client;
  50. static Alt a[] = {
  51. /* c v op */
  52. { nil, nil, CHANRCV }, // timer
  53. { nil, nil, CHANRCV }, // unsent
  54. { nil, nil, CHANEND },
  55. };
  56. static void fromnet(void*);
  57. static void fromclient(void*);
  58. static void reconnect(void);
  59. static void synchronize(void);
  60. static int sendcommand(ulong, ulong);
  61. static void showmsg(int, char *, Buf *);
  62. static int writen(int, uchar *, int);
  63. static int getport(char *);
  64. static void dmessage(int, char *, ...);
  65. static void timerproc(void *);
  66. static Endpoints *getendpoints(char *);
  67. static void freeendpoints(Endpoints *);
  68. static void
  69. usage(void)
  70. {
  71. fprint(2, "Usage: %s [-c] [-d] [-m maxto] dialstring|handle\n", progname);
  72. exits(nil);
  73. }
  74. static int
  75. catch(void *, char *s)
  76. {
  77. if (!strcmp(s, "alarm")) {
  78. syslog(0, Logname, "Timed out while waiting for client on %s, exiting...",
  79. devdir);
  80. threadexitsall(nil);
  81. }
  82. return 0;
  83. }
  84. void
  85. threadmain(int argc, char **argv)
  86. {
  87. int i, failed;
  88. Buf *b;
  89. Channel *timer;
  90. vlong synctime;
  91. progname = argv[0];
  92. ARGBEGIN {
  93. case 'c':
  94. client++;
  95. break;
  96. case 'd':
  97. debug++;
  98. break;
  99. case 'm':
  100. maxto = (int)strtol(EARGF(usage()), (char **)nil, 0);
  101. break;
  102. case '?':
  103. default:
  104. usage();
  105. }
  106. ARGEND;
  107. if (argc != 1)
  108. usage();
  109. if (!client) {
  110. char *p;
  111. devdir = argv[0];
  112. if ((p = strstr(devdir, "/local")) != nil)
  113. *p = '\0';
  114. }
  115. else
  116. dialstring = argv[0];
  117. if (debug > 0) {
  118. int fd = open("#c/cons", OWRITE|OCEXEC);
  119. dup(fd, 2);
  120. }
  121. fmtinstall('F', fcallfmt);
  122. atnotify(catch, 1);
  123. unsent = chancreate(sizeof(Buf *), Nbuf);
  124. unacked = chancreate(sizeof(Buf *), Nbuf);
  125. empty = chancreate(sizeof(Buf *), Nbuf);
  126. timer = chancreate(sizeof(uchar *), 1);
  127. for (i = 0; i != Nbuf; i++) {
  128. Buf *b = malloc(sizeof(Buf));
  129. sendp(empty, b);
  130. }
  131. netfd = -1;
  132. if (proccreate(fromnet, nil, Stacksize) < 0)
  133. sysfatal("%s; Cannot start fromnet; %r\n", progname);
  134. reconnect(); // Set up the initial connection.
  135. synchronize();
  136. if (proccreate(fromclient, nil, Stacksize) < 0)
  137. sysfatal("%s; Cannot start fromclient; %r\n", progname);
  138. if (proccreate(timerproc, timer, Stacksize) < 0)
  139. sysfatal("%s; Cannot start timerproc; %r\n", progname);
  140. a[Timer].c = timer;
  141. a[Unsent].c = unsent;
  142. a[Unsent].v = &b;
  143. synctime = nsec() + Synctime;
  144. failed = 0;
  145. while (!done) {
  146. vlong now;
  147. int delta;
  148. if (failed) {
  149. // Wait for the netreader to die.
  150. while (netfd >= 0) {
  151. dmessage(1, "main; waiting for netreader to die\n");
  152. sleep(1000);
  153. }
  154. // the reader died; reestablish the world.
  155. reconnect();
  156. synchronize();
  157. failed = 0;
  158. }
  159. now = nsec();
  160. delta = (synctime - nsec()) / MS(1);
  161. if (delta <= 0) {
  162. Hdr hdr;
  163. hdr.nb = 0;
  164. hdr.acked = inmsg;
  165. hdr.msg = -1;
  166. if (writen(netfd, (uchar *)&hdr, sizeof(Hdr)) < 0) {
  167. dmessage(2, "main; writen failed; %r\n");
  168. failed = 1;
  169. continue;
  170. }
  171. synctime = nsec() + Synctime;
  172. assert(synctime > now);
  173. }
  174. switch (alt(a)) {
  175. case Timer:
  176. break;
  177. case Unsent:
  178. sendp(unacked, b);
  179. b->hdr.acked = inmsg;
  180. if (writen(netfd, (uchar *)&b->hdr, sizeof(Hdr)) < 0) {
  181. dmessage(2, "main; writen failed; %r\n");
  182. failed = 1;
  183. }
  184. if (writen(netfd, b->buf, b->hdr.nb) < 0) {
  185. dmessage(2, "main; writen failed; %r\n");
  186. failed = 1;
  187. }
  188. if (b->hdr.nb == 0)
  189. done = 1;
  190. break;
  191. }
  192. }
  193. syslog(0, Logname, "exiting...");
  194. threadexitsall(nil);
  195. }
  196. static void
  197. fromclient(void*)
  198. {
  199. static int outmsg;
  200. for (;;) {
  201. Buf *b;
  202. b = recvp(empty);
  203. if ((int)(b->hdr.nb = read(0, b->buf, Bufsize)) <= 0) {
  204. if ((int)b->hdr.nb < 0)
  205. dmessage(2, "fromclient; Cannot read 9P message; %r\n");
  206. else
  207. dmessage(2, "fromclient; Client terminated\n");
  208. b->hdr.nb = 0;
  209. }
  210. b->hdr.msg = outmsg++;
  211. showmsg(1, "fromclient", b);
  212. sendp(unsent, b);
  213. if (b->hdr.nb == 0)
  214. break;
  215. }
  216. }
  217. static void
  218. fromnet(void*)
  219. {
  220. static int lastacked;
  221. Buf *b;
  222. b = (Buf *)malloc(sizeof(Buf));
  223. assert(b);
  224. while (!done) {
  225. int len, acked, i;
  226. while (netfd < 0) {
  227. dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n",
  228. inmsg);
  229. sleep(1000);
  230. }
  231. // Read the header.
  232. if ((len = readn(netfd, &b->hdr, sizeof(Hdr))) <= 0) {
  233. if (len < 0)
  234. dmessage(1, "fromnet; (hdr) network failure; %r\n");
  235. else
  236. dmessage(1, "fromnet; (hdr) network closed\n");
  237. close(netfd);
  238. netfd = -1;
  239. continue;
  240. }
  241. dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d\n", len,
  242. b->hdr.nb, b->hdr.msg);
  243. if (b->hdr.nb == 0) {
  244. if ((long)b->hdr.msg >= 0) {
  245. dmessage(1, "fromnet; network closed\n");
  246. break;
  247. }
  248. continue;
  249. }
  250. if ((len = readn(netfd, b->buf, b->hdr.nb)) <= 0 || len != b->hdr.nb) {
  251. if (len == 0)
  252. dmessage(1, "fromnet; network closed\n");
  253. else
  254. dmessage(1, "fromnet; network failure; %r\n");
  255. close(netfd);
  256. netfd = -1;
  257. continue;
  258. }
  259. if (b->hdr.msg < inmsg) {
  260. dmessage(1, "fromnet; skipping message %d, currently at %d\n",
  261. b->hdr.msg, inmsg);
  262. continue;
  263. }
  264. // Process the acked list.
  265. acked = b->hdr.acked - lastacked;
  266. for (i = 0; i != acked; i++) {
  267. Buf *rb;
  268. rb = recvp(unacked);
  269. if (rb->hdr.msg != lastacked + i) {
  270. dmessage(1, "rb %p, msg %d, lastacked %d, i %d\n",
  271. rb, rb? rb->hdr.msg: -2, lastacked, i);
  272. assert(0);
  273. }
  274. rb->hdr.msg = -1;
  275. sendp(empty, rb);
  276. }
  277. lastacked = b->hdr.acked;
  278. inmsg++;
  279. showmsg(1, "fromnet", b);
  280. if (writen(1, b->buf, len) < 0)
  281. sysfatal("fromnet; cannot write to client; %r\n");
  282. }
  283. done = 1;
  284. }
  285. static void
  286. reconnect(void)
  287. {
  288. char ldir[40];
  289. int lcfd, fd;
  290. if (dialstring) {
  291. syslog(0, Logname, "dialing %s", dialstring);
  292. while ((fd = dial(dialstring, nil, nil, nil)) < 0) {
  293. char err[32];
  294. err[0] = '\0';
  295. errstr(err, sizeof err);
  296. if (strstr(err, "connection refused")) {
  297. dmessage(1, "reconnect; server died...\n");
  298. threadexitsall("server died...");
  299. }
  300. dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err);
  301. sleep(1000);
  302. }
  303. syslog(0, Logname, "reconnected to %s", dialstring);
  304. }
  305. else {
  306. Endpoints *ep;
  307. syslog(0, Logname, "waiting for connection on %s", devdir);
  308. alarm(maxto * 1000);
  309. if ((lcfd = listen(devdir, ldir)) < 0)
  310. sysfatal("reconnect; cannot listen; %r");
  311. if ((fd = accept(lcfd, ldir)) < 0)
  312. sysfatal("reconnect; cannot accept; %r");
  313. alarm(0);
  314. close(lcfd);
  315. ep = getendpoints(ldir);
  316. dmessage(1, "rsys '%s'\n", ep->rsys);
  317. syslog(0, Logname, "connected from %s", ep->rsys);
  318. freeendpoints(ep);
  319. }
  320. netfd = fd; // Wakes up the netreader.
  321. }
  322. static void
  323. synchronize(void)
  324. {
  325. Channel *tmp;
  326. Buf *b;
  327. // Ignore network errors here. If we fail during
  328. // synchronization, the next alarm will pick up
  329. // the error.
  330. tmp = chancreate(sizeof(Buf *), Nbuf);
  331. while ((b = nbrecvp(unacked)) != nil) {
  332. writen(netfd, (uchar *)b, sizeof(Hdr) + b->hdr.nb);
  333. sendp(tmp, b);
  334. }
  335. chanfree(unacked);
  336. unacked = tmp;
  337. }
  338. static void
  339. showmsg(int level, char *s, Buf *b)
  340. {
  341. if (b == nil) {
  342. dmessage(level, "%s; b == nil\n", s);
  343. return;
  344. }
  345. dmessage(level,
  346. "%s; (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s,
  347. b->hdr.nb,
  348. b->buf[0], b->buf[1], b->buf[2],
  349. b->buf[3], b->buf[4], b->buf[5],
  350. b->buf[6], b->buf[7], b->buf[8], b);
  351. }
  352. static int
  353. writen(int fd, uchar *buf, int nb)
  354. {
  355. int len = nb;
  356. while (nb > 0) {
  357. int n;
  358. if (fd < 0)
  359. return -1;
  360. if ((n = write(fd, buf, nb)) < 0) {
  361. dmessage(1, "writen; Write failed; %r\n");
  362. return -1;
  363. }
  364. dmessage(2, "writen: wrote %d bytes\n", n);
  365. buf += n;
  366. nb -= n;
  367. }
  368. return len;
  369. }
  370. static void
  371. timerproc(void *x)
  372. {
  373. Channel *timer = x;
  374. while (!done) {
  375. sleep((Synctime / MS(1)) >> 1);
  376. sendp(timer, "timer");
  377. }
  378. }
  379. static void
  380. dmessage(int level, char *fmt, ...)
  381. {
  382. va_list arg;
  383. if (level > debug)
  384. return;
  385. va_start(arg, fmt);
  386. vfprint(2, fmt, arg);
  387. va_end(arg);
  388. }
  389. static void
  390. getendpoint(char *dir, char *file, char **sysp, char **servp)
  391. {
  392. int fd, n;
  393. char buf[128];
  394. char *sys, *serv;
  395. sys = serv = 0;
  396. snprint(buf, sizeof buf, "%s/%s", dir, file);
  397. fd = open(buf, OREAD);
  398. if(fd >= 0){
  399. n = read(fd, buf, sizeof(buf)-1);
  400. if(n>0){
  401. buf[n-1] = 0;
  402. serv = strchr(buf, '!');
  403. if(serv){
  404. *serv++ = 0;
  405. serv = strdup(serv);
  406. }
  407. sys = strdup(buf);
  408. }
  409. close(fd);
  410. }
  411. if(serv == 0)
  412. serv = strdup("unknown");
  413. if(sys == 0)
  414. sys = strdup("unknown");
  415. *servp = serv;
  416. *sysp = sys;
  417. }
  418. static Endpoints *
  419. getendpoints(char *dir)
  420. {
  421. Endpoints *ep;
  422. ep = malloc(sizeof(*ep));
  423. getendpoint(dir, "local", &ep->lsys, &ep->lserv);
  424. getendpoint(dir, "remote", &ep->rsys, &ep->rserv);
  425. return ep;
  426. }
  427. static void
  428. freeendpoints(Endpoints *ep)
  429. {
  430. free(ep->lsys);
  431. free(ep->rsys);
  432. free(ep->lserv);
  433. free(ep->rserv);
  434. free(ep);
  435. }