#include #include #include #include "dns.h" typedef struct Scan Scan; struct Scan { uchar *base; uchar *p; uchar *ep; char *err; }; #define NAME(x) gname(x, sp) #define SYMBOL(x) (x = gsym(sp)) #define STRING(x) (x = gstr(sp)) #define USHORT(x) (x = gshort(sp)) #define ULONG(x) (x = glong(sp)) #define UCHAR(x) (x = gchar(sp)) #define V4ADDR(x) (x = gv4addr(sp)) #define V6ADDR(x) (x = gv6addr(sp)) #define BYTES(x, y) (y = gbytes(sp, &x, len - (sp->p - data))) static char *toolong = "too long"; /* * get a ushort/ulong */ static ushort gchar(Scan *sp) { ushort x; if(sp->err) return 0; if(sp->ep - sp->p < 1){ sp->err = toolong; return 0; } x = sp->p[0]; sp->p += 1; return x; } static ushort gshort(Scan *sp) { ushort x; if(sp->err) return 0; if(sp->ep - sp->p < 2){ sp->err = toolong; return 0; } x = (sp->p[0]<<8) | sp->p[1]; sp->p += 2; return x; } static ulong glong(Scan *sp) { ulong x; if(sp->err) return 0; if(sp->ep - sp->p < 4){ sp->err = toolong; return 0; } x = (sp->p[0]<<24) | (sp->p[1]<<16) | (sp->p[2]<<8) | sp->p[3]; sp->p += 4; return x; } /* * get an ip address */ static DN* gv4addr(Scan *sp) { char addr[32]; if(sp->err) return 0; if(sp->ep - sp->p < 4){ sp->err = toolong; return 0; } snprint(addr, sizeof(addr), "%V", sp->p); sp->p += 4; return dnlookup(addr, Cin, 1); } static DN* gv6addr(Scan *sp) { char addr[64]; if(sp->err) return 0; if(sp->ep - sp->p < IPaddrlen){ sp->err = toolong; return 0; } snprint(addr, sizeof(addr), "%I", sp->p); sp->p += IPaddrlen; return dnlookup(addr, Cin, 1); } /* * get a string. make it an internal symbol. */ static DN* gsym(Scan *sp) { int n; char sym[Strlen+1]; if(sp->err) return 0; n = *(sp->p++); if(sp->p+n > sp->ep){ sp->err = toolong; return 0; } if(n > Strlen){ sp->err = "illegal string"; return 0; } strncpy(sym, (char*)sp->p, n); sym[n] = 0; sp->p += n; return dnlookup(sym, Csym, 1); } /* * get a string. don't make it an internal symbol. */ static Txt* gstr(Scan *sp) { int n; char sym[Strlen+1]; Txt *t; if(sp->err) return 0; n = *(sp->p++); if(sp->p+n > sp->ep){ sp->err = toolong; return 0; } if(n > Strlen){ sp->err = "illegal string"; return 0; } strncpy(sym, (char*)sp->p, n); sym[n] = 0; sp->p += n; t = emalloc(sizeof(*t)); t->next = nil; t->p = estrdup(sym); return t; } /* * get a sequence of bytes */ static int gbytes(Scan *sp, uchar **p, int n) { if(sp->err) return 0; if(sp->p+n > sp->ep || n < 0){ sp->err = toolong; return 0; } *p = emalloc(n); memmove(*p, sp->p, n); sp->p += n; return n; } /* * get a domain name. 'to' must point to a buffer at least Domlen+1 long. */ static char* gname(char *to, Scan *sp) { int len, off; int pointer; int n; char *tostart; char *toend; uchar *p; tostart = to; if(sp->err) goto err; pointer = 0; p = sp->p; toend = to + Domlen; for(len = 0; *p; len += pointer ? 0 : (n+1)){ if((*p & 0xc0) == 0xc0){ /* pointer to other spot in message */ if(pointer++ > 10){ sp->err = "pointer loop"; goto err; } off = ((p[0]<<8) + p[1]) & 0x3ff; p = sp->base + off; if(p >= sp->ep){ sp->err = "bad pointer"; goto err; } n = 0; continue; } n = *p++; if(len + n < Domlen - 1){ if(to + n > toend){ sp->err = toolong; goto err; } memmove(to, p, n); to += n; } p += n; if(*p){ if(to >= toend){ sp->err = toolong; goto err; } *to++ = '.'; } } *to = 0; if(pointer) sp->p += len + 2; /* + 2 for pointer */ else sp->p += len + 1; /* + 1 for the null domain */ return tostart; err: *tostart = 0; return tostart; } /* * convert the next RR from a message */ static RR* convM2RR(Scan *sp) { RR *rp; int type; int class; uchar *data; int len; char dname[Domlen+1]; Txt *t, **l; retry: NAME(dname); USHORT(type); USHORT(class); rp = rralloc(type); rp->owner = dnlookup(dname, class, 1); rp->type = type; ULONG(rp->ttl); rp->ttl += now; USHORT(len); data = sp->p; if(sp->p + len > sp->ep) sp->err = toolong; if(sp->err){ rrfree(rp); return 0; } switch(type){ default: /* unknown type, just ignore it */ sp->p = data + len; rrfree(rp); goto retry; case Thinfo: SYMBOL(rp->cpu); SYMBOL(rp->os); break; case Tcname: case Tmb: case Tmd: case Tmf: case Tns: rp->host = dnlookup(NAME(dname), Cin, 1); break; case Tmg: case Tmr: rp->mb = dnlookup(NAME(dname), Cin, 1); break; case Tminfo: rp->rmb = dnlookup(NAME(dname), Cin, 1); rp->mb = dnlookup(NAME(dname), Cin, 1); break; case Tmx: USHORT(rp->pref); rp->host = dnlookup(NAME(dname), Cin, 1); break; case Ta: V4ADDR(rp->ip); break; case Taaaa: V6ADDR(rp->ip); break; case Tptr: rp->ptr = dnlookup(NAME(dname), Cin, 1); break; case Tsoa: rp->host = dnlookup(NAME(dname), Cin, 1); rp->rmb = dnlookup(NAME(dname), Cin, 1); ULONG(rp->soa->serial); ULONG(rp->soa->refresh); ULONG(rp->soa->retry); ULONG(rp->soa->expire); ULONG(rp->soa->minttl); break; case Ttxt: l = &rp->txt; *l = nil; while(sp->p-data < len){ STRING(t); *l = t; l = &t->next; } break; case Tnull: BYTES(rp->null->data, rp->null->dlen); break; case Trp: rp->rmb = dnlookup(NAME(dname), Cin, 1); rp->rp = dnlookup(NAME(dname), Cin, 1); break; case Tkey: USHORT(rp->key->flags); UCHAR(rp->key->proto); UCHAR(rp->key->alg); BYTES(rp->key->data, rp->key->dlen); break; case Tsig: USHORT(rp->sig->type); UCHAR(rp->sig->alg); UCHAR(rp->sig->labels); ULONG(rp->sig->ttl); ULONG(rp->sig->exp); ULONG(rp->sig->incep); USHORT(rp->sig->tag); rp->sig->signer = dnlookup(NAME(dname), Cin, 1); BYTES(rp->sig->data, rp->sig->dlen); break; case Tcert: USHORT(rp->cert->type); USHORT(rp->cert->tag); UCHAR(rp->cert->alg); BYTES(rp->cert->data, rp->cert->dlen); break; } if(sp->p - data != len) sp->err = "bad RR len"; return rp; } /* * convert the next question from a message */ static RR* convM2Q(Scan *sp) { char dname[Domlen+1]; int type; int class; RR *rp; NAME(dname); USHORT(type); USHORT(class); if(sp->err) return 0; rp = rralloc(type); rp->owner = dnlookup(dname, class, 1); return rp; } static RR* rrloop(Scan *sp, int count, int quest) { int i; static char errbuf[64]; RR *first, *rp, **l; if(sp->err) return 0; l = &first; first = 0; for(i = 0; i < count; i++){ rp = quest ? convM2Q(sp) : convM2RR(sp); if(rp == 0) break; if(sp->err){ rrfree(rp); break; } *l = rp; l = &rp->next; } return first; } /* * convert the next DNS from a message stream */ char* convM2DNS(uchar *buf, int len, DNSmsg *m) { Scan scan; Scan *sp; char *err; scan.base = buf; scan.p = buf; scan.ep = buf + len; scan.err = 0; sp = &scan; memset(m, 0, sizeof(DNSmsg)); USHORT(m->id); USHORT(m->flags); USHORT(m->qdcount); USHORT(m->ancount); USHORT(m->nscount); USHORT(m->arcount); m->qd = rrloop(sp, m->qdcount, 1); m->an = rrloop(sp, m->ancount, 0); m->ns = rrloop(sp, m->nscount, 0); err = scan.err; /* live with bad ar's */ m->ar = rrloop(sp, m->arcount, 0); return err; }