msg.c 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. #include "ssh.h"
  2. static ulong sum32(ulong, void*, int);
  3. char *msgnames[] =
  4. {
  5. /* 0 */
  6. "SSH_MSG_NONE",
  7. "SSH_MSG_DISCONNECT",
  8. "SSH_SMSG_PUBLIC_KEY",
  9. "SSH_CMSG_SESSION_KEY",
  10. "SSH_CMSG_USER",
  11. "SSH_CMSG_AUTH_RHOSTS",
  12. "SSH_CMSG_AUTH_RSA",
  13. "SSH_SMSG_AUTH_RSA_CHALLENGE",
  14. "SSH_CMSG_AUTH_RSA_RESPONSE",
  15. "SSH_CMSG_AUTH_PASSWORD",
  16. /* 10 */
  17. "SSH_CMSG_REQUEST_PTY",
  18. "SSH_CMSG_WINDOW_SIZE",
  19. "SSH_CMSG_EXEC_SHELL",
  20. "SSH_CMSG_EXEC_CMD",
  21. "SSH_SMSG_SUCCESS",
  22. "SSH_SMSG_FAILURE",
  23. "SSH_CMSG_STDIN_DATA",
  24. "SSH_SMSG_STDOUT_DATA",
  25. "SSH_SMSG_STDERR_DATA",
  26. "SSH_CMSG_EOF",
  27. /* 20 */
  28. "SSH_SMSG_EXITSTATUS",
  29. "SSH_MSG_CHANNEL_OPEN_CONFIRMATION",
  30. "SSH_MSG_CHANNEL_OPEN_FAILURE",
  31. "SSH_MSG_CHANNEL_DATA",
  32. "SSH_MSG_CHANNEL_INPUT_EOF",
  33. "SSH_MSG_CHANNEL_OUTPUT_CLOSED",
  34. "SSH_MSG_UNIX_DOMAIN_X11_FORWARDING (obsolete)",
  35. "SSH_SMSG_X11_OPEN",
  36. "SSH_CMSG_PORT_FORWARD_REQUEST",
  37. "SSH_MSG_PORT_OPEN",
  38. /* 30 */
  39. "SSH_CMSG_AGENT_REQUEST_FORWARDING",
  40. "SSH_SMSG_AGENT_OPEN",
  41. "SSH_MSG_IGNORE",
  42. "SSH_CMSG_EXIT_CONFIRMATION",
  43. "SSH_CMSG_X11_REQUEST_FORWARDING",
  44. "SSH_CMSG_AUTH_RHOSTS_RSA",
  45. "SSH_MSG_DEBUG",
  46. "SSH_CMSG_REQUEST_COMPRESSION",
  47. "SSH_CMSG_MAX_PACKET_SIZE",
  48. "SSH_CMSG_AUTH_TIS",
  49. /* 40 */
  50. "SSH_SMSG_AUTH_TIS_CHALLENGE",
  51. "SSH_CMSG_AUTH_TIS_RESPONSE",
  52. "SSH_CMSG_AUTH_KERBEROS",
  53. "SSH_SMSG_AUTH_KERBEROS_RESPONSE",
  54. "SSH_CMSG_HAVE_KERBEROS_TGT"
  55. };
  56. void
  57. badmsg(Msg *m, int want)
  58. {
  59. char *s, buf[20+ERRMAX];
  60. if(m==nil){
  61. snprint(buf, sizeof buf, "<early eof: %r>");
  62. s = buf;
  63. }else{
  64. snprint(buf, sizeof buf, "<unknown type %d>", m->type);
  65. s = buf;
  66. if(0 <= m->type && m->type < nelem(msgnames))
  67. s = msgnames[m->type];
  68. }
  69. if(want)
  70. error("got %s message expecting %s", s, msgnames[want]);
  71. error("got unexpected %s message", s);
  72. }
  73. Msg*
  74. allocmsg(Conn *c, int type, int len)
  75. {
  76. uchar *p;
  77. Msg *m;
  78. if(len > 256*1024)
  79. abort();
  80. m = (Msg*)emalloc(sizeof(Msg)+4+8+1+len+4);
  81. setmalloctag(m, getcallerpc(&c));
  82. p = (uchar*)&m[1];
  83. m->c = c;
  84. m->bp = p;
  85. m->ep = p+len;
  86. m->wp = p;
  87. m->type = type;
  88. return m;
  89. }
  90. void
  91. unrecvmsg(Conn *c, Msg *m)
  92. {
  93. debug(DBG_PROTO, "unreceived %s len %d\n", msgnames[m->type], m->ep - m->rp);
  94. free(c->unget);
  95. c->unget = m;
  96. }
  97. static Msg*
  98. recvmsg0(Conn *c)
  99. {
  100. int pad;
  101. uchar *p, buf[4];
  102. ulong crc, crc0, len;
  103. Msg *m;
  104. if(c->unget){
  105. m = c->unget;
  106. c->unget = nil;
  107. return m;
  108. }
  109. if(readn(c->fd[0], buf, 4) != 4){
  110. werrstr("short net read: %r");
  111. return nil;
  112. }
  113. len = LONG(buf);
  114. if(len > 256*1024){
  115. werrstr("packet size far too big: %.8lux", len);
  116. return nil;
  117. }
  118. pad = 8 - len%8;
  119. m = (Msg*)emalloc(sizeof(Msg)+pad+len);
  120. setmalloctag(m, getcallerpc(&c));
  121. m->c = c;
  122. m->bp = (uchar*)&m[1];
  123. m->ep = m->bp + pad+len-4; /* -4: don't include crc */
  124. m->rp = m->bp;
  125. if(readn(c->fd[0], m->bp, pad+len) != pad+len){
  126. werrstr("short net read: %r");
  127. free(m);
  128. return nil;
  129. }
  130. if(c->cipher)
  131. c->cipher->decrypt(c->cstate, m->bp, len+pad);
  132. crc = sum32(0, m->bp, pad+len-4);
  133. p = m->bp + pad+len-4;
  134. crc0 = LONG(p);
  135. if(crc != crc0){
  136. werrstr("bad crc %#lux != %#lux (packet length %lud)", crc, crc0, len);
  137. free(m);
  138. return nil;
  139. }
  140. m->rp += pad;
  141. m->type = *m->rp++;
  142. return m;
  143. }
  144. Msg*
  145. recvmsg(Conn *c, int type)
  146. {
  147. Msg *m;
  148. while((m = recvmsg0(c)) != nil){
  149. debug(DBG_PROTO, "received %s len %d\n", msgnames[m->type], m->ep - m->rp);
  150. if(m->type != SSH_MSG_DEBUG && m->type != SSH_MSG_IGNORE)
  151. break;
  152. if(m->type == SSH_MSG_DEBUG)
  153. debug(DBG_PROTO, "remote DEBUG: %s\n", getstring(m));
  154. free(m);
  155. }
  156. if(type == 0){
  157. /* no checking */
  158. }else if(type == -1){
  159. /* must not be nil */
  160. if(m == nil)
  161. error(Ehangup);
  162. }else{
  163. /* must be given type */
  164. if(m==nil || m->type!=type)
  165. badmsg(m, type);
  166. }
  167. setmalloctag(m, getcallerpc(&c));
  168. return m;
  169. }
  170. int
  171. sendmsg(Msg *m)
  172. {
  173. int i, pad;
  174. uchar *p;
  175. ulong datalen, len, crc;
  176. Conn *c;
  177. datalen = m->wp - m->bp;
  178. len = datalen + 5;
  179. pad = 8 - len%8;
  180. debug(DBG_PROTO, "sending %s len %d\n", msgnames[m->type], datalen);
  181. p = m->bp;
  182. memmove(m->bp+4+pad+1, m->bp, datalen); /* slide data to correct position */
  183. PLONG(p, len);
  184. p += 4;
  185. if(m->c->cstate){
  186. for(i=0; i<pad; i++)
  187. *p++ = fastrand();
  188. }else{
  189. memset(p, 0, pad);
  190. p += pad;
  191. }
  192. *p++ = m->type;
  193. /* data already in position */
  194. p += datalen;
  195. crc = sum32(0, m->bp+4, pad+1+datalen);
  196. PLONG(p, crc);
  197. p += 4;
  198. c = m->c;
  199. qlock(c);
  200. if(c->cstate)
  201. c->cipher->encrypt(c->cstate, m->bp+4, len+pad);
  202. if(write(c->fd[1], m->bp, p - m->bp) != p-m->bp){
  203. qunlock(c);
  204. free(m);
  205. return -1;
  206. }
  207. qunlock(c);
  208. free(m);
  209. return 0;
  210. }
  211. uchar
  212. getbyte(Msg *m)
  213. {
  214. if(m->rp >= m->ep)
  215. error(Edecode);
  216. return *m->rp++;
  217. }
  218. ushort
  219. getshort(Msg *m)
  220. {
  221. ushort x;
  222. if(m->rp+2 > m->ep)
  223. error(Edecode);
  224. x = SHORT(m->rp);
  225. m->rp += 2;
  226. return x;
  227. }
  228. ulong
  229. getlong(Msg *m)
  230. {
  231. ulong x;
  232. if(m->rp+4 > m->ep)
  233. error(Edecode);
  234. x = LONG(m->rp);
  235. m->rp += 4;
  236. return x;
  237. }
  238. char*
  239. getstring(Msg *m)
  240. {
  241. char *p;
  242. ulong len;
  243. /* overwrites length to make room for NUL */
  244. len = getlong(m);
  245. if(m->rp+len > m->ep)
  246. error(Edecode);
  247. p = (char*)m->rp-1;
  248. memmove(p, m->rp, len);
  249. p[len] = '\0';
  250. return p;
  251. }
  252. void*
  253. getbytes(Msg *m, int n)
  254. {
  255. uchar *p;
  256. if(m->rp+n > m->ep)
  257. error(Edecode);
  258. p = m->rp;
  259. m->rp += n;
  260. return p;
  261. }
  262. mpint*
  263. getmpint(Msg *m)
  264. {
  265. int n;
  266. n = (getshort(m)+7)/8; /* getshort returns # bits */
  267. return betomp(getbytes(m, n), n, nil);
  268. }
  269. RSApub*
  270. getRSApub(Msg *m)
  271. {
  272. RSApub *key;
  273. getlong(m);
  274. key = rsapuballoc();
  275. if(key == nil)
  276. error(Ememory);
  277. key->ek = getmpint(m);
  278. key->n = getmpint(m);
  279. setmalloctag(key, getcallerpc(&m));
  280. return key;
  281. }
  282. void
  283. putbyte(Msg *m, uchar x)
  284. {
  285. if(m->wp >= m->ep)
  286. error(Eencode);
  287. *m->wp++ = x;
  288. }
  289. void
  290. putshort(Msg *m, ushort x)
  291. {
  292. if(m->wp+2 > m->ep)
  293. error(Eencode);
  294. PSHORT(m->wp, x);
  295. m->wp += 2;
  296. }
  297. void
  298. putlong(Msg *m, ulong x)
  299. {
  300. if(m->wp+4 > m->ep)
  301. error(Eencode);
  302. PLONG(m->wp, x);
  303. m->wp += 4;
  304. }
  305. void
  306. putstring(Msg *m, char *s)
  307. {
  308. int len;
  309. len = strlen(s);
  310. putlong(m, len);
  311. putbytes(m, s, len);
  312. }
  313. void
  314. putbytes(Msg *m, void *a, long n)
  315. {
  316. if(m->wp+n > m->ep)
  317. error(Eencode);
  318. memmove(m->wp, a, n);
  319. m->wp += n;
  320. }
  321. void
  322. putmpint(Msg *m, mpint *b)
  323. {
  324. int bits, n;
  325. bits = mpsignif(b);
  326. putshort(m, bits);
  327. n = (bits+7)/8;
  328. if(m->wp+n > m->ep)
  329. error(Eencode);
  330. mptobe(b, m->wp, n, nil);
  331. m->wp += n;
  332. }
  333. void
  334. putRSApub(Msg *m, RSApub *key)
  335. {
  336. putlong(m, mpsignif(key->n));
  337. putmpint(m, key->ek);
  338. putmpint(m, key->n);
  339. }
  340. static ulong crctab[256];
  341. static void
  342. initsum32(void)
  343. {
  344. ulong crc, poly;
  345. int i, j;
  346. poly = 0xEDB88320;
  347. for(i = 0; i < 256; i++){
  348. crc = i;
  349. for(j = 0; j < 8; j++){
  350. if(crc & 1)
  351. crc = (crc >> 1) ^ poly;
  352. else
  353. crc >>= 1;
  354. }
  355. crctab[i] = crc;
  356. }
  357. }
  358. static ulong
  359. sum32(ulong lcrc, void *buf, int n)
  360. {
  361. static int first=1;
  362. uchar *s = buf;
  363. ulong crc = lcrc;
  364. if(first){
  365. first=0;
  366. initsum32();
  367. }
  368. while(n-- > 0)
  369. crc = crctab[(crc^*s++)&0xff] ^ (crc>>8);
  370. return crc;
  371. }
  372. mpint*
  373. rsapad(mpint *b, int n)
  374. {
  375. int i, pad, nbuf;
  376. uchar buf[2560];
  377. mpint *c;
  378. if(n > sizeof buf)
  379. error("buffer too small in rsapad");
  380. nbuf = (mpsignif(b)+7)/8;
  381. pad = n - nbuf;
  382. assert(pad >= 3);
  383. mptobe(b, buf, nbuf, nil);
  384. memmove(buf+pad, buf, nbuf);
  385. buf[0] = 0;
  386. buf[1] = 2;
  387. for(i=2; i<pad-1; i++)
  388. buf[i]=1+fastrand()%255;
  389. buf[pad-1] = 0;
  390. c = betomp(buf, n, nil);
  391. memset(buf, 0, sizeof buf);
  392. return c;
  393. }
  394. mpint*
  395. rsaunpad(mpint *b)
  396. {
  397. int i, n;
  398. uchar buf[2560];
  399. n = (mpsignif(b)+7)/8;
  400. if(n > sizeof buf)
  401. error("buffer too small in rsaunpad");
  402. mptobe(b, buf, n, nil);
  403. /* the initial zero has been eaten by the betomp -> mptobe sequence */
  404. if(buf[0] != 2)
  405. error("bad data in rsaunpad");
  406. for(i=1; i<n; i++)
  407. if(buf[i]==0)
  408. break;
  409. return betomp(buf+i, n-i, nil);
  410. }
  411. void
  412. mptoberjust(mpint *b, uchar *buf, int len)
  413. {
  414. int n;
  415. n = mptobe(b, buf, len, nil);
  416. assert(n >= 0);
  417. if(n < len){
  418. len -= n;
  419. memmove(buf+len, buf, n);
  420. memset(buf, 0, len);
  421. }
  422. }
  423. mpint*
  424. rsaencryptbuf(RSApub *key, uchar *buf, int nbuf)
  425. {
  426. int n;
  427. mpint *a, *b, *c;
  428. n = (mpsignif(key->n)+7)/8;
  429. a = betomp(buf, nbuf, nil);
  430. b = rsapad(a, n);
  431. mpfree(a);
  432. c = rsaencrypt(key, b, nil);
  433. mpfree(b);
  434. return c;
  435. }