ssort6.c 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. #include <u.h>
  2. #include <libc.h>
  3. #include "ssort.h"
  /*
   * Circular index arithmetic on a string of length n.
   * Both macros rely on two variables of the calling function:
   * `t' (int scratch) and `n' (the string length).
   */
  #define pred(i, h) ((t = (i) - (h)) < 0 ? t + n : t)
  #define succ(i, h) ((t = (i) + (h)) >= n ? t - n : t)

  enum
  {
      BUCK = ~(~0u >> 1),     /* high bit; flags the last entry of a bucket in p[] */
      MAXI = ~0u >> 1,        /* biggest int */
  };

  /* element type handled by the specialised quicksort below */
  typedef int Elem;

  static void qsort2(Elem *a, Elem *asucc, int n);
  static int ssortit(int a[], int p[], int s[], int q[], int n, int h, int *pe, int nbuck);
  static void lift(int si, int q[], int i);
  int sharedlen(int i, int j, int s[], int q[]);
  16. int
  17. ssort(int a[], int s[], int n)
  18. {
  19. int i, l;
  20. int c, cc, ncc, lab, cum, nbuck;
  21. int k;
  22. int *p = 0;
  23. int result = 0;
  24. int *q = 0;
  25. int *al;
  26. int *pl;
  27. # define finish(r) if(1){result=r; goto out;}else
  28. for(k=0,i=0; i<n; i++)
  29. if(a[i] > k)
  30. k = a[i]; /* max element */
  31. k++;
  32. if(k>n)
  33. finish(2);
  34. nbuck = 0;
  35. p = malloc(n*sizeof(int));
  36. if(p == 0)
  37. finish(1);
  38. if(s) { /* shared lengths */
  39. q = malloc(((n+1)>>1)*sizeof(int));
  40. if(q == 0)
  41. finish(1);
  42. for(i=0; i<n; i++)
  43. s[i] = q[i>>1] = MAXI;
  44. q[i>>1] = MAXI;
  45. }
  46. pl = p + n - k;
  47. al = a;
  48. memset(pl, -1, k*sizeof(int));
  49. for(i=0; i<n; i++) { /* (1) link */
  50. l = a[i];
  51. al[i] = pl[l];
  52. pl[l] = i;
  53. }
  54. for(i=0; i<k; i++) /* check input - no holes */
  55. if(pl[i]<0)
  56. finish(2);
  57. lab = 0; /* (2) create p and label a */
  58. cum = 0;
  59. i = 0;
  60. for(c = 0; c < k; c++){
  61. for(cc = pl[c]; cc != -1; cc = ncc){
  62. ncc = al[cc];
  63. al[cc] = lab;
  64. cum++;
  65. p[i++] = cc;
  66. }
  67. if(lab + 1 == cum) {
  68. i--;
  69. } else {
  70. p[i-1] |= BUCK;
  71. nbuck++;
  72. }
  73. if(s) {
  74. s[lab] = 0;
  75. lift(0, q, lab);
  76. }
  77. lab = cum;
  78. }
  79. ssortit(a, p, s, q, n, 1, p+i, nbuck);
  80. memcpy(a, p, n*sizeof(int));
  81. out:
  82. free(p);
  83. free(q);
  84. return result;
  85. }
  86. /*
  87. * calculate the suffix array for the bytes in buf,
  88. * terminated by a unique end marker less than any character in buf
  89. * returns the index of the identity permutation,
  90. * and -1 if there was an error.
  91. */
  92. int
  93. ssortbyte(uchar buf[], int p[], int s[], int n)
  94. {
  95. int *a, *q, buckets[256*256];
  96. int i, last, lastc, cum, c, cc, ncc, lab, id, nbuck;
  97. a = malloc((n+1)*sizeof(int));
  98. if(a == nil)
  99. return -1;
  100. q = nil;
  101. if(s) { /* shared lengths */
  102. q = malloc(((n+2)>>1)*sizeof(int));
  103. if(q == nil){
  104. free(a);
  105. return -1;
  106. }
  107. for(i=0; i<n+1; i++)
  108. s[i] = q[i>>1] = MAXI;
  109. q[i>>1] = MAXI;
  110. }
  111. memset(buckets, -1, sizeof(buckets));
  112. c = buf[n-1] << 8;
  113. last = c;
  114. for(i = n - 2; i >= 0; i--){
  115. c = (buf[i] << 8) | (c >> 8);
  116. a[i] = buckets[c];
  117. buckets[c] = i;
  118. }
  119. /*
  120. * end of string comes before anything else
  121. */
  122. a[n] = 0;
  123. if(s) {
  124. s[0] = 0;
  125. lift(0, q, 0);
  126. }
  127. lab = 1;
  128. cum = 1;
  129. i = 0;
  130. lastc = 1; /* won't match c & 0xff00 for any c */
  131. nbuck = 0;
  132. for(c = 0; c < 256*256; c++) {
  133. /*
  134. * last character is followed by unique end of string
  135. */
  136. if(c == last) {
  137. a[n-1] = lab;
  138. if(s) {
  139. s[lab] = 0;
  140. lift(0, q, lab);
  141. }
  142. cum++;
  143. lab++;
  144. lastc = c & 0xff00;
  145. }
  146. for(cc = buckets[c]; cc != -1; cc = ncc) {
  147. ncc = a[cc];
  148. a[cc] = lab;
  149. cum++;
  150. p[i++] = cc;
  151. }
  152. if(lab == cum)
  153. continue;
  154. if(lab + 1 == cum)
  155. i--;
  156. else {
  157. p[i - 1] |= BUCK;
  158. nbuck++;
  159. }
  160. if(s) {
  161. cc = (c & 0xff00) == lastc;
  162. s[lab] = cc;
  163. lift(cc, q, lab);
  164. }
  165. lastc = c & 0xff00;
  166. lab = cum;
  167. }
  168. id = ssortit(a, p, s, q, n+1, 2, p+i, nbuck);
  169. free(a);
  170. free(q);
  171. return id;
  172. }
  173. /*
  174. * can get an interval for the shared lengths from [h,3h) by recording h
  175. * rather than h + sharedlen(..) when relabelling. if so, no calls to lift are needed.
  176. */
  177. static int
  178. ssortit(int a[], int p[], int shared[], int q[], int n, int h, int *pe, int nbuck)
  179. {
  180. int *s, *ss, *packing, *sorting;
  181. int v, sv, vv, packed, lab, t, i;
  182. for(; h < n && p < pe; h=2*h) {
  183. packing = p;
  184. nbuck = 0;
  185. for(sorting = p; sorting < pe; sorting = s){
  186. /*
  187. * find length of stuff to sort
  188. */
  189. lab = a[*sorting];
  190. for(s = sorting; ; s++) {
  191. sv = *s;
  192. v = a[succ(sv & ~BUCK, h)];
  193. if(v & BUCK)
  194. v = lab;
  195. a[sv & ~BUCK] = v | BUCK;
  196. if(sv & BUCK)
  197. break;
  198. }
  199. *s++ &= ~BUCK;
  200. nbuck++;
  201. qsort2(sorting, a, s - sorting);
  202. v = a[*sorting];
  203. a[*sorting] = lab;
  204. packed = 0;
  205. for(ss = sorting + 1; ss < s; ss++) {
  206. sv = *ss;
  207. vv = a[sv];
  208. if(vv == v) {
  209. *packing++ = ss[-1];
  210. packed++;
  211. } else {
  212. if(packed) {
  213. *packing++ = ss[-1] | BUCK;
  214. }
  215. lab += packed + 1;
  216. if(shared) {
  217. v = h + sharedlen(v&~BUCK, vv&~BUCK, shared, q);
  218. shared[lab] = v;
  219. lift(v, q, lab);
  220. }
  221. packed = 0;
  222. v = vv;
  223. }
  224. a[sv] = lab;
  225. }
  226. if(packed) {
  227. *packing++ = ss[-1] | BUCK;
  228. }
  229. }
  230. pe = packing;
  231. }
  232. /*
  233. * reconstuct the permutation matrix
  234. * return index of the entire string
  235. */
  236. v = a[0];
  237. for(i = 0; i < n; i++)
  238. p[a[i]] = i;
  239. return v;
  240. }
  /*
   * Propagate a new entry s[i], with value si, into q[].
   * Starting at slot i>>1, every visited slot holding more than si is
   * lowered to si and the walk continues at the index obtained by
   * clearing the lowest set bit; it stops at the first slot that is
   * already <= si.
   */
  static void
  lift(int si, int q[], int i)
  {
      int slot;

      slot = i >> 1;
      while(q[slot] > si) {
          q[slot] = si;
          slot &= slot - 1;       /* squash the low 1-bit */
      }
  }
  248. /*
  249. * Find in s[] the minimum value over the interval i<=k<=j, using
  250. * tree q[] to do logarithmic, rather than linear search
  251. */
  252. int
  253. sharedlen(int i, int j, int s[], int q[])
  254. {
  255. int k, v;
  256. int min = MAXI;
  257. int max = 0;
  258. if(i > j) { /* swap i & j */
  259. i ^= j;
  260. j ^= i;
  261. i ^= j;
  262. }
  263. i++;
  264. while(i <= j && min > max) {
  265. k = i & -i;
  266. if(i & 1)
  267. v = s[i];
  268. else
  269. v = q[i>>1];
  270. if(i+k > j+1) {
  271. if(v > max)
  272. max = v;
  273. if(s[i] < min)
  274. min = s[i];
  275. i++;
  276. } else {
  277. if(v < min)
  278. min = v;
  279. i += k;
  280. }
  281. }
  282. return min;
  283. }
  284. /*
  285. * qsort specialized for sorting permutations based on successors
  286. */
  287. static void
  288. vecswap2(Elem *a, Elem *b, int n)
  289. {
  290. while (n-- > 0) {
  291. Elem t = *a;
  292. *a++ = *b;
  293. *b++ = t;
  294. }
  295. }
  296. #define swap2(a, b) { t = *(a); *(a) = *(b); *(b) = t; }
  297. #define ptr2char(i, asucc) (asucc[*(i)])
  298. static Elem*
  299. med3(Elem *a, Elem *b, Elem *c, Elem *asucc)
  300. {
  301. Elem va, vb, vc;
  302. if ((va=ptr2char(a, asucc)) == (vb=ptr2char(b, asucc)))
  303. return a;
  304. if ((vc=ptr2char(c, asucc)) == va || vc == vb)
  305. return c;
  306. return va < vb ?
  307. (vb < vc ? b : (va < vc ? c : a))
  308. : (vb > vc ? b : (va < vc ? a : c));
  309. }
  310. static void
  311. inssort(Elem *a, Elem *asucc, int n)
  312. {
  313. Elem *pi, *pj, t;
  314. for (pi = a + 1; --n > 0; pi++)
  315. for (pj = pi; pj > a; pj--) {
  316. if(ptr2char(pj-1, asucc) <= ptr2char(pj, asucc))
  317. break;
  318. swap2(pj, pj-1);
  319. }
  320. }
  321. static void
  322. qsort2(Elem *a, Elem *asucc, int n)
  323. {
  324. Elem d, r, partval;
  325. Elem *pa, *pb, *pc, *pd, *pl, *pm, *pn, t;
  326. if (n < 15) {
  327. inssort(a, asucc, n);
  328. return;
  329. }
  330. pl = a;
  331. pm = a + (n >> 1);
  332. pn = a + (n-1);
  333. if (n > 30) { /* On big arrays, pseudomedian of 9 */
  334. d = (n >> 3);
  335. pl = med3(pl, pl+d, pl+2*d, asucc);
  336. pm = med3(pm-d, pm, pm+d, asucc);
  337. pn = med3(pn-2*d, pn-d, pn, asucc);
  338. }
  339. pm = med3(pl, pm, pn, asucc);
  340. swap2(a, pm);
  341. partval = ptr2char(a, asucc);
  342. pa = pb = a + 1;
  343. pc = pd = a + n-1;
  344. for (;;) {
  345. while (pb <= pc && (r = ptr2char(pb, asucc)-partval) <= 0) {
  346. if (r == 0) {
  347. swap2(pa, pb);
  348. pa++;
  349. }
  350. pb++;
  351. }
  352. while (pb <= pc && (r = ptr2char(pc, asucc)-partval) >= 0) {
  353. if (r == 0) {
  354. swap2(pc, pd);
  355. pd--;
  356. }
  357. pc--;
  358. }
  359. if (pb > pc)
  360. break;
  361. swap2(pb, pc);
  362. pb++;
  363. pc--;
  364. }
  365. pn = a + n;
  366. r = pa-a;
  367. if(pb-pa < r)
  368. r = pb-pa;
  369. vecswap2(a, pb-r, r);
  370. r = pn-pd-1;
  371. if(pd-pc < r)
  372. r = pd-pc;
  373. vecswap2(pb, pn-r, r);
  374. if ((r = pb-pa) > 1)
  375. qsort2(a, asucc, r);
  376. if ((r = pd-pc) > 1)
  377. qsort2(a + n-r, asucc, r);
  378. }