Skip to content

Commit 370ea9b

Browse files
authored
Merge pull request #23 from alan-turing-institute/dev
Add models and tests
2 parents 984db26 + 5296aa4 commit 370ea9b

13 files changed

+1635
-237
lines changed

Project.toml

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,33 @@
1-
name = "MLJNearestNeighborsInterface"
1+
name = "NearestNeighborModels"
22
uuid = "636a865e-7cf4-491e-846c-de09b730eb36"
3-
authors = ["Sebastian Vollmer <s.vollmer.4@warwick.ac.uk> and contributors"]
4-
version = "0.1.0"
3+
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>", "Sebastian Vollmer <s.vollmer.4@warwick.ac.uk>", "Thibaut Lienart <thibaut.lienart@gmail.com>", "Okon Samuel <okonsamuel50@gmail.com>"]
4+
version = "0.1.1"
55

66
[deps]
7-
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
87
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
9-
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
10-
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
11-
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
12-
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
13-
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
8+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
9+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1411
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
1512
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
16-
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
17-
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
18-
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
19-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
20-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
21-
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
2213
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2314
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2415
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
25-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2616

2717
[compat]
18+
Distances = "^0.9,^0.10"
19+
FillArrays = "^0.9"
20+
MLJModelInterface = "^0.3.5, ^0.4"
21+
NearestNeighbors = "^0.4"
22+
StatsBase = "^0.33"
23+
Tables = "^1.2"
2824
julia = "1"
2925

3026
[extras]
3127
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
28+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
29+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3230
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3331

3432
[targets]
35-
test = ["MLJBase", "Test"]
33+
test = ["MLJBase", "OffsetArrays", "StableRNGs", "Test"]

