123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522 |
- #include <u.h>
- #include <libc.h>
- #include <auth.h>
- #include <fcall.h>
- #include <thread.h>
// Time-unit helpers.  Each expands to a vlong count of nanoseconds:
// NS is the identity, US/MS/S scale by successive factors of 1000.
// The argument is parenthesized before the cast so that a compound
// expression such as NS(a + b) is converted as a whole instead of the
// cast binding to `a` alone.
#define NS(x)	((vlong)(x))
#define US(x)	(NS(x) * 1000LL)
#define MS(x)	(US(x) * 1000LL)
#define S(x)	(MS(x) * 1000LL)

// Name under which this program writes its syslog entries.
#define LOGNAME "aan"
- enum {
- Synctime = S(8),
- Nbuf = 10,
- K = 1024,
- Bufsize = 8 * K,
- Stacksize = 8 * K,
- Timer = 0, // Alt channels.
- Unsent = 1,
- Maxto = 24 * 3600, // A full day to reconnect.
- };
// System and service names for both ends of a connection.  Instances
// are built by getendpoints() and released by freeendpoints(); every
// string is a separately allocated copy.
typedef struct Endpoints {
	char	*lsys;		// local system
	char	*lserv;		// local service
	char	*rsys;		// remote system
	char	*rserv;		// remote service
} Endpoints;
- typedef struct {
- ulong nb; // Number of data bytes in this message
- ulong msg; // Message number
- ulong acked; // Number of messages acked
- } Hdr;
- typedef struct t_Buf {
- Hdr hdr;
- uchar buf[Bufsize];
- } Buf;
- static char *progname;
- static Channel *unsent;
- static Channel *unacked;
- static Channel *empty;
- static int netfd;
- static int inmsg;
- static char *devdir;
- static int debug;
- static int done;
- static char *dialstring;
- static int maxto = Maxto;
- static char *Logname = LOGNAME;
- static int client;
- static Alt a[] = {
- /* c v op */
- { nil, nil, CHANRCV }, // timer
- { nil, nil, CHANRCV }, // unsent
- { nil, nil, CHANEND },
- };
- static void fromnet(void*);
- static void fromclient(void*);
- static void reconnect(void);
- static void synchronize(void);
- static int sendcommand(ulong, ulong);
- static void showmsg(int, char *, Buf *);
- static int writen(int, uchar *, int);
- static int getport(char *);
- static void dmessage(int, char *, ...);
- static void timerproc(void *);
- static Endpoints *getendpoints(char *);
- static void freeendpoints(Endpoints *);
- static void
- usage(void)
- {
- fprint(2, "Usage: %s [-c] [-d] [-m maxto] dialstring|handle\n", progname);
- exits(nil);
- }
- static int
- catch(void *, char *s)
- {
- if (!strcmp(s, "alarm")) {
- syslog(0, Logname, "Timed out while waiting for client on %s, exiting...",
- devdir);
- threadexitsall(nil);
- }
- return 0;
- }
- void
- threadmain(int argc, char **argv)
- {
- int i, failed;
- Buf *b;
- Channel *timer;
- vlong synctime;
- progname = argv[0];
- ARGBEGIN {
- case 'c':
- client++;
- break;
- case 'd':
- debug++;
- break;
- case 'm':
- maxto = (int)strtol(EARGF(usage()), (char **)nil, 0);
- break;
- case '?':
- default:
- usage();
- }
- ARGEND;
- if (argc != 1)
- usage();
- if (!client) {
- char *p;
- devdir = argv[0];
- if ((p = strstr(devdir, "/local")) != nil)
- *p = '\0';
- }
- else
- dialstring = argv[0];
- if (debug > 0) {
- int fd = open("#c/cons", OWRITE|OCEXEC);
- dup(fd, 2);
- }
- fmtinstall('F', fcallfmt);
- atnotify(catch, 1);
- unsent = chancreate(sizeof(Buf *), Nbuf);
- unacked = chancreate(sizeof(Buf *), Nbuf);
- empty = chancreate(sizeof(Buf *), Nbuf);
- timer = chancreate(sizeof(uchar *), 1);
- for (i = 0; i != Nbuf; i++) {
- Buf *b = malloc(sizeof(Buf));
- sendp(empty, b);
- }
- netfd = -1;
- if (proccreate(fromnet, nil, Stacksize) < 0)
- sysfatal("%s; Cannot start fromnet; %r\n", progname);
- reconnect(); // Set up the initial connection.
- synchronize();
- if (proccreate(fromclient, nil, Stacksize) < 0)
- sysfatal("%s; Cannot start fromclient; %r\n", progname);
- if (proccreate(timerproc, timer, Stacksize) < 0)
- sysfatal("%s; Cannot start timerproc; %r\n", progname);
- a[Timer].c = timer;
- a[Unsent].c = unsent;
- a[Unsent].v = &b;
- synctime = nsec() + Synctime;
- failed = 0;
- while (!done) {
- vlong now;
- int delta;
- if (failed) {
- // Wait for the netreader to die.
- while (netfd >= 0) {
- dmessage(1, "main; waiting for netreader to die\n");
- sleep(1000);
- }
- // the reader died; reestablish the world.
- reconnect();
- synchronize();
- failed = 0;
- }
- now = nsec();
- delta = (synctime - nsec()) / MS(1);
- if (delta <= 0) {
- Hdr hdr;
- hdr.nb = 0;
- hdr.acked = inmsg;
- hdr.msg = -1;
- if (writen(netfd, (uchar *)&hdr, sizeof(Hdr)) < 0) {
- dmessage(2, "main; writen failed; %r\n");
- failed = 1;
- continue;
- }
- synctime = nsec() + Synctime;
- assert(synctime > now);
- }
- switch (alt(a)) {
- case Timer:
- break;
- case Unsent:
- sendp(unacked, b);
- b->hdr.acked = inmsg;
- if (writen(netfd, (uchar *)&b->hdr, sizeof(Hdr)) < 0) {
- dmessage(2, "main; writen failed; %r\n");
- failed = 1;
- }
- if (writen(netfd, b->buf, b->hdr.nb) < 0) {
- dmessage(2, "main; writen failed; %r\n");
- failed = 1;
- }
- if (b->hdr.nb == 0)
- done = 1;
- break;
- }
- }
- syslog(0, Logname, "exiting...");
- threadexitsall(nil);
- }
- static void
- fromclient(void*)
- {
- static int outmsg;
- for (;;) {
- Buf *b;
- b = recvp(empty);
- if ((int)(b->hdr.nb = read(0, b->buf, Bufsize)) <= 0) {
- if ((int)b->hdr.nb < 0)
- dmessage(2, "fromclient; Cannot read 9P message; %r\n");
- else
- dmessage(2, "fromclient; Client terminated\n");
- b->hdr.nb = 0;
- }
- b->hdr.msg = outmsg++;
- showmsg(1, "fromclient", b);
- sendp(unsent, b);
-
- if (b->hdr.nb == 0)
- break;
- }
- }
- static void
- fromnet(void*)
- {
- static int lastacked;
- Buf *b;
- b = (Buf *)malloc(sizeof(Buf));
- assert(b);
- while (!done) {
- int len, acked, i;
- while (netfd < 0) {
- dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n",
- inmsg);
- sleep(1000);
- }
- // Read the header.
- if ((len = readn(netfd, &b->hdr, sizeof(Hdr))) <= 0) {
- if (len < 0)
- dmessage(1, "fromnet; (hdr) network failure; %r\n");
- else
- dmessage(1, "fromnet; (hdr) network closed\n");
- close(netfd);
- netfd = -1;
- continue;
- }
- dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d\n", len,
- b->hdr.nb, b->hdr.msg);
- if (b->hdr.nb == 0) {
- if ((long)b->hdr.msg >= 0) {
- dmessage(1, "fromnet; network closed\n");
- break;
- }
- continue;
- }
-
- if ((len = readn(netfd, b->buf, b->hdr.nb)) <= 0 || len != b->hdr.nb) {
- if (len == 0)
- dmessage(1, "fromnet; network closed\n");
- else
- dmessage(1, "fromnet; network failure; %r\n");
- close(netfd);
- netfd = -1;
- continue;
- }
- if (b->hdr.msg < inmsg) {
- dmessage(1, "fromnet; skipping message %d, currently at %d\n",
- b->hdr.msg, inmsg);
- continue;
- }
- // Process the acked list.
- acked = b->hdr.acked - lastacked;
- for (i = 0; i != acked; i++) {
- Buf *rb;
- rb = recvp(unacked);
- if (rb->hdr.msg != lastacked + i) {
- dmessage(1, "rb %p, msg %d, lastacked %d, i %d\n",
- rb, rb? rb->hdr.msg: -2, lastacked, i);
- assert(0);
- }
- rb->hdr.msg = -1;
- sendp(empty, rb);
- }
- lastacked = b->hdr.acked;
- inmsg++;
- showmsg(1, "fromnet", b);
- if (writen(1, b->buf, len) < 0)
- sysfatal("fromnet; cannot write to client; %r\n");
- }
- done = 1;
- }
- static void
- reconnect(void)
- {
- char ldir[40];
- int lcfd, fd;
- if (dialstring) {
- syslog(0, Logname, "dialing %s", dialstring);
- while ((fd = dial(dialstring, nil, nil, nil)) < 0) {
- char err[32];
- err[0] = '\0';
- errstr(err, sizeof err);
- if (strstr(err, "connection refused")) {
- dmessage(1, "reconnect; server died...\n");
- threadexitsall("server died...");
- }
- dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err);
- sleep(1000);
- }
- syslog(0, Logname, "reconnected to %s", dialstring);
- }
- else {
- Endpoints *ep;
- syslog(0, Logname, "waiting for connection on %s", devdir);
- alarm(maxto * 1000);
- if ((lcfd = listen(devdir, ldir)) < 0)
- sysfatal("reconnect; cannot listen; %r");
-
- if ((fd = accept(lcfd, ldir)) < 0)
- sysfatal("reconnect; cannot accept; %r");
- alarm(0);
- close(lcfd);
-
- ep = getendpoints(ldir);
- dmessage(1, "rsys '%s'\n", ep->rsys);
- syslog(0, Logname, "connected from %s", ep->rsys);
- freeendpoints(ep);
- }
-
- netfd = fd; // Wakes up the netreader.
- }
- static void
- synchronize(void)
- {
- Channel *tmp;
- Buf *b;
- // Ignore network errors here. If we fail during
- // synchronization, the next alarm will pick up
- // the error.
- tmp = chancreate(sizeof(Buf *), Nbuf);
- while ((b = nbrecvp(unacked)) != nil) {
- writen(netfd, (uchar *)b, sizeof(Hdr) + b->hdr.nb);
- sendp(tmp, b);
- }
- chanfree(unacked);
- unacked = tmp;
- }
- static void
- showmsg(int level, char *s, Buf *b)
- {
- if (b == nil) {
- dmessage(level, "%s; b == nil\n", s);
- return;
- }
- dmessage(level,
- "%s; (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s,
- b->hdr.nb,
- b->buf[0], b->buf[1], b->buf[2],
- b->buf[3], b->buf[4], b->buf[5],
- b->buf[6], b->buf[7], b->buf[8], b);
- }
- static int
- writen(int fd, uchar *buf, int nb)
- {
- int len = nb;
- while (nb > 0) {
- int n;
- if (fd < 0)
- return -1;
- if ((n = write(fd, buf, nb)) < 0) {
- dmessage(1, "writen; Write failed; %r\n");
- return -1;
- }
- dmessage(2, "writen: wrote %d bytes\n", n);
- buf += n;
- nb -= n;
- }
- return len;
- }
- static void
- timerproc(void *x)
- {
- Channel *timer = x;
- while (!done) {
- sleep((Synctime / MS(1)) >> 1);
- sendp(timer, "timer");
- }
- }
- static void
- dmessage(int level, char *fmt, ...)
- {
- va_list arg;
- if (level > debug)
- return;
- va_start(arg, fmt);
- vfprint(2, fmt, arg);
- va_end(arg);
- }
- static void
- getendpoint(char *dir, char *file, char **sysp, char **servp)
- {
- int fd, n;
- char buf[128];
- char *sys, *serv;
- sys = serv = 0;
- snprint(buf, sizeof buf, "%s/%s", dir, file);
- fd = open(buf, OREAD);
- if(fd >= 0){
- n = read(fd, buf, sizeof(buf)-1);
- if(n>0){
- buf[n-1] = 0;
- serv = strchr(buf, '!');
- if(serv){
- *serv++ = 0;
- serv = strdup(serv);
- }
- sys = strdup(buf);
- }
- close(fd);
- }
- if(serv == 0)
- serv = strdup("unknown");
- if(sys == 0)
- sys = strdup("unknown");
- *servp = serv;
- *sysp = sys;
- }
- static Endpoints *
- getendpoints(char *dir)
- {
- Endpoints *ep;
- ep = malloc(sizeof(*ep));
- getendpoint(dir, "local", &ep->lsys, &ep->lserv);
- getendpoint(dir, "remote", &ep->rsys, &ep->rserv);
- return ep;
- }
- static void
- freeendpoints(Endpoints *ep)
- {
- free(ep->lsys);
- free(ep->rsys);
- free(ep->lserv);
- free(ep->rserv);
- free(ep);
- }
|