aan.c 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. #include <u.h>
  2. #include <libc.h>
  3. #include <auth.h>
  4. #include <fcall.h>
  5. #include <thread.h>
  /*
   * Time conversions.  Each macro yields a vlong count of nanoseconds;
   * the argument is parenthesized so expressions like NS(a+b) convert
   * the whole sum, not just its first operand.
   */
  #define NS(x) ((vlong)(x))
  #define US(x) (NS(x) * 1000LL)
  #define MS(x) (US(x) * 1000LL)
  #define S(x) (MS(x) * 1000LL)
  #define LOGNAME "aan"   /* log name passed to syslog */

  enum {
      Synctime = S(8),    // Interval between sync headers on the wire (ns).
      Nbuf = 10,          // Number of message buffers allocated at startup.
      K = 1024,
      Bufsize = 8 * K,    // Payload capacity of one Buf.
      Stacksize = 8 * K,  // Stack size for each proccreate'd proc.
      Timer = 0,          // Alt channels: index of the timer entry in a[].
      Unsent = 1,         // Alt channels: index of the unsent entry in a[].
      Maxto = 24 * 3600,  // A full day to reconnect (seconds; -m default).
  };
  /*
   * Both ends of an accepted connection, as read from the "local" and
   * "remote" files of its connection directory by getendpoint.  All four
   * strings are heap-allocated (strdup).
   */
  typedef struct Endpoints Endpoints;
  struct Endpoints {
      char *lsys;     /* local system */
      char *lserv;    /* local service */
      char *rsys;     /* remote system */
      char *rserv;    /* remote service */
  };

  /*
   * Per-message header.  It is written to and read from the network as
   * raw bytes (sizeof(Hdr)), so both peers must agree on the size and
   * byte order of ulong.
   */
  typedef struct Hdr Hdr;
  struct Hdr {
      ulong nb;       /* payload bytes that follow the header */
      ulong msg;      /* message sequence number; -1 marks a sync header */
      ulong acked;    /* number of peer messages received so far */
  };

  /* One message: header plus payload storage. */
  typedef struct t_Buf Buf;
  struct t_Buf {
      Hdr hdr;
      uchar buf[Bufsize];
  };
  /* Configuration, set from the command line in threadmain. */
  static char *progname;          /* argv[0], used in fatal messages */
  static char *Logname = LOGNAME; /* syslog log name */
  static char *devdir;            /* server mode: network dir to listen on */
  static char *dialstring;        /* client mode (-c): address to dial */
  static int client;              /* -c given */
  static int debug;               /* -d count; dmessage verbosity */
  static int maxto = Maxto;       /* -m: listen timeout in seconds */

  /* Buffer queues; every channel carries Buf pointers. */
  static Channel *empty;          /* free buffers */
  static Channel *unsent;         /* filled by fromclient, not yet sent */
  static Channel *unacked;        /* sent, waiting for the peer's ack */

  /* State shared between the procs. */
  static int netfd;               /* current connection, -1 while down */
  static int inmsg;               /* data messages received from the net */
  static int done;                /* nonzero once the session is over */

  /* Alt table for the main loop; entry order must match Timer and Unsent. */
  static Alt a[] = {
      /* c     v     op */
      { nil, nil, CHANRCV },  /* Timer */
      { nil, nil, CHANRCV },  /* Unsent */
      { nil, nil, CHANEND },
  };

  /* Procs. */
  static void fromnet(void*);
  static void fromclient(void*);
  static void timerproc(void *);

  /* Connection management. */
  static void reconnect(void);
  static void synchronize(void);
  static int sendcommand(ulong, ulong);
  static int writen(int, uchar *, int);
  static int getport(char *);
  static Endpoints *getendpoints(char *);
  static void freeendpoints(Endpoints *);

  /* Diagnostics. */
  static void showmsg(int, char *, Buf *);
  static void dmessage(int, char *, ...);
  /* Print the command synopsis to stderr and terminate every thread. */
  static void
  usage(void)
  {
      char *status = "usage";

      fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", progname);
      threadexitsall(status);
  }
  /*
   * Note handler registered with atnotify.  The "alarm" note is raised
   * by the alarm() armed around listen/accept in reconnect; when it
   * fires, log the timeout and exit.  Every other note is reported as
   * not handled (return 0).
   */
  static int
  catch(void *, char *s)
  {
      if (strcmp(s, "alarm") != 0)
          return 0;

      syslog(0, Logname, "Timed out while waiting for client on %s, exiting...",
          devdir);
      threadexitsall(nil);
      return 0;
  }
  /*
   * Write a header-only sync message: nb 0, msg -1, and our current
   * receive count in acked.  Returns writen's result (-1 on failure).
   */
  static int
  sendsync(void)
  {
      Hdr hdr;

      hdr.nb = 0;
      hdr.acked = inmsg;
      hdr.msg = -1;
      return writen(netfd, (uchar *)&hdr, sizeof(Hdr));
  }

  /*
   * Send b with a fresh ack count.  The payload is written only if the
   * header went out; otherwise payload bytes would be read by the peer
   * as a header.  Returns -1 if either write fails.
   */
  static int
  sendbuf(Buf *b)
  {
      b->hdr.acked = inmsg;
      if (writen(netfd, (uchar *)&b->hdr, sizeof(Hdr)) < 0)
          return -1;
      return writen(netfd, b->buf, b->hdr.nb);
  }

  /*
   * Relay between the client on fds 0/1 and a network connection that
   * may be re-established: -c dials dialstring, otherwise listen on
   * netdir.  Buffers circulate
   *     empty -> fromclient -> unsent -> main -> unacked -> fromnet -> empty.
   * The main loop sends queued buffers plus a sync header about every
   * Synctime.  After a write failure it waits for fromnet to drop netfd,
   * then reconnects and resends the unacked buffers.
   */
  void
  threadmain(int argc, char **argv)
  {
      int i, failed;
      Buf *b;
      Channel *timer;
      vlong synctime;

      progname = argv[0];
      ARGBEGIN {
      case 'c':
          client++;
          break;
      case 'd':
          debug++;
          break;
      case 'm':
          maxto = (int)strtol(EARGF(usage()), (char **)nil, 0);
          break;
      default:
          usage();
      } ARGEND;
      if (argc != 1)
          usage();
      if (!client) {
          char *p;

          // A netdir like /net/tcp/5/local is trimmed to its directory.
          devdir = argv[0];
          if ((p = strstr(devdir, "/local")) != nil)
              *p = '\0';
      } else
          dialstring = argv[0];
      if (debug > 0) {
          int fd = open("#c/cons", OWRITE|OCEXEC);

          // Only redirect stderr if the console could actually be opened.
          if (fd >= 0)
              dup(fd, 2);
      }
      fmtinstall('F', fcallfmt);
      atnotify(catch, 1);
      unsent = chancreate(sizeof(Buf *), Nbuf);
      unacked = chancreate(sizeof(Buf *), Nbuf);
      empty = chancreate(sizeof(Buf *), Nbuf);
      timer = chancreate(sizeof(uchar *), 1);
      for (i = 0; i != Nbuf; i++) {
          Buf *nbuf = malloc(sizeof(Buf));

          if (nbuf == nil)
              sysfatal("%s; Cannot allocate buffer; %r", progname);
          sendp(empty, nbuf);
      }
      netfd = -1;
      if (proccreate(fromnet, nil, Stacksize) < 0)
          sysfatal("%s; Cannot start fromnet; %r", progname);
      reconnect();    // Set up the initial connection.
      synchronize();
      if (proccreate(fromclient, nil, Stacksize) < 0)
          sysfatal("%s; Cannot start fromclient; %r", progname);
      if (proccreate(timerproc, timer, Stacksize) < 0)
          sysfatal("%s; Cannot start timerproc; %r", progname);

      a[Timer].c = timer;
      a[Unsent].c = unsent;
      a[Unsent].v = &b;     // alt() stores the received Buf* here.

      synctime = nsec() + Synctime;
      failed = 0;
      while (!done) {
          vlong now;
          int delta;

          if (failed) {
              // Wait for the netreader to notice and close netfd.
              while (netfd >= 0) {
                  dmessage(1, "main; waiting for netreader to die\n");
                  sleep(1000);
              }
              // The reader died; reestablish the world.
              reconnect();
              synchronize();
              failed = 0;
          }

          now = nsec();
          delta = (synctime - now) / MS(1);
          if (delta <= 0) {
              if (sendsync() < 0) {
                  dmessage(2, "main; writen failed; %r\n");
                  failed = 1;
                  continue;
              }
              synctime = nsec() + Synctime;
              assert(synctime > now);
          }

          switch (alt(a)) {
          case Timer:
              break;
          case Unsent:
              // Queue for retransmission first so a failed send is resent.
              sendp(unacked, b);
              if (sendbuf(b) < 0) {
                  dmessage(2, "main; writen failed; %r\n");
                  failed = 1;
              }
              // A zero-length buffer is the client's EOF.
              if (b->hdr.nb == 0)
                  done = 1;
              break;
          }
      }
      syslog(0, Logname, "exiting...");
      threadexitsall(nil);
  }
  /*
   * Proc body: fill free buffers from the client on fd 0, number them
   * and queue them on unsent.  A read error or EOF is forwarded as one
   * zero-length message, after which the proc returns.
   */
  static void
  fromclient(void*)
  {
      static int outmsg;
      Buf *b;
      long n;
      int eof;

      do {
          b = recvp(empty);
          n = read(0, b->buf, Bufsize);
          b->hdr.nb = n;
          if (n <= 0) {
              if (n < 0)
                  dmessage(2, "fromclient; Cannot read 9P message; %r\n");
              else
                  dmessage(2, "fromclient; Client terminated\n");
              b->hdr.nb = 0;
          }
          b->hdr.msg = outmsg++;
          showmsg(1, "fromclient", b);
          eof = b->hdr.nb == 0;
          sendp(unsent, b);
      } while (!eof);
  }
  /*
   * Proc body: read messages from netfd and deliver their payloads to
   * the client on fd 1.
   *
   * The acked count in each data header retires that many buffers from
   * unacked back to empty.  On any read error or malformed header,
   * netfd is closed and set to -1.  The main loop waits for that before
   * reconnecting; this proc then polls until netfd is valid again.
   *
   * A zero-length message with a non-negative msg ends the session.  A
   * zero-length message with msg -1 is a sync and is skipped.
   */
  static void
  fromnet(void*)
  {
      static int lastacked;
      Buf *b;

      b = (Buf *)malloc(sizeof(Buf));
      assert(b);
      while (!done) {
          int len, acked, i;

          while (netfd < 0) {
              dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n",
                  inmsg);
              sleep(1000);
          }

          /*
           * Read the header.  readn returns a short count at EOF; a
           * partial header is garbage, so treat it like a close.
           */
          len = readn(netfd, &b->hdr, sizeof(Hdr));
          if (len != sizeof(Hdr)) {
              if (len < 0)
                  dmessage(1, "fromnet; (hdr) network failure; %r\n");
              else
                  dmessage(1, "fromnet; (hdr) network closed\n");
              close(netfd);
              netfd = -1;
              continue;
          }
          dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d\n", len,
              b->hdr.nb, b->hdr.msg);
          if (b->hdr.nb == 0) {
              if ((long)b->hdr.msg >= 0) {
                  dmessage(1, "fromnet; network closed\n");
                  break;
              }
              continue;   // sync message
          }

          /* nb comes from the peer: never read past the end of b->buf. */
          if (b->hdr.nb > Bufsize) {
              dmessage(1, "fromnet; bad message size %lud\n", b->hdr.nb);
              close(netfd);
              netfd = -1;
              continue;
          }
          if ((len = readn(netfd, b->buf, b->hdr.nb)) <= 0 || len != b->hdr.nb) {
              if (len == 0)
                  dmessage(1, "fromnet; network closed\n");
              else
                  dmessage(1, "fromnet; network failure; %r\n");
              close(netfd);
              netfd = -1;
              continue;
          }

          /* Already delivered before a reconnect; drop the duplicate. */
          if (b->hdr.msg < inmsg) {
              dmessage(1, "fromnet; skipping message %d, currently at %d\n",
                  b->hdr.msg, inmsg);
              continue;
          }

          /* Retire newly acknowledged buffers, which must arrive in order. */
          acked = b->hdr.acked - lastacked;
          for (i = 0; i != acked; i++) {
              Buf *rb;

              rb = recvp(unacked);
              if (rb->hdr.msg != lastacked + i) {
                  dmessage(1, "rb %p, msg %d, lastacked %d, i %d\n",
                      rb, rb? rb->hdr.msg: -2, lastacked, i);
                  assert(0);
              }
              rb->hdr.msg = -1;
              sendp(empty, rb);
          }
          lastacked = b->hdr.acked;
          inmsg++;
          showmsg(1, "fromnet", b);
          if (writen(1, b->buf, len) < 0)
              sysfatal("fromnet; cannot write to client; %r");
      }
      done = 1;
  }
  /*
   * Establish a new connection and publish it in netfd, which the
   * polling loop in fromnet waits on.
   *
   * Client mode: redial dialstring once a second.  Exit if the error
   * says "connection refused".
   * Server mode: listen/accept on devdir with an alarm of maxto
   * seconds; catch() exits when it fires.  Listen or accept failures
   * are fatal.
   */
  static void
  reconnect(void)
  {
      char ldir[NETPATHLEN];
      int lcfd, fd;

      if (dialstring) {
          syslog(0, Logname, "dialing %s", dialstring);
          while ((fd = dial(dialstring, nil, nil, nil)) < 0) {
              /* errstr(2) buffers must be ERRMAX bytes; a shorter one can
                 truncate away the "connection refused" text tested below. */
              char err[ERRMAX];

              err[0] = '\0';
              errstr(err, sizeof err);
              if (strstr(err, "connection refused")) {
                  dmessage(1, "reconnect; server died...\n");
                  threadexitsall("server died...");
              }
              dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err);
              sleep(1000);
          }
          syslog(0, Logname, "reconnected to %s", dialstring);
      } else {
          Endpoints *ep;

          syslog(0, Logname, "waiting for connection on %s", devdir);
          alarm(maxto * 1000);
          if ((lcfd = listen(devdir, ldir)) < 0)
              sysfatal("reconnect; cannot listen; %r");
          if ((fd = accept(lcfd, ldir)) < 0)
              sysfatal("reconnect; cannot accept; %r");
          alarm(0);
          close(lcfd);
          ep = getendpoints(ldir);
          dmessage(1, "rsys '%s'\n", ep->rsys);
          syslog(0, Logname, "connected from %s", ep->rsys);
          freeendpoints(ep);
      }
      netfd = fd;     // Wakes up the netreader.
  }
  /*
   * Resend every still-unacknowledged buffer, oldest first, on the
   * current netfd.  Then replace unacked with a fresh channel holding
   * the same buffers in the same order.
   *
   * Write errors are deliberately ignored here.  If the connection is
   * already broken, a later write in the main loop fails and triggers
   * another reconnect.
   */
  static void
  synchronize(void)
  {
      Channel *requeue;
      Buf *pending;

      requeue = chancreate(sizeof(Buf *), Nbuf);
      for (;;) {
          pending = nbrecvp(unacked);
          if (pending == nil)
              break;
          /* Hdr and payload are adjacent in Buf, so one write sends both. */
          writen(netfd, (uchar *)pending, sizeof(Hdr) + pending->hdr.nb);
          sendp(requeue, pending);
      }
      chanfree(unacked);
      unacked = requeue;
  }
  /*
   * Debug dump at the given level: tag s, payload length, the first
   * nine payload bytes in hex, and the buffer address.
   */
  static void
  showmsg(int level, char *s, Buf *b)
  {
      uchar *p;

      if (b != nil) {
          p = b->buf;
          dmessage(level,
              "%s; (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s,
              b->hdr.nb,
              p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], b);
      } else
          dmessage(level, "%s; b == nil\n", s);
  }
  /*
   * Write all nb bytes of buf to fd, looping over partial writes.
   * Returns nb on success.  Returns -1 if a write fails or fd is
   * negative while data remains; nb == 0 returns 0 without checking fd.
   */
  static int
  writen(int fd, uchar *buf, int nb)
  {
      int total, n;
      uchar *p;

      total = nb;
      for (p = buf; nb > 0; p += n, nb -= n) {
          if (fd < 0)
              return -1;
          n = write(fd, p, nb);
          if (n < 0) {
              dmessage(1, "writen; Write failed; %r\n");
              return -1;
          }
          dmessage(2, "writen: wrote %d bytes\n", n);
      }
      return total;
  }
  /*
   * Proc body: post a tick on the channel x every half Synctime, so that
   * alt() in the main loop wakes often enough to send sync headers.
   * Stops once done is set.
   */
  static void
  timerproc(void *x)
  {
      Channel *tick;
      long halfsync;

      tick = x;
      halfsync = (Synctime / MS(1)) >> 1;    /* milliseconds */
      for (;;) {
          if (done)
              break;
          sleep(halfsync);
          sendp(tick, "timer");
      }
  }
  /* Formatted debug output to fd 2, printed only when level <= debug (-d count). */
  static void
  dmessage(int level, char *fmt, ...)
  {
      va_list arg;

      if (level <= debug) {
          va_start(arg, fmt);
          vfprint(2, fmt, arg);
          va_end(arg);
      }
  }
  /*
   * Read dir/file (a connection directory's "local" or "remote" file),
   * expected to hold "sys!serv" with an optional trailing newline.
   * Stores strdup'd copies in *sysp and *servp; a part that cannot be
   * read or is absent becomes "unknown".  The caller frees both.
   */
  static void
  getendpoint(char *dir, char *file, char **sysp, char **servp)
  {
      int fd, n;
      char buf[128];
      char *sys, *serv;

      sys = serv = nil;
      snprint(buf, sizeof buf, "%s/%s", dir, file);
      fd = open(buf, OREAD);
      if (fd >= 0) {
          n = read(fd, buf, sizeof(buf)-1);
          if (n > 0) {
              /* Terminate at n, then strip a trailing newline only if one
                 is present, so the last real character is never lost. */
              buf[n] = 0;
              if (buf[n-1] == '\n')
                  buf[n-1] = 0;
              serv = strchr(buf, '!');
              if (serv) {
                  *serv++ = 0;
                  serv = strdup(serv);
              }
              sys = strdup(buf);
          }
          close(fd);
      }
      if (serv == nil)
          serv = strdup("unknown");
      if (sys == nil)
          sys = strdup("unknown");
      *servp = serv;
      *sysp = sys;
  }
  /*
   * Allocate an Endpoints filled from dir/local and dir/remote.
   * Allocation failure is fatal, because callers dereference the result
   * unconditionally.  Release with freeendpoints.
   */
  static Endpoints *
  getendpoints(char *dir)
  {
      Endpoints *ep;

      ep = malloc(sizeof(*ep));
      if (ep == nil)
          sysfatal("getendpoints; cannot allocate; %r");
      getendpoint(dir, "local", &ep->lsys, &ep->lserv);
      getendpoint(dir, "remote", &ep->rsys, &ep->rserv);
      return ep;
  }
  /* Release an Endpoints from getendpoints: its four strings, then itself. */
  static void
  freeendpoints(Endpoints *ep)
  {
      free(ep->lserv);
      free(ep->rserv);
      free(ep->lsys);
      free(ep->rsys);
      free(ep);
  }