README.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
# NearestNeighborsModels
1+
# NearestNeighborModels
2+
3+
Repository implementing the MLJ model interface and weighting kernels for
4+
[NearestNeighbors](https://github.com/KristofferC/NearestNeighbors.jl) models.
5+
6+
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://vollmersj.github.io/NearestNeighborModels.jl/stable)
7+
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://vollmersj.github.io/NearestNeighborModels.jl/dev)
8+
[![Build Status](https://github.com/alan-turing-institute/NearestNeighborModels.jl/workflows/CI/badge.svg)](https://github.com/alan-turing-institute/NearestNeighborModels.jl/actions)
9+
[![Coverage](https://codecov.io/gh/alan-turing-institute/NearestNeighborModels.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/alan-turing-institute/NearestNeighborModels.jl)
210

3-
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://vollmersj.github.io/NearestNeighborsModels.jl/stable)
4-
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://vollmersj.github.io/NearestNeighborsModels.jl/dev)
5-
[![Build Status](https://github.com/alan-turing-institute/NearestNeighborsModels.jl/workflows/CI/badge.svg)](https://github.com/alan-turing-institute/NearestNeighborsModels.jl/actions)
6-
[![Coverage](https://codecov.io/gh/alan-turing-institute/NearestNeighborsModels.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/vollmersj/NearestNeighborsModels.jl)

src/MLJNearestNeighborsInterface.jl

Lines changed: 0 additions & 36 deletions
This file was deleted.

src/NearestNeighborModels.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""
    NearestNeighborModels

MLJ model interface and neighbor-weighting kernels for models built on
[NearestNeighbors.jl](https://github.com/KristofferC/NearestNeighbors.jl).

The interfaced models are collected in `MODELS`: `KNNClassifier`, `KNNRegressor`,
`MultitargetKNNRegressor` and `MultitargetKNNClassifier`. The module also re-exports
a selection of metrics from `Distances.jl` and exports the weighting kernels listed
by `list_kernels()`.
"""
module NearestNeighborModels

# ===================================================================
# IMPORTS
import MLJModelInterface
import MLJModelInterface: @mlj_model, metadata_model, metadata_pkg,
    Table, Continuous, Count, Finite, OrderedFactor, Multiclass
import NearestNeighbors
import StatsBase
import Tables

using Distances
using FillArrays
using LinearAlgebra
using Statistics

# ===================================================================
## EXPORTS
export list_kernels, ColumnTable, DictTable

# Export KNN models
# KNN models are exported automatically by `@mlj_model`

# Re-Export Distance Metrics from `Distances.jl`
export Euclidean, Cityblock, Minkowski, Chebyshev, Hamming, WeightedEuclidean,
    WeightedCityblock, WeightedMinkowski

# Export KNN Kernels
export DualU, DualD, Dudani, Fibonacci, Inverse, ISquared, KNNKernel, Macleod, Rank,
    ReciprocalRank, UDK, Uniform, UserDefinedKernel, Zavreal

# ===================================================================
## CONSTANTS
# Short aliases for the abstract array types.
const Vec{T} = AbstractVector{T}
const Mat{T} = AbstractMatrix{T}
const Arr{T, N} = AbstractArray{T, N}

# Table types accepted for the `output_type` hyperparameter documented in
# `MultitargetKNNClassifierFields`.
const ColumnTable = Tables.ColumnTable
const DictTable = Dict{Symbol, <:AbstractVector}
const MultiUnivariateFinite = Union{DictTable, ColumnTable}

# Define constants for easy referencing of packages
const MMI = MLJModelInterface
const NN = NearestNeighbors
# Must match the module/package name declared in Project.toml.
const PKG = "NearestNeighborModels"

# Definitions of model descriptions for use in model doc-strings.
const KNNRegressorDescription = """
K-Nearest Neighbors regressor: predicts the response associated with a new point
by taking a weighted average of the response of the K-nearest points.
"""

const KNNClassifierDescription = """
K-Nearest Neighbors classifier: predicts the class associated with a new point
by taking a vote over the classes of the K-nearest points.
"""

# Hyperparameter documentation shared by every KNN model.
const KNNCoreFields = """
* `K::Int=5` : number of neighbors
* `algorithm::Symbol = :kdtree` : one of `(:kdtree, :brutetree, :balltree)`
* `metric::Metric = Euclidean()` : a `Metric` object for the distance between points
* `leafsize::Int = 10` : at what number of points to stop splitting the tree
* `reorder::Bool = true` : if true puts points close in distance close in memory
* `weights::KNNKernel=Uniform()` : kernel used in assigning weights to the
k-nearest neighbors for each observation. An instance of one of the types in
`list_kernels()`. User-defined weighting functions can be passed by wrapping the
function in a `UDK` kernel. If sample weights `w` are passed during machine
construction e.g. `machine(model, X, y, w)` then the weight assigned to each
neighbor is the product of the `KNNKernel` generated weight and the corresponding
neighbor sample weight.

"""

# NOTE: the closing `]` and opening `(` of a markdown link must be adjacent, hence
# the long final line.
const SeeAlso = """
See also the
[package documentation](https://github.com/KristofferC/NearestNeighbors.jl).
For more information about the kernels see the paper by Geler et al.
[Comparison of different weighting schemes for the kNN classifier
on time-series data](https://perun.pmf.uns.ac.rs/radovanovic/publications/2016-kais-knn-weighting.pdf).
"""

const MultitargetKNNClassifierFields = """
## Keyword Parameters

$KNNCoreFields
* `output_type::Type{<:MultiUnivariateFinite}=DictTable` : One of
(`ColumnTable`, `DictTable`). The type of table to use for predictions.
Setting to `ColumnTable` might improve performance for narrow tables while setting to
`DictTable` improves performance for wide tables.

$SeeAlso

"""

const KNNFields = """
## Keyword Parameters

$KNNCoreFields

$SeeAlso

"""

# ===================================================================
# Includes
include("utils.jl")
include("kernels.jl")
include("models.jl")

# ===================================================================
# List of all models interfaced
const MODELS = (
    KNNClassifier, KNNRegressor, MultitargetKNNRegressor, MultitargetKNNClassifier
)

# ====================================================================
# PKG_METADATA
# `uuid` is this package's own UUID, as declared in Project.toml.
metadata_pkg.(
    MODELS,
    name = PKG,
    uuid = "636a865e-7cf4-491e-846c-de09b730eb36",
    url = "https://github.com/alan-turing-institute/NearestNeighborModels.jl",
    license = "MIT",
    julia = true,
    is_wrapper = false
)

end # module

0 commit comments

Comments
 (0)