msg.c 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. /*
  2. * This file is part of the UCB release of Plan 9. It is subject to the license
  3. * terms in the LICENSE file found in the top-level directory of this
  4. * distribution and at http://akaros.cs.berkeley.edu/files/Plan9License. No
  5. * part of the UCB release of Plan 9, including this file, may be copied,
  6. * modified, propagated, or distributed except according to the terms contained
  7. * in the LICENSE file.
  8. */
  9. #include "ssh.h"
  10. static uint32_t sum32(uint32_t, void*, int);
  /*
   * Printable names of the SSH message types, indexed by the numeric
   * message type.  Used when formatting debug and error output.
   */
  char *msgnames[] = {
      "SSH_MSG_NONE",                                     /*  0 */
      "SSH_MSG_DISCONNECT",                               /*  1 */
      "SSH_SMSG_PUBLIC_KEY",                              /*  2 */
      "SSH_CMSG_SESSION_KEY",                             /*  3 */
      "SSH_CMSG_USER",                                    /*  4 */
      "SSH_CMSG_AUTH_RHOSTS",                             /*  5 */
      "SSH_CMSG_AUTH_RSA",                                /*  6 */
      "SSH_SMSG_AUTH_RSA_CHALLENGE",                      /*  7 */
      "SSH_CMSG_AUTH_RSA_RESPONSE",                       /*  8 */
      "SSH_CMSG_AUTH_PASSWORD",                           /*  9 */
      "SSH_CMSG_REQUEST_PTY",                             /* 10 */
      "SSH_CMSG_WINDOW_SIZE",                             /* 11 */
      "SSH_CMSG_EXEC_SHELL",                              /* 12 */
      "SSH_CMSG_EXEC_CMD",                                /* 13 */
      "SSH_SMSG_SUCCESS",                                 /* 14 */
      "SSH_SMSG_FAILURE",                                 /* 15 */
      "SSH_CMSG_STDIN_DATA",                              /* 16 */
      "SSH_SMSG_STDOUT_DATA",                             /* 17 */
      "SSH_SMSG_STDERR_DATA",                             /* 18 */
      "SSH_CMSG_EOF",                                     /* 19 */
      "SSH_SMSG_EXITSTATUS",                              /* 20 */
      "SSH_MSG_CHANNEL_OPEN_CONFIRMATION",                /* 21 */
      "SSH_MSG_CHANNEL_OPEN_FAILURE",                     /* 22 */
      "SSH_MSG_CHANNEL_DATA",                             /* 23 */
      "SSH_MSG_CHANNEL_INPUT_EOF",                        /* 24 */
      "SSH_MSG_CHANNEL_OUTPUT_CLOSED",                    /* 25 */
      "SSH_MSG_UNIX_DOMAIN_X11_FORWARDING (obsolete)",    /* 26 */
      "SSH_SMSG_X11_OPEN",                                /* 27 */
      "SSH_CMSG_PORT_FORWARD_REQUEST",                    /* 28 */
      "SSH_MSG_PORT_OPEN",                                /* 29 */
      "SSH_CMSG_AGENT_REQUEST_FORWARDING",                /* 30 */
      "SSH_SMSG_AGENT_OPEN",                              /* 31 */
      "SSH_MSG_IGNORE",                                   /* 32 */
      "SSH_CMSG_EXIT_CONFIRMATION",                       /* 33 */
      "SSH_CMSG_X11_REQUEST_FORWARDING",                  /* 34 */
      "SSH_CMSG_AUTH_RHOSTS_RSA",                         /* 35 */
      "SSH_MSG_DEBUG",                                    /* 36 */
      "SSH_CMSG_REQUEST_COMPRESSION",                     /* 37 */
      "SSH_CMSG_MAX_PACKET_SIZE",                         /* 38 */
      "SSH_CMSG_AUTH_TIS",                                /* 39 */
      "SSH_SMSG_AUTH_TIS_CHALLENGE",                      /* 40 */
      "SSH_CMSG_AUTH_TIS_RESPONSE",                       /* 41 */
      "SSH_CMSG_AUTH_KERBEROS",                           /* 42 */
      "SSH_SMSG_AUTH_KERBEROS_RESPONSE",                  /* 43 */
      "SSH_CMSG_HAVE_KERBEROS_TGT"                        /* 44 */
  };
  64. void
  65. badmsg(Msg *m, int want)
  66. {
  67. char *s, buf[20+ERRMAX];
  68. if(m==nil){
  69. snprint(buf, sizeof buf, "<early eof: %r>");
  70. s = buf;
  71. }else{
  72. snprint(buf, sizeof buf, "<unknown type %d>", m->type);
  73. s = buf;
  74. if(0 <= m->type && m->type < nelem(msgnames))
  75. s = msgnames[m->type];
  76. }
  77. if(want)
  78. error("got %s message expecting %s", s, msgnames[want]);
  79. error("got unexpected %s message", s);
  80. }
  81. Msg*
  82. allocmsg(Conn *c, int type, int len)
  83. {
  84. uint8_t *p;
  85. Msg *m;
  86. if(len > 256*1024)
  87. abort();
  88. m = (Msg*)emalloc(sizeof(Msg)+4+8+1+len+4);
  89. setmalloctag(m, getcallerpc(&c));
  90. p = (uint8_t*)&m[1];
  91. m->c = c;
  92. m->bp = p;
  93. m->ep = p+len;
  94. m->wp = p;
  95. m->type = type;
  96. return m;
  97. }
  98. void
  99. unrecvmsg(Conn *c, Msg *m)
  100. {
  101. debug(DBG_PROTO, "unreceived %s len %ld\n", msgnames[m->type], m->ep - m->rp);
  102. free(c->unget);
  103. c->unget = m;
  104. }
  105. static Msg*
  106. recvmsg0(Conn *c)
  107. {
  108. int pad;
  109. uint8_t *p, buf[4];
  110. uint32_t crc, crc0, len;
  111. Msg *m;
  112. if(c->unget){
  113. m = c->unget;
  114. c->unget = nil;
  115. return m;
  116. }
  117. if(readn(c->fd[0], buf, 4) != 4){
  118. werrstr("short net read: %r");
  119. return nil;
  120. }
  121. len = LONG(buf);
  122. if(len > 256*1024){
  123. werrstr("packet size far too big: %.8lux", len);
  124. return nil;
  125. }
  126. pad = 8 - len%8;
  127. m = (Msg*)emalloc(sizeof(Msg)+pad+len);
  128. setmalloctag(m, getcallerpc(&c));
  129. m->c = c;
  130. m->bp = (uint8_t*)&m[1];
  131. m->ep = m->bp + pad+len-4; /* -4: don't include crc */
  132. m->rp = m->bp;
  133. if(readn(c->fd[0], m->bp, pad+len) != pad+len){
  134. werrstr("short net read: %r");
  135. free(m);
  136. return nil;
  137. }
  138. if(c->cipher)
  139. c->cipher->decrypt(c->cstate, m->bp, len+pad);
  140. crc = sum32(0, m->bp, pad+len-4);
  141. p = m->bp + pad+len-4;
  142. crc0 = LONG(p);
  143. if(crc != crc0){
  144. werrstr("bad crc %#lux != %#lux (packet length %lud)", crc, crc0, len);
  145. free(m);
  146. return nil;
  147. }
  148. m->rp += pad;
  149. m->type = *m->rp++;
  150. return m;
  151. }
  152. Msg*
  153. recvmsg(Conn *c, int type)
  154. {
  155. Msg *m;
  156. while((m = recvmsg0(c)) != nil){
  157. debug(DBG_PROTO, "received %s len %ld\n", msgnames[m->type], m->ep - m->rp);
  158. if(m->type != SSH_MSG_DEBUG && m->type != SSH_MSG_IGNORE)
  159. break;
  160. if(m->type == SSH_MSG_DEBUG)
  161. debug(DBG_PROTO, "remote DEBUG: %s\n", getstring(m));
  162. free(m);
  163. }
  164. if(type == 0){
  165. /* no checking */
  166. }else if(type == -1){
  167. /* must not be nil */
  168. if(m == nil)
  169. error(Ehangup);
  170. }else{
  171. /* must be given type */
  172. if(m==nil || m->type!=type)
  173. badmsg(m, type);
  174. }
  175. setmalloctag(m, getcallerpc(&c));
  176. return m;
  177. }
  178. int
  179. sendmsg(Msg *m)
  180. {
  181. int i, pad;
  182. uint8_t *p;
  183. uint32_t datalen, len, crc;
  184. Conn *c;
  185. datalen = m->wp - m->bp;
  186. len = datalen + 5;
  187. pad = 8 - len%8;
  188. debug(DBG_PROTO, "sending %s len %lud\n", msgnames[m->type], datalen);
  189. p = m->bp;
  190. memmove(m->bp+4+pad+1, m->bp, datalen); /* slide data to correct position */
  191. PLONG(p, len);
  192. p += 4;
  193. if(m->c->cstate){
  194. for(i=0; i<pad; i++)
  195. *p++ = fastrand();
  196. }else{
  197. memset(p, 0, pad);
  198. p += pad;
  199. }
  200. *p++ = m->type;
  201. /* data already in position */
  202. p += datalen;
  203. crc = sum32(0, m->bp+4, pad+1+datalen);
  204. PLONG(p, crc);
  205. p += 4;
  206. c = m->c;
  207. qlock(c);
  208. if(c->cstate)
  209. c->cipher->encrypt(c->cstate, m->bp+4, len+pad);
  210. if(write(c->fd[1], m->bp, p - m->bp) != p-m->bp){
  211. qunlock(c);
  212. free(m);
  213. return -1;
  214. }
  215. qunlock(c);
  216. free(m);
  217. return 0;
  218. }
  219. uint8_t
  220. getbyte(Msg *m)
  221. {
  222. if(m->rp >= m->ep)
  223. error(Edecode);
  224. return *m->rp++;
  225. }
  226. uint16_t
  227. getshort(Msg *m)
  228. {
  229. uint16_t x;
  230. if(m->rp+2 > m->ep)
  231. error(Edecode);
  232. x = SHORT(m->rp);
  233. m->rp += 2;
  234. return x;
  235. }
  236. uint32_t
  237. getlong(Msg *m)
  238. {
  239. uint32_t x;
  240. if(m->rp+4 > m->ep)
  241. error(Edecode);
  242. x = LONG(m->rp);
  243. m->rp += 4;
  244. return x;
  245. }
  246. char*
  247. getstring(Msg *m)
  248. {
  249. char *p;
  250. uint32_t len;
  251. /* overwrites length to make room for NUL */
  252. len = getlong(m);
  253. if(m->rp+len > m->ep)
  254. error(Edecode);
  255. p = (char*)m->rp-1;
  256. memmove(p, m->rp, len);
  257. p[len] = '\0';
  258. return p;
  259. }
  260. void*
  261. getbytes(Msg *m, int n)
  262. {
  263. uint8_t *p;
  264. if(m->rp+n > m->ep)
  265. error(Edecode);
  266. p = m->rp;
  267. m->rp += n;
  268. return p;
  269. }
  270. mpint*
  271. getmpint(Msg *m)
  272. {
  273. int n;
  274. n = (getshort(m)+7)/8; /* getshort returns # bits */
  275. return betomp(getbytes(m, n), n, nil);
  276. }
  277. RSApub*
  278. getRSApub(Msg *m)
  279. {
  280. RSApub *key;
  281. getlong(m);
  282. key = rsapuballoc();
  283. if(key == nil)
  284. error(Ememory);
  285. key->ek = getmpint(m);
  286. key->n = getmpint(m);
  287. setmalloctag(key, getcallerpc(&m));
  288. return key;
  289. }
  290. void
  291. putbyte(Msg *m, uint8_t x)
  292. {
  293. if(m->wp >= m->ep)
  294. error(Eencode);
  295. *m->wp++ = x;
  296. }
  297. void
  298. putshort(Msg *m, uint16_t x)
  299. {
  300. if(m->wp+2 > m->ep)
  301. error(Eencode);
  302. PSHORT(m->wp, x);
  303. m->wp += 2;
  304. }
  305. void
  306. putlong(Msg *m, uint32_t x)
  307. {
  308. if(m->wp+4 > m->ep)
  309. error(Eencode);
  310. PLONG(m->wp, x);
  311. m->wp += 4;
  312. }
  313. void
  314. putstring(Msg *m, char *s)
  315. {
  316. int len;
  317. len = strlen(s);
  318. putlong(m, len);
  319. putbytes(m, s, len);
  320. }
  321. void
  322. putbytes(Msg *m, void *a, int32_t n)
  323. {
  324. if(m->wp+n > m->ep)
  325. error(Eencode);
  326. memmove(m->wp, a, n);
  327. m->wp += n;
  328. }
  329. void
  330. putmpint(Msg *m, mpint *b)
  331. {
  332. int bits, n;
  333. bits = mpsignif(b);
  334. putshort(m, bits);
  335. n = (bits+7)/8;
  336. if(m->wp+n > m->ep)
  337. error(Eencode);
  338. mptobe(b, m->wp, n, nil);
  339. m->wp += n;
  340. }
  341. void
  342. putRSApub(Msg *m, RSApub *key)
  343. {
  344. putlong(m, mpsignif(key->n));
  345. putmpint(m, key->ek);
  346. putmpint(m, key->n);
  347. }
  /* byte-indexed lookup table used by sum32; filled in by initsum32 */
  static uint32_t crctab[256];

  /*
   * Fill crctab: for each byte value, run eight shift-right steps,
   * xoring in the polynomial 0xEDB88320 whenever the low bit is set.
   */
  static void
  initsum32(void)
  {
      const uint32_t poly = 0xEDB88320;
      uint32_t v;
      int byte, bit;

      for(byte = 0; byte < 256; byte++){
          v = byte;
          for(bit = 8; bit > 0; bit--)
              v = (v & 1) ? (v >> 1) ^ poly : v >> 1;
          crctab[byte] = v;
      }
  }
  366. static uint32_t
  367. sum32(uint32_t lcrc, void *buf, int n)
  368. {
  369. static int first=1;
  370. uint8_t *s = buf;
  371. uint32_t crc = lcrc;
  372. if(first){
  373. first=0;
  374. initsum32();
  375. }
  376. while(n-- > 0)
  377. crc = crctab[(crc^*s++)&0xff] ^ (crc>>8);
  378. return crc;
  379. }
  380. mpint*
  381. rsapad(mpint *b, int n)
  382. {
  383. int i, pad, nbuf;
  384. uint8_t buf[2560];
  385. mpint *c;
  386. if(n > sizeof buf)
  387. error("buffer too small in rsapad");
  388. nbuf = (mpsignif(b)+7)/8;
  389. pad = n - nbuf;
  390. assert(pad >= 3);
  391. mptobe(b, buf, nbuf, nil);
  392. memmove(buf+pad, buf, nbuf);
  393. buf[0] = 0;
  394. buf[1] = 2;
  395. for(i=2; i<pad-1; i++)
  396. buf[i]=1+fastrand()%255;
  397. buf[pad-1] = 0;
  398. c = betomp(buf, n, nil);
  399. memset(buf, 0, sizeof buf);
  400. return c;
  401. }
  402. mpint*
  403. rsaunpad(mpint *b)
  404. {
  405. int i, n;
  406. uint8_t buf[2560];
  407. n = (mpsignif(b)+7)/8;
  408. if(n > sizeof buf)
  409. error("buffer too small in rsaunpad");
  410. mptobe(b, buf, n, nil);
  411. /* the initial zero has been eaten by the betomp -> mptobe sequence */
  412. if(buf[0] != 2)
  413. error("bad data in rsaunpad");
  414. for(i=1; i<n; i++)
  415. if(buf[i]==0)
  416. break;
  417. return betomp(buf+i, n-i, nil);
  418. }
  419. void
  420. mptoberjust(mpint *b, uint8_t *buf, int len)
  421. {
  422. int n;
  423. n = mptobe(b, buf, len, nil);
  424. assert(n >= 0);
  425. if(n < len){
  426. len -= n;
  427. memmove(buf+len, buf, n);
  428. memset(buf, 0, len);
  429. }
  430. }
  431. mpint*
  432. rsaencryptbuf(RSApub *key, uint8_t *buf, int nbuf)
  433. {
  434. int n;
  435. mpint *a, *b, *c;
  436. n = (mpsignif(key->n)+7)/8;
  437. a = betomp(buf, nbuf, nil);
  438. b = rsapad(a, n);
  439. mpfree(a);
  440. c = rsaencrypt(key, b, nil);
  441. mpfree(b);
  442. return c;
  443. }