msg.c 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. #include "ssh.h"
  2. static ulong sum32(ulong, void*, int);
/*
 * Printable names of the SSH message types, indexed by numeric type
 * code: msgnames[t] names type t.  badmsg, recvmsg and sendmsg use it
 * when reporting messages.  The bracketed numbers mark every tenth
 * index.  The table has 45 entries (indices 0-44), so a type read from
 * the network must be range-checked before it is used as an index.
 */
char *msgnames[] =
{
/* 0 */
	"SSH_MSG_NONE",
	"SSH_MSG_DISCONNECT",
	"SSH_SMSG_PUBLIC_KEY",
	"SSH_CMSG_SESSION_KEY",
	"SSH_CMSG_USER",
	"SSH_CMSG_AUTH_RHOSTS",
	"SSH_CMSG_AUTH_RSA",
	"SSH_SMSG_AUTH_RSA_CHALLENGE",
	"SSH_CMSG_AUTH_RSA_RESPONSE",
	"SSH_CMSG_AUTH_PASSWORD",
/* 10 */
	"SSH_CMSG_REQUEST_PTY",
	"SSH_CMSG_WINDOW_SIZE",
	"SSH_CMSG_EXEC_SHELL",
	"SSH_CMSG_EXEC_CMD",
	"SSH_SMSG_SUCCESS",
	"SSH_SMSG_FAILURE",
	"SSH_CMSG_STDIN_DATA",
	"SSH_SMSG_STDOUT_DATA",
	"SSH_SMSG_STDERR_DATA",
	"SSH_CMSG_EOF",
/* 20 */
	"SSH_SMSG_EXITSTATUS",
	"SSH_MSG_CHANNEL_OPEN_CONFIRMATION",
	"SSH_MSG_CHANNEL_OPEN_FAILURE",
	"SSH_MSG_CHANNEL_DATA",
	"SSH_MSG_CHANNEL_INPUT_EOF",
	"SSH_MSG_CHANNEL_OUTPUT_CLOSED",
	"SSH_MSG_UNIX_DOMAIN_X11_FORWARDING (obsolete)",
	"SSH_SMSG_X11_OPEN",
	"SSH_CMSG_PORT_FORWARD_REQUEST",
	"SSH_MSG_PORT_OPEN",
/* 30 */
	"SSH_CMSG_AGENT_REQUEST_FORWARDING",
	"SSH_SMSG_AGENT_OPEN",
	"SSH_MSG_IGNORE",
	"SSH_CMSG_EXIT_CONFIRMATION",
	"SSH_CMSG_X11_REQUEST_FORWARDING",
	"SSH_CMSG_AUTH_RHOSTS_RSA",
	"SSH_MSG_DEBUG",
	"SSH_CMSG_REQUEST_COMPRESSION",
	"SSH_CMSG_MAX_PACKET_SIZE",
	"SSH_CMSG_AUTH_TIS",
/* 40 */
	"SSH_SMSG_AUTH_TIS_CHALLENGE",
	"SSH_CMSG_AUTH_TIS_RESPONSE",
	"SSH_CMSG_AUTH_KERBEROS",
	"SSH_SMSG_AUTH_KERBEROS_RESPONSE",
	"SSH_CMSG_HAVE_KERBEROS_TGT"	/* 44: last valid index */
};
  56. void
  57. badmsg(Msg *m, int want)
  58. {
  59. char *s, buf[20+ERRMAX];
  60. if(m==nil){
  61. snprint(buf, sizeof buf, "<early eof: %r>");
  62. s = buf;
  63. }else{
  64. snprint(buf, sizeof buf, "<unknown type %d>", m->type);
  65. s = buf;
  66. if(0 <= m->type && m->type < nelem(msgnames))
  67. s = msgnames[m->type];
  68. }
  69. if(want)
  70. error("got %s message expecting %s", s, msgnames[want]);
  71. error("got unexpected %s message", s);
  72. }
  73. Msg*
  74. allocmsg(Conn *c, int type, int len)
  75. {
  76. uchar *p;
  77. Msg *m;
  78. if(len > 256*1024)
  79. abort();
  80. m = (Msg*)emalloc(sizeof(Msg)+4+8+1+len+4);
  81. setmalloctag(m, getcallerpc(&c));
  82. p = (uchar*)&m[1];
  83. m->c = c;
  84. m->bp = p;
  85. m->ep = p+len;
  86. m->wp = p;
  87. m->type = type;
  88. return m;
  89. }
  90. void
  91. unrecvmsg(Conn *c, Msg *m)
  92. {
  93. free(c->unget);
  94. c->unget = m;
  95. }
  96. static Msg*
  97. recvmsg0(Conn *c)
  98. {
  99. int pad;
  100. uchar *p, buf[4];
  101. ulong crc, crc0, len;
  102. Msg *m;
  103. if(c->unget){
  104. m = c->unget;
  105. c->unget = nil;
  106. return m;
  107. }
  108. if(readn(c->fd[0], buf, 4) != 4){
  109. werrstr("short net read: %r");
  110. return nil;
  111. }
  112. len = LONG(buf);
  113. if(len > 256*1024){
  114. werrstr("packet size far too big: %.8lux", len);
  115. return nil;
  116. }
  117. pad = 8 - len%8;
  118. m = (Msg*)emalloc(sizeof(Msg)+pad+len);
  119. setmalloctag(m, getcallerpc(&c));
  120. m->c = c;
  121. m->bp = (uchar*)&m[1];
  122. m->ep = m->bp + pad+len-4; /* -4: don't include crc */
  123. m->rp = m->bp;
  124. if(readn(c->fd[0], m->bp, pad+len) != pad+len){
  125. werrstr("short net read: %r");
  126. free(m);
  127. return nil;
  128. }
  129. if(c->cipher)
  130. c->cipher->decrypt(c->cstate, m->bp, len+pad);
  131. crc = sum32(0, m->bp, pad+len-4);
  132. p = m->bp + pad+len-4;
  133. crc0 = LONG(p);
  134. if(crc != crc0){
  135. werrstr("bad crc %#lux != %#lux (packet length %lud)", crc, crc0, len);
  136. free(m);
  137. return nil;
  138. }
  139. m->rp += pad;
  140. m->type = *m->rp++;
  141. return m;
  142. }
  143. Msg*
  144. recvmsg(Conn *c, int type)
  145. {
  146. Msg *m;
  147. while((m = recvmsg0(c)) != nil){
  148. debug(DBG_PROTO, "received %s len %d\n", msgnames[m->type], m->ep - m->rp);
  149. if(m->type != SSH_MSG_DEBUG && m->type != SSH_MSG_IGNORE)
  150. break;
  151. if(m->type == SSH_MSG_DEBUG)
  152. debug(DBG_PROTO, "remote DEBUG: %s\n", getstring(m));
  153. free(m);
  154. }
  155. if(type == 0){
  156. /* no checking */
  157. }else if(type == -1){
  158. /* must not be nil */
  159. if(m == nil)
  160. error(Ehangup);
  161. }else{
  162. /* must be given type */
  163. if(m==nil || m->type!=type)
  164. badmsg(m, type);
  165. }
  166. setmalloctag(m, getcallerpc(&c));
  167. return m;
  168. }
  169. int
  170. sendmsg(Msg *m)
  171. {
  172. int i, pad;
  173. uchar *p;
  174. ulong datalen, len, crc;
  175. Conn *c;
  176. datalen = m->wp - m->bp;
  177. len = datalen + 5;
  178. pad = 8 - len%8;
  179. debug(DBG_PROTO, "sending %s len %d\n", msgnames[m->type], datalen);
  180. p = m->bp;
  181. memmove(m->bp+4+pad+1, m->bp, datalen); /* slide data to correct position */
  182. PLONG(p, len);
  183. p += 4;
  184. if(m->c->cstate){
  185. for(i=0; i<pad; i++)
  186. *p++ = fastrand();
  187. }else{
  188. memset(p, 0, pad);
  189. p += pad;
  190. }
  191. *p++ = m->type;
  192. /* data already in position */
  193. p += datalen;
  194. crc = sum32(0, m->bp+4, pad+1+datalen);
  195. PLONG(p, crc);
  196. p += 4;
  197. c = m->c;
  198. qlock(c);
  199. if(c->cstate)
  200. c->cipher->encrypt(c->cstate, m->bp+4, len+pad);
  201. if(write(c->fd[1], m->bp, p - m->bp) != p-m->bp){
  202. qunlock(c);
  203. free(m);
  204. return -1;
  205. }
  206. qunlock(c);
  207. free(m);
  208. return 0;
  209. }
  210. uchar
  211. getbyte(Msg *m)
  212. {
  213. if(m->rp >= m->ep)
  214. error(Edecode);
  215. return *m->rp++;
  216. }
  217. ushort
  218. getshort(Msg *m)
  219. {
  220. ushort x;
  221. if(m->rp+2 > m->ep)
  222. error(Edecode);
  223. x = SHORT(m->rp);
  224. m->rp += 2;
  225. return x;
  226. }
  227. ulong
  228. getlong(Msg *m)
  229. {
  230. ulong x;
  231. if(m->rp+4 > m->ep)
  232. error(Edecode);
  233. x = LONG(m->rp);
  234. m->rp += 4;
  235. return x;
  236. }
  237. char*
  238. getstring(Msg *m)
  239. {
  240. char *p;
  241. ulong len;
  242. /* overwrites length to make room for NUL */
  243. len = getlong(m);
  244. if(m->rp+len > m->ep)
  245. error(Edecode);
  246. p = (char*)m->rp-1;
  247. memmove(p, m->rp, len);
  248. p[len] = '\0';
  249. return p;
  250. }
  251. void*
  252. getbytes(Msg *m, int n)
  253. {
  254. uchar *p;
  255. if(m->rp+n > m->ep)
  256. error(Edecode);
  257. p = m->rp;
  258. m->rp += n;
  259. return p;
  260. }
  261. mpint*
  262. getmpint(Msg *m)
  263. {
  264. int n;
  265. n = (getshort(m)+7)/8; /* getshort returns # bits */
  266. return betomp(getbytes(m, n), n, nil);
  267. }
  268. RSApub*
  269. getRSApub(Msg *m)
  270. {
  271. RSApub *key;
  272. getlong(m);
  273. key = rsapuballoc();
  274. if(key == nil)
  275. error(Ememory);
  276. key->ek = getmpint(m);
  277. key->n = getmpint(m);
  278. setmalloctag(key, getcallerpc(&m));
  279. return key;
  280. }
  281. void
  282. putbyte(Msg *m, uchar x)
  283. {
  284. if(m->wp >= m->ep)
  285. error(Eencode);
  286. *m->wp++ = x;
  287. }
  288. void
  289. putshort(Msg *m, ushort x)
  290. {
  291. if(m->wp+2 > m->ep)
  292. error(Eencode);
  293. PSHORT(m->wp, x);
  294. m->wp += 2;
  295. }
  296. void
  297. putlong(Msg *m, ulong x)
  298. {
  299. if(m->wp+4 > m->ep)
  300. error(Eencode);
  301. PLONG(m->wp, x);
  302. m->wp += 4;
  303. }
  304. void
  305. putstring(Msg *m, char *s)
  306. {
  307. int len;
  308. len = strlen(s);
  309. putlong(m, len);
  310. putbytes(m, s, len);
  311. }
  312. void
  313. putbytes(Msg *m, void *a, long n)
  314. {
  315. if(m->wp+n > m->ep)
  316. error(Eencode);
  317. memmove(m->wp, a, n);
  318. m->wp += n;
  319. }
  320. void
  321. putmpint(Msg *m, mpint *b)
  322. {
  323. int bits, n;
  324. bits = mpsignif(b);
  325. putshort(m, bits);
  326. n = (bits+7)/8;
  327. if(m->wp+n > m->ep)
  328. error(Eencode);
  329. mptobe(b, m->wp, n, nil);
  330. m->wp += n;
  331. }
  332. void
  333. putRSApub(Msg *m, RSApub *key)
  334. {
  335. putlong(m, mpsignif(key->n));
  336. putmpint(m, key->ek);
  337. putmpint(m, key->n);
  338. }
  339. static ulong crctab[256];
  340. static void
  341. initsum32(void)
  342. {
  343. ulong crc, poly;
  344. int i, j;
  345. poly = 0xEDB88320;
  346. for(i = 0; i < 256; i++){
  347. crc = i;
  348. for(j = 0; j < 8; j++){
  349. if(crc & 1)
  350. crc = (crc >> 1) ^ poly;
  351. else
  352. crc >>= 1;
  353. }
  354. crctab[i] = crc;
  355. }
  356. }
  357. static ulong
  358. sum32(ulong lcrc, void *buf, int n)
  359. {
  360. static int first=1;
  361. uchar *s = buf;
  362. ulong crc = lcrc;
  363. if(first){
  364. first=0;
  365. initsum32();
  366. }
  367. while(n-- > 0)
  368. crc = crctab[(crc^*s++)&0xff] ^ (crc>>8);
  369. return crc;
  370. }
  371. mpint*
  372. rsapad(mpint *b, int n)
  373. {
  374. int i, pad, nbuf;
  375. uchar buf[2560];
  376. mpint *c;
  377. if(n > sizeof buf)
  378. error("buffer too small in rsapad");
  379. nbuf = (mpsignif(b)+7)/8;
  380. pad = n - nbuf;
  381. assert(pad >= 3);
  382. mptobe(b, buf, nbuf, nil);
  383. memmove(buf+pad, buf, nbuf);
  384. buf[0] = 0;
  385. buf[1] = 2;
  386. for(i=2; i<pad-1; i++)
  387. buf[i]=1+fastrand()%255;
  388. buf[pad-1] = 0;
  389. c = betomp(buf, n, nil);
  390. memset(buf, 0, sizeof buf);
  391. return c;
  392. }
  393. mpint*
  394. rsaunpad(mpint *b)
  395. {
  396. int i, n;
  397. uchar buf[2560];
  398. n = (mpsignif(b)+7)/8;
  399. if(n > sizeof buf)
  400. error("buffer too small in rsaunpad");
  401. mptobe(b, buf, n, nil);
  402. /* the initial zero has been eaten by the betomp -> mptobe sequence */
  403. if(buf[0] != 2)
  404. error("bad data in rsaunpad");
  405. for(i=1; i<n; i++)
  406. if(buf[i]==0)
  407. break;
  408. return betomp(buf+i, n-i, nil);
  409. }
  410. void
  411. mptoberjust(mpint *b, uchar *buf, int len)
  412. {
  413. int n;
  414. n = mptobe(b, buf, len, nil);
  415. assert(n >= 0);
  416. if(n < len){
  417. len -= n;
  418. memmove(buf+len, buf, n);
  419. memset(buf, 0, len);
  420. }
  421. }
  422. mpint*
  423. rsaencryptbuf(RSApub *key, uchar *buf, int nbuf)
  424. {
  425. int n;
  426. mpint *a, *b, *c;
  427. n = (mpsignif(key->n)+7)/8;
  428. a = betomp(buf, nbuf, nil);
  429. b = rsapad(a, n);
  430. mpfree(a);
  431. c = rsaencrypt(key, b, nil);
  432. mpfree(b);
  433. return c;
  434. }