123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- #include <u.h>
- #include <libc.h>
- #include <bio.h>
- #include <regexp.h>
- #include "hash.h"
/*
 * Fixed capacity limits: the tab[] array and Word's per-table
 * arrays are sized by MAXTAB, and best[] is sized by MAXBEST.
 */
enum {
	MAXTAB	= 256,	/* most box hash files accepted before "~" */
	MAXBEST	= 32	/* largest value accepted for -m */
};
- typedef struct Table Table;
- struct Table
- {
- char *file;
- Hash *hash;
- int nmsg;
- };
- typedef struct Word Word;
- struct Word
- {
- Stringtab *s; /* from hmsg */
- int count[MAXTAB]; /* counts from each table */
- double p[MAXTAB]; /* probabilities from each table */
- double mp; /* max probability */
- int mi; /* w.p[w.mi] = w.mp */
- };
- Table tab[MAXTAB];
- int ntab;
- Word best[MAXBEST];
- int mbest;
- int nbest;
- int debug;
/*
 * Print the command syntax to standard error and exit with status "usage".
 * Lists every flag main's ARGBEGIN accepts: -D (debug), -k (keywords),
 * and -m maxword.
 */
void
usage(void)
{
	fprint(2, "usage: bayes [-Dk] [-m maxword] boxhash ... ~ msghash ...\n");
	exits("usage");
}
- void*
- emalloc(int n)
- {
- void *v;
- v = mallocz(n, 1);
- if(v == nil)
- sysfatal("out of memory");
- return v;
- }
- void
- noteword(Word *w)
- {
- int i;
- for(i=nbest-1; i>=0; i--)
- if(w->mp < best[i].mp)
- break;
- i++;
- if(i >= mbest)
- return;
- if(nbest == mbest)
- nbest--;
- if(i < nbest)
- memmove(&best[i+1], &best[i], (nbest-i)*sizeof(best[0]));
- best[i] = *w;
- nbest++;
- }
- Hash*
- hread(char *s)
- {
- Hash *h;
- Biobuf *b;
- if((b = Bopenlock(s, OREAD)) == nil)
- sysfatal("open %s: %r", s);
- h = emalloc(sizeof(Hash));
- Breadhash(b, h, 1);
- Bterm(b);
- return h;
- }
- void
- main(int argc, char **argv)
- {
- int i, j, a, mi, oi, tot, keywords;
- double totp, p, xp[MAXTAB];
- Hash *hmsg;
- Word w;
- Stringtab *s, *t;
- Biobuf bout;
- mbest = 15;
- keywords = 0;
- ARGBEGIN{
- case 'D':
- debug = 1;
- break;
- case 'k':
- keywords = 1;
- break;
- case 'm':
- mbest = atoi(EARGF(usage()));
- if(mbest > MAXBEST)
- sysfatal("cannot keep more than %d words", MAXBEST);
- break;
- default:
- usage();
- }ARGEND
- for(i=0; i<argc; i++)
- if(strcmp(argv[i], "~") == 0)
- break;
- if(i > MAXTAB)
- sysfatal("cannot handle more than %d tables", MAXTAB);
- if(i+1 >= argc)
- usage();
- for(i=0; i<argc; i++){
- if(strcmp(argv[i], "~") == 0)
- break;
- tab[ntab].file = argv[i];
- tab[ntab].hash = hread(argv[i]);
- s = findstab(tab[ntab].hash, "*nmsg*", 6, 1);
- if(s == nil || s->count == 0)
- tab[ntab].nmsg = 1;
- else
- tab[ntab].nmsg = s->count;
- ntab++;
- }
- Binit(&bout, 1, OWRITE);
- oi = ++i;
- for(a=i; a<argc; a++){
- hmsg = hread(argv[a]);
- nbest = 0;
- for(s=hmsg->all; s; s=s->link){
- w.s = s;
- tot = 0;
- totp = 0.0;
- for(i=0; i<ntab; i++){
- t = findstab(tab[i].hash, s->str, s->n, 0);
- if(t == nil)
- w.count[i] = 0;
- else
- w.count[i] = t->count;
- tot += w.count[i];
- p = w.count[i]/(double)tab[i].nmsg;
- if(p >= 1.0)
- p = 1.0;
- w.p[i] = p;
- totp += p;
- }
- if(tot < 5){ /* word does not appear enough; give to box 0 */
- w.p[0] = 0.5;
- for(i=1; i<ntab; i++)
- w.p[i] = 0.1;
- w.mp = 0.5;
- w.mi = 0;
- noteword(&w);
- continue;
- }
- w.mp = 0.0;
- for(i=0; i<ntab; i++){
- p = w.p[i];
- p /= totp;
- if(p < 0.01)
- p = 0.01;
- else if(p > 0.99)
- p = 0.99;
- if(p > w.mp){
- w.mp = p;
- w.mi = i;
- }
- w.p[i] = p;
- }
- noteword(&w);
- }
- totp = 0.0;
- for(i=0; i<ntab; i++){
- p = 1.0;
- for(j=0; j<nbest; j++)
- p *= best[j].p[i];
- xp[i] = p;
- totp += p;
- }
- for(i=0; i<ntab; i++)
- xp[i] /= totp;
- mi = 0;
- for(i=1; i<ntab; i++)
- if(xp[i] > xp[mi])
- mi = i;
- if(oi != argc-1)
- Bprint(&bout, "%s: ", argv[a]);
- Bprint(&bout, "%s %f", tab[mi].file, xp[mi]);
- if(keywords){
- for(i=0; i<nbest; i++){
- Bprint(&bout, " ");
- Bwrite(&bout, best[i].s->str, best[i].s->n);
- Bprint(&bout, " %f", best[i].p[mi]);
- }
- }
- freehash(hmsg);
- Bprint(&bout, "\n");
- if(debug){
- for(i=0; i<nbest; i++){
- Bwrite(&bout, best[i].s->str, best[i].s->n);
- Bprint(&bout, " %f", best[i].p[mi]);
- if(best[i].p[mi] < best[i].mp)
- Bprint(&bout, " (%f %s)", best[i].mp, tab[best[i].mi].file);
- Bprint(&bout, "\n");
- }
- }
- }
- Bterm(&bout);
- }
|