bayes.c 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. #include <u.h>
  2. #include <libc.h>
  3. #include <bio.h>
  4. #include <regexp.h>
  5. #include "hash.h"
  /* Capacities of the statically sized arrays tab[] and best[]. */
  enum
  {
      MAXTAB = 256,   /* most box hash tables accepted before "~" */
      MAXBEST = 32,   /* most words that can be kept in best[] */
  };
  11. typedef struct Table Table;
  12. struct Table
  13. {
  14. char *file;
  15. Hash *hash;
  16. int nmsg;
  17. };
  18. typedef struct Word Word;
  19. struct Word
  20. {
  21. Stringtab *s; /* from hmsg */
  22. int count[MAXTAB]; /* counts from each table */
  23. double p[MAXTAB]; /* probabilities from each table */
  24. double mp; /* max probability */
  25. int mi; /* w.p[w.mi] = w.mp */
  26. };
  27. Table tab[MAXTAB];
  28. int ntab;
  29. Word best[MAXBEST];
  30. int mbest;
  31. int nbest;
  32. int debug;
  /*
   * Print the command synopsis on file descriptor 2 and exit with
   * status "usage".  The synopsis lists every flag main() parses:
   * -D, -k and -m.
   */
  void
  usage(void)
  {
      fprint(2, "usage: bayes [-Dk] [-m maxword] boxhash ... ~ msghash ...\n");
      exits("usage");
  }
  39. void*
  40. emalloc(int n)
  41. {
  42. void *v;
  43. v = mallocz(n, 1);
  44. if(v == nil)
  45. sysfatal("out of memory");
  46. return v;
  47. }
  48. void
  49. noteword(Word *w)
  50. {
  51. int i;
  52. for(i=nbest-1; i>=0; i--)
  53. if(w->mp < best[i].mp)
  54. break;
  55. i++;
  56. if(i >= mbest)
  57. return;
  58. if(nbest == mbest)
  59. nbest--;
  60. if(i < nbest)
  61. memmove(&best[i+1], &best[i], (nbest-i)*sizeof(best[0]));
  62. best[i] = *w;
  63. nbest++;
  64. }
  65. Hash*
  66. hread(char *s)
  67. {
  68. Hash *h;
  69. Biobuf *b;
  70. if((b = Bopenlock(s, OREAD)) == nil)
  71. sysfatal("open %s: %r", s);
  72. h = emalloc(sizeof(Hash));
  73. Breadhash(b, h, 1);
  74. Bterm(b);
  75. return h;
  76. }
  77. void
  78. main(int argc, char **argv)
  79. {
  80. int i, j, a, mi, oi, tot, keywords;
  81. double totp, p, xp[MAXTAB];
  82. Hash *hmsg;
  83. Word w;
  84. Stringtab *s, *t;
  85. Biobuf bout;
  86. mbest = 15;
  87. keywords = 0;
  88. ARGBEGIN{
  89. case 'D':
  90. debug = 1;
  91. break;
  92. case 'k':
  93. keywords = 1;
  94. break;
  95. case 'm':
  96. mbest = atoi(EARGF(usage()));
  97. if(mbest > MAXBEST)
  98. sysfatal("cannot keep more than %d words", MAXBEST);
  99. break;
  100. default:
  101. usage();
  102. }ARGEND
  103. for(i=0; i<argc; i++)
  104. if(strcmp(argv[i], "~") == 0)
  105. break;
  106. if(i > MAXTAB)
  107. sysfatal("cannot handle more than %d tables", MAXTAB);
  108. if(i+1 >= argc)
  109. usage();
  110. for(i=0; i<argc; i++){
  111. if(strcmp(argv[i], "~") == 0)
  112. break;
  113. tab[ntab].file = argv[i];
  114. tab[ntab].hash = hread(argv[i]);
  115. s = findstab(tab[ntab].hash, "*nmsg*", 6, 1);
  116. if(s == nil || s->count == 0)
  117. tab[ntab].nmsg = 1;
  118. else
  119. tab[ntab].nmsg = s->count;
  120. ntab++;
  121. }
  122. Binit(&bout, 1, OWRITE);
  123. oi = ++i;
  124. for(a=i; a<argc; a++){
  125. hmsg = hread(argv[a]);
  126. nbest = 0;
  127. for(s=hmsg->all; s; s=s->link){
  128. w.s = s;
  129. tot = 0;
  130. totp = 0.0;
  131. for(i=0; i<ntab; i++){
  132. t = findstab(tab[i].hash, s->str, s->n, 0);
  133. if(t == nil)
  134. w.count[i] = 0;
  135. else
  136. w.count[i] = t->count;
  137. tot += w.count[i];
  138. p = w.count[i]/(double)tab[i].nmsg;
  139. if(p >= 1.0)
  140. p = 1.0;
  141. w.p[i] = p;
  142. totp += p;
  143. }
  144. if(tot < 5){ /* word does not appear enough; give to box 0 */
  145. w.p[0] = 0.5;
  146. for(i=1; i<ntab; i++)
  147. w.p[i] = 0.1;
  148. w.mp = 0.5;
  149. w.mi = 0;
  150. noteword(&w);
  151. continue;
  152. }
  153. w.mp = 0.0;
  154. for(i=0; i<ntab; i++){
  155. p = w.p[i];
  156. p /= totp;
  157. if(p < 0.01)
  158. p = 0.01;
  159. else if(p > 0.99)
  160. p = 0.99;
  161. if(p > w.mp){
  162. w.mp = p;
  163. w.mi = i;
  164. }
  165. w.p[i] = p;
  166. }
  167. noteword(&w);
  168. }
  169. totp = 0.0;
  170. for(i=0; i<ntab; i++){
  171. p = 1.0;
  172. for(j=0; j<nbest; j++)
  173. p *= best[j].p[i];
  174. xp[i] = p;
  175. totp += p;
  176. }
  177. for(i=0; i<ntab; i++)
  178. xp[i] /= totp;
  179. mi = 0;
  180. for(i=1; i<ntab; i++)
  181. if(xp[i] > xp[mi])
  182. mi = i;
  183. if(oi != argc-1)
  184. Bprint(&bout, "%s: ", argv[a]);
  185. Bprint(&bout, "%s %f", tab[mi].file, xp[mi]);
  186. if(keywords){
  187. for(i=0; i<nbest; i++){
  188. Bprint(&bout, " ");
  189. Bwrite(&bout, best[i].s->str, best[i].s->n);
  190. Bprint(&bout, " %f", best[i].p[mi]);
  191. }
  192. }
  193. freehash(hmsg);
  194. Bprint(&bout, "\n");
  195. if(debug){
  196. for(i=0; i<nbest; i++){
  197. Bwrite(&bout, best[i].s->str, best[i].s->n);
  198. Bprint(&bout, " %f", best[i].p[mi]);
  199. if(best[i].p[mi] < best[i].mp)
  200. Bprint(&bout, " (%f %s)", best[i].mp, tab[best[i].mi].file);
  201. Bprint(&bout, "\n");
  202. }
  203. }
  204. }
  205. Bterm(&bout);
  206. }