Skip to content

Commit 372ef31

Browse files
authored
added simple classify method for kmeans (#255)
1 parent ae46903 commit 372ef31

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "OnlineStats"
22
uuid = "a15396b6-48d5-5d58-9928-6d29437db91e"
3-
version = "1.5.14"
3+
version = "1.5.15"
44

55
[deps]
66
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/stats/stats.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ Approximate K-Means clustering of `k` clusters.
265265
o = fit!(KMeans(2, 2), eachrow(x))
266266
267267
sort!(o; rev=true) # Order clusters by number of observations
268+
269+
classify(o, x[1]) # returns index of cluster closest to x[1]
268270
"""
269271
mutable struct KMeans{T, C <: NTuple{N, Cluster{T}} where N, W} <: OnlineStat{VectorOb}
270272
value::C
@@ -305,6 +307,8 @@ function _fit!(o::KMeans{T}, x) where {T}
305307
end
306308
end
307309

310+
classify(o::KMeans, x) = findmin(c -> norm(x .- c.value), o.value)[2]
311+
308312
#-----------------------------------------------------------------------# MovingTimeWindow
309313
"""
310314
MovingTimeWindow{T<:TimeType, S}(window::Dates.Period)

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,10 @@ end
299299
o = fit!(KMeans(2), eachrow(ymat))
300300
sort!(o, rev=true)
301301
@test o.value[1].n o.value[2].n
302+
303+
x = [repeat([[1.0, 1.0]], 3); repeat([[-1.0, -1.0]], 3)]
304+
o = fit!(KMeans(2), (ξ for ξ x))
305+
@test classify(o, x[1]) classify(o, x[4])
302306
end
303307
#-----------------------------------------------------------------------# LinReg
304308
@testset "LinReg" begin

0 commit comments

Comments
 (0)