bootp.c 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. #include "u.h"
  2. #include "lib.h"
  3. #include "mem.h"
  4. #include "dat.h"
  5. #include "fns.h"
  6. #include "ip.h"
  7. uchar broadcast[Eaddrlen] = {
  8. 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
  9. };
  10. static ushort tftpport = 5000;
  11. static int Id = 1;
  12. static Netaddr myaddr;
  13. static Netaddr server;
  14. typedef struct {
  15. uchar header[4];
  16. uchar data[Segsize];
  17. } Tftp;
  18. static Tftp tftpb;
  19. int
  20. etherrxpkt(int ctlrno, Etherpkt *pkt, int timo)
  21. {
  22. int n;
  23. for (;;) {
  24. n = devread(ctlrno, (uchar*)pkt, sizeof(*pkt), 0);
  25. if (n >= 0)
  26. return n;
  27. if (timo-- < 0)
  28. return -1;
  29. }
  30. }
  31. int
  32. ethertxpkt(int ctlrno, Etherpkt *pkt, int len, int timo)
  33. {
  34. USED(timo);
  35. return devwrite(ctlrno, (uchar*)pkt, len, 0);
  36. }
  37. static void
  38. hnputs(uchar *ptr, ushort val)
  39. {
  40. ptr[0] = val>>8;
  41. ptr[1] = val;
  42. }
  43. static void
  44. hnputl(uchar *ptr, ulong val)
  45. {
  46. ptr[0] = val>>24;
  47. ptr[1] = val>>16;
  48. ptr[2] = val>>8;
  49. ptr[3] = val;
  50. }
  51. static ulong
  52. nhgetl(uchar *ptr)
  53. {
  54. return ((ptr[0]<<24) | (ptr[1]<<16) | (ptr[2]<<8) | ptr[3]);
  55. }
  56. static ushort
  57. nhgets(uchar *ptr)
  58. {
  59. return ((ptr[0]<<8) | ptr[1]);
  60. }
  /*
   * Run-time byte-order probe: the lowest-addressed byte of the short 1
   * is non-zero only on a little-endian host, so LITTLE is true there.
   */
  static short endian = 1;
  static char *aendian = (char*)&endian;
  #define LITTLE (*aendian)
  /*
   * Ones-complement Internet checksum over len bytes starting at a.
   *
   * Returns the complement of the folded 16-bit sum as a numeric value;
   * udpsend stores it with hnputs, and udprecv treats a result of 0 over a
   * datagram that already carries a checksum as "checksum verified".
   *
   * Bytes are accumulated in two lanes:
   *   losum - quantities already in network-order position;
   *   hisum - quantities whose bytes are swapped relative to network order
   *           (a lone high-order byte counts as swapped); they are
   *           byte-swapped into losum at the end.
   * mdsum collects native-endian 16-bit words and is assigned to one lane
   * depending on host byte order and on whether a leading odd byte was
   * skipped (x).
   * NOTE(review): the *(ushort*) loads rely on the compiler tolerating
   * type punning of uchar data (strict aliasing in ISO C).
   */
  static ushort
  ptcl_csum(void *a, int len)
  {
      uchar *addr;
      ulong t1, t2;
      ulong losum, hisum, mdsum, x;
      addr = a;
      losum = 0;
      hisum = 0;
      mdsum = 0;
      x = 0;
      /*
       * Odd start address: consume one byte (the high byte of the first
       * network-order word) so the word loads below are 2-byte aligned.
       * x records that every later word is shifted by one byte.
       */
      if((ulong)addr & 1) {
          if(len) {
              hisum += addr[0];
              len--;
              addr++;
          }
          x = 1;
      }
      /* sum 16 bytes per iteration, interleaving loads and adds */
      while(len >= 16) {
          t1 = *(ushort*)(addr+0);
          t2 = *(ushort*)(addr+2); mdsum += t1;
          t1 = *(ushort*)(addr+4); mdsum += t2;
          t2 = *(ushort*)(addr+6); mdsum += t1;
          t1 = *(ushort*)(addr+8); mdsum += t2;
          t2 = *(ushort*)(addr+10); mdsum += t1;
          t1 = *(ushort*)(addr+12); mdsum += t2;
          t2 = *(ushort*)(addr+14); mdsum += t1;
          mdsum += t2;
          len -= 16;
          addr += 16;
      }
      /* remaining whole words */
      while(len >= 2) {
          mdsum += *(ushort*)addr;
          len -= 2;
          addr += 2;
      }
      if(x) {
          /*
           * Shifted by one byte: a trailing byte sits at an odd offset (low
           * lane).  Native words already match network lanes on a
           * little-endian host and are swapped on a big-endian one.
           */
          if(len)
              losum += addr[0];
          if(LITTLE)
              losum += mdsum;
          else
              hisum += mdsum;
      } else {
          /*
           * Aligned: a trailing byte sits at an even offset (high lane).
           * Native words are swapped on a little-endian host.
           */
          if(len)
              hisum += addr[0];
          if(LITTLE)
              hisum += mdsum;
          else
              losum += mdsum;
      }
      /* byte-swap hisum into losum (valid modulo 0xffff) */
      losum += hisum >> 8;
      losum += (hisum & 0xff) << 8;
      /* fold carries back into the low 16 bits; hisum is reused as the carry */
      while(hisum = losum>>16)
          losum = hisum + (losum & 0xffff);
      return ~losum;
  }
  122. static ushort
  123. ip_csum(uchar *addr)
  124. {
  125. int len;
  126. ulong sum = 0;
  127. len = (addr[0]&0xf)<<2;
  128. while(len > 0) {
  129. sum += addr[0]<<8 | addr[1] ;
  130. len -= 2;
  131. addr += 2;
  132. }
  133. sum = (sum & 0xffff) + (sum >> 16);
  134. sum = (sum & 0xffff) + (sum >> 16);
  135. return (sum^0xffff);
  136. }
  137. static void
  138. udpsend(int ctlrno, Netaddr *a, void *data, int dlen)
  139. {
  140. Udphdr *uh;
  141. Etherhdr *ip;
  142. Etherpkt pkt;
  143. int len, ptcllen;
  144. uh = (Udphdr*)&pkt;
  145. memset(uh, 0, sizeof(Etherpkt));
  146. memmove(uh->udpcksum+sizeof(uh->udpcksum), data, dlen);
  147. /*
  148. * UDP portion
  149. */
  150. ptcllen = dlen + (UDP_HDRSIZE-UDP_PHDRSIZE);
  151. uh->ttl = 0;
  152. uh->udpproto = IP_UDPPROTO;
  153. uh->frag[0] = 0;
  154. uh->frag[1] = 0;
  155. hnputs(uh->udpplen, ptcllen);
  156. hnputl(uh->udpsrc, myaddr.ip);
  157. hnputs(uh->udpsport, myaddr.port);
  158. hnputl(uh->udpdst, a->ip);
  159. hnputs(uh->udpdport, a->port);
  160. hnputs(uh->udplen, ptcllen);
  161. uh->udpcksum[0] = 0;
  162. uh->udpcksum[1] = 0;
  163. dlen = (dlen+1)&~1;
  164. hnputs(uh->udpcksum, ptcl_csum(&uh->ttl, dlen+UDP_HDRSIZE));
  165. /*
  166. * IP portion
  167. */
  168. ip = (Etherhdr*)&pkt;
  169. len = UDP_EHSIZE+UDP_HDRSIZE+dlen; /* non-descriptive names */
  170. ip->vihl = IP_VER|IP_HLEN;
  171. ip->tos = 0;
  172. ip->ttl = 255;
  173. hnputs(ip->length, len-ETHER_HDR);
  174. hnputs(ip->id, Id++);
  175. ip->frag[0] = 0;
  176. ip->frag[1] = 0;
  177. ip->cksum[0] = 0;
  178. ip->cksum[1] = 0;
  179. hnputs(ip->cksum, ip_csum(&ip->vihl));
  180. /*
  181. * Ethernet MAC portion
  182. */
  183. hnputs(ip->type, ET_IP);
  184. memmove(ip->d, a->ea, sizeof(ip->d));
  185. ethertxpkt(ctlrno, &pkt, len, Timeout);
  186. }
  187. static void
  188. nak(int ctlrno, Netaddr *a, int code, char *msg, int report)
  189. {
  190. int n;
  191. char buf[128];
  192. buf[0] = 0;
  193. buf[1] = Tftp_ERROR;
  194. buf[2] = 0;
  195. buf[3] = code;
  196. strcpy(buf+4, msg);
  197. n = strlen(msg) + 4 + 1;
  198. udpsend(ctlrno, a, buf, n);
  199. if(report)
  200. print("\ntftp: error(%d): %s\n", code, msg);
  201. }
  202. static int
  203. udprecv(int ctlrno, Netaddr *a, void *data, int dlen)
  204. {
  205. int n, len;
  206. ushort csm;
  207. Udphdr *h;
  208. ulong addr, timo;
  209. Etherpkt pkt;
  210. static int rxactive;
  211. if(rxactive == 0)
  212. timo = 1000;
  213. else
  214. timo = Timeout;
  215. timo += msec();
  216. while(timo > msec()){
  217. n = etherrxpkt(ctlrno, &pkt, timo-msec());
  218. if(n <= 0)
  219. continue;
  220. h = (Udphdr*)&pkt;
  221. if(nhgets(h->type) != ET_IP)
  222. continue;
  223. if(ip_csum(&h->vihl)) {
  224. print("ip chksum error\n");
  225. continue;
  226. }
  227. if(h->vihl != (IP_VER|IP_HLEN)) {
  228. print("ip bad vers/hlen\n");
  229. continue;
  230. }
  231. if(h->udpproto != IP_UDPPROTO)
  232. continue;
  233. h->ttl = 0;
  234. len = nhgets(h->udplen);
  235. hnputs(h->udpplen, len);
  236. if(nhgets(h->udpcksum)) {
  237. csm = ptcl_csum(&h->ttl, len+UDP_PHDRSIZE);
  238. if(csm != 0) {
  239. print("udp chksum error csum #%4lux len %d\n", csm, n);
  240. break;
  241. }
  242. }
  243. if(a->port != 0 && nhgets(h->udpsport) != a->port)
  244. continue;
  245. if(myaddr.port != 0 && nhgets(h->udpdport) != myaddr.port)
  246. continue;
  247. addr = nhgetl(h->udpsrc);
  248. if(a->ip != Bcastip && addr != a->ip)
  249. continue;
  250. len -= UDP_HDRSIZE-UDP_PHDRSIZE;
  251. if(len > dlen) {
  252. print("udp: packet too big\n");
  253. continue;
  254. }
  255. memmove(data, h->udpcksum+sizeof(h->udpcksum), len);
  256. a->ip = addr;
  257. a->port = nhgets(h->udpsport);
  258. memmove(a->ea, pkt.s, sizeof(a->ea));
  259. rxactive = 1;
  260. return len;
  261. }
  262. return 0;
  263. }
  264. static int tftpblockno;
  265. static int
  266. tftpopen(int ctlrno, Netaddr *a, char *name, Tftp *tftp)
  267. {
  268. int i, len, rlen;
  269. char buf[Segsize+2];
  270. buf[0] = 0;
  271. buf[1] = Tftp_READ;
  272. len = sprint(buf+2, "%s", name) + 2;
  273. len += sprint(buf+len+1, "octet") + 2;
  274. for(i = 0; i < 5; i++){
  275. udpsend(ctlrno, a, buf, len);
  276. a->port = 0;
  277. if((rlen = udprecv(ctlrno, a, tftp, sizeof(Tftp))) < sizeof(tftp->header))
  278. continue;
  279. switch((tftp->header[0]<<8)|tftp->header[1]){
  280. case Tftp_ERROR:
  281. print("tftpopen: error (%d): %s\n",
  282. (tftp->header[2]<<8)|tftp->header[3], tftp->data);
  283. return -1;
  284. case Tftp_DATA:
  285. tftpblockno = 1;
  286. len = (tftp->header[2]<<8)|tftp->header[3];
  287. if(len != tftpblockno){
  288. print("tftpopen: block error: %d\n", len);
  289. nak(ctlrno, a, 1, "block error", 0);
  290. return -1;
  291. }
  292. return rlen-sizeof(tftp->header);
  293. }
  294. }
  295. print("tftpopen: failed to connect to server\n");
  296. return -1;
  297. }
  298. static int
  299. tftpread(int ctlrno, Netaddr *a, Tftp *tftp, int dlen)
  300. {
  301. int blockno, len;
  302. uchar buf[4];
  303. buf[0] = 0;
  304. buf[1] = Tftp_ACK;
  305. buf[2] = tftpblockno>>8;
  306. buf[3] = tftpblockno;
  307. tftpblockno++;
  308. dlen += sizeof(tftp->header);
  309. buggery:
  310. udpsend(ctlrno, a, buf, sizeof(buf));
  311. if((len = udprecv(ctlrno, a, tftp, dlen)) != dlen){
  312. print("tftpread: %d != %d\n", len, dlen);
  313. nak(ctlrno, a, 2, "short read", 0);
  314. }
  315. blockno = (tftp->header[2]<<8)|tftp->header[3];
  316. if(blockno != tftpblockno){
  317. print("tftpread: block error: %d, expected %d\n", blockno, tftpblockno);
  318. if(blockno == tftpblockno-1)
  319. goto buggery;
  320. nak(ctlrno, a, 1, "block error", 0);
  321. return -1;
  322. }
  323. return len-sizeof(tftp->header);
  324. }
  /*
   * Magic number bootp() requires in the Exec header of the image it loads.
   * (An earlier revision used L_MAGIC here.)
   */
  #define BOOT_MAGIC	0x0700e0c3
  327. void
  328. getether(char *dev, uchar *ea)
  329. {
  330. int i;
  331. char *p;
  332. p = dev;
  333. for (i = 0; i < 8; i++) {
  334. p = strchr(p, ' ');
  335. if (p == 0)
  336. panic("no ether addr");
  337. p++;
  338. }
  339. for (i = 0; i < 6; i++) {
  340. ea[i] = strtoul(p, &p, 16);
  341. if (*p != (i == 5 ? ' ' : '-'))
  342. panic("bad ether addr");
  343. p++;
  344. }
  345. }
  346. static char inibuf[BOOTARGSLEN];
  347. int
  348. bootp(char *dev)
  349. {
  350. Bootp req, rep;
  351. int i, fd, dlen, segsize, text, data, bss, total;
  352. uchar *addr, *p, ea[6];
  353. char *cp;
  354. ulong entry;
  355. Exec *exec;
  356. char *filename, confname[32];
  357. getether(dev, ea);
  358. fd = devopen(dev);
  359. if (fd < 0)
  360. panic("bootp devopen");
  361. memset(&req, 0, sizeof(req));
  362. req.op = Bootrequest;
  363. req.htype = 1; /* ethernet */
  364. req.hlen = Eaddrlen; /* ethernet */
  365. memmove(req.chaddr, ea, Eaddrlen);
  366. myaddr.ip = 0;
  367. myaddr.port = BPportsrc;
  368. memmove(myaddr.ea, ea, Eaddrlen);
  369. for(i = 0; i < 10; i++) {
  370. server.ip = Bcastip;
  371. server.port = BPportdst;
  372. memmove(server.ea, broadcast, sizeof(server.ea));
  373. udpsend(fd, &server, &req, sizeof(req));
  374. if(udprecv(fd, &server, &rep, sizeof(rep)) <= 0)
  375. continue;
  376. if(memcmp(req.chaddr, rep.chaddr, Eaddrlen))
  377. continue;
  378. if(rep.htype != 1 || rep.hlen != Eaddrlen)
  379. continue;
  380. break;
  381. }
  382. if(i >= 10) {
  383. print("bootp timed out\n");
  384. return -1;
  385. }
  386. sprint(confname, "/alpha/conf/%d.%d.%d.%d",
  387. rep.yiaddr[0],
  388. rep.yiaddr[1],
  389. rep.yiaddr[2],
  390. rep.yiaddr[3]);
  391. if(rep.sname[0] != '\0')
  392. print("%s ", rep.sname);
  393. print("(%d.%d.%d.%d!%d): %s...",
  394. rep.siaddr[0],
  395. rep.siaddr[1],
  396. rep.siaddr[2],
  397. rep.siaddr[3],
  398. server.port,
  399. confname);
  400. myaddr.ip = nhgetl(rep.yiaddr);
  401. myaddr.port = tftpport++;
  402. server.ip = nhgetl(rep.siaddr);
  403. server.port = TFTPport;
  404. if((dlen = tftpopen(fd, &server, confname, &tftpb)) < 0)
  405. return -1;
  406. cp = inibuf;
  407. while(dlen > 0) {
  408. if(cp-inibuf+dlen > BOOTARGSLEN)
  409. panic("conf too large");
  410. memmove(cp, tftpb.data, dlen);
  411. cp += dlen;
  412. if(dlen != Segsize)
  413. break;
  414. if((dlen = tftpread(fd, &server, &tftpb, sizeof(tftpb.data))) < 0)
  415. return -1;
  416. }
  417. *cp = 0;
  418. setconf(inibuf);
  419. filename = "/alpha/9apc";
  420. cp = getconf("bootfile");
  421. if(cp != nil)
  422. filename = cp;
  423. print("%s\n", filename);
  424. myaddr.port = tftpport++;
  425. server.port = TFTPport;
  426. if((dlen = tftpopen(fd, &server, filename, &tftpb)) < 0)
  427. return -1;
  428. exec = (Exec*)(tftpb.data);
  429. if(dlen < sizeof(Exec) || GLLONG(exec->magic) != BOOT_MAGIC){
  430. nak(fd, &server, 0, "bad magic number", 1);
  431. return -1;
  432. }
  433. text = GLLONG(exec->text);
  434. data = GLLONG(exec->data);
  435. bss = GLLONG(exec->bss);
  436. total = text+data+bss;
  437. entry = GLLONG(exec->entry);
  438. if (!validrgn(entry, entry+total))
  439. panic("memory range not available: %lux-%lux\n", entry, entry+total);
  440. print("%d", text);
  441. addr = (uchar*)entry;
  442. p = tftpb.data+sizeof(Exec);
  443. dlen -= sizeof(Exec);
  444. segsize = text;
  445. for(;;){
  446. if(dlen == 0){
  447. if((dlen = tftpread(fd, &server, &tftpb, sizeof(tftpb.data))) < 0)
  448. return -1;
  449. p = tftpb.data;
  450. }
  451. if(segsize <= dlen)
  452. i = segsize;
  453. else
  454. i = dlen;
  455. memmove(addr, p, i);
  456. addr += i;
  457. p += i;
  458. segsize -= i;
  459. dlen -= i;
  460. if(segsize <= 0){
  461. if(data == 0)
  462. break;
  463. print("+%d", data);
  464. segsize = data;
  465. data = 0;
  466. // addr = (uchar*)pground((uvlong)addr);
  467. }
  468. }
  469. nak(fd, &server, 3, "ok", 0); /* tftpclose */
  470. print("+%d=%d\n", bss, total);
  471. print("entry: 0x%lux\n", entry);
  472. kexec(entry);
  473. return 0;
  474. }