Skip to content

Commit 078c007

Browse files
authored
Move combinatorial code previously in mmap-bitvec dependency
* Transplant combinatorial code from mmap-bitvec, use newly published mmap-bitvec * Remove injection of SSH private key (no longer needed with open mmap-bitvec crate) * Fix `cargo fmt`, clippy warnings * Update CI Rust toolchain version to 1.60.0
1 parent 62427d9 commit 078c007

File tree

6 files changed

+215
-30
lines changed

6 files changed

+215
-30
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,9 @@ jobs:
1515
- uses: actions-rs/toolchain@v1
1616
with:
1717
profile: minimal
18-
toolchain: 1.40.0
18+
toolchain: 1.60.0
1919
override: true
2020

21-
- name: create SSH key
22-
uses: webfactory/ssh-agent@v0.2.0
23-
with:
24-
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
25-
2621
- name: version info
2722
run: rustc --version; cargo --version;
2823

@@ -41,11 +36,6 @@ jobs:
4136
toolchain: nightly
4237
override: true
4338

44-
- name: create SSH key
45-
uses: webfactory/ssh-agent@v0.2.0
46-
with:
47-
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
48-
4939
- name: version info
5040
run: rustc --version; cargo --version;
5141

@@ -64,11 +54,6 @@ jobs:
6454
toolchain: stable
6555
override: true
6656

67-
- name: create SSH key
68-
uses: webfactory/ssh-agent@v0.2.0
69-
with:
70-
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
71-
7257
- uses: actions-rs/cargo@v1
7358
with:
7459
command: fmt
@@ -86,11 +71,6 @@ jobs:
8671
toolchain: stable
8772
override: true
8873

89-
- name: create SSH key
90-
uses: webfactory/ssh-agent@v0.2.0
91-
with:
92-
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
93-
9474
- uses: actions-rs/cargo@v1
9575
with:
9676
command: clippy

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ edition = "2018"
66

77
[dependencies]
88
bincode = "1"
9-
mmap-bitvec = { git="ssh://git@github.com/onecodex/mmap-bitvec.git" }
9+
mmap-bitvec = "0.4.0"
1010
murmurhash3 = "0.0.5"
1111
serde = { version = "1.0", features = ["derive"] }
12+
once_cell = "1.3.1"
1213

1314
[dev-dependencies]
1415
criterion = "0.3"

src/bfield.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::io;
22
use std::path::{Path, PathBuf};
33

4-
use mmap_bitvec::combinatorial::rank;
4+
use crate::combinatorial::rank;
55
use serde::de::DeserializeOwned;
66
use serde::Serialize;
77

src/bfield_member.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use std::intrinsics;
55
use std::io;
66
use std::path::{Path, PathBuf};
77

8+
use crate::combinatorial::{rank, unrank};
89
use bincode::{deserialize, serialize};
9-
use mmap_bitvec::combinatorial::{rank, unrank};
1010
use mmap_bitvec::{BitVector, MmapBitVec};
1111
use murmurhash3::murmurhash3_x64_128;
1212
use serde::de::DeserializeOwned;
@@ -43,8 +43,6 @@ fn prefetch_read(pointer: *const u8) {
4343
unsafe {
4444
arch_impl::_mm_prefetch::<{ arch_impl::_MM_HINT_NTA }>(pointer as *const i8);
4545
}
46-
47-
return;
4846
}
4947
}
5048

