Base implementation of SVGP #9
Merged
Commits (69 in total, all by rossviljoen; changes shown from 66 commits):
3cec6a0  Add files
6c814a6  Fixed KL and posterior covariance.
798f77a  Update example to use Flux
0641423  Remove Flux as a dep & factor out expected_loglik
1e4fc90  Update example to use basic Flux layer
bb42044  Add minibatching.
102d812  Improved variance calculation.
3089d93  Initial quadrature implementation
59474c5  Moved quadrature to new file.
25e6627  Fixed AD for quadrature.
54b5470  Fixed AD for KL divergence.
5e1c882  Added classification example.
ce20eba  Updated examples.
359b3d5  Renamed SVGPLayer to SVGPModel.
3bdbedb  Added basic test structure.
cb3a341  Started equivalence tests
3a2c8a9  First pass (doesn't work yet)
005f8f0  Working tests
443a2d4  Fixed KL divergence
92da73c  Refactored elbo stuff
7d05d1b  Fixed elbo mistakes
c0dd737  Remove type restiction in ELBO
92dcdf5  Infer batch size
f8086c8  Merge branch 'master' into ross/tests
787c57d  Merge branch 'dev' into base_implementation
ec5fa05  Added docstrings to elbo.jl
6d4e87b  Added cross-covariance
22c999a  Removed unnecessary dependencies
2763972  Updated regression example
23e5c2e  Added exact posterior tests
60d5072  Merge pull request #6 from rossviljoen/ross/tests
a8e5cbe  Address review comments
1bbeae0  Fix docstrings
1a0782f  Rename kldivergence
eddc7ab  Factor out exact posterior
7ea3c2f  Use AbstractGPs TestUtils
9b6557f  Added support for prior mean function
0e59e49  Added MC expectation and refactored elbo
38ed15f  Updated docstrings
c8a974f  Dispatch on types instead of symbols
56507a8  Update doctrings
857ecc3  Enforce type for MonteCarlo and GaussHermite
1bbf385  Added error for Analytic
bbd8502  Rename GaussHermite to Quadrature
0563d01  Assume homoscedastic Gaussian noise
fb9a563  Add tests for `expected_loglik`
e62fbf7  Require ExpLink for Poisson closed form
36c62b9  Better error message
0ee1004  Added close form for Gamma and Exponential
f648a7c  Fix docstring
a9b9a57  Update docstring
b8e7d6b  Fix docstring
9353e44  Restrict types for continuous distributions
ea3d3c6  Use `AbstractGPs.approx_posterior` and `elbo`
c1a4546  Minor formatting
835da22  Dispatch on filled diagonal matrix obs noise
fa1cdc3  Add elbo tests
af41ca3  Small test changes
de2c4cd  Fix elbo error
f07c6f1  Remove qualifier from kldivergence
9f4d295  Check for ZeroMean
ca5f148  Fix classification example jitter
66ec256  Remove unnecessary imports from AbstractGPs
6841074  Better cholesky of covariance methods
1594ee8  Use KLDivergences
878b214  Use vector of marginals `q_f` vs. `f_mean, f_var`
be96722  Ran JuliaFormatter
39f243a  Revert "Ran JuliaFormatter"
ef3292c  Reformat with JuliaFormatter - BlueStyle
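Several of the commits above concern how the expected log-likelihood term of the ELBO is evaluated (Monte Carlo, GaussHermite/Quadrature, and closed forms for specific likelihoods). As a rough illustration of the quadrature idea only — the function name and signature below are hypothetical and not the package's `expected_loglik` API — a Gauss–Hermite estimate of the per-datum expectation could look like this:

```julia
# Illustrative sketch (not the package's internal code): Gauss–Hermite approximation
# of E_{N(μ, v)}[log p(y | f)]. `lik(f)` is assumed to return a Distributions.jl
# distribution, as GPLikelihoods' likelihoods do (e.g. BernoulliLikelihood()).
using Distributions
using FastGaussQuadrature: gausshermite

function gh_expected_loglik(y, μ, v, lik; n_points=20)
    ts, ws = gausshermite(n_points)     # nodes/weights for ∫ exp(-t²) g(t) dt
    fs = μ .+ sqrt(2v) .* ts            # change of variables f = μ + √(2v) t
    return sum(ws .* logpdf.(lik.(fs), y)) / sqrt(π)
end
```

The same expectation can instead be estimated by sampling from N(μ, v), which is what the MonteCarlo option in the commit history refers to.
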
Project.toml (new file)
@@ -0,0 +1,17 @@
name = "SparseGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["Ross Viljoen <ross@viljoen.co.uk>"]
version = "0.1.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
KLDivergences = "3c9cd921-3d3f-41e2-830c-e020174918cc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
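As an aside (not part of the diff): with this Project.toml in place, the package environment and its pinned [deps] can be activated before working with the code. A minimal sketch, assuming the repository root is the current working directory:

```julia
# Minimal sketch: activate the SparseGPs environment defined by the Project.toml above.
using Pkg
Pkg.activate(".")    # assumes the current directory is the repository root
Pkg.instantiate()    # fetch the dependencies listed under [deps]
```
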
Classification example script (new file)
@@ -0,0 +1,160 @@
# Recreation of https://gpflow.readthedocs.io/en/master/notebooks/basics/classification.html

# %%
using SparseGPs
using AbstractGPs
using GPLikelihoods
using StatsFuns
using FastGaussQuadrature
using Distributions
using LinearAlgebra
using DelimitedFiles
using IterTools

using Plots
default(; legend=:outertopright, size=(700, 400))

using Random
Random.seed!(1234)

# %%
# Read in the classification data
data_file = pkgdir(SparseGPs) * "/examples/data/classif_1D.csv"
x, y = eachcol(readdlm(data_file))
scatter(x, y)

# %%
# First, create the GP kernel from given parameters k
function make_kernel(k)
    return softplus(k[1]) * (SqExponentialKernel() ∘ ScaleTransform(softplus(k[2])))
end

k = [10, 0.1]

kernel = make_kernel(k)
f = LatentGP(GP(kernel), BernoulliLikelihood(), 0.1)
fx = f(x)

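For reference (not text from the PR): assuming KernelFunctions' conventions, where `SqExponentialKernel()` is exp(-‖x - x'‖²/2) and `ScaleTransform(s)` maps x to s·x, the construction above corresponds to the kernel below, with softplus keeping both parameters positive:

```latex
% Kernel implied by make_kernel(k), under the stated KernelFunctions conventions.
k(x, x') = \operatorname{softplus}(k_1)\,
    \exp\!\left(-\tfrac{1}{2}\,\operatorname{softplus}(k_2)^2\,\|x - x'\|^2\right)
```
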
# %%
# Then, plot some samples from the prior underlying GP
x_plot = 0:0.02:6
prior_f_samples = rand(f.f(x_plot, 1e-6), 20)

plt = plot(
    x_plot,
    prior_f_samples;
    seriescolor="red",
    linealpha=0.2,
    label=""
)
scatter!(plt, x, y; seriescolor="blue", label="Data points")

# %%
# Plot the same samples, but pushed through a logistic sigmoid to constrain
# them in (0, 1).
prior_y_samples = mean.(f.lik.(prior_f_samples))

plt = plot(
    x_plot,
    prior_y_samples;
    seriescolor="red",
    linealpha=0.2,
    label=""
)
scatter!(plt, x, y; seriescolor="blue", label="Data points")

# %%
# A simple Flux model
using Flux

struct SVGPModel
    k  # kernel parameters
    m  # variational mean
    A  # variational covariance
    z  # inducing points
end

@Flux.functor SVGPModel (k, m, A,)  # Don't train the inducing inputs

lik = BernoulliLikelihood()
jitter = 1e-4

function (m::SVGPModel)(x)
    kernel = make_kernel(m.k)
    f = LatentGP(GP(kernel), lik, jitter)
    q = MvNormal(m.m, m.A'm.A)
    fx = f(x)
    fu = f(m.z).fx
    return fx, fu, q
end

function flux_loss(x, y; n_data=length(y))
    fx, fu, q = model(x)
    return -SparseGPs.elbo(fx, y, fu, q; n_data, method=MonteCarlo())
end

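For context (this note is not part of the PR): `flux_loss` returns the negative of the standard SVGP evidence lower bound, with q(u) = N(m, AᵀA) the variational distribution over the inducing variables u = f(z), and `method=MonteCarlo()` estimating the intractable Bernoulli expectation by sampling. A sketch of the bound, where the n_data/|batch| rescaling is an assumption about how the `n_data` keyword is used:

```latex
% Sketch of the bound that -flux_loss(x, y) approximates (standard SVGP ELBO).
\mathcal{L}(m, A, k, z)
  = \frac{n_{\text{data}}}{|\mathcal{B}|}
    \sum_{i \in \mathcal{B}} \mathbb{E}_{q(f_i)}\big[\log p(y_i \mid f_i)\big]
  - \mathrm{KL}\big(q(u)\,\|\,p(u)\big),
\qquad q(u) = \mathcal{N}(m,\, A^\top A)
```
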
# %%
M = 15  # number of inducing points

# Initialise the parameters
k = [10, 0.1]
m = zeros(M)
A = Matrix{Float64}(I, M, M)
z = x[1:M]

model = SVGPModel(k, m, A, z)

opt = ADAM(0.1)
parameters = Flux.params(model)

# %%
# Negative ELBO before training
println(flux_loss(x, y))

# %%
# Train the model
Flux.train!(
    (x, y) -> flux_loss(x, y),
    parameters,
    ncycle([(x, y)], 2000),  # Train for 2000 epochs
    opt
)

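The call above trains on the full dataset at every step. Since `flux_loss` already exposes an `n_data` keyword (and the commit history adds minibatching support), a minibatched variant might look like the sketch below; the use of `Flux.Data.DataLoader` (spelled `Flux.DataLoader` in newer Flux versions), the batch size, and the number of passes are illustrative assumptions, not part of the PR:

```julia
# Minibatched training sketch (illustrative values throughout).
batchsize = 10
data_loader = Flux.Data.DataLoader((x, y); batchsize=batchsize, shuffle=true)

Flux.train!(
    (xb, yb) -> flux_loss(xb, yb; n_data=length(y)),  # rescale to the full dataset size
    parameters,
    ncycle(data_loader, 300),  # number of passes over the data
    opt
)
```
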
# %%
# Negative ELBO after training
println(flux_loss(x, y))

# %%
# After optimisation, plot samples from the underlying posterior GP.
fu = f(z).fx  # want the underlying FiniteGP
post = SparseGPs.approx_posterior(SVGP(), fu, MvNormal(m, A'A))
l_post = LatentGP(post, BernoulliLikelihood(), jitter)

post_f_samples = rand(l_post.f(x_plot, 1e-6), 20)

plt = plot(
    x_plot,
    post_f_samples;
    seriescolor="red",
    linealpha=0.2,
    legend=false
)

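Rather than plotting individual samples, one could also summarise the approximate posterior directly; a sketch assuming AbstractGPs' `mean_and_var` for a FiniteGP and Plots' `ribbon` keyword (not part of the PR):

```julia
# Sketch: posterior mean of the latent GP with a ±2σ band over the plotting grid.
post_fx = post(x_plot, 1e-6)             # FiniteGP at the plotting grid
m_post, v_post = mean_and_var(post_fx)   # marginal means and variances
plot(x_plot, m_post; ribbon=2 .* sqrt.(v_post), label="Posterior mean ± 2σ")
```
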
# %%
# As above, push these samples through a logistic sigmoid to get posterior predictions.
post_y_samples = mean.(l_post.lik.(post_f_samples))

plt = plot(
    x_plot,
    post_y_samples;
    seriescolor="red",
    linealpha=0.2,
    # legend=false,
    label=""
)
scatter!(plt, x, y; seriescolor="blue", label="Data points")
vline!(z; label="Pseudo-points")
examples/data/classif_1D.csv (new file)
@@ -0,0 +1,50 @@
5.668341708542713242e+00 0.000000000000000000e+00
5.758793969849246075e+00 0.000000000000000000e+00
5.517587939698492150e+00 0.000000000000000000e+00
2.954773869346733584e+00 1.000000000000000000e+00
3.648241206030150785e+00 1.000000000000000000e+00
2.110552763819095290e+00 1.000000000000000000e+00
4.613065326633165597e+00 0.000000000000000000e+00
4.793969849246231263e+00 0.000000000000000000e+00
4.703517587939698430e+00 0.000000000000000000e+00
6.030150753768843686e-01 1.000000000000000000e+00
3.015075376884421843e-01 0.000000000000000000e+00
3.979899497487437099e+00 0.000000000000000000e+00
3.226130653266331638e+00 1.000000000000000000e+00
1.899497487437185939e+00 1.000000000000000000e+00
1.145728643216080256e+00 1.000000000000000000e+00
3.316582914572864249e-01 0.000000000000000000e+00
6.030150753768843686e-01 1.000000000000000000e+00
2.231155778894472252e+00 1.000000000000000000e+00
3.256281407035175768e+00 1.000000000000000000e+00
1.085427135678391997e+00 1.000000000000000000e+00
1.809045226130653106e+00 1.000000000000000000e+00
4.492462311557789079e+00 0.000000000000000000e+00
1.959798994974874198e+00 1.000000000000000000e+00
0.000000000000000000e+00 0.000000000000000000e+00
3.346733668341708601e+00 1.000000000000000000e+00
1.507537688442210921e-01 0.000000000000000000e+00
1.809045226130653328e-01 1.000000000000000000e+00
5.517587939698492150e+00 0.000000000000000000e+00
2.201005025125628123e+00 1.000000000000000000e+00
5.577889447236180409e+00 0.000000000000000000e+00
1.809045226130653328e-01 0.000000000000000000e+00
1.688442211055276365e+00 1.000000000000000000e+00
4.160804020100502321e+00 0.000000000000000000e+00
2.170854271356783993e+00 1.000000000000000000e+00
4.311557788944723413e+00 0.000000000000000000e+00
3.075376884422110546e+00 1.000000000000000000e+00
5.125628140703517133e+00 0.000000000000000000e+00
1.989949748743718549e+00 1.000000000000000000e+00
5.366834170854271058e+00 0.000000000000000000e+00
4.100502512562814061e+00 0.000000000000000000e+00
7.236180904522613311e-01 1.000000000000000000e+00
2.261306532663316382e+00 1.000000000000000000e+00
3.467336683417085119e+00 1.000000000000000000e+00
1.085427135678391997e+00 1.000000000000000000e+00
5.095477386934673447e+00 0.000000000000000000e+00
5.185929648241205392e+00 0.000000000000000000e+00
2.743718592964823788e+00 1.000000000000000000e+00
2.773869346733668362e+00 1.000000000000000000e+00
1.417085427135678311e+00 1.000000000000000000e+00
1.989949748743718549e+00 1.000000000000000000e+00