Skip to content

Commit 3d0a6cc

Browse files
committed
Change hashtype() so that the output is the dtype of h[idx] (where h is a list of hashes computed by an LSHFunction).
1 parent ed70483 commit 3d0a6cc

File tree

11 files changed

+28
-23
lines changed

11 files changed

+28
-23
lines changed

src/hashes/lphash.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ LSHFunction and SymmetricLSHFunction API compliance
180180
========================#
181181

182182
n_hashes(h::LpHash) = length(h.shift)
183-
hashtype(::LpHash) = Vector{Int32}
183+
hashtype(::LpHash) = Int32
184184

185185
# See Section 3.2 of the reference
186186
function single_hash_collision_probability(hashfn::LpHash, sim::Real)

src/hashes/minhash.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ end
137137
#========================
138138
LSHFunction and SymmetricLSHFunction API compliance
139139
========================#
140-
n_hashes(hashfn :: MinHash) = length(hashfn.mappings)
141-
hashtype(:: MinHash{T, I}) where {T, I} = I
140+
n_hashes(hashfn::MinHash) = length(hashfn.mappings)
141+
hashtype(::MinHash{T, I}) where {T, I} = I
142142
similarity(::MinHash) = jaccard
143143

144144
single_hash_collision_probability(::MinHash, sim) = sim

src/hashes/mips_hash.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,4 +233,4 @@ query_hash(hashfn::MIPSHash, x) = MIPSHash_Q(hashfn, x)
233233
similarity(::MIPSHash) = inner_prod
234234

235235
n_hashes(hashfn::MIPSHash) = length(hashfn.shift)
236-
hashtype(::MIPSHash) = Vector{Int32}
236+
hashtype(::MIPSHash) = Int32

src/hashes/sign_alsh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,4 +205,4 @@ query_hash(h::SignALSH, x) = SignALSH_Q(h, x)
205205

206206
n_hashes(h::SignALSH) = size(h.coeff_A, 1)
207207
similarity(::SignALSH) = inner_prod
208-
hashtype(::SignALSH) = BitArray{1}
208+
hashtype(::SignALSH) = Bool

src/hashes/simhash.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ current_max_input_size(hashfn::SimHash) = size(hashfn.coeff, 1)
120120
LSHFunction and SymmetricLSHFunction API compliance
121121
========================#
122122

123-
hashtype(::SimHash) = BitArray{1}
123+
hashtype(::SimHash) = Bool
124124
n_hashes(hashfn::SimHash) = size(hashfn.coeff, 2)
125125
similarity(::SimHash) = cossim
126126
single_hash_collision_probability(::SimHash, sim::Real) = (1 - acos(sim) / π)

src/tables/table.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ end
2020
# Outer constructors
2121
function LSHTable(hashfn; valtype=Any, unique_values=false, entrytype=Vector)
2222
htype = hashtype(hashfn)
23+
htype = (htype == Bool) ? BitArray{1} : Vector{htype}
24+
2325
vtype, etype, ltype = begin
2426
if entrytype <: Vector
2527
vtype = (valtype <: eltype(entrytype)) ? valtype : eltype(entrytype)

test/hashes/test_lphash.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Tests
1616
@test L1_hash.r == 2
1717
@test L1_hash.power == 1
1818
@test similarity(L1_hash) == ℓ1
19+
@test hashtype(L1_hash) == Int32
1920

2021
# Construct a hash for L^2 distance
2122
L2_hash = L2Hash(12; r = 3.4)
@@ -57,7 +58,7 @@ Tests
5758

5859
# Test 1: Vector{Float64} -> Vector{Int32}
5960
hashes = hashfn(randn(5))
60-
@test Vector{eltype(hashes)} == hashtype(hashfn)
61+
@test eltype(hashes) == hashtype(hashfn)
6162
@test isa(hashes, Vector{Int32})
6263

6364
# Test 2: Matrix{Float64} -> Matrix{Int32}

test/hashes/test_minhash.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,21 @@ Tests
5656

5757
dataset = shuffle(symbols)[1:10]
5858
hashes = hashfn(dataset)
59-
hashes_match = true
6059

61-
for (hash, mapping) in zip(hashes, hashfn.mappings)
62-
if !hashes_match
63-
break
64-
end
60+
@test eltype(hashes) == hashtype(hashfn)
6561

66-
# Compute MinHash manually
67-
expected_hash = minimum(mapping[x] for x in dataset)
68-
hashes_match &= (hash == expected_hash)
69-
end
62+
@test let hashes_match = true
63+
for (hash, mapping) in zip(hashes, hashfn.mappings)
64+
if !hashes_match
65+
break
66+
end
7067

71-
@test hashes_match
68+
# Compute MinHash manually
69+
expected_hash = minimum(mapping[x] for x in dataset)
70+
hashes_match &= (hash == expected_hash)
71+
end
72+
hashes_match
73+
end
7274
end
7375

7476
@testset "Collision probabilities correlated with Jaccard similarity" begin

test/hashes/test_mips_hash.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Tests
1515
hashfn = MIPSHash(; maxnorm=1)
1616

1717
@test n_hashes(hashfn) == 1
18-
@test hashtype(hashfn) == Vector{Int32}
18+
@test hashtype(hashfn) == Int32
1919
@test similarity(hashfn) == inner_prod
2020
@test isa(hashfn, MIPSHash{Float32}) # Default dtype should be Float32
2121
@test isa(hashfn, LSH.AsymmetricLSHFunction)
@@ -63,8 +63,8 @@ Tests
6363

6464
@test isa(p_hashes, Matrix{Int32})
6565
@test isa(q_hashes, Matrix{Int32})
66-
@test Vector{eltype(p_hashes)} == hashtype(hashfn)
67-
@test Vector{eltype(q_hashes)} == hashtype(hashfn)
66+
@test hashtype(hashfn) == eltype(p_hashes)
67+
@test hashtype(hashfn) == eltype(q_hashes)
6868

6969
# Vector{Float64} -> Vector{Int32}
7070
x = randn(4)

test/hashes/test_sign_alsh.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Tests
1313
@test n_hashes(hashfn) == 1
1414
@test isa(hashfn, SignALSH{Float32})
1515
@test isa(hashfn, LSH.AsymmetricLSHFunction)
16-
@test hashtype(hashfn) == BitArray{1}
16+
@test hashtype(hashfn) == Bool
1717
@test similarity(hashfn) == inner_prod
1818

1919
hashfn = SignALSH(32; maxnorm=1)
@@ -50,7 +50,7 @@ Tests
5050
qhashes = query_hash(hashfn, X)
5151

5252
@test eltype(ihashes) == eltype(qhashes)
53-
@test BitArray{1} == hashtype(hashfn)
53+
@test hashtype(hashfn) == Bool
5454

5555
# 1. Compute the indexing hashes manually
5656
norms = map(norm, eachcol(X))

0 commit comments

Comments
 (0)