@@ -65,6 +63,7 @@ pub(crate) struct BFieldMember<T> {
6563

6664
/// A simple type alias to make the code more readable
6765
pub type BFieldVal = u32;
66+
/// Magic bytes used to indicate the `bfield` file type for `MmapBitvec`
6867
const BF_MAGIC: [u8; 2] = [0xBF, 0x1D];
6968

7069
#[derive(Debug, PartialEq)]
@@ -95,7 +94,7 @@ impl<T: Clone + DeserializeOwned + Serialize> BFieldMember<T> {
9594
MmapBitVec::from_memory(size)?
9695
} else {
9796
let header: Vec<u8> = serialize(&bf_params).unwrap();
98-
MmapBitVec::create(&filename, size, BF_MAGIC, &header)?
97+
MmapBitVec::create(&filename, size, Some(BF_MAGIC), &header)?
9998
};
10099

101100
Ok(BFieldMember {
@@ -123,7 +122,7 @@ impl<T: Clone + DeserializeOwned + Serialize> BFieldMember<T> {
123122
let header: Vec<u8> = serialize(&self.params).unwrap();
124123
self.bitvec
125124
.get()
126-
.save_to_disk(&self.filename, BF_MAGIC, &header)?;
125+
.save_to_disk(&self.filename, Some(BF_MAGIC), &header)?;
127126
let bitvec = BitVec::new(MmapBitVec::open(&self.filename, Some(&BF_MAGIC), false)?);
128127
Ok(Self {
129128
bitvec,
@@ -213,7 +212,7 @@ impl<T: Clone + DeserializeOwned + Serialize> BFieldMember<T> {
213212
let pos = marker_pos(hash, marker_ix, self.bitvec.get().size(), marker_width);
214213
positions[marker_ix] = pos;
215214
unsafe {
216-
let byte_idx_st = (pos >> 3) as usize;
215+
let byte_idx_st = pos >> 3;
217216
let ptr: *const u8 = self.bitvec.get().mmap.as_ptr().add(byte_idx_st);
218217
prefetch_read(ptr);
219218
}

src/combinatorial.rs

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
use once_cell::sync::Lazy;
2+
use std::collections::HashMap;
3+
use std::convert::TryFrom;
4+
5+
const MARKER_TABLE_SIZE: usize = 200_000;
6+
7+
// TODO: replace with const fn when it is possible
8+
// (for and if are not allowed in const fn on current stable)
9+
// https://github.com/rust-lang/rust/issues/87575
10+
static MARKER_TABLES: Lazy<HashMap<u8, Vec<u128>>> = Lazy::new(|| {
11+
let mut m = HashMap::new();
12+
for k in 1..10u8 {
13+
let mut table = vec![0u128; MARKER_TABLE_SIZE];
14+
let table_size = if k == 1 {
15+
128
16+
} else if k == 2 {
17+
8128
18+
} else {
19+
table.len()
20+
};
21+
22+
table[0] = ((1 << k) - 1) as u128;
23+
for i in 1..table_size {
24+
table[i] = next_rank(table[i - 1]);
25+
}
26+
m.insert(k, table);
27+
}
28+
m
29+
});
30+
31+
/// https://en.wikipedia.org/wiki/Combinatorial_number_system
32+
pub fn rank(value: usize, k: u8) -> u128 {
33+
assert!(k > 0 && k < 10, "kappa needs to be less than 10");
34+
// it's possible this may overflow if value > (128 choose k) or return
35+
// a bad value (0) if value > (128 choose k) and k == 1 or 2
36+
if value >= MARKER_TABLE_SIZE {
37+
let mut marker = MARKER_TABLES[&k][MARKER_TABLE_SIZE - 1];
38+
for _ in 0..(value - MARKER_TABLE_SIZE) {
39+
// next_rank would overflow if we pass 0, we return it instead
40+
if marker == 0 {
41+
return marker;
42+
}
43+
marker = next_rank(marker);
44+
}
45+
marker
46+
} else {
47+
MARKER_TABLES[&k][value]
48+
}
49+
}
50+
51+
/// https://en.wikipedia.org/wiki/Combinatorial_number_system
52+
pub fn unrank(marker: u128) -> usize {
53+
// val = choose(rank(0), 1) + choose(rank(1), 2) + choose(rank(2), 3) + ...
54+
let mut working_marker = marker;
55+
let mut value = 0u64;
56+
let mut idx = 0;
57+
while working_marker != 0 {
58+
let rank = u64::from(working_marker.trailing_zeros());
59+
working_marker -= 1 << rank;
60+
idx += 1;
61+
value += choose(rank, idx);
62+
}
63+
value as usize
64+
}
65+
66+
/// (Hopefully) fast, exact implementation of the binomial coefficient
/// `n choose k`.
///
/// * Returns 0 when `k > n` (there is no way to pick more items than exist).
/// * Uses the symmetry `C(n, k) == C(n, n - k)` so only the smaller of the two
///   is ever iterated over.
/// * For small `n` and an effective `k` of 2 through 7, uses closed-form
///   products in `u64`.
/// * Otherwise falls back to the multiplicative formula in `u128`, where every
///   intermediate result is itself a binomial coefficient and therefore an
///   exact integer.
///
/// # Panics
///
/// Panics with "`n` choose `k` is greater than 2**64" if the result does not
/// fit in a `u64`.
#[inline]
pub fn choose(n: u64, k: u8) -> u64 {
    /// Largest `n` for which the closed forms cannot overflow a `u64`:
    /// the widest product has 7 factors, each at most `n`, and 512^7 == 2^63.
    const CLOSED_FORM_MAX_N: u64 = 512;

    #[cold]
    fn too_large(n: u64, k: u8) -> ! {
        panic!("{} choose {} is greater than 2**64", n, k)
    }

    let k_wide = u64::from(k);
    if k_wide > n {
        return 0;
    }
    // C(n, k) == C(n, n - k); `r <= n / 2` from here on, so `n - j` below
    // cannot underflow for any `j < r`.
    let r = k_wide.min(n - k_wide);
    match r {
        0 => return 1,
        1 => return n,
        _ => {}
    }

    if n <= CLOSED_FORM_MAX_N {
        match r {
            2 => return n * (n - 1) / 2,
            3 => return n * (n - 1) * (n - 2) / 6,
            4 => return n * (n - 1) * (n - 2) * (n - 3) / 24,
            5 => return n * (n - 1) * (n - 2) * (n - 3) * (n - 4) / 120,
            6 => return n * (n - 1) * (n - 2) * (n - 3) * (n - 4) * (n - 5) / 720,
            7 => return n * (n - 1) * (n - 2) * (n - 3) * (n - 4) * (n - 5) * (n - 6) / 5040,
            _ => {}
        }
    }

    // Invariant: after step `i`, `acc == C(n, i)`. The product
    // `C(n, i - 1) * (n + 1 - i)` equals `i * C(n, i)`, so the division by `i`
    // is always exact.
    let mut acc: u128 = 1;
    for i in 1..=u128::from(r) {
        acc = acc
            .checked_mul(u128::from(n) + 1 - i)
            .unwrap_or_else(|| too_large(n, k))
            / i;
    }
    u64::try_from(acc).unwrap_or_else(|_| too_large(n, k))
}
109+
110+
/// Returns the next larger `u128` with the same number of set bits as `marker`.
///
/// Returns 0 when no such number exists, i.e. when the set bits of `marker`
/// already occupy the top of the word (for example `1 << 127`). `rank` treats 0
/// as "no marker", so the sentinel propagates instead of overflowing.
///
/// # Panics
///
/// Panics if `marker` is 0, which has no set bits to permute.
#[inline]
fn next_rank(marker: u128) -> u128 {
    if marker == 0 {
        unreachable!("Got next_rank called with marker == 0");
    }
    // `t` is `marker` with every bit below its lowest set bit filled in.
    let t = marker | (marker - 1);
    // `t + 1` carries through the lowest run of ones; it only overflows when
    // that run reaches bit 127, meaning `marker` is the last permutation.
    let carried = match t.checked_add(1) {
        Some(carried) => carried,
        None => return 0,
    };
    // Re-insert the ones that the carry cleared, packed at the bottom.
    carried | (((!t & carried) - 1) >> (marker.trailing_zeros() + 1))
}
118+
119+
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rank() {
        // exact markers for the first 3-bit values
        assert_eq!(rank(0, 3), 7);
        assert_eq!(rank(2, 3), 13);

        // every marker must carry exactly `k` set bits
        // (41663 is the maximum value of 64 choose 3)
        let cases: [(usize, u8); 4] = [(0, 3), (2, 3), (35001, 4), (41663, 3)];
        for &(value, k) in cases.iter() {
            assert_eq!(rank(value, k).count_ones(), u32::from(k));
        }
    }

    #[test]
    fn test_unrank() {
        // 3 bit markers
        let cases: [(u128, usize); 2] = [(7, 0), (13, 2)];
        for &(marker, value) in cases.iter() {
            assert_eq!(unrank(marker), value);
        }
    }

    #[test]
    fn test_rank_and_unrank() {
        let values: [usize; 3] = [1, 23, 45];
        for k in 1..4u8 {
            for &value in values.iter() {
                assert_eq!(unrank(rank(value, k)), value);
            }
        }
    }

    #[test]
    fn test_choose() {
        let cases: [(u64, u8, u64); 23] = [
            (1, 1, 1),
            (10, 1, 10),
            (5, 2, 10),
            (5, 3, 10),
            (5, 4, 5),
            (5, 5, 1),
            (20, 5, 15504),
            (20, 6, 38760),
            (20, 7, 77520),
            (23, 7, 245157),
            // the last (fallback) branch
            (8, 8, 1),
            (9, 8, 9),
            // every value of 64 choose n should work
            (64, 0, 1),
            (64, 1, 64),
            (64, 16, 488526937079580),
            (64, 32, 1832624140942590534),
            (64, 48, 488526937079580),
            (64, 63, 64),
            (64, 64, 1),
            // super high values can overflow; these are approaching the limit
            (128, 11, 2433440563030400),
            (128, 13, 211709328983644800),
            (256, 9, 11288510714272000),
            (256, 9, 11288510714272000),
        ];
        for &(n, k, expected) in cases.iter() {
            assert_eq!(choose(n, k), expected);
        }
    }

    #[test]
    #[should_panic(expected = "256 choose 20 is greater than 2**64")]
    fn test_choose_overflow() {
        assert_eq!(choose(256, 20), 11288510714272000);
    }

    #[test]
    fn test_next_rank() {
        let cases: [(u128, u128); 4] = [
            (0b1, 0b10),
            (0b100, 0b1000),
            (0b111, 0b1011),
            (0b1000101, 0b1000110),
        ];
        for &(marker, successor) in cases.iter() {
            assert_eq!(next_rank(marker), successor);
        }
    }
}

src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
66
mod bfield;
77
mod bfield_member;
8+
/// Some combinatorial utilities
9+
mod combinatorial;
810

911
pub use crate::bfield::BField;
1012
pub use crate::bfield_member::BFieldVal;
11-
pub use mmap_bitvec::combinatorial::choose;
13+
pub use combinatorial::choose;

0 commit comments

Comments
 (0)