-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaldr.c
83 lines (73 loc) · 2.26 KB
/
aldr.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
/*
Name: aldr.c
Purpose: Fast sampling of random integers.
Author: CMU Probabilistic Computing Systems Lab
Copyright (C) 2025 CMU Probabilistic Computing Systems Lab, All Rights Reserved.
Released under Apache 2.0; refer to LICENSE.txt
*/
#include <stdlib.h>
#include "flip.h"
#include "aldr.h"
/*
 * Release the two heap arrays referenced by an aldr_s value.
 *
 * x is received by value, so the caller's copy still holds the (now
 * dangling) pointers afterwards and must not be used or freed again.
 */
void aldr_free (struct aldr_s x) {
    // The two arrays are independent allocations; free(NULL) is a no-op.
    free(x.leaves_flat);
    free(x.breadths);
}
struct aldr_s aldr_preprocess(uint32_t* a, uint32_t n) {
// this algorithm requires that
// 0 < n < (1 << 32) - 1
// 0 < sum(a) < (1 << 31) - 1
uint32_t m = 0;
for (uint32_t i = 0; i < n; ++i) {
m += a[i];
}
uint8_t k = 32 - __builtin_clz(m-1);
uint8_t K = k << 1; // depth
uint64_t c = (1ll << K) / m; // amplification factor
uint64_t r = (1ll << K) % m; // reject weight
uint32_t num_leaves = __builtin_popcountll(r);
for (uint32_t i = 0; i < n; ++i) {
num_leaves += __builtin_popcountll(c * a[i]);
}
uint32_t *breadths = calloc(K + 1, sizeof(*breadths));
uint32_t *leaves_flat = calloc(num_leaves, sizeof(*leaves_flat));
uint32_t location = 0;
for(uint8_t j = 0; j <= K; j++) {
uint64_t bitmask = (1ll << (K - j));
if (r & bitmask) {
leaves_flat[location] = 0;
++breadths[j];
++location;
}
for (uint32_t i = 0; i < n; ++i) {
uint64_t Qi = c*a[i];
if (Qi & bitmask) {
leaves_flat[location] = i + 1;
++breadths[j];
++location;
}
}
}
return (struct aldr_s){
.length_breadths = K+1,
.length_leaves_flat = num_leaves,
.breadths = breadths,
.leaves_flat = leaves_flat
};
}
/*
 * Draw one sample from the tables in *f.
 *
 * Starting at level 0 with node index 0, the walk checks whether the index
 * falls among the leaves of the current level (index < breadths[level]).
 * If so, the label stored in leaves_flat is read: a non-zero label is
 * returned minus one, a zero label restarts the walk from the top.
 * Otherwise the index is rebased past that level's leaves, doubled, and
 * extended with one result of flip() before moving to the next level.
 */
uint32_t aldr_sample(struct aldr_s* f) {
    while (1) {
        uint32_t level = 0;     // current level in breadths
        uint32_t offset = 0;    // start of this level's leaves in leaves_flat
        uint32_t node = 0;      // index of the current node within its level
        uint32_t width = f->breadths[level];

        // Descend until the node index lands on a leaf of the current level.
        while (node >= width) {
            offset += width;
            node = ((node - width) << 1) | flip();
            ++level;
            width = f->breadths[level];
        }

        uint32_t label = f->leaves_flat[offset + node];
        if (label != 0) {
            return label - 1;
        }
        // label == 0: restart the walk from the top.
    }
}