Commit a9ceadd

rewrite some supertypes to fix fit methods with eachrow, eachslice, etc.
1 parent c4e1a5c

9 files changed: +19 −296 lines

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 name = "OnlineStats"
 uuid = "a15396b6-48d5-5d58-9928-6d29437db91e"
-version = "1.6.1"
+version = "1.6.2"

 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/OnlineStats.jl

Lines changed: 3 additions & 3 deletions

@@ -58,9 +58,9 @@ export
     OnlineStat, BiasVec

 #-----------------------------------------------------------------------------# utils
-const Tup = Union{Tuple, NamedTuple}
-const VectorOb = Union{AbstractVector, Tup}
-const XY{T,S} = Union{Tuple{T,S}, Pair{T,S}, NamedTuple{names,Tuple{T,S}}} where {names,T<:AbstractVector{<:Number},S<:Number}
+const Tup{T} = Union{NTuple{N,T} where {N}, NamedTuple{names, Tuple{N,<:T} where {N}} where {names}}
+const VectorOb{T} = Union{AbstractVector{<:T}, Tup{T}}
+const XY{T,S} = Union{Tuple{T,S}, Pair{T,S}, NamedTuple{names,Tuple{T,S}}} where {names,T<:VectorOb{Number},S<:Number}

 const ϵ = 1e-7 # avoid dividing by 0 in some cases

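This alias rewrite is the heart of the commit: each element of `eachrow(y)` is a `SubArray`, i.e. an `AbstractVector{<:Number}`, so it now matches `VectorOb{Number}` and is consumed as a single vector-valued observation. A minimal sketch of the fixed behavior (illustrative usage, not part of the diff):

    using OnlineStats

    y = randn(100, 2)

    # Each element of eachrow(y) is an AbstractVector{Float64} <: AbstractVector{<:Number},
    # so vector-valued stats like CovMatrix treat it as one 2-dimensional observation.
    o = fit!(CovMatrix(2), eachrow(y))
    cov(o)

    # eachslice over dims=1 yields the same row views, so it works identically.
    o2 = fit!(CovMatrix(2), eachslice(y; dims=1))
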
src/stats/distributions.jl

Lines changed: 4 additions & 4 deletions

@@ -166,7 +166,7 @@ distribution is returned as 1.
     x = [1 2 3; 4 8 12]
     fit!(FitMultinomial(3), x)
 """
-mutable struct FitMultinomial{T} <: OnlineStat{VectorOb}
+mutable struct FitMultinomial{T} <: OnlineStat{VectorOb{Number}}
     grp::Group{T}
 end

@@ -176,7 +176,7 @@ nobs(o::FitMultinomial) = nobs(o.grp)
 function value(o::FitMultinomial)
     m = value.(o.grp.stats)
     p = length(o.grp)
-    outvec = all(x-> x==0.0, m) ? ones(p) ./ p : collect(m) ./ sum(m)
+    outvec = all(iszero, m) ? ones(p) ./ p : collect(m) ./ sum(m)
     return 1, outvec
 end
 _merge!(o::FitMultinomial, o2::FitMultinomial) = _merge!(o.grp, o2.grp)

@@ -192,7 +192,7 @@ Online parameter estimate of a `d`-dimensional MvNormal distribution (MLE).
     y = randn(100, 2)
     o = fit!(FitMvNormal(2), eachrow(y))
 """
-struct FitMvNormal{C <: CovMatrix} <: OnlineStat{VectorOb}
+struct FitMvNormal{C <: CovMatrix} <: OnlineStat{VectorOb{Number}}
     cov::C
 end
 FitMvNormal(p::Integer) = FitMvNormal(CovMatrix(p))

@@ -214,4 +214,4 @@ _merge!(o::FitMvNormal, o2::FitMvNormal) = _merge!(o.cov, o2.cov)

 Statistics.mean(o::FitMvNormal) = mean(o.cov)
 Statistics.var(o::FitMvNormal) = var(o.cov)
-Statistics.cov(o::FitMvNormal) = cov(o.cov)
+Statistics.cov(o::FitMvNormal) = cov(o.cov)

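With the parametric supertype, the docstring examples above run as written. A short combined usage sketch (outputs omitted):

    using OnlineStats

    # FitMultinomial: each row of the matrix is one vector of counts.
    fit!(FitMultinomial(3), [1 2 3; 4 8 12])

    # FitMvNormal: each row is one observation; mean/var/cov forward to the wrapped CovMatrix.
    o = fit!(FitMvNormal(2), eachrow(randn(100, 2)))
    mean(o), var(o), cov(o)
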
src/stats/linreg.jl

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ parameter `λ`. An intercept (`bias`) term is added by default.

     coef(o; y=7, x=[2,5,4])
 """
-mutable struct LinRegBuilder{W} <: OnlineStat{VectorOb}
+mutable struct LinRegBuilder{W} <: OnlineStat{VectorOb{Number}}
     A::Matrix{Float64} # x'x, pretend that x = [x, 1]
     weight::W
     n::Int

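`LinRegBuilder` takes whole rows (all variables together) as observations, so it needed the same supertype fix to accept `eachrow`. A hedged sketch, assuming the `LinRegBuilder(p)` constructor; the `coef` call mirrors the docstring above:

    using OnlineStats

    x = randn(1000, 7)
    o = fit!(LinRegBuilder(7), eachrow(x))   # each row supplies all 7 variables

    # Regress variable 7 on variables 2, 5, and 4 (intercept added by default):
    coef(o; y = 7, x = [2, 5, 4])
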
src/stats/nbclassifier.jl

Lines changed: 4 additions & 281 deletions

@@ -106,7 +106,7 @@ nvars(o::NBClassifier) = length(o.init)
 nobs(o::NBClassifier) = isempty(o.d) ? 0 : sum(nobs, values(o))
 probs(o::NBClassifier) = isempty(o.d) ? zeros(0) : map(nobs, values(o)) ./ nobs(o)

-function _predict(o::NBClassifier, x::VectorOb, p = zeros(nkeys(o)), n = nobs(o))
+function _predict(o::NBClassifier, x::VectorOb{Number}, p = zeros(nkeys(o)), n = nobs(o))
     for (k, gk) in enumerate(values(o))
         # P(Ck)
         p[k] = log(nobs(gk) / n + ϵ)
@@ -119,7 +119,7 @@ function _predict(o::NBClassifier, x::VectorOb, p = zeros(nkeys(o)), n = nobs(o)
     sp = sum(p)
     sp == 0.0 ? p : rmul!(p, inv(sp))
 end
-function _classify(o::NBClassifier, x::VectorOb, p = zeros(nkeys(o)), n = nobs(o))
+function _classify(o::NBClassifier, x::VectorOb{Number}, p = zeros(nkeys(o)), n = nobs(o))
     _, k = findmax(_predict(o, x, p, n))
     index_to_key(o, k)
 end
@@ -129,11 +129,11 @@ function index_to_key(d, i)
     end
 end

-predict(o::NBClassifier, x::VectorOb) = _predict(o, x)
+predict(o::NBClassifier, x::VectorOb{Number}) = _predict(o, x)
 predict(o::NBClassifier, x) = [predict(o, xi) for xi in x]
 predict(o::NBClassifier, x::AbstractMatrix) = predict(o, OnlineStatsBase.eachrow(x))

-classify(o::NBClassifier, x::VectorOb) = _classify(o, x)
+classify(o::NBClassifier, x::VectorOb{Number}) = _classify(o, x)
 classify(o::NBClassifier, x) = [classify(o, xi) for xi in x]
 classify(o::NBClassifier, x::AbstractMatrix) = classify(o, OnlineStatsBase.eachrow(x))

@@ -151,280 +151,3 @@ function split!(o::NBClassifier)
 end

 entropy(o::NBClassifier) = entropy(probs(o), 2)
-
-# function split(o::NBClassifier)
-#     nroot = [nobs(g) for g in values(o)]
-#     nleft = copy(nroot)
-#     nright = copy(nroot)
-#     split = NBSplit(length(nroot))
-#     entropy_root = entropy(o)
-#     for j in 1:nvars(o)
-#         ss = o[j]
-#         stat = merge(ss)
-#         for loc in split_candidates(stat)
-#             for k in 1:nkeys(o)
-#                 nleft[k] = round(Int, n_sent_left(ss[k], loc))
-#             end
-#             entropy_left = entropy(nleft ./ sum(nleft))
-#             @. nright = nroot - nleft
-#             entropy_right = entropy(nright ./ sum(nright))
-#             entropy_after = smooth(entropy_right, entropy_left, sum(nleft) / sum(nroot))
-#             ig = entropy_root - entropy_after
-#             if ig > split.ig
-#                 split.j = j
-#                 split.at = loc
-#                 split.ig = ig
-#                 split.nleft .= nleft
-#             end
-#         end
-#     end
-#     left = NBClassifier(collect(keys(o)), o.init)
-#     right = NBClassifier(collect(keys(o)), o.init)
-#     for (i, g) in enumerate(values(left.d))
-#         g.nobs = split.nleft[i]
-#     end
-#     for (i, g) in enumerate(values(right.d))
-#         g.nobs = nroot[i] - split.nleft[i]
-#     end
-#     o, split, left, right
-# end
-
-# #-----------------------------------------------------------------------# NBClassifier
-# """
-#     NBClassifier(group::Group, labeltype::Type)
-
-# Create a naive bayes classifier, using the stats in `group` to approximate the
-# distributions of each predictor variable conditioned on label.
-
-# - For continuous variables, use [`Hist(nbin)`](@ref).
-# - For categorical variables, use [`CountMap(T)`](@ref).
-
-# # Example
-
-#     x = randn(10^5, 10)
-#     y = rand(1:5, 10^5)
-#     o = NBClassifier(10Hist(20), Float64)
-#     series((x, y), o)
-#     predict(o, x)
-#     classify(o, x)
-# """
-# struct NBClassifier{T, G <: Group} <: ExactStat{(1, 0)}
-#     d::OrderedDict{T, G} # class => group
-#     init::G # empty group
-# end
-# NBClassifier(T::Type, g::G) where {G<:Group} = NBClassifier(OrderedDict{T,G}(), g)
-# NBClassifier(g::Group, T::Type) = NBClassifier(T, g)
-# function NBClassifier(labels::Vector{T}, g::G) where {T, G<:Group}
-#     NBClassifier(OrderedDict{T, G}(lab=>copy(g) for lab in labels), g)
-# end
-# NBClassifier(p::Int, T::Type, b=20) = NBClassifier(T, p * Hist(b))
-
-
-# function Base.show(io::IO, o::NBClassifier)
-#     print(io, name(o))
-#     sd = sort(o.d)
-#     for di in sd
-#         print(io, "\n > ", first(di), " (", round(nobs(last(di)) / nobs(o), 4), ")")
-#     end
-# end
-
-# Base.keys(o::NBClassifier) = keys(o.d)
-# Base.values(o::NBClassifier) = values(o.d)
-# Base.haskey(o::NBClassifier, y) = haskey(o.d, y)
-# nvars(o::NBClassifier) = length(o.init)
-# nkeys(o::NBClassifier) = length(o.d)
-# nobs(o::NBClassifier) = sum(nobs, values(o))
-# probs(o::NBClassifier) = [nobs(g) for g in values(o)] ./ nobs(o)
-# Base.getindex(o::NBClassifier, j) = [stat[j] for stat in values(o)]
-
-# # d is an object that iterates keys in known order
-# function index_to_key(d, i)
-#     for (k, ky) in enumerate(keys(d))
-#         k == i && return ky
-#     end
-# end
-
-# function fit!(o::NBClassifier, xy, γ)
-#     x, y = xy
-#     if haskey(o, y)
-#         g = o.d[y]
-#         fit!(g, x, 1 / (nobs(g) + 1))
-#     else
-#         o.d[y] = fit!(copy(o.init), x, 1.0)
-#     end
-# end
-# entropy(o::NBClassifier) = entropy(probs(o), 2)
-
-# function _predict(o::NBClassifier, x::VectorOb, p = zeros(nkeys(o)), n = nobs(o))
-#     for (k, gk) in enumerate(values(o))
-#         # P(Ck)
-#         p[k] = log(nobs(gk) / n + ϵ)
-#         # P(xj | Ck)
-#         for j in 1:length(x)
-#             p[k] += log(pdf(gk[j], x[j]) + ϵ)
-#         end
-#         p[k] = exp(p[k])
-#     end
-#     sp = sum(p)
-#     sp == 0.0 ? p : p ./= sum(p)
-# end
-# function _classify(o::NBClassifier, x::VectorOb, p = zeros(nkeys(o)), n = nobs(o))
-#     _, k = findmax(_predict(o, x, p, n))
-#     index_to_key(o, k)
-# end
-# predict(o::NBClassifier, x::VectorOb) = _predict(o, x)
-# classify(o::NBClassifier, x::VectorOb) = _classify(o, x)
-# function classify_node(o::NBClassifier)
-#     _, k = findmax([nobs(g) for g in values(o)])
-#     index_to_key(o, k)
-# end
-# for f in [:(_predict), :(_classify)]
-#     @eval begin
-#         function $f(o::NBClassifier, x::AbstractMatrix, ::Rows = Rows())
-#             n = nobs(o)
-#             p = zeros(nkeys(o))
-#             mapslices(xi -> $f(o, xi, p, n), x, 2)
-#         end
-#         function $f(o::NBClassifier, x::AbstractMatrix, ::Cols)
-#             n = nobs(o)
-#             p = zeros(nkeys(o))
-#             mapslices(xi -> $f(o, xi, p, n), x, 1)
-#         end
-#     end
-# end
-
-# function split(o::NBClassifier)
-#     nroot = [nobs(g) for g in values(o)]
-#     nleft = copy(nroot)
-#     nright = copy(nroot)
-#     split = NBSplit(length(nroot))
-#     entropy_root = entropy(o)
-#     for j in 1:nvars(o)
-#         ss = o[j]
-#         stat = merge(ss)
-#         for loc in split_candidates(stat)
-#             for k in 1:nkeys(o)
-#                 nleft[k] = round(Int, n_sent_left(ss[k], loc))
-#             end
-#             entropy_left = entropy(nleft ./ sum(nleft))
-#             @. nright = nroot - nleft
-#             entropy_right = entropy(nright ./ sum(nright))
-#             entropy_after = smooth(entropy_right, entropy_left, sum(nleft) / sum(nroot))
-#             ig = entropy_root - entropy_after
-#             if ig > split.ig
-#                 split.j = j
-#                 split.at = loc
-#                 split.ig = ig
-#                 split.nleft .= nleft
-#             end
-#         end
-#     end
-#     left = NBClassifier(collect(keys(o)), o.init)
-#     right = NBClassifier(collect(keys(o)), o.init)
-#     for (i, g) in enumerate(values(left.d))
-#         g.nobs = split.nleft[i]
-#     end
-#     for (i, g) in enumerate(values(right.d))
-#         g.nobs = nroot[i] - split.nleft[i]
-#     end
-#     o, split, left, right
-# end
-
-# n_sent_left(o::Union{OrderStats, Hist}, loc) = sum(o, loc)
-# n_sent_left(o::CountMap, label) = o[label]
-
-# #-----------------------------------------------------------------------# NBSplit
-# # Continuous: x[j] < at
-# # Categorical: x[j] == at
-# mutable struct NBSplit{}
-#     j::Int
-#     at::Any
-#     ig::Float64
-#     nleft::Vector{Int}
-# end
-# NBSplit(n=0) = NBSplit(0, -Inf, -Inf, zeros(Int, n))
-
-# whichchild(o::NBSplit, x) = x[o.j] < o.at ? 1 : 2
-
-# #-----------------------------------------------------------------------# NBNode
-# mutable struct NBNode{T <: NBClassifier} <: ExactStat{(1, 0)}
-#     nbc::T
-#     id::Int
-#     parent::Int
-#     children::Vector{Int}
-#     split::NBSplit
-# end
-# function NBNode(o::NBClassifier; id = 1, parent = 0, children = Int[], split = NBSplit())
-#     NBNode(o, id, parent, children, split)
-# end
-# function Base.show(io::IO, o::NBNode)
-#     print(io, "NBNode ", o.id)
-#     if o.split.j > 0
-#         print(io, " (split on $(o.split.j)")
-#     end
-# end
-
-# #-----------------------------------------------------------------------# NBTree
-# """
-#     NBTree(o::NBClassifier; maxsize=5000, splitsize=1000)
-
-# Create a decision tree where each node is a naive bayes classifier. A node will split
-# when it reaches `splitsize` observations and no more splits will occur once `maxsize`
-# nodes are in the tree.
-
-# # Example
-
-#     x = randn(10^5, 10)
-#     y = rand(Bool, 10^5)
-#     o = NBTree(NBClassifier(10Hist(20), Bool))
-#     series((x,y), o)
-#     classify(o, x)
-# """
-# mutable struct NBTree{T<:NBNode} <: ExactStat{(1, 0)}
-#     tree::Vector{T}
-#     maxsize::Int
-#     splitsize::Int
-# end
-# function NBTree(o::NBClassifier; maxsize = 5000, splitsize = 1000)
-#     NBTree([NBNode(o)], maxsize, splitsize)
-# end
-# NBTree(args...; kw...) = NBTree(NBClassifier(args...); kw...)
-# function Base.show(io::IO, o::NBTree)
-#     print(io, "NBTree(size = $(length(o.tree)), splitsize=$(o.splitsize))")
-# end
-
-# function fit!(o::NBTree, xy, γ)
-#     x, y = xy
-#     i, node = whichleaf(o, x)
-#     fit!(node.nbc, xy, γ)
-#     if length(o.tree) < o.maxsize && nobs(node.nbc) >= o.splitsize
-#         nbc, spl, left_nbc, right_nbc = split(node.nbc)
-#         # if spl.ig > o.cp
-#             node.split = spl
-#             node.children = [length(o.tree) + 1, length(o.tree) + 2]
-#             t = length(o.tree)
-#             left = NBNode(left_nbc, id = t + 1, parent = i)
-#             right = NBNode(right_nbc, id = t + 2, parent = i)
-#             push!(o.tree, left)
-#             push!(o.tree, right)
-#         # end
-#     end
-# end
-
-# function whichleaf(o::NBTree, x::VectorOb)
-#     i = 1
-#     node = o.tree[i]
-#     while length(node.children) > 0
-#         i = node.children[whichchild(node.split, x)]
-#         node = o.tree[i]
-#     end
-#     i, node
-# end
-
-# function classify(o::NBTree, x::VectorOb)
-#     i, node = whichleaf(o, x)
-#     classify_node(node.nbc)
-# end
-# function classify(o::NBTree, x::AbstractMatrix)
-#     mapslices(xi -> classify(o, xi), x, 2)
-# end

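The `VectorOb{Number}` methods above handle a single observation, while the `AbstractMatrix` methods fan out over `OnlineStatsBase.eachrow`. A hedged sketch, assuming the `NBClassifier(p, labeltype)` constructor from the package docs:

    using OnlineStats

    x, y = randn(10^4, 5), rand(Bool, 10^4)
    o = fit!(NBClassifier(5, Bool), zip(eachrow(x), y))

    predict(o, x[1, :])   # class probabilities for one observation
    classify(o, x)        # one label per row, via the AbstractMatrix method
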
src/stats/statlearn.jl

Lines changed: 2 additions & 2 deletions

@@ -221,11 +221,11 @@ function _merge!(o::StatLearn, o2::StatLearn)
     smooth!(o.λ, o2.λ, γ)
 end

-predict(o::StatLearn, x::VectorOb) = dot(x, o.β)
+predict(o::StatLearn, x::VectorOb{Number}) = dot(x, o.β)
 predict(o::StatLearn, x::AbstractMatrix) = x * o.β
 classify(o::StatLearn, x) = sign.(predict(o, x))

-function objective(o::StatLearn, x::AbstractMatrix, y::VectorOb)
+function objective(o::StatLearn, x::AbstractMatrix, y::VectorOb{Number})
     mean(o.loss.(y, predict(o,x))) + sum(o.λ .* o.penalty.(o.β))
 end

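Same pattern for `StatLearn`: the `VectorOb{Number}` method predicts for one observation via `dot(x, o.β)`, and the matrix method computes `x * o.β`. A hedged sketch, assuming the `MSPI` algorithm and default loss shown in the package docs:

    using OnlineStats

    x = randn(1000, 5)
    y = x * collect(1.0:5.0) + randn(1000)
    o = fit!(StatLearn(MSPI()), zip(eachrow(x), y))

    predict(o, x[1, :])   # single observation: dot(x[1, :], o.β)
    predict(o, x)         # whole matrix: x * o.β
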
src/stats/stats.jl

Lines changed: 1 addition & 1 deletion

@@ -268,7 +268,7 @@ Approximate K-Means clustering of `k` clusters.

     classify(o, x[1]) # returns index of cluster closest to x[1]
 """
-mutable struct KMeans{T, C <: NTuple{N, Cluster{T}} where N, W} <: OnlineStat{VectorOb}
+mutable struct KMeans{T, C <: NTuple{N, Cluster{T}} where N, W} <: OnlineStat{VectorOb{Number}}
     value::C
     buffer::Vector{T}
     rate::W

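`KMeans` also consumes vector observations, so `eachrow` now feeds it one point at a time. A hedged sketch, assuming the `KMeans(k)` constructor; the `classify` call mirrors the docstring above:

    using OnlineStats

    x = randn(10^4, 2) .+ 5 .* rand(Bool, 10^4)   # two fuzzy clusters in the plane
    o = fit!(KMeans(2), eachrow(x))

    classify(o, x[1, :])   # index of the cluster closest to the first row
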