diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 9c79359112..b345310472 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,2 +1,4 @@ style = "sciml" -format_markdown = true \ No newline at end of file +annotate_untyped_fields_with_any = false +format_markdown = true +separate_kwargs_with_semicolon = true \ No newline at end of file diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 5090c18581..27abbe5de6 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -27,6 +27,7 @@ steps: using Pkg Pkg.instantiate() Pkg.activate("docs") + Pkg.develop(PackageSpec(path=pwd())) Pkg.instantiate() push!(LOAD_PATH, @__DIR__) println("+++ :julia: Building documentation") diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6036d0f045..276605cf58 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,8 +21,10 @@ jobs: - BasicNeuralDE - AdvancedNeuralDE - Newton + - Aqua version: - - 1 + - '1' + - '~1.10.0-0' steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index 07964dae70..5e1829c1e5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,53 +1,53 @@ name = "DiffEqFlux" uuid = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" authors = ["Chris Rackauckas "] -version = "2.5.1" +version = "3.0.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" -DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" -LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" -TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] +ADTypes = "0.2" Adapt = "3" ChainRulesCore = "1" -ConsoleProgressMonitor = "0.1" -DataInterpolations = "3.3, 4" +ComponentArrays = "0.15.5" +ConcreteStructs = "0.2" DiffEqBase = "6.41" Distributions = "0.23, 0.24, 0.25" DistributionsAD = "0.6" -Flux = "0.14" ForwardDiff = "0.10" Functors = "0.4" -LoggingExtras = "0.4, 1" +LinearAlgebra = "<0.0.1, 1" +Lux = "0.5.5" LuxCore = "0.1" -ProgressLogging = "0.1" +PrecompileTools = "1" +Random = "<0.0.1, 1" RecursiveArrayTools = "2" Reexport = "0.2, 1" SciMLBase = "1, 2" SciMLSensitivity = "7" -TerminalLoggers = "0.1" +Tracker = "0.2.29" Zygote = "0.5, 0.6" ZygoteRules = "0.2" julia = "1.9" diff --git a/README.md b/README.md index fc52731ef7..100b73a978 100644 --- a/README.md +++ b/README.md @@ -7,16 +7,16 @@ [![Build Status](https://github.com/SciML/DiffEqFlux.jl/workflows/CI/badge.svg)](https://github.com/SciML/DiffEqFlux.jl/actions?query=workflow%3ACI) [![Build status](https://badge.buildkite.com/a1fecf87b085b452fe0f3d3968ddacb5c1d5570806834e1d52.svg)](https://buildkite.com/julialang/diffeqflux-dot-jl) -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) -DiffEqFlux.jl fuses the world of differential equations with machine learning +DiffEq(For)Lux.jl (aka DiffEqFlux.jl) fuses the world of differential equations with machine learning by helping users put diffeq solvers into neural networks. This package utilizes -[DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/), -[Flux.jl](https://docs.sciml.ai/Flux/stable/) and [Lux.jl](https://lux.csail.mit.edu/v0.5.5/api/) as its building blocks to support research in -[Scientific Machine Learning](https://www.stochasticlifestyle.com/the-essential-tools-of-scientific-machine-learning-scientific-ml/), -specifically neural differential equations and universal differential equations, -to add physical information into traditional machine learning. +[DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/), and [Lux.jl](https://lux.csail.mit.edu/) as its building blocks to support research in +[Scientific Machine Learning](https://www.stochasticlifestyle.com/the-essential-tools-of-scientific-machine-learning-scientific-ml/), specifically neural differential equations and universal differential equations, to add physical information into traditional machine learning. + +> [!NOTE] +> We maintain backwards compatibility with [Flux.jl](https://docs.sciml.ai/Flux/stable/) via [Lux.transform](https://lux.csail.mit.edu/stable/api/Lux/flux_to_lux#Lux.transform) ## Tutorials and Documentation @@ -40,12 +40,12 @@ Scientific Machine Learning](https://arxiv.org/abs/2001.04385). As such, it is the first package to support and demonstrate: -- Stiff and non-stiff universal ordinary differential equations (universal ODEs) -- Universal stochastic differential equations (universal SDEs) -- Universal delay differential equations (universal DDEs) -- Universal partial differential equations (universal PDEs) -- Universal jump stochastic differential equations (universal jump diffusions) -- Hybrid universal differential equations (universal DEs with event handling) + - Stiff and non-stiff universal ordinary differential equations (universal ODEs) + - Universal stochastic differential equations (universal SDEs) + - Universal delay differential equations (universal DDEs) + - Universal partial differential equations (universal PDEs) + - Universal jump stochastic differential equations (universal jump diffusions) + - Hybrid universal differential equations (universal DEs with event handling) with high order, adaptive, implicit, GPU-accelerated, Newton-Krylov, etc. methods. For examples, please refer to @@ -58,10 +58,18 @@ PDEs and neural jump SDEs, can be found Do not limit yourself to the current neuralization. With this package, you can explore various ways to integrate the two methodologies: -- Neural networks can be defined where the “activations” are nonlinear functions - described by differential equations -- Neural networks can be defined where some layers are ODE solves -- ODEs can be defined where some terms are neural networks -- Cost functions on ODEs can define neural networks + - Neural networks can be defined where the “activations” are nonlinear functions + described by differential equations + - Neural networks can be defined where some layers are ODE solves + - ODEs can be defined where some terms are neural networks + - Cost functions on ODEs can define neural networks ![Flux ODE Training Animation](https://user-images.githubusercontent.com/1814174/88589293-e8207f80-d026-11ea-86e2-8a3feb8252ca.gif) + +## Breaking Changes in v3 + + - Flux dependency is dropped. If a non Lux `AbstractExplicitLayer` is passed we try to automatically convert it to a Lux model with `Lux.transform(model)`. + - `Flux` is no longer re-exported from `DiffEqFlux`. Instead we reexport `Lux`. + - `NeuralDAE` now allows an optional `du0` as input. + - `TensorLayer` is now a Lux Neural Network. + - APIs for quite a few layer constructions have changed. Please refer to the updated documentation for more details. diff --git a/docs/Project.toml b/docs/Project.toml index bca0893bd4..bd5b068ff3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,10 +1,10 @@ [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" -DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -13,8 +13,11 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" @@ -22,12 +25,15 @@ OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] @@ -35,17 +41,20 @@ CSV = "0.10" ComponentArrays = "0.13, 0.14, 0.15" DataDeps = "0.7" DataFrames = "1" -DiffEqFlux = "2" -DifferentialEquations = "7.6.0" +DiffEqFlux = "3" Distances = "0.10.7" Distributions = "0.25.78" Documenter = "1" Flux = "0.14" ForwardDiff = "0.10" IterTools = "1" -Lux = "0.4.34, 0.5" +LinearAlgebra = "<0.0.1, 1" +Lux = "0.5.5" +LuxCUDA = "0.3" MLDataUtils = "0.5" MLDatasets = "0.7" +MLUtils = "0.4" +NNlib = "0.9" Optimisers = "0.2, 0.3" Optimization = "3.9" OptimizationOptimJL = "0.1" @@ -53,9 +62,12 @@ OptimizationOptimisers = "0.1" OptimizationPolyalgorithms = "0.1" OrdinaryDiffEq = "6.31" Plots = "1.36" +Printf = "<0.0.1, 1" +Random = "<0.0.1, 1" ReverseDiff = "1.14" SciMLBase = "1.72, 2" SciMLSensitivity = "7.11" Statistics = "1" StochasticDiffEq = "6.56" +Test = "<0.0.1, 1" Zygote = "0.6.62" diff --git a/docs/make.jl b/docs/make.jl index d9c41acc84..c9d04fc1b4 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,7 +1,7 @@ using Documenter, DiffEqFlux -cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml", force = true) -cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true) +cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force = true) +cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force = true) ENV["GKSwstype"] = "100" using Plots @@ -9,18 +9,14 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true include("pages.jl") -makedocs( - sitename = "DiffEqFlux.jl", - authors="Chris Rackauckas et al.", +makedocs(; sitename = "DiffEqFlux.jl", + authors = "Chris Rackauckas et al.", clean = true, doctest = false, linkcheck = true, warnonly = [:docs_block, :missing_docs], modules = [DiffEqFlux], - format = Documenter.HTML(assets = ["assets/favicon.ico"], - canonical="https://docs.sciml.ai/DiffEqFlux/stable/"), - pages=pages -) + format = Documenter.HTML(; assets = ["assets/favicon.ico"], + canonical = "https://docs.sciml.ai/DiffEqFlux/stable/"), + pages = pages) -deploydocs( - repo = "github.com/SciML/DiffEqFlux.jl.git"; - push_preview = true -) +deploydocs(; repo = "github.com/SciML/DiffEqFlux.jl.git", + push_preview = true) diff --git a/docs/pages.jl b/docs/pages.jl index 1d21724e65..a215826cd7 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,29 +1,23 @@ pages = [ - "DiffEqFlux.jl: High Level Scientific Machine Learning (SciML) Pre-Built Architectures" => "index.md", - "Differential Equation Machine Learning Tutorials" => Any[ - "examples/neural_ode.md", - "examples/GPUs.md", - "examples/mnist_neural_ode.md", - "examples/mnist_conv_neural_ode.md", - "examples/augmented_neural_ode.md", - "examples/neural_sde.md", - "examples/collocation.md", - "examples/normalizing_flows.md", - "examples/hamiltonian_nn.md", - "examples/tensor_layer.md", - "examples/multiple_shooting.md", - "examples//neural_ode_weather_forecast.md" - ], - "Layer APIs" => Any[ - "Classical Basis Layers" => "layers/BasisLayers.md", - "Tensor Product Layer" => "layers/TensorLayer.md", - "Continuous Normalizing Flows Layer" => "layers/CNFLayer.md", - "Spline Layer" => "layers/SplineLayer.md", - "Neural Differential Equation Layers" => "layers/NeuralDELayers.md", - "Hamiltonian Neural Network Layer" => "layers/HamiltonianNN.md" - ], - "Utility Function APIs" => Any[ - "Smoothed Collocation" => "utilities/Collocation.md", - "Multiple Shooting Functionality" => "utilities/MultipleShooting.md", - ], - ] + "DiffEqFlux.jl: High Level Scientific Machine Learning (SciML) Pre-Built Architectures" => "index.md", + "Differential Equation Machine Learning Tutorials" => Any["examples/neural_ode.md", + "examples/GPUs.md", + "examples/mnist_neural_ode.md", + "examples/mnist_conv_neural_ode.md", + "examples/augmented_neural_ode.md", + "examples/neural_sde.md", + "examples/collocation.md", + "examples/normalizing_flows.md", + "examples/hamiltonian_nn.md", + "examples/tensor_layer.md", + "examples/multiple_shooting.md", + "examples//neural_ode_weather_forecast.md"], + "Layer APIs" => Any["Classical Basis Layers" => "layers/BasisLayers.md", + "Tensor Product Layer" => "layers/TensorLayer.md", + "Continuous Normalizing Flows Layer" => "layers/CNFLayer.md", + "Spline Layer" => "layers/SplineLayer.md", + "Neural Differential Equation Layers" => "layers/NeuralDELayers.md", + "Hamiltonian Neural Network Layer" => "layers/HamiltonianNN.md"], + "Utility Function APIs" => Any["Smoothed Collocation" => "utilities/Collocation.md", + "Multiple Shooting Functionality" => "utilities/MultipleShooting.md"], +] diff --git a/docs/src/examples/GPUs.md b/docs/src/examples/GPUs.md index d0ae477304..0088fccff7 100644 --- a/docs/src/examples/GPUs.md +++ b/docs/src/examples/GPUs.md @@ -2,35 +2,41 @@ Note that the differential equation solvers will run on the GPU if the initial condition is a GPU array. Thus, for example, we can define a neural ODE manually -that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU): +that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU). + +For a detailed discussion on how GPUs need to be setup refer to +[Lux Docs](https://lux.csail.mit.edu/stable/manual/gpu_management). ```julia -using DifferentialEquations, Lux, SciMLSensitivity, ComponentArrays +using OrdinaryDiffEq, Lux, LuxCUDA, SciMLSensitivity, ComponentArrays using Random rng = Random.default_rng() +const cdev = cpu_device() +const gdev = gpu_device() + model = Chain(Dense(2, 50, tanh), Dense(50, 2)) -ps, st = Lux.setup(rng,model) -ps = ps |> ComponentArray |> gpu -st = st |> gpu +ps, st = Lux.setup(rng, model) +ps = ps |> ComponentArray |> gdev +st = st |> gdev dudt(u, p, t) = model(u, p, st)[1] # Simulation interval and intermediary points -tspan = (0f0, 10f0) -tsteps = 0f0:1f-1:10f0 +tspan = (0.0f0, 10.0f0) +tsteps = 0.0f0:1.0f-1:10.0f0 -u0 = Float32[2.0; 0.0] |> gpu +u0 = Float32[2.0; 0.0] |> gdev prob_gpu = ODEProblem(dudt, u0, tspan, ps) # Runs on a GPU -sol_gpu = solve(prob_gpu, Tsit5(), saveat = tsteps) +sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps) ``` Or we could directly use the neural ODE layer function, like: ```julia using DiffEqFlux: NeuralODE -prob_neuralode_gpu = NeuralODE(model, tspan, Tsit5(), saveat = tsteps) +prob_neuralode_gpu = NeuralODE(model, tspan, Tsit5(); saveat = tsteps) ``` If one is using `Lux.Chain`, then the computation takes place on the GPU with @@ -39,41 +45,39 @@ If one is using `Lux.Chain`, then the computation takes place on the GPU with ```julia import Lux -dudt2 = Lux.Chain(x -> x.^3, - Lux.Dense(2,50,tanh), - Lux.Dense(50,2)) +dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) -u0 = Float32[2.; 0.] |> gpu -p, st = Lux.setup(rng, dudt2) |> gpu +u0 = Float32[2.0; 0.0] |> gdev +p, st = Lux.setup(rng, dudt2) |> gdev -dudt2_(u, p, t) = dudt2(u,p,st)[1] +dudt2_(u, p, t) = dudt2(u, p, st)[1] # Simulation interval and intermediary points -tspan = (0f0, 10f0) -tsteps = 0f0:1f-1:10f0 +tspan = (0.0f0, 10.0f0) +tsteps = 0.0f0:1.0f-1:10.0f0 prob_gpu = ODEProblem(dudt2_, u0, tspan, p) # Runs on a GPU -sol_gpu = solve(prob_gpu, Tsit5(), saveat = tsteps) +sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps) ``` or via the NeuralODE struct: ```julia -prob_neuralode_gpu = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) -prob_neuralode_gpu(u0,p,st) +prob_neuralode_gpu = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps) +prob_neuralode_gpu(u0, p, st) ``` ## Neural ODE Example -Here is the full neural ODE example. Note that we use the `gpu` function so that the -same code works on CPUs and GPUs, dependent on `using CUDA`. +Here is the full neural ODE example. Note that we use the `gpu_device` function so that the +same code works on CPUs and GPUs, dependent on `using LuxCUDA`. ```julia -using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq, - Plots, CUDA, SciMLSensitivity, Random, ComponentArrays -import DiffEqFlux: NeuralODE +using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq, + Plots, LuxCUDA, SciMLSensitivity, Random, ComponentArrays +import DiffEqFlux: NeuralODE CUDA.allowscalar(false) # Makes sure no slow operations are occuring @@ -83,26 +87,25 @@ rng = Random.default_rng() u0 = Float32[2.0; 0.0] datasize = 30 tspan = (0.0f0, 1.5f0) -tsteps = range(tspan[1], tspan[2], length = datasize) +tsteps = range(tspan[1], tspan[2]; length = datasize) function trueODEfunc(du, u, p, t) true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' + du .= ((u .^ 3)'true_A)' end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) # Make the data into a GPU-based array if the user has a GPU -ode_data = gpu(solve(prob_trueode, Tsit5(), saveat = tsteps)) - +ode_data = gdev(solve(prob_trueode, Tsit5(); saveat = tsteps)) -dudt2 = Chain(x -> x.^3, Dense(2, 50, tanh), Dense(50, 2)) -u0 = Float32[2.0; 0.0] |> gpu -p, st = Lux.setup(rng, dudt2) -p = p |> ComponentArray |> gpu -st = st |> gpu +dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) +u0 = Float32[2.0; 0.0] |> gdev +p, st = Lux.setup(rng, dudt2) +p = p |> ComponentArray |> gdev +st = st |> gdev -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) +prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps) function predict_neuralode(p) - gpu(first(prob_neuralode(u0,p,st))) + gdev(first(prob_neuralode(u0, p, st))) end function loss_neuralode(p) pred = predict_neuralode(p) @@ -113,25 +116,24 @@ end list_plots = [] iter = 0 callback = function (p, l, pred; doplot = false) - global list_plots, iter - if iter == 0 - list_plots = [] - end - iter += 1 - display(l) - # plot current prediction against data - plt = scatter(tsteps, Array(ode_data[1,:]), label = "data") - scatter!(plt, tsteps, Array(pred[1,:]), label = "prediction") - push!(list_plots, plt) - if doplot - display(plot(plt)) - end - return false + global list_plots, iter + if iter == 0 + list_plots = [] + end + iter += 1 + display(l) + # plot current prediction against data + plt = scatter(tsteps, Array(ode_data[1, :]); label = "data") + scatter!(plt, tsteps, Array(pred[1, :]); label = "prediction") + push!(list_plots, plt) + if doplot + display(plot(plt)) + end + return false end adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((x,p)->loss_neuralode(x), adtype) +optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) optprob = Optimization.OptimizationProblem(optf, p) -result_neuralode = Optimization.solve(optprob,Adam(0.05),callback = callback,maxiters = 300) - -``` \ No newline at end of file +result_neuralode = Optimization.solve(optprob, Adam(0.05); callback, maxiters = 300) +``` diff --git a/docs/src/examples/augmented_neural_ode.md b/docs/src/examples/augmented_neural_ode.md index e1a7ce5398..f1d54f6bb1 100644 --- a/docs/src/examples/augmented_neural_ode.md +++ b/docs/src/examples/augmented_neural_ode.md @@ -3,19 +3,23 @@ ## Copy-Pasteable Code ```@example augneuralode_cp -using DiffEqFlux, DifferentialEquations -using Statistics, LinearAlgebra, Plots -using Flux.Data: DataLoader +using DiffEqFlux, OrdinaryDiffEq, Statistics, LinearAlgebra, Plots, LuxCUDA, Random +using MLUtils, ComponentArrays +using Optimization, OptimizationOptimisers, IterTools + +const cdev = cpu_device() +const gdev = gpu_device() function random_point_in_sphere(dim, min_radius, max_radius) - distance = (max_radius - min_radius) .* (rand(Float32,1) .^ (1f0 / dim)) .+ min_radius - direction = randn(Float32,dim) + distance = (max_radius - min_radius) .* (rand(Float32, 1) .^ (1.0f0 / dim)) .+ + min_radius + direction = randn(Float32, dim) unit_direction = direction ./ norm(direction) return distance .* unit_direction end function concentric_sphere(dim, inner_radius_range, outer_radius_range, - num_samples_inner, num_samples_outer; batch_size = 64) + num_samples_inner, num_samples_outer; batch_size = 64) data = [] labels = [] for _ in 1:num_samples_inner @@ -26,92 +30,99 @@ function concentric_sphere(dim, inner_radius_range, outer_radius_range, push!(data, reshape(random_point_in_sphere(dim, outer_radius_range...), :, 1)) push!(labels, -ones(1, 1)) end - data = cat(data..., dims=2) - labels = cat(labels..., dims=2) - DataLoader((data |> Flux.gpu, labels |> Flux.gpu); batchsize=batch_size, shuffle=true, - partial=false) + data = cat(data...; dims = 2) + labels = cat(labels...; dims = 2) + return DataLoader((data |> gdev, labels |> gdev); batchsize = batch_size, + shuffle = true, partial = false) end -diffeqarray_to_array(x) = reshape(Flux.gpu(x), size(x)[1:2]) +diffeqarray_to_array(x) = reshape(gdev(x), size(x)[1:2]) function construct_model(out_dim, input_dim, hidden_dim, augment_dim) input_dim = input_dim + augment_dim - node = NeuralODE(Flux.Chain(Flux.Dense(input_dim, hidden_dim, relu), - Flux.Dense(hidden_dim, hidden_dim, relu), - Flux.Dense(hidden_dim, input_dim)) |> Flux.gpu, - (0.f0, 1.f0), Tsit5(), save_everystep = false, - reltol = 1f-3, abstol = 1f-3, save_start = false) |> Flux.gpu + node = NeuralODE(Chain(Dense(input_dim, hidden_dim, relu), + Dense(hidden_dim, hidden_dim, relu), + Dense(hidden_dim, input_dim)), (0.0f0, 1.0f0), Tsit5(); save_everystep = false, + reltol = 1.0f-3, abstol = 1.0f-3, save_start = false) node = augment_dim == 0 ? node : AugmentedNDELayer(node, augment_dim) - return Flux.Chain((x, p=node.p) -> node(x, p), - Array, - diffeqarray_to_array, - Flux.Dense(input_dim, out_dim) |> Flux.gpu), node.p |> Flux.gpu + model = Chain(node, diffeqarray_to_array, Dense(input_dim, out_dim)) + ps, st = Lux.setup(Random.default_rng(), model) + return model, ps |> gdev, st |> gdev end -function plot_contour(model, npoints = 300) - grid_points = zeros(Float32, 2, npoints ^ 2) +function plot_contour(model, ps, st, npoints = 300) + grid_points = zeros(Float32, 2, npoints^2) idx = 1 - x = range(-4f0, 4f0, length = npoints) - y = range(-4f0, 4f0, length = npoints) + x = range(-4.0f0, 4.0f0; length = npoints) + y = range(-4.0f0, 4.0f0; length = npoints) for x1 in x, x2 in y grid_points[:, idx] .= [x1, x2] idx += 1 end - sol = reshape(model(grid_points |> Flux.gpu), npoints, npoints) |> Flux.cpu + sol = reshape(model(grid_points |> gdev, ps, st)[1], npoints, npoints) |> cdev - return contour(x, y, sol, fill = true, linewidth=0.0) + return contour(x, y, sol; fill = true, linewidth = 0.0) end -loss_node(x, y) = mean((model(x) .- y) .^ 2) +loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2) println("Generating Dataset") -dataloader = concentric_sphere(2, (0f0, 2f0), (3f0, 4f0), 2000, 2000; batch_size = 256) +dataloader = concentric_sphere(2, (0.0f0, 2.0f0), (3.0f0, 4.0f0), 2000, 2000; + batch_size = 256) iter = 0 -cb = function() - global iter +cb = function (ps, l) + global iter iter += 1 if iter % 10 == 0 - println("Iteration $iter || Loss = $(loss_node(dataloader.data[1], dataloader.data[2]))") + @info "Augmented Neural ODE" iter=iter loss=l end + return false end -model, parameters = construct_model(1, 2, 64, 0) +model, ps, st = construct_model(1, 2, 64, 0) opt = Adam(0.005) +loss_node(model, dataloader.data[1], dataloader.data[2], ps, st) + println("Training Neural ODE") -for _ in 1:10 - Flux.train!(loss_node, Flux.params(parameters, model), dataloader, opt, cb = cb) -end +optfunc = OptimizationFunction((x, p, data, target) -> loss_node(model, data, target, x, st), + Optimization.AutoZygote()) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) +res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) -plt_node = plot_contour(model) +plt_node = plot_contour(model, res.u, st) -model, parameters = construct_model(1, 2, 64, 1) -opt = Adam(5f-3) +model, ps, st = construct_model(1, 2, 64, 1) +opt = Adam(0.005) println() println("Training Augmented Neural ODE") -for _ in 1:10 - Flux.train!(loss_node, Flux.params(parameters, model), dataloader, opt, cb = cb) -end +optfunc = OptimizationFunction((x, p, data, target) -> loss_node(model, data, target, x, st), + Optimization.AutoZygote()) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) +res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) -plt_anode = plot_contour(model) +plt_node = plot_contour(model, res.u, st) ``` -# Step-by-Step Explanation +## Step-by-Step Explanation -## Loading required packages +### Loading required packages ```@example augneuralode -using DiffEqFlux, DifferentialEquations -using Statistics, LinearAlgebra, Plots -using Flux.Data: DataLoader +using DiffEqFlux, OrdinaryDiffEq, Statistics, LinearAlgebra, Plots, LuxCUDA, Random +using MLUtils, ComponentArrays +using Optimization, OptimizationOptimisers, IterTools + +const cdev = cpu_device() +const gdev = gpu_device() ``` -## Generating a toy dataset +### Generating a toy dataset In this example, we will be using data sampled uniformly in two concentric circles and then train our Neural ODEs to do regression on that values. We assign `1` to any point which lies inside the inner @@ -121,7 +132,8 @@ circle, and `-1` to any point which lies between the inner and outer circle. Our ```@example augneuralode function random_point_in_sphere(dim, min_radius, max_radius) - distance = (max_radius - min_radius) .* (rand(Float32, 1) .^ (1f0 / dim)) .+ min_radius + distance = (max_radius - min_radius) .* (rand(Float32, 1) .^ (1.0f0 / dim)) .+ + min_radius direction = randn(Float32, dim) unit_direction = direction ./ norm(direction) return distance .* unit_direction @@ -133,7 +145,7 @@ and shuffle the data. ```@example augneuralode function concentric_sphere(dim, inner_radius_range, outer_radius_range, - num_samples_inner, num_samples_outer; batch_size = 64) + num_samples_inner, num_samples_outer; batch_size = 64) data = [] labels = [] for _ in 1:num_samples_inner @@ -144,14 +156,14 @@ function concentric_sphere(dim, inner_radius_range, outer_radius_range, push!(data, reshape(random_point_in_sphere(dim, outer_radius_range...), :, 1)) push!(labels, -ones(1, 1)) end - data = cat(data..., dims=2) - labels = cat(labels..., dims=2) - return DataLoader((data |> Flux.gpu, labels |> Flux.gpu); batchsize=batch_size, shuffle=true, - partial=false) + data = cat(data...; dims = 2) + labels = cat(labels...; dims = 2) + return DataLoader((data |> gdev, labels |> gdev); batchsize = batch_size, + shuffle = true, partial = false) end ``` -## Models +### Models We consider 2 models in this tutorial. The first is a simple Neural ODE which is described in detail in [this tutorial](https://docs.sciml.ai/SciMLSensitivity/stable/examples/neural_ode/neural_ode_flux/). The other one is an @@ -164,97 +176,99 @@ In order to run the models on Flux.gpu, we need to manually transfer the models predicting the derivatives inside the Neural ODE and the other one is the last layer in the Chain. ```@example augneuralode -diffeqarray_to_array(x) = reshape(Flux.gpu(x), size(x)[1:2]) +diffeqarray_to_array(x) = reshape(gdev(x), size(x)[1:2]) function construct_model(out_dim, input_dim, hidden_dim, augment_dim) input_dim = input_dim + augment_dim - node = NeuralODE(Flux.Chain(Flux.Dense(input_dim, hidden_dim, relu), - Flux.Dense(hidden_dim, hidden_dim, relu), - Flux.Dense(hidden_dim, input_dim)) |> Flux.gpu, - (0.f0, 1.f0), Tsit5(), save_everystep = false, - reltol = 1f-3, abstol = 1f-3, save_start = false) |> Flux.gpu - node = augment_dim == 0 ? node : (AugmentedNDELayer(node, augment_dim) |> Flux.gpu) - return Flux.Chain((x, p=node.p) -> node(x, p), - Array, - diffeqarray_to_array, - Flux.Dense(input_dim, out_dim) |> Flux.gpu), node.p |> Flux.gpu + node = NeuralODE(Chain(Dense(input_dim, hidden_dim, relu), + Dense(hidden_dim, hidden_dim, relu), + Dense(hidden_dim, input_dim)), (0.0f0, 1.0f0), Tsit5(); save_everystep = false, + reltol = 1.0f-3, abstol = 1.0f-3, save_start = false) + node = augment_dim == 0 ? node : AugmentedNDELayer(node, augment_dim) + model = Chain(node, diffeqarray_to_array, Dense(input_dim, out_dim)) + ps, st = Lux.setup(Random.default_rng(), model) + return model, ps |> gdev, st |> gdev end ``` -## Plotting the Results +### Plotting the Results Here, we define a utility to plot our model regression results as a heatmap. ```@example augneuralode -function plot_contour(model, npoints = 300) - grid_points = zeros(2, npoints ^ 2) +function plot_contour(model, ps, st, npoints = 300) + grid_points = zeros(Float32, 2, npoints^2) idx = 1 - x = range(-4f0, 4f0, length = npoints) - y = range(-4f0, 4f0, length = npoints) + x = range(-4.0f0, 4.0f0; length = npoints) + y = range(-4.0f0, 4.0f0; length = npoints) for x1 in x, x2 in y grid_points[:, idx] .= [x1, x2] idx += 1 end - sol = reshape(model(grid_points |> Flux.gpu), npoints, npoints) |> Flux.cpu + sol = reshape(model(grid_points |> gdev, ps, st)[1], npoints, npoints) |> cdev - return contour(x, y, sol, fill = true, linewidth=0.0) + return contour(x, y, sol; fill = true, linewidth = 0.0) end ``` -## Training Parameters +### Training Parameters -### Loss Functions +#### Loss Functions We use the L2 distance between the model prediction `model(x)` and the actual prediction `y` as the optimization objective. ```@example augneuralode -loss_node(x, y) = mean((model(x) .- y) .^ 2) +loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2) ``` -### Dataset +#### Dataset Next, we generate the dataset. We restrict ourselves to 2 dimensions as it is easy to visualize. We sample a total of `4000` data points. ```@example augneuralode -dataloader = concentric_sphere(2, (0f0, 2f0), (3f0, 4f0), 2000, 2000; batch_size = 256) +dataloader = concentric_sphere(2, (0.0f0, 2.0f0), (3.0f0, 4.0f0), 2000, 2000; + batch_size = 256) ``` -### Callback Function +#### Callback Function Additionally, we define a callback function which displays the total loss at specific intervals. ```@example augneuralode iter = 0 -cb = function() - global iter += 1 - if iter % 10 == 1 - println("Iteration $iter || Loss = $(loss_node(dataloader.data[1], dataloader.data[2]))") +cb = function (ps, l) + global iter + iter += 1 + if iter % 10 == 0 + @info "Augmented Neural ODE" iter=iter loss=l end + return false end ``` -### Optimizer +#### Optimizer We use Adam as the optimizer with a learning rate of 0.005 ```@example augneuralode -opt = Adam(5f-3) +opt = Adam(5.0f-3) ``` -## Training the Neural ODE +### Training the Neural ODE To train our neural ode model, we need to pass the appropriate learnable parameters, `parameters` which are returned by the `construct_models` function. It is simply the `node.p` vector. We then train our model for `20` epochs. ```@example augneuralode -model, parameters = construct_model(1, 2, 64, 0) +model, ps, st = construct_model(1, 2, 64, 0) -for _ in 1:10 - Flux.train!(loss_node, Flux.params(model, parameters), dataloader, opt, cb = cb) -end +optfunc = OptimizationFunction((x, p, data, target) -> loss_node(model, data, target, x, st), + Optimization.AutoZygote()) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) +res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) ``` Here is what the contour plot should look for Neural ODE. Notice that the regression is not perfect due to @@ -269,57 +283,71 @@ input with a single zero. This makes the problem 3-dimensional, and as such it i a function which can be expressed by the neural ode. For more details and proofs, please refer to [1]. ```@example augneuralode -model, parameters = construct_model(1, 2, 64, 1) +model, ps, st = construct_model(1, 2, 64, 1) -for _ in 1:10 - Flux.train!(loss_node, Flux.params(model, parameters), dataloader, opt, cb = cb) -end +optfunc = OptimizationFunction((x, p, data, target) -> loss_node(model, data, target, x, st), + Optimization.AutoZygote()) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) +res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) ``` For the augmented Neural ODE we notice that the artifact is gone. ![anode](https://user-images.githubusercontent.com/30564094/85916607-02bcd880-b870-11ea-84fa-d15e24295ea6.png) -# Expected Output +## Expected Output ``` Generating Dataset -Training Neural ODE -Iteration 10 || Loss = 0.9802582 -Iteration 20 || Loss = 0.6727416 -Iteration 30 || Loss = 0.5862373 -Iteration 40 || Loss = 0.5278132 -Iteration 50 || Loss = 0.4867624 -Iteration 60 || Loss = 0.41630346 -Iteration 70 || Loss = 0.3325938 -Iteration 80 || Loss = 0.28235924 -Iteration 90 || Loss = 0.24069068 -Iteration 100 || Loss = 0.20503852 -Iteration 110 || Loss = 0.17608969 -Iteration 120 || Loss = 0.1491399 -Iteration 130 || Loss = 0.12711425 -Iteration 140 || Loss = 0.10686825 -Iteration 150 || Loss = 0.089558244 +┌ Info: Augmented Neural ODE +│ iter = 10 +└ loss = 1.3382126f0 +┌ Info: Augmented Neural ODE +│ iter = 20 +└ loss = 0.7405951f0 +┌ Info: Augmented Neural ODE +│ iter = 30 +└ loss = 0.65393615f0 +┌ Info: Augmented Neural ODE +│ iter = 40 +└ loss = 0.6115348f0 +┌ Info: Augmented Neural ODE +│ iter = 50 +└ loss = 0.5469544f0 +┌ Info: Augmented Neural ODE +│ iter = 60 +└ loss = 0.61832863f0 +┌ Info: Augmented Neural ODE +│ iter = 70 +└ loss = 0.45164242f0 Training Augmented Neural ODE -Iteration 10 || Loss = 1.3911372 -Iteration 20 || Loss = 0.7694144 -Iteration 30 || Loss = 0.5639633 -Iteration 40 || Loss = 0.33187616 -Iteration 50 || Loss = 0.14787851 -Iteration 60 || Loss = 0.094676435 -Iteration 70 || Loss = 0.07363529 -Iteration 80 || Loss = 0.060333826 -Iteration 90 || Loss = 0.04998395 -Iteration 100 || Loss = 0.044843454 -Iteration 110 || Loss = 0.042587914 -Iteration 120 || Loss = 0.042706195 -Iteration 130 || Loss = 0.040252227 -Iteration 140 || Loss = 0.037686247 -Iteration 150 || Loss = 0.036247417 +┌ Info: Augmented Neural ODE +│ iter = 80 +└ loss = 2.5972328f0 +┌ Info: Augmented Neural ODE +│ iter = 90 +└ loss = 0.79345906f0 +┌ Info: Augmented Neural ODE +│ iter = 100 +└ loss = 0.6131873f0 +┌ Info: Augmented Neural ODE +│ iter = 110 +└ loss = 0.36244678f0 +┌ Info: Augmented Neural ODE +│ iter = 120 +└ loss = 0.14108367f0 +┌ Info: Augmented Neural ODE +│ iter = 130 +└ loss = 0.09875094f0 +┌ Info: Augmented Neural ODE +│ iter = 140 +└ loss = 0.060682703f0 +┌ Info: Augmented Neural ODE +│ iter = 150 +└ loss = 0.050104875f0 ``` -# References +## References [1] Dupont, Emilien, Arnaud Doucet, and Yee Whye Teh. "Augmented neural ODEs." In Proceedings of the 33rd International Conference on Neural Information Processing Systems, pp. 3140-3150. 2019. - diff --git a/docs/src/examples/collocation.md b/docs/src/examples/collocation.md index f6fbf62077..d70ac93ee6 100644 --- a/docs/src/examples/collocation.md +++ b/docs/src/examples/collocation.md @@ -1,7 +1,7 @@ # Smoothed Collocation for Fast Two-Stage Training !!! note - + This is one of many methods for calculating the collocation coefficients for the training process. For a more comprehensive set of collocation methods, see the [JuliaSimModelOptimizer](https://help.juliahub.com/jsmo/stable/manual/collocation/). @@ -11,7 +11,8 @@ pretraining the neural network against a smoothed collocation of the data. First the example and then an explanation. ```@example collocation_cp -using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, Plots +using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, SciMLSensitivity, Optimization, + OptimizationOptimisers, Plots using Random rng = Random.default_rng() @@ -19,34 +20,32 @@ rng = Random.default_rng() u0 = Float32[2.0; 0.0] datasize = 300 tspan = (0.0f0, 1.5f0) -tsteps = range(tspan[1], tspan[2], length = datasize) +tsteps = range(tspan[1], tspan[2]; length = datasize) function trueODEfunc(du, u, p, t) true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' + du .= ((u .^ 3)'true_A)' end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) -data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) .+ 0.1randn(2,300) +data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) .+ 0.1randn(2, 300) -du,u = collocate_data(data,tsteps,EpanechnikovKernel()) +du, u = collocate_data(data, tsteps, EpanechnikovKernel()) -scatter(tsteps,data') -plot!(tsteps,u',lw=5) +scatter(tsteps, data') +plot!(tsteps, u'; lw = 5) savefig("colloc.png") -plot(tsteps,du') +plot(tsteps, du') savefig("colloc_du.png") -dudt2 = Lux.Chain(x -> x.^3, - Lux.Dense(2, 50, tanh), - Lux.Dense(50, 2)) +dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) function loss(p) cost = zero(first(p)) - for i in 1:size(du,2) - _du, _ = dudt2(@view(u[:,i]),p, st) - dui = @view du[:,i] - cost += sum(abs2,dui .- _du) + for i in 1:size(du, 2) + _du, _ = dudt2(@view(u[:, i]), p, st) + dui = @view du[:, i] + cost += sum(abs2, dui .- _du) end sqrt(cost) end @@ -54,23 +53,23 @@ end pinit, st = Lux.setup(rng, dudt2) callback = function (p, l) - return false + return false end adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype) +optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit)) -result_neuralode = Optimization.solve(optprob, Adam(0.05), callback = callback, maxiters = 10000) +result_neuralode = Optimization.solve(optprob, Adam(0.05); callback, maxiters = 10000) -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) +prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps) nn_sol, st = prob_neuralode(u0, result_neuralode.u, st) -scatter(tsteps,data') +scatter(tsteps, data') plot!(nn_sol) savefig("colloc_trained.png") function predict_neuralode(p) - Array(prob_neuralode(u0, p, st)[1]) + Array(prob_neuralode(u0, p, st)[1]) end function loss_neuralode(p) @@ -83,14 +82,11 @@ adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit)) -numerical_neuralode = Optimization.solve(optprob, - Adam(0.05), - callback = callback, - maxiters = 300) +numerical_neuralode = Optimization.solve(optprob, Adam(0.05); callback, maxiters = 300) nn_sol, st = prob_neuralode(u0, numerical_neuralode.u, st) -scatter(tsteps,data') -plot!(nn_sol,lw=5) +scatter(tsteps, data') +plot!(nn_sol; lw = 5) ``` ## Generating the Collocation @@ -99,7 +95,8 @@ The smoothed collocation is a spline fit of the data points which allows us to get an estimate of the approximate noiseless dynamics: ```@example collocation -using ComponentArrays, Lux, DiffEqFlux, Optimization, OptimizationOptimisers, DifferentialEquations, Plots +using ComponentArrays, + Lux, DiffEqFlux, Optimization, OptimizationOptimisers, OrdinaryDiffEq, Plots using Random rng = Random.default_rng() @@ -107,27 +104,27 @@ rng = Random.default_rng() u0 = Float32[2.0; 0.0] datasize = 300 tspan = (0.0f0, 1.5f0) -tsteps = range(tspan[1], tspan[2], length = datasize) +tsteps = range(tspan[1], tspan[2]; length = datasize) function trueODEfunc(du, u, p, t) true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' + du .= ((u .^ 3)'true_A)' end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) -data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) .+ 0.1randn(2,300) +data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) .+ 0.1randn(2, 300) -du,u = collocate_data(data,tsteps,EpanechnikovKernel()) +du, u = collocate_data(data, tsteps, EpanechnikovKernel()) -scatter(tsteps,data') -plot!(tsteps,u',lw=5) +scatter(tsteps, data') +plot!(tsteps, u'; lw = 5) ``` We can then differentiate the smoothed function to get estimates of the derivative at each data point: ```@example collocation -plot(tsteps,du') +plot(tsteps, du') ``` Because we have `(u',u)` pairs, we can write a loss function that @@ -135,16 +132,14 @@ calculates the squared difference between `f(u,p,t)` and `u'` at each point, and find the parameters which minimize this difference: ```@example collocation -dudt2 = Lux.Chain(x -> x.^3, - Lux.Dense(2, 50, tanh), - Lux.Dense(50, 2)) +dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) function loss(p) cost = zero(first(p)) - for i in 1:size(du,2) - _du, _ = dudt2(@view(u[:,i]),p, st) - dui = @view du[:,i] - cost += sum(abs2,dui .- _du) + for i in 1:size(du, 2) + _du, _ = dudt2(@view(u[:, i]), p, st) + dui = @view du[:, i] + cost += sum(abs2, dui .- _du) end sqrt(cost) end @@ -152,18 +147,18 @@ end pinit, st = Lux.setup(rng, dudt2) callback = function (p, l) - return false + return false end adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype) +optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit)) -result_neuralode = Optimization.solve(optprob, Adam(0.05), callback = callback, maxiters = 10000) +result_neuralode = Optimization.solve(optprob, Adam(0.05); callback, maxiters = 10000) -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) +prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps) nn_sol, st = prob_neuralode(u0, result_neuralode.u, st) -scatter(tsteps,data') +scatter(tsteps, data') plot!(nn_sol) ``` @@ -174,7 +169,7 @@ initial condition to the next phase of our fitting: ```@example collocation function predict_neuralode(p) - Array(prob_neuralode(u0, p, st)[1]) + Array(prob_neuralode(u0, p, st)[1]) end function loss_neuralode(p) @@ -187,14 +182,11 @@ adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit)) -numerical_neuralode = Optimization.solve(optprob, - Adam(0.05), - callback = callback, - maxiters = 300) +numerical_neuralode = Optimization.solve(optprob, Adam(0.05); callback, maxiters = 300) nn_sol, st = prob_neuralode(u0, numerical_neuralode.u, st) -scatter(tsteps,data') -plot!(nn_sol,lw=5) +scatter(tsteps, data') +plot!(nn_sol; lw = 5) ``` This method then has a good global starting position, making it less diff --git a/docs/src/examples/hamiltonian_nn.md b/docs/src/examples/hamiltonian_nn.md index f04edd805d..41e2d5a49a 100644 --- a/docs/src/examples/hamiltonian_nn.md +++ b/docs/src/examples/hamiltonian_nn.md @@ -10,9 +10,9 @@ Now we make some simplifying assumptions, and assign ``m = 1`` and ``k = 1``. An ```@example hamiltonian_cp using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random, - ComponentArrays, Optimization, OptimizationOptimisers, IterTools + ComponentArrays, Optimization, OptimizationOptimisers, IterTools -t = range(0.0f0, 1.0f0, length=1024) +t = range(0.0f0, 1.0f0; length = 1024) π_32 = Float32(π) q_t = reshape(sin.(2π_32 * t), 1, :) p_t = reshape(cos.(2π_32 * t), 1, :) @@ -22,12 +22,12 @@ dpdt = -2π_32 .* q_t data = vcat(q_t, p_t) target = vcat(dqdt, dpdt) B = 256 -NEPOCHS = 500 +NEPOCHS = 100 dataloader = ncycle(((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))), - selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2))))) + selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2))))) for i in 1:(size(data, 2) ÷ B)), NEPOCHS) -hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 64, relu), Lux.Dense(64, 1))) +hnn = HamiltonianNN(Chain(Dense(2 => 64, relu), Dense(64 => 1)); ad = AutoZygote()) ps, st = Lux.setup(Random.default_rng(), hnn) ps_c = ps |> ComponentArray @@ -38,20 +38,25 @@ function loss_function(ps, data, target) return mean(abs2, pred .- target), pred end +function callback(ps, loss, pred) + println("[Hamiltonian NN] Loss: ", loss) + return false +end + opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target), - Optimization.AutoForwardDiff()) + Optimization.AutoForwardDiff()) opt_prob = OptimizationProblem(opt_func, ps_c) -res = Optimization.solve(opt_prob, opt, dataloader) +res = Optimization.solve(opt_prob, opt, dataloader; callback) ps_trained = res.u -model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false, - save_start=true, saveat=t) +model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, + save_start = true, saveat = t) pred = Array(first(model(data[:, 1], ps_trained, st))) -plot(data[1, :], data[2, :], lw=4, label="Original") -plot!(pred[1, :], pred[2, :], lw=4, label="Predicted") +plot(data[1, :], data[2, :]; lw = 4, label = "Original") +plot!(pred[1, :], pred[2, :]; lw = 4, label = "Predicted") xlabel!("Position (q)") ylabel!("Momentum (p)") ``` @@ -63,22 +68,22 @@ ylabel!("Momentum (p)") The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we generate the pairs ``(q, p)`` using the equations given at the top. Additionally, to supervise the training, we also generate the gradients. Next, we use Flux DataLoader for automatically batching our dataset. ```@example hamiltonian -using Flux, DiffEqFlux, DifferentialEquations, Statistics, Plots, ReverseDiff, Random, - IterTools, Lux, ComponentArrays, Optimization, OptimizationOptimisers +using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random, + ComponentArrays, Optimization, OptimizationOptimisers, IterTools -t = range(0.0f0, 1.0f0, length = 1024) +t = range(0.0f0, 1.0f0; length = 1024) π_32 = Float32(π) q_t = reshape(sin.(2π_32 * t), 1, :) p_t = reshape(cos.(2π_32 * t), 1, :) dqdt = 2π_32 .* p_t dpdt = -2π_32 .* q_t -data = cat(q_t, p_t, dims = 1) -target = cat(dqdt, dpdt, dims = 1) +data = cat(q_t, p_t; dims = 1) +target = cat(dqdt, dpdt; dims = 1) B = 256 -NEPOCHS = 500 +NEPOCHS = 100 dataloader = ncycle(((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))), - selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2))))) + selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2))))) for i in 1:(size(data, 2) ÷ B)), NEPOCHS) ``` @@ -87,7 +92,7 @@ dataloader = ncycle(((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data We parameterize the HamiltonianNN with a small MultiLayered Perceptron. HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN Layer for Optimization. ```@example hamiltonian -hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 64, relu), Lux.Dense(64, 1))) +hnn = HamiltonianNN(Chain(Dense(2 => 64, relu), Dense(64 => 1)); ad = AutoZygote()) ps, st = Lux.setup(Random.default_rng(), hnn) ps_c = ps |> ComponentArray @@ -99,12 +104,12 @@ function loss_function(ps, data, target) end function callback(ps, loss, pred) - println("Loss: ", loss) + println("[Hamiltonian NN] Loss: ", loss) return false end opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target), - Optimization.AutoForwardDiff()) + Optimization.AutoZygote()) opt_prob = OptimizationProblem(opt_func, ps_c) res = solve(opt_prob, opt, dataloader; callback) @@ -117,12 +122,12 @@ ps_trained = res.u In order to visualize the learned trajectories, we need to solve the ODE. We will use the `NeuralHamiltonianDE` layer, which is essentially a wrapper over `HamiltonianNN` layer, and solves the ODE. ```@example hamiltonian -model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false, - save_start=true, saveat=t) +model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, + save_start = true, saveat = t) pred = Array(first(model(data[:, 1], ps_trained, st))) -plot(data[1, :], data[2, :], lw=4, label="Original") -plot!(pred[1, :], pred[2, :], lw=4, label="Predicted") +plot(data[1, :], data[2, :]; lw = 4, label = "Original") +plot!(pred[1, :], pred[2, :]; lw = 4, label = "Predicted") xlabel!("Position (q)") ylabel!("Momentum (p)") ``` @@ -131,7 +136,7 @@ ylabel!("Momentum (p)") ## Expected Output -```julia +```txt Loss: 19.865715 Loss: 18.196068 Loss: 19.179213 diff --git a/docs/src/examples/mnist_conv_neural_ode.md b/docs/src/examples/mnist_conv_neural_ode.md index 7b8a015c34..eb79cac50b 100644 --- a/docs/src/examples/mnist_conv_neural_ode.md +++ b/docs/src/examples/mnist_conv_neural_ode.md @@ -3,305 +3,89 @@ Training a Convolutional Neural Net Classifier for **MNIST** using a neural ordinary differential equation **NN-ODE** on **GPUs** with **Minibatching**. -(Step-by-step description below) +For a step-by-step tutorial see the tutorial on the MNIST Neural ODE Classification Tutorial +using Fully Connected Layers. ```julia -using DiffEqFlux, DifferentialEquations, Printf -using Flux.Losses: logitcrossentropy -using Flux.Data: DataLoader -using MLDatasets -using MLDataUtils: LabelEnc, convertlabel, stratifiedobs -using CUDA -CUDA.allowscalar(false) - -function loadmnist(batchsize = bs, train_split = 0.9) - # Use MLDataUtils LabelEnc for natural onehot conversion - onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, - LabelEnc.NativeLabels(collect(0:9))) - # Load MNIST - mnist = MNIST(split = :train) - imgs, labels_raw = mnist.features, mnist.targets - # Process images into (H,W,C,BS) batches - x_data = Float32.(reshape(imgs, size(imgs,1), size(imgs,2), 1, size(imgs,3))) - y_data = onehot(labels_raw) - (x_train, y_train), (x_test, y_test) = stratifiedobs((x_data, y_data), - p = train_split) - return ( - # Use Flux's DataLoader to automatically minibatch and shuffle the data - DataLoader(Flux.gpu.(collect.((x_train, y_train))); batchsize = batchsize, - shuffle = true), - # Don't shuffle the test data - DataLoader(Flux.gpu.(collect.((x_test, y_test))); batchsize = batchsize, - shuffle = false) - ) -end - -# Main -const bs = 128 -const train_split = 0.9 -train_dataloader, test_dataloader = loadmnist(bs, train_split) - -down = Flux.Chain(Flux.Conv((3, 3), 1=>64, relu, stride = 1), Flux.GroupNorm(64, 64), - Flux.Conv((4, 4), 64=>64, relu, stride = 2, pad=1), Flux.GroupNorm(64, 64), - Flux.Conv((4, 4), 64=>64, stride = 2, pad = 1)) |> Flux.gpu - -dudt = Flux.Chain(Flux.Conv((3, 3), 64=>64, tanh, stride=1, pad=1), - Flux.Conv((3, 3), 64=>64, tanh, stride=1, pad=1)) |> Flux.gpu - -fc = Flux.Chain(Flux.GroupNorm(64, 64), x -> relu.(x), Flux.MeanPool((6, 6)), - x -> reshape(x, (64, :)), Flux.Dense(64,10)) |> Flux.gpu - -nn_ode = NeuralODE(dudt, (0.f0, 1.f0), Tsit5(), - save_everystep = false, - reltol = 1e-3, abstol = 1e-3, - save_start = false) |> Flux.gpu - -function DiffEqArray_to_Array(x) - xarr = Flux.gpu(x) - return xarr[:,:,:,:,1] -end - -# Build our over-all model topology -model = Flux.Chain(down, # (28, 28, 1, BS) -> (6, 6, 64, BS) - nn_ode, # (6, 6, 64, BS) -> (6, 6, 64, BS, 1) - DiffEqArray_to_Array, # (6, 6, 64, BS, 1) -> (6, 6, 64, BS) - fc) # (6, 6, 64, BS) -> (10, BS) - -# To understand the intermediate NN-ODE layer, we can examine it's dimensionality -img, lab = train_dataloader.data[1][:, :, :, 1:1], train_dataloader.data[2][:, 1:1] - -x_d = down(img) - -# We can see that we can compute the forward pass through the NN topology -# featuring an NNODE layer. -x_m = model(img) - -classify(x) = argmax.(eachcol(x)) - -function accuracy(model, data; n_batches = 100) - total_correct = 0 - total = 0 - for (i, (x, y)) in enumerate(data) - # Only evaluate accuracy for n_batches - i > n_batches && break - target_class = classify(Flux.cpu(y)) - predicted_class = classify(Flux.cpu(model(x))) - total_correct += sum(target_class .== predicted_class) - total += length(target_class) - end - return total_correct / total -end - -# burn in accuracy -accuracy(model, train_dataloader) - -loss(x, y) = logitcrossentropy(model(x), y) - -# burn in loss -loss(img, lab) - -opt = Adam(0.05) -iter = 0 - -callback() = begin - global iter += 1 - # Monitor that the weights do infact update - # Every 10 training iterations show accuracy - if iter % 10 == 1 - train_accuracy = accuracy(model, train_dataloader) * 100 - test_accuracy = accuracy(model, test_dataloader; - n_batches = length(test_dataloader)) * 100 - @printf("Iter: %3d || Train Accuracy: %2.3f || Test Accuracy: %2.3f\n", - iter, train_accuracy, test_accuracy) - end -end - -Flux.train!(loss, Flux.params(down, nn_ode.p, fc), train_dataloader, opt, cb = callback) -``` +using DiffEqFlux, Statistics, + ComponentArrays, CUDA, Zygote, MLDatasets, OrdinaryDiffEq, Printf, Test, LuxCUDA, Random +using Optimization, OptimizationOptimisers +using MLDatasets: MNIST +using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview +using OneHotArrays +const cdev = cpu_device() +const gdev = gpu_device() -## Step-by-Step Description - -### Load Packages - -```julia -using DiffEqFlux, DifferentialEquations, Printf -using Flux.Losses: logitcrossentropy -using Flux.Data: DataLoader -using MLDatasets -using MLDataUtils: LabelEnc, convertlabel, stratifiedobs -``` - -### GPU -A good trick used here: - -```julia -using CUDA CUDA.allowscalar(false) -``` - -Ensures that only optimized kernels are called when using the GPU. -Additionally, the `gpu` function is shown as a way to translate models and data over to the GPU. -Note that this function is CPU-safe, so if the GPU is disabled or unavailable, this -code will fallback to the CPU. - -### Load MNIST Dataset into Minibatches - -The preprocessing is done in `loadmnist` where the raw MNIST data is split into features `x_train` -and labels `y_train` by specifying batchsize `bs`. The function `convertlabel` will then transform -the current labels (`labels_raw`) from numbers 0 to 9 (`LabelEnc.NativeLabels(collect(0:9))`) into -one hot encoding (`LabelEnc.OneOfK`). +ENV["DATADEPS_ALWAYS_ACCEPT"] = true -Features are reshaped into format **[Height, Width, Color, BatchSize]** or in this case **[28, 28, 1, 128]** -meaning that every minibatch will contain 128 images with a single color channel of 28x28 pixels. -The entire dataset of 60,000 images is split into the train and test dataset, ensuring a balanced ratio -of labels. These splits are then passed to Flux's DataLoader. This automatically minibatches both the images and -labels. Additionally, it allows us to shuffle the train dataset in each epoch while keeping the order of the -test data the same. +logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims = 1); dims = 1)) -```julia -function loadmnist(batchsize = bs, train_split = 0.9) +function loadmnist(batchsize = bs) # Use MLDataUtils LabelEnc for natural onehot conversion - onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, - LabelEnc.NativeLabels(collect(0:9))) + function onehot(labels_raw) + convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) + end # Load MNIST - mnist = MNIST(split = :train) + mnist = MNIST(; split = :train) imgs, labels_raw = mnist.features, mnist.targets # Process images into (H,W,C,BS) batches - x_data = Float32.(reshape(imgs, size(imgs,1), size(imgs,2), 1, size(imgs,3))) - y_data = onehot(labels_raw) - (x_train, y_train), (x_test, y_test) = stratifiedobs((x_data, y_data), - p = train_split) - return ( - # Use Flux's DataLoader to automatically minibatch and shuffle the data - DataLoader(Flux.gpu.(collect.((x_train, y_train))); batchsize = batchsize, - shuffle = true), - # Don't shuffle the test data - DataLoader(Flux.gpu.(collect.((x_test, y_test))); batchsize = batchsize, - shuffle = false) - ) + x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> + gdev + x_train = batchview(x_train, batchsize) + # Onehot and batch the labels + y_train = onehot(labels_raw) |> gdev + y_train = batchview(y_train, batchsize) + return x_train, y_train end -``` -and then loaded from main: -```julia # Main const bs = 128 -const train_split = 0.9 -train_dataloader, test_dataloader = loadmnist(bs, train_split) -``` - - -### Layers +x_train, y_train = loadmnist(bs) -The Neural Network requires passing inputs sequentially through multiple layers. We use -`Chain` which allows inputs to functions to come from previous layer and sends the outputs -to the next. Four different sets of layers are used here: +down = Chain(Conv((3, 3), 1 => 64, relu; stride = 1), GroupNorm(64, 64), + Conv((4, 4), 64 => 64, relu; stride = 2, pad = 1), GroupNorm(64, 64), + Conv((4, 4), 64 => 64; stride = 2, pad = 1)) +dudt = Chain(Conv((3, 3), 64 => 64, tanh; stride = 1, pad = 1), + Conv((3, 3), 64 => 64, tanh; stride = 1, pad = 1)) -```julia -down = Flux.Chain(Flux.Conv((3, 3), 1=>64, relu, stride = 1), Flux.GroupNorm(64, 64), - Flux.Conv((4, 4), 64=>64, relu, stride = 2, pad=1), Flux.GroupNorm(64, 64), - Flux.Conv((4, 4), 64=>64, stride = 2, pad = 1)) |> Flux.gpu - -dudt = Flux.Chain(Flux.Conv((3, 3), 64=>64, tanh, stride=1, pad=1), - Flux.Conv((3, 3), 64=>64, tanh, stride=1, pad=1)) |> Flux.gpu - -fc = Flux.Chain(Flux.GroupNorm(64, 64), x -> relu.(x), Flux.MeanPool((6, 6)), - x -> reshape(x, (64, :)), Flux.Dense(64,10)) |> Flux.gpu - -nn_ode = NeuralODE(dudt, (0.f0, 1.f0), Tsit5(), - save_everystep = false, - reltol = 1e-3, abstol = 1e-3, - save_start = false) |> Flux.gpu -``` +fc = Chain(GroupNorm(64, 64), x -> relu.(x), MeanPool((6, 6)), + x -> reshape(x, (64, :)), Dense(64, 10)) -`down`: This layer downsamples our images into `6 x 6 x 64` dimensional features. - It takes a 28 x 28 image, and passes it through a convolutional neural network - layer with `relu` activation +nn_ode = NeuralODE(dudt, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, + reltol = 1e-3, abstol = 1e-3, save_start = false) -`nn`: A 2 layer Convolutional Neural Network Chain with `tanh` activation which is used to model - our differential equation - -`nn_ode`: ODE solver layer - -`fc`: The final fully connected layer which maps our learned features to the probability of - the feature vector of belonging to a particular class - -`gpu`: A utility function which transfers our model to GPU, if one is available - -### Array Conversion - -When using `NeuralODE`, we can use the following function as a cheap conversion of `DiffEqArray` -from the ODE solver into a Matrix that can be used in the following layer: - -```julia function DiffEqArray_to_Array(x) - xarr = Flux.gpu(x) - return xarr[:,:,:,:,1] + xarr = gdev(x) + return xarr[:, :, :, :, 1] end -``` - -For CPU: If this function does not automatically fallback to CPU when no GPU is present, we can -change `gpu(x)` with `Array(x)`. - - -### Build Topology -Next we connect all layers together in a single chain: - -```julia # Build our over-all model topology -model = Flux.Chain(down, # (28, 28, 1, BS) -> (6, 6, 64, BS) - nn_ode, # (6, 6, 64, BS) -> (6, 6, 64, BS, 1) - DiffEqArray_to_Array, # (6, 6, 64, BS, 1) -> (6, 6, 64, BS) - fc) # (6, 6, 64, BS) -> (10, BS) -``` - -There are a few things we can do to examine the inner workings of our neural network: - -```julia -img, lab = train_dataloader.data[1][:, :, :, 1:1], train_dataloader.data[2][:, 1:1] +m = Chain(down, # (28, 28, 1, BS) -> (6, 6, 64, BS) + nn_ode, # (6, 6, 64, BS) -> (6, 6, 64, BS, 1) + DiffEqArray_to_Array, # (6, 6, 64, BS, 1) -> (6, 6, 64, BS) + fc) # (6, 6, 64, BS) -> (10, BS) +ps, st = Lux.setup(Random.default_rng(), m) +ps = ComponentArray(ps) |> gdev +st = st |> gdev # To understand the intermediate NN-ODE layer, we can examine it's dimensionality -x_d = down(img) - -# We can see that we can compute the forward pass through the NN topology -# featuring an NNODE layer. -x_m = model(img) -``` - -This can also be built without the NN-ODE by replacing `nn-ode` with a simple `nn`: - -```julia -# We can also build the model topology without a NN-ODE -m_no_ode = Flux.Chain(down, nn, fc) |> Flux.gpu - -x_m = m_no_ode(img) -``` - -### Prediction +img = x_train[1][:, :, :, 1:1] |> gdev +lab = x_train[2][:, 1:1] |> gdev -To convert the classification back into readable numbers, we use `classify` which returns the -prediction by taking the arg max of the output for each column of the minibatch: +x_m, _ = m(img, ps, st) -```julia classify(x) = argmax.(eachcol(x)) -``` - -### Accuracy - -We then evaluate the accuracy on `n_batches` at a time through the entire network: -```julia -function accuracy(model, data; n_batches = 100) +function accuracy(model, data, ps, st; n_batches = 10) total_correct = 0 total = 0 - for (i, (x, y)) in enumerate(data) - # Only evaluate accuracy for n_batches - i > n_batches && break - target_class = classify(Flux.cpu(y)) - predicted_class = classify(Flux.cpu(model(x))) + st = Lux.testmode(st) + for (x, y) in collect(data)[1:n_batches] + target_class = classify(cdev(y)) + predicted_class = classify(cdev(first(model(x, ps, st)))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end @@ -309,69 +93,41 @@ function accuracy(model, data; n_batches = 100) end # burn in accuracy -accuracy(model, train_dataloader) -``` - -### Training Parameters - -Once we have our model, we can train our neural network by backpropagation using `Flux.train!`. -This function requires **Loss**, **Optimizer** and **Callback** functions. +accuracy(m, zip(x_train, y_train), ps, st) -#### Loss - -**Cross Entropy** is the loss function computed here which applies a **Softmax** operation on the -final output of our model. `logitcrossentropy` takes in the prediction from our -model `model(x)` and compares it to actual output `y`: - -```julia -loss(x, y) = logitcrossentropy(model(x), y) - -# burn in loss -loss(img, lab) -``` - -#### Optimizer - -`Adam` is specified here as our optimizer with a **learning rate of 0.05**: +function loss_function(ps, x, y) + pred, st_ = m(x, ps, st) + return logitcrossentropy(pred, y), pred +end -```julia -opt = Adam(0.05) -``` +#burn in loss +loss_function(ps, x_train[1], y_train[1]) -#### CallBack +opt = OptimizationOptimisers.Adam(0.05) +iter = 0 -This callback function is used to print both the training and testing accuracy after -10 training iterations: +opt_func = OptimizationFunction((ps, _, x, y) -> loss_function(ps, x, y), + Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps) -```julia -callback() = begin +function callback(ps, l, pred) global iter += 1 - # Monitor that the weights update - # Every 10 training iterations show accuracy - if iter % 10 == 1 - train_accuracy = accuracy(model, train_dataloader) * 100 - test_accuracy = accuracy(model, test_dataloader; - n_batches = length(test_dataloader)) * 100 - @printf("Iter: %3d || Train Accuracy: %2.3f || Test Accuracy: %2.3f\n", - iter, train_accuracy, test_accuracy) + #Monitor that the weights do infact update + #Every 10 training iterations show accuracy + if (iter % 10 == 0) + @info "[MNIST Conv GPU] Accuracy: $(accuracy(m, zip(x_train, y_train), ps, st))" end + return false end -``` - -### Train -To train our model, we select the appropriate trainable parameters of our network with `params`. -In our case, backpropagation is required for `down`, `nn_ode` and `fc`. Notice that the parameters -for Neural ODE is given by `nn_ode.p`: - -```julia # Train the NN-ODE and monitor the loss and weights. -Flux.train!(loss, Flux.params(down, nn_ode.p, fc), train_dataloader, opt, callback = callback) +res = Optimization.solve(opt_prob, opt, zip(x_train, y_train); maxiters = 10, callback) +@test accuracy(m, zip(x_train, y_train), res.u, st) > 0.8 ``` -### Expected Output +## Expected Output -```julia +```txt Iter: 1 || Train Accuracy: 8.453 || Test Accuracy: 8.883 Iter: 11 || Train Accuracy: 14.773 || Test Accuracy: 14.967 Iter: 21 || Train Accuracy: 24.383 || Test Accuracy: 24.433 diff --git a/docs/src/examples/mnist_neural_ode.md b/docs/src/examples/mnist_neural_ode.md index c5b83fcdb7..74e3d4577d 100644 --- a/docs/src/examples/mnist_neural_ode.md +++ b/docs/src/examples/mnist_neural_ode.md @@ -6,125 +6,147 @@ on **GPUs** with **minibatching**. (Step-by-step description below) ```julia -using DiffEqFlux, DifferentialEquations, NNlib, MLDataUtils, Printf -using Flux.Losses: logitcrossentropy -using MLDatasets -using CUDA +using DiffEqFlux, CUDA, Zygote, MLDataUtils, NNlib, OrdinaryDiffEq, Test, Lux, Statistics, + ComponentArrays, Random, Optimization, OptimizationOptimisers, LuxCUDA +using MLDatasets: MNIST +using MLDataUtils: LabelEnc, convertlabel, stratifiedobs + CUDA.allowscalar(false) +ENV["DATADEPS_ALWAYS_ACCEPT"] = true + +const cdev = cpu_device() +const gdev = gpu_device() -train_data = MNIST(split = :train) -test_data = MNIST(split = :test) +logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims = 1); dims = 1)) -function loadmnist(data::MNIST; batchsize::Int = 128, shuffle = true) - # Process images into (H,W,C,N) batches by inserting channel dimension - x = reshape(data.features, 28, 28, 1, :) - # One-hot encode targets - y = Flux.onehotbatch(data.targets, 0:9) - # Minibatch data using Flux DataLoader - return Flux.DataLoader((x, y); batchsize = batchsize, shuffle = shuffle) |> Flux.gpu +function loadmnist(batchsize = bs) + # Use MLDataUtils LabelEnc for natural onehot conversion + function onehot(labels_raw) + convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) + end + # Load MNIST + mnist = MNIST(; split = :train) + imgs, labels_raw = mnist.features, mnist.targets + # Process images into (H,W,C,BS) batches + x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> + gdev + x_train = batchview(x_train, batchsize) + # Onehot and batch the labels + y_train = onehot(labels_raw) |> gdev + y_train = batchview(y_train, batchsize) + return x_train, y_train end # Main -train_dataloader = loadmnist(train_data) -test_dataloader = loadmnist(test_data, batchsize = length(test_data), shuffle = false) +const bs = 128 +x_train, y_train = loadmnist(bs) -down = Flux.Chain(Flux.flatten, Flux.Dense(784, 20, tanh)) |> Flux.gpu +down = Lux.Chain(Lux.FlattenLayer(), Lux.Dense(784, 20, tanh)) +nn = Lux.Chain(Lux.Dense(20, 10, tanh), Lux.Dense(10, 10, tanh), + Lux.Dense(10, 20, tanh)) +fc = Lux.Dense(20, 10) -nn = Flux.Chain(Flux.Dense(20, 10, tanh), - Flux.Dense(10, 10, tanh), - Flux.Dense(10, 20, tanh)) |> Flux.gpu - - -nn_ode = NeuralODE(nn, (0.f0, 1.f0), Tsit5(), - save_everystep = false, - reltol = 1e-3, abstol = 1e-3, - save_start = false) |> Flux.gpu -fc = Flux.Chain(Flux.Dense(20, 10)) |> Flux.gpu +nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, reltol = 1e-3, + abstol = 1e-3, save_start = false) function DiffEqArray_to_Array(x) - xarr = Flux.gpu(x) + xarr = gdev(x) return reshape(xarr, size(xarr)[1:2]) end -# Build our overall model topology -model = Flux.Chain(down, - nn_ode, - DiffEqArray_to_Array, - fc) |> Flux.gpu; +#Build our over-all model topology +m = Lux.Chain(; down, nn_ode, convert = Lux.WrappedFunction(DiffEqArray_to_Array), fc) +ps, st = Lux.setup(Random.default_rng(), m) +ps = ComponentArray(ps) |> gdev +st = st |> gdev -# To understand the intermediate NN-ODE layer, we can examine it's dimensionality -img, lab = train_dataloader.data[1:1] -x_d = down(img) +#We can also build the model topology without a NN-ODE +m_no_ode = Lux.Chain(; down, nn, fc) +ps_no_ode, st_no_ode = Lux.setup(Random.default_rng(), m_no_ode) +ps_no_ode = ComponentArray(ps_no_ode) |> gdev +st_no_ode = st_no_ode |> gdev -# We can see that we can compute the forward pass through the NN topology -# featuring an NNODE layer. -x_m = model(img) +#To understand the intermediate NN-ODE layer, we can examine it's dimensionality +x_d = first(down(x_train[1], ps.down, st.down)) + +# We can see that we can compute the forward pass through the NN topology featuring an NNODE layer. +x_m = first(m(x_train[1], ps, st)) +#Or without the NN-ODE layer. +x_m = first(m_no_ode(x_train[1], ps_no_ode, st_no_ode)) classify(x) = argmax.(eachcol(x)) -function accuracy(model, data; n_batches = 100) +function accuracy(model, data, ps, st; n_batches = 100) total_correct = 0 total = 0 - for (i, (x, y)) in enumerate(collect(data)) - # Only evaluate accuracy for n_batches - i > n_batches && break - target_class = classify(Flux.cpu(y)) - predicted_class = classify(Flux.cpu(model(x))) + st = Lux.testmode(st) + for (x, y) in collect(data)[1:n_batches] + target_class = classify(cdev(y)) + predicted_class = classify(cdev(first(model(x, ps, st)))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end return total_correct / total end +#burn in accuracy +accuracy(m, zip(x_train, y_train), ps, st) -# burn in accuracy -accuracy(model, train_dataloader) - -loss(x, y) = logitcrossentropy(model(x), y) +function loss_function(ps, x, y) + pred, st_ = m(x, ps, st) + return logitcrossentropy(pred, y), pred +end -# burn in loss -loss(img, lab) +#burn in loss +loss_function(ps, x_train[1], y_train[1]) -opt = Adam(0.05) +opt = OptimizationOptimisers.Adam(0.05) iter = 0 -callback() = begin +opt_func = OptimizationFunction((ps, _, x, y) -> loss_function(ps, x, y), + Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps) + +function callback(ps, l, pred) global iter += 1 - # Monitor that the weights do infact update - # Every 10 training iterations show accuracy - if iter % 10 == 1 - train_accuracy = accuracy(model, train_dataloader) * 100 - test_accuracy = accuracy(model, test_dataloader; - n_batches = length(test_dataloader)) * 100 - @printf("Iter: %3d || Train Accuracy: %2.3f || Test Accuracy: %2.3f\n", - iter, train_accuracy, test_accuracy) + #Monitor that the weights do infact update + #Every 10 training iterations show accuracy + if (iter % 10 == 0) + @info "[MNIST GPU] Accuracy: $(accuracy(m, zip(x_train, y_train), ps, st))" end + return false end # Train the NN-ODE and monitor the loss and weights. -Flux.train!(loss, Flux.params(down, nn_ode.p, fc), train_dataloader, opt, cb = callback) +res = Optimization.solve(opt_prob, opt, zip(x_train, y_train); callback) +@test accuracy(m, zip(x_train, y_train), res.u, st) > 0.8 ``` - ## Step-by-Step Description ### Load Packages ```julia -using DiffEqFlux, DifferentialEquations, NNlib, MLDataUtils, Printf -using Flux.Losses: logitcrossentropy -using MLDatasets +using DiffEqFlux, CUDA, Zygote, MLDataUtils, NNlib, OrdinaryDiffEq, Test, Lux, Statistics, + ComponentArrays, Random, Optimization, OptimizationOptimisers, LuxCUDA +using MLDatasets: MNIST +using MLDataUtils: LabelEnc, convertlabel, stratifiedobs ``` ### GPU + A good trick used here: ```julia -using CUDA + CUDA.allowscalar(false) +ENV["DATADEPS_ALWAYS_ACCEPT"] = true + +const cdev = cpu_device() +const gdev = gpu_device() ``` ensures that only optimized kernels are called when using the GPU. -Additionally, the `gpu` function is shown as a way to translate models and data over to the GPU. +Additionally, the `gpu_device` function is shown as a way to translate models and data over to the GPU. Note that this function is CPU-safe, so if the GPU is disabled or unavailable, this code will fall back to the CPU. @@ -136,22 +158,30 @@ The preprocessing is done in `loadmnist` where the raw MNIST data is split into Features are reshaped into format **[Height, Width, Color, Samples]**, in case of the train set **[28, 28, 1, 60000]**. Using Flux's `onehotbatch` function, the labels (numbers 0 to 9) are one-hot encoded, resulting in a a **[10, 60000]** `OneHotMatrix`. -Features and labels are then passed to Flux's DataLoader. +Features and labels are then passed to Flux's DataLoader. This automatically minibatches both the images and labels using the specified `batchsize`, meaning that every minibatch will contain 128 images with a single color channel of 28x28 pixels. Additionally, it allows us to shuffle the train dataset in each epoch. ```julia -train_data = MNIST(split = :train) -test_data = MNIST(split = :test) - -function loadmnist(data::MNIST; batchsize::Int = 128, shuffle = true) - # Process images into (H,W,C,N) batches by inserting channel dimension - x = reshape(data.features, 28, 28, 1, :) - # One-hot encode targets - y = Flux.onehotbatch(data.targets, 0:9) - # Minibatch data using Flux DataLoader - return Flux.DataLoader((x, y); batchsize = batchsize, shuffle = shuffle) |> Flux.gpu +logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims = 1); dims = 1)) + +function loadmnist(batchsize = bs) + # Use MLDataUtils LabelEnc for natural onehot conversion + function onehot(labels_raw) + convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) + end + # Load MNIST + mnist = MNIST(; split = :train) + imgs, labels_raw = mnist.features, mnist.targets + # Process images into (H,W,C,BS) batches + x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> + gdev + x_train = batchview(x_train, batchsize) + # Onehot and batch the labels + y_train = onehot(labels_raw) |> gdev + y_train = batchview(y_train, batchsize) + return x_train, y_train end ``` @@ -159,47 +189,35 @@ and then loaded from main: ```julia # Main -train_dataloader = loadmnist(train_data) -test_dataloader = loadmnist(test_data, batchsize = length(test_data), shuffle = false) +const bs = 128 +x_train, y_train = loadmnist(bs) ``` - ### Layers The Neural Network requires passing inputs sequentially through multiple layers. We use `Chain` which allows inputs to functions to come from the previous layer and sends the outputs to the next. Four different sets of layers are used here: - ```julia -down = Flux.Chain(Flux.flatten, Flux.Dense(784, 20, tanh)) |> Flux.gpu - -nn = Flux.Chain(Flux.Dense(20, 10, tanh), - Flux.Dense(10, 10, tanh), - Flux.Dense(10, 20, tanh)) |> Flux.gpu - - -nn_ode = NeuralODE(nn, (0.f0, 1.f0), Tsit5(), - save_everystep = false, - reltol = 1e-3, abstol = 1e-3, - save_start = false) |> Flux.gpu -fc = Flux.Chain(Flux.Dense(20, 10)) |> Flux.gpu +down = Lux.Chain(Lux.FlattenLayer(), Lux.Dense(784, 20, tanh)) +nn = Lux.Chain(Lux.Dense(20, 10, tanh), Lux.Dense(10, 10, tanh), + Lux.Dense(10, 20, tanh)) +fc = Lux.Dense(20, 10) ``` `down`: This layer downsamples our images into a 20 dimensional feature vector. - It takes a 28 x 28 image, flattens it, and then passes it through a fully connected - layer with `tanh` activation +It takes a 28 x 28 image, flattens it, and then passes it through a fully connected +layer with `tanh` activation `nn`: A 3 layers Deep Neural Network Chain with `tanh` activation which is used to model - our differential equation +our differential equation `nn_ode`: ODE solver layer `fc`: The final fully connected layer which maps our learned feature vector to the probability of - the feature vector of belonging to a particular class - -`|> gpu`: An utility function which transfers our model to GPU, if it is available +the feature vector of belonging to a particular class ### Array Conversion @@ -207,15 +225,17 @@ When using `NeuralODE`, this function converts the ODESolution's `DiffEqArray` t a Matrix (CuArray), and reduces the matrix from 3 to 2 dimensions for use in the next layer. ```julia +nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, reltol = 1e-3, + abstol = 1e-3, save_start = false) + function DiffEqArray_to_Array(x) - xarr = Flux.gpu(x) + xarr = gdev(x) return reshape(xarr, size(xarr)[1:2]) end ``` For CPU: If this function does not automatically fall back to CPU when no GPU is present, we can -change `gpu(x)` to `Array(x)`. - +change `gdev(x)` to `Array(x)`. ### Build Topology @@ -223,34 +243,10 @@ Next, we connect all layers together in a single chain: ```julia # Build our overall model topology -model = Flux.Chain(down, - nn_ode, - DiffEqArray_to_Array, - fc) |> Flux.gpu; -``` - -There are a few things we can do to examine the inner workings of our neural network: - -```julia -img, lab = train_dataloader.data[1:1] - -# To understand the intermediate NN-ODE layer, we can examine it's dimensionality -x_d = down(img) - -# We can see that we can compute the forward pass through the NN topology -# featuring an NNODE layer. -x_m = model(img) -``` - -This can also be built without the NN-ODE by replacing `nn-ode` with a simple `nn`: - -```julia -# We can also build the model topology without a NN-ODE -m_no_ode = Flux.Chain(down, - nn, - fc) |> Flux.gpu - -x_m = m_no_ode(img) +m = Lux.Chain(; down, nn_ode, convert = Lux.WrappedFunction(DiffEqArray_to_Array), fc) +ps, st = Lux.setup(Random.default_rng(), m) +ps = ComponentArray(ps) |> gdev +st = st |> gdev ``` ### Prediction @@ -267,22 +263,20 @@ classify(x) = argmax.(eachcol(x)) We then evaluate the accuracy on `n_batches` at a time through the entire network: ```julia -function accuracy(model, data; n_batches = 100) +function accuracy(model, data, ps, st; n_batches = 100) total_correct = 0 total = 0 - for (i, (x, y)) in enumerate(collect(data)) - # Only evaluate accuracy for n_batches - i > n_batches && break - target_class = classify(Flux.cpu(y)) - predicted_class = classify(Flux.cpu(model(x))) + st = Lux.testmode(st) + for (x, y) in collect(data)[1:n_batches] + target_class = classify(cdev(y)) + predicted_class = classify(cdev(first(model(x, ps, st)))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end return total_correct / total end - -# burn in accuracy -accuracy(m, train_dataloader) +#burn in accuracy +accuracy(m, zip(x_train, y_train), ps, st) ``` ### Training Parameters @@ -297,10 +291,13 @@ final output of our model. `logitcrossentropy` takes in the prediction from our model `model(x)` and compares it to actual output `y`: ```julia -loss(x, y) = logitcrossentropy(model(x), y) +function loss_function(ps, x, y) + pred, st_ = m(x, ps, st) + return logitcrossentropy(pred, y), pred +end -# burn in loss -loss(img, lab) +#burn in loss +loss_function(ps, x_train[1], y_train[1]) ``` #### Optimizer @@ -308,7 +305,7 @@ loss(img, lab) `Adam` is specified here as our optimizer with a **learning rate of 0.05**: ```julia -opt = Adam(0.05) +opt = OptimizationOptimisers.Adam(0.05) ``` #### CallBack @@ -317,17 +314,20 @@ This callback function is used to print both the training and testing accuracy a 10 training iterations: ```julia -callback() = begin +iter = 0 + +opt_func = OptimizationFunction((ps, _, x, y) -> loss_function(ps, x, y), + Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps) + +function callback(ps, l, pred) global iter += 1 - # Monitor that the weights update - # Every 10 training iterations show accuracy - if iter % 10 == 1 - train_accuracy = accuracy(model, train_dataloader) * 100 - test_accuracy = accuracy(model, test_dataloader; - n_batches = length(test_dataloader)) * 100 - @printf("Iter: %3d || Train Accuracy: %2.3f || Test Accuracy: %2.3f\n", - iter, train_accuracy, test_accuracy) + #Monitor that the weights do infact update + #Every 10 training iterations show accuracy + if (iter % 10 == 0) + @info "[MNIST GPU] Accuracy: $(accuracy(m, zip(x_train, y_train), ps, st))" end + return false end ``` @@ -339,53 +339,57 @@ for Neural ODE is given by `nn_ode.p`: ```julia # Train the NN-ODE and monitor the loss and weights. -Flux.train!(loss, Flux.params( down, nn_ode.p, fc), zip( x_train, y_train ), opt, callback = callback) +res = Optimization.solve(opt_prob, opt, zip(x_train, y_train); callback) +@test accuracy(m, zip(x_train, y_train), res.u, st) > 0.8 ``` ### Expected Output -```julia -Iter: 1 || Train Accuracy: 16.203 || Test Accuracy: 16.933 -Iter: 11 || Train Accuracy: 64.406 || Test Accuracy: 64.900 -Iter: 21 || Train Accuracy: 76.656 || Test Accuracy: 76.667 -Iter: 31 || Train Accuracy: 81.758 || Test Accuracy: 81.683 -Iter: 41 || Train Accuracy: 81.078 || Test Accuracy: 81.967 -Iter: 51 || Train Accuracy: 83.953 || Test Accuracy: 84.417 -Iter: 61 || Train Accuracy: 85.266 || Test Accuracy: 85.017 -Iter: 71 || Train Accuracy: 85.938 || Test Accuracy: 86.400 -Iter: 81 || Train Accuracy: 84.836 || Test Accuracy: 85.533 -Iter: 91 || Train Accuracy: 86.148 || Test Accuracy: 86.583 -Iter: 101 || Train Accuracy: 83.859 || Test Accuracy: 84.500 -Iter: 111 || Train Accuracy: 86.227 || Test Accuracy: 86.617 -Iter: 121 || Train Accuracy: 87.508 || Test Accuracy: 87.200 -Iter: 131 || Train Accuracy: 86.227 || Test Accuracy: 85.917 -Iter: 141 || Train Accuracy: 84.453 || Test Accuracy: 84.850 -Iter: 151 || Train Accuracy: 86.063 || Test Accuracy: 85.650 -Iter: 161 || Train Accuracy: 88.375 || Test Accuracy: 88.033 -Iter: 171 || Train Accuracy: 87.398 || Test Accuracy: 87.683 -Iter: 181 || Train Accuracy: 88.070 || Test Accuracy: 88.350 -Iter: 191 || Train Accuracy: 86.836 || Test Accuracy: 87.150 -Iter: 201 || Train Accuracy: 89.266 || Test Accuracy: 88.583 -Iter: 211 || Train Accuracy: 86.633 || Test Accuracy: 85.550 -Iter: 221 || Train Accuracy: 89.313 || Test Accuracy: 88.217 -Iter: 231 || Train Accuracy: 88.641 || Test Accuracy: 89.417 -Iter: 241 || Train Accuracy: 88.617 || Test Accuracy: 88.550 -Iter: 251 || Train Accuracy: 88.211 || Test Accuracy: 87.950 -Iter: 261 || Train Accuracy: 87.742 || Test Accuracy: 87.317 -Iter: 271 || Train Accuracy: 89.070 || Test Accuracy: 89.217 -Iter: 281 || Train Accuracy: 89.703 || Test Accuracy: 89.067 -Iter: 291 || Train Accuracy: 88.484 || Test Accuracy: 88.250 -Iter: 301 || Train Accuracy: 87.898 || Test Accuracy: 88.367 -Iter: 311 || Train Accuracy: 88.438 || Test Accuracy: 88.633 -Iter: 321 || Train Accuracy: 88.664 || Test Accuracy: 88.567 -Iter: 331 || Train Accuracy: 89.906 || Test Accuracy: 89.883 -Iter: 341 || Train Accuracy: 88.883 || Test Accuracy: 88.667 -Iter: 351 || Train Accuracy: 89.609 || Test Accuracy: 89.283 -Iter: 361 || Train Accuracy: 89.516 || Test Accuracy: 89.117 -Iter: 371 || Train Accuracy: 89.898 || Test Accuracy: 89.633 -Iter: 381 || Train Accuracy: 89.055 || Test Accuracy: 89.017 -Iter: 391 || Train Accuracy: 89.445 || Test Accuracy: 89.467 -Iter: 401 || Train Accuracy: 89.156 || Test Accuracy: 88.250 -Iter: 411 || Train Accuracy: 88.977 || Test Accuracy: 89.083 -Iter: 421 || Train Accuracy: 90.109 || Test Accuracy: 89.417 +```txt +[ Info: [MNIST GPU] Accuracy: 0.602734375 +[ Info: [MNIST GPU] Accuracy: 0.719609375 +[ Info: [MNIST GPU] Accuracy: 0.783671875 +[ Info: [MNIST GPU] Accuracy: 0.8171875 +[ Info: [MNIST GPU] Accuracy: 0.82390625 +[ Info: [MNIST GPU] Accuracy: 0.840546875 +[ Info: [MNIST GPU] Accuracy: 0.839765625 +[ Info: [MNIST GPU] Accuracy: 0.843046875 +[ Info: [MNIST GPU] Accuracy: 0.8609375 +[ Info: [MNIST GPU] Accuracy: 0.86 +[ Info: [MNIST GPU] Accuracy: 0.866875 +[ Info: [MNIST GPU] Accuracy: 0.86484375 +[ Info: [MNIST GPU] Accuracy: 0.883515625 +[ Info: [MNIST GPU] Accuracy: 0.87046875 +[ Info: [MNIST GPU] Accuracy: 0.87609375 +[ Info: [MNIST GPU] Accuracy: 0.880703125 +[ Info: [MNIST GPU] Accuracy: 0.874609375 +[ Info: [MNIST GPU] Accuracy: 0.870859375 +[ Info: [MNIST GPU] Accuracy: 0.881640625 +[ Info: [MNIST GPU] Accuracy: 0.887734375 +[ Info: [MNIST GPU] Accuracy: 0.88734375 +[ Info: [MNIST GPU] Accuracy: 0.880078125 +[ Info: [MNIST GPU] Accuracy: 0.88078125 +[ Info: [MNIST GPU] Accuracy: 0.88125 +[ Info: [MNIST GPU] Accuracy: 0.87203125 +[ Info: [MNIST GPU] Accuracy: 0.857890625 +[ Info: [MNIST GPU] Accuracy: 0.87203125 +[ Info: [MNIST GPU] Accuracy: 0.877578125 +[ Info: [MNIST GPU] Accuracy: 0.879765625 +[ Info: [MNIST GPU] Accuracy: 0.885703125 +[ Info: [MNIST GPU] Accuracy: 0.895 +[ Info: [MNIST GPU] Accuracy: 0.90171875 +[ Info: [MNIST GPU] Accuracy: 0.893359375 +[ Info: [MNIST GPU] Accuracy: 0.882109375 +[ Info: [MNIST GPU] Accuracy: 0.87453125 +[ Info: [MNIST GPU] Accuracy: 0.881171875 +[ Info: [MNIST GPU] Accuracy: 0.891171875 +[ Info: [MNIST GPU] Accuracy: 0.899921875 +[ Info: [MNIST GPU] Accuracy: 0.89890625 +[ Info: [MNIST GPU] Accuracy: 0.895078125 +[ Info: [MNIST GPU] Accuracy: 0.89171875 +[ Info: [MNIST GPU] Accuracy: 0.899296875 +[ Info: [MNIST GPU] Accuracy: 0.891484375 +[ Info: [MNIST GPU] Accuracy: 0.899375 +[ Info: [MNIST GPU] Accuracy: 0.88953125 +[ Info: [MNIST GPU] Accuracy: 0.88890625 ``` diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md index 0963f16f16..8dd888a2d2 100644 --- a/docs/src/examples/multiple_shooting.md +++ b/docs/src/examples/multiple_shooting.md @@ -1,7 +1,8 @@ # Multiple Shooting !!! note - The form of multiple shooting found here is a specialized form for implicit layer deep learning (known as data shooting) which assumes full observability of the underlying dynamics and lack of noise. For a more general implementation of multiple shooting, see the [JuliaSimModelOptimizer](https://help.juliahub.com/jsmo/stable/). For an implementation more directly tied to parameter estimation against data, see [DiffEqParamEstim.jl](https://docs.sciml.ai/DiffEqParamEstim/stable/). + + The form of multiple shooting found here is a specialized form for implicit layer deep learning (known as data shooting) which assumes full observability of the underlying dynamics and lack of noise. For a more general implementation of multiple shooting, see the [JuliaSimModelOptimizer](https://help.juliahub.com/jsmo/stable/). For an implementation more directly tied to parameter estimation against data, see [DiffEqParamEstim.jl](https://docs.sciml.ai/DiffEqParamEstim/stable/). In Multiple Shooting, the training data is split into overlapping intervals. The solver is then trained on individual intervals. If the end conditions of any @@ -12,8 +13,7 @@ then the joined/combined solution is the same as solving on the whole dataset To ensure that the overlapping part of two consecutive intervals coincide, we add a penalizing term: -`continuity_term * absolute_value_of(prediction -of last point of group i - prediction of first point of group i+1)` +`continuity_term * absolute_value_of(prediction of last point of group i - prediction of first point of group i+1)` to the loss. @@ -23,7 +23,8 @@ high penalties in case the solver predicts discontinuous values. The following is a working demo, using Multiple Shooting: ```julia -using ComponentArrays, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms, DifferentialEquations, Plots +using ComponentArrays, + Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms, OrdinaryDiffEq, Plots using DiffEqFlux: group_ranges using Random @@ -33,51 +34,49 @@ rng = Random.default_rng() datasize = 30 u0 = Float32[2.0, 0.0] tspan = (0.0f0, 5.0f0) -tsteps = range(tspan[1], tspan[2], length = datasize) +tsteps = range(tspan[1], tspan[2]; length = datasize) # Get the data function trueODEfunc(du, u, p, t) true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' + du .= ((u .^ 3)'true_A)' end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) -ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) +ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) # Define the Neural Network -nn = Lux.Chain(x -> x.^3, - Lux.Dense(2, 16, tanh), - Lux.Dense(16, 2)) +nn = Chain(x -> x .^ 3, Dense(2, 16, tanh), Dense(16, 2)) p_init, st = Lux.setup(rng, nn) -neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps) -prob_node = ODEProblem((u,p,t)->nn(u,p,st)[1], u0, tspan, ComponentArray(p_init)) +neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps) +prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan, ComponentArray(p_init)) function plot_multiple_shoot(plt, preds, group_size) - step = group_size-1 - ranges = group_ranges(datasize, group_size) + step = group_size - 1 + ranges = group_ranges(datasize, group_size) - for (i, rg) in enumerate(ranges) - plot!(plt, tsteps[rg], preds[i][1,:], markershape=:circle, label="Group $(i)") - end + for (i, rg) in enumerate(ranges) + plot!(plt, tsteps[rg], preds[i][1, :]; markershape = :circle, label = "Group $(i)") + end end anim = Plots.Animation() iter = 0 callback = function (p, l, preds; doplot = true) - display(l) - global iter - iter += 1 - if doplot && iter%1 == 0 - # plot the original data - plt = scatter(tsteps, ode_data[1,:], label = "Data") - - # plot the different predictions for individual shoot - plot_multiple_shoot(plt, preds, group_size) - - frame(anim) - display(plot(plt)) - end - return false + display(l) + global iter + iter += 1 + if doplot && iter % 1 == 0 + # plot the original data + plt = scatter(tsteps, ode_data[1, :]; label = "Data") + + # plot the different predictions for individual shoot + plot_multiple_shoot(plt, preds, group_size) + + frame(anim) + display(plot(plt)) + end + return false end # Define parameters for Multiple Shooting @@ -85,19 +84,23 @@ group_size = 3 continuity_term = 200 function loss_function(data, pred) - return sum(abs2, data - pred) + return sum(abs2, data - pred) end +ps = ComponentArray(p_init) +pd, pax = getdata(ps), getaxes(ps) + function loss_multiple_shooting(p) - return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(), - group_size; continuity_term) + ps = ComponentArray(p, pax) + return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, Tsit5(), + group_size; continuity_term) end adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((x,p) -> loss_multiple_shooting(x), adtype) -optprob = Optimization.OptimizationProblem(optf, ComponentArray(p_init)) -res_ms = Optimization.solve(optprob, PolyOpt(), callback = callback) -gif(anim, "multiple_shooting.gif", fps=15) +optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype) +optprob = Optimization.OptimizationProblem(optf, pd) +res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback) +gif(anim, "multiple_shooting.gif"; fps = 15) ``` ![pic](https://camo.githubusercontent.com/9f1a4b38895ebaa47b7d90e53268e6f10d04da684b58549624c637e85c22d27b/68747470733a2f2f692e696d6775722e636f6d2f636d507a716a722e676966) diff --git a/docs/src/examples/neural_ode.md b/docs/src/examples/neural_ode.md index 770476fc3b..5f40ccfa89 100644 --- a/docs/src/examples/neural_ode.md +++ b/docs/src/examples/neural_ode.md @@ -12,30 +12,29 @@ Before getting to the explanation, here's some code to start with. We will follow a full explanation of the definition and training process: ```@example neuralode_cp -using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, OptimizationOptimisers, Random, Plots +using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, OptimizationOptimJL, + OptimizationOptimisers, Random, Plots rng = Random.default_rng() u0 = Float32[2.0; 0.0] datasize = 30 tspan = (0.0f0, 1.5f0) -tsteps = range(tspan[1], tspan[2], length = datasize) +tsteps = range(tspan[1], tspan[2]; length = datasize) function trueODEfunc(du, u, p, t) true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' + du .= ((u .^ 3)'true_A)' end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) -ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) +ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) -dudt2 = Lux.Chain(x -> x.^3, - Lux.Dense(2, 50, tanh), - Lux.Dense(50, 2)) +dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) p, st = Lux.setup(rng, dudt2) -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) +prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps) function predict_neuralode(p) - Array(prob_neuralode(u0, p, st)[1]) + Array(prob_neuralode(u0, p, st)[1]) end function loss_neuralode(p) @@ -47,18 +46,18 @@ end # Do not plot by default for the documentation # Users should change doplot=true to see the plots callbacks callback = function (p, l, pred; doplot = false) - println(l) - # plot current prediction against data - if doplot - plt = scatter(tsteps, ode_data[1,:], label = "data") - scatter!(plt, tsteps, pred[1,:], label = "prediction") - display(plot(plt)) - end - return false + println(l) + # plot current prediction against data + if doplot + plt = scatter(tsteps, ode_data[1, :]; label = "data") + scatter!(plt, tsteps, pred[1, :]; label = "prediction") + display(plot(plt)) + end + return false end pinit = ComponentArray(p) -callback(pinit, loss_neuralode(pinit)...; doplot=true) +callback(pinit, loss_neuralode(pinit)...; doplot = true) # use Optimization.jl to solve the problem adtype = Optimization.AutoZygote() @@ -66,19 +65,15 @@ adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) optprob = Optimization.OptimizationProblem(optf, pinit) -result_neuralode = Optimization.solve(optprob, - Adam(0.05), - callback = callback, - maxiters = 300) +result_neuralode = Optimization.solve(optprob, Adam(0.05); callback = callback, + maxiters = 300) -optprob2 = remake(optprob,u0 = result_neuralode.u) +optprob2 = remake(optprob; u0 = result_neuralode.u) -result_neuralode2 = Optimization.solve(optprob2, - Optim.BFGS(initial_stepnorm=0.01), - callback=callback, - allow_f_increases = false) +result_neuralode2 = Optimization.solve(optprob2, Optim.BFGS(; initial_stepnorm = 0.01); + callback, allow_f_increases = false) -callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true) +callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true) ``` ![Neural ODE](https://user-images.githubusercontent.com/1814174/88589293-e8207f80-d026-11ea-86e2-8a3feb8252ca.gif) @@ -88,21 +83,22 @@ callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=tru Let's get a time series array from a spiral ODE to train against. ```@example neuralode -using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, OptimizationOptimisers, Random, Plots +using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, + OptimizationOptimJL, OptimizationOptimisers, Random, Plots rng = Random.default_rng() u0 = Float32[2.0; 0.0] datasize = 30 tspan = (0.0f0, 1.5f0) -tsteps = range(tspan[1], tspan[2], length = datasize) +tsteps = range(tspan[1], tspan[2]; length = datasize) function trueODEfunc(du, u, p, t) true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' + du .= ((u .^ 3)'true_A)' end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) -ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) +ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) ``` Now let's define a neural network with a `NeuralODE` layer. First, we define @@ -110,19 +106,15 @@ the layer. Here we're going to use `Lux.Chain`, which is a suitable neural netwo structure for NeuralODEs with separate handling of state variables: ```@example neuralode -dudt2 = Lux.Chain(x -> x.^3, - Lux.Dense(2, 50, tanh), - Lux.Dense(50, 2)) +dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) p, st = Lux.setup(rng, dudt2) -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) +prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps) ``` -Note that we can directly use `Chain`s from Flux.jl as well, for example: +Note that we can directly use `Chain`s from Lux.jl as well, for example: ```julia -dudt2 = Chain(x -> x.^3, - Dense(2, 50, tanh), - Dense(50, 2)) +dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) ``` In our model, we used the `x -> x.^3` assumption in the model. By incorporating @@ -136,7 +128,7 @@ output against the time series data: ```@example neuralode function predict_neuralode(p) - Array(prob_neuralode(u0, p, st)[1]) + Array(prob_neuralode(u0, p, st)[1]) end function loss_neuralode(p) @@ -153,14 +145,14 @@ it would show every step and overflow the documentation, but for your use case ```@example neuralode # Callback function to observe training callback = function (p, l, pred; doplot = false) - println(l) - # plot current prediction against data - if doplot - plt = scatter(tsteps, ode_data[1,:], label = "data") - scatter!(plt, tsteps, pred[1,:], label = "prediction") - display(plot(plt)) - end - return false + println(l) + # plot current prediction against data + if doplot + plt = scatter(tsteps, ode_data[1, :]; label = "data") + scatter!(plt, tsteps, pred[1, :]; label = "prediction") + display(plot(plt)) + end + return false end pinit = ComponentArray(p) @@ -189,9 +181,9 @@ optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) optprob = Optimization.OptimizationProblem(optf, pinit) result_neuralode = Optimization.solve(optprob, - Adam(0.05), - callback = callback, - maxiters = 300) + Adam(0.05); + callback = callback, + maxiters = 300) ``` We then complete the training using a different optimizer, starting from where @@ -200,16 +192,16 @@ halt when near the minimum. ```@example neuralode # Retrain using the LBFGS optimizer -optprob2 = remake(optprob,u0 = result_neuralode.u) +optprob2 = remake(optprob; u0 = result_neuralode.u) result_neuralode2 = Optimization.solve(optprob2, - Optim.BFGS(initial_stepnorm=0.01), - callback = callback, - allow_f_increases = false) + Optim.BFGS(; initial_stepnorm = 0.01); + callback = callback, + allow_f_increases = false) ``` And then we use the callback with `doplot=true` to see the final plot: ```@example neuralode -callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true) +callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true) ``` diff --git a/docs/src/examples/neural_ode_weather_forecast.md b/docs/src/examples/neural_ode_weather_forecast.md index e04fface26..fa2151068b 100644 --- a/docs/src/examples/neural_ode_weather_forecast.md +++ b/docs/src/examples/neural_ode_weather_forecast.md @@ -1,43 +1,27 @@ # Weather forecasting with neural ODEs -In this example we are going to apply neural ODEs to a multidimensional weather dataset and use it for weather forecasting. +In this example we are going to apply neural ODEs to a multidimensional weather dataset and use it for weather forecasting. This example is adapted from [Forecasting the weather with neural ODEs - Sebatian Callh personal blog](https://sebastiancallh.github.io/post/neural-ode-weather-forecast/). ## The data The data is a four-dimensional dataset of daily temperature, humidity, wind speed and pressure meassured over four years in the city Delhi. Let us download and plot it. - ```julia -using Random -using Dates -using Optimization -using ComponentArrays -using Lux -using DiffEqFlux: NeuralODE, AdamW, swish -using DifferentialEquations -using CSV -using DataFrames -using Dates -using Statistics -using Plots -using DataDeps - - -function download_data( - data_url = "https://raw.githubusercontent.com/SebastianCallh/neural-ode-weather-forecast/master/data/", - data_local_path = "./delhi" -) - - load(file_name) = begin - data_dep = DataDep("delhi/train", "", "$data_url/$file_name") - Base.download(data_dep, data_local_path; i_accept_the_terms_of_use=true) - CSV.read(joinpath(data_local_path, file_name), DataFrame) - end - - train_df = load("DailyDelhiClimateTrain.csv") - test_df = load("DailyDelhiClimateTest.csv") - return vcat(train_df, test_df) +using Random, Dates, Optimization, ComponentArrays, Lux, OptimizationOptimisers, DiffEqFlux, + OrdinaryDiffEq, CSV, DataFrames, Dates, Statistics, Plots, DataDeps + +function download_data(data_url = "https://raw.githubusercontent.com/SebastianCallh/neural-ode-weather-forecast/master/data/", + data_local_path = "./delhi") + function load(file_name) + data_dep = DataDep("delhi/train", "", "$data_url/$file_name") + Base.download(data_dep, data_local_path; i_accept_the_terms_of_use = true) + CSV.read(joinpath(data_local_path, file_name), DataFrame) + end + + train_df = load("DailyDelhiClimateTrain.csv") + test_df = load("DailyDelhiClimateTest.csv") + return vcat(train_df, test_df) end df = download_data() @@ -50,95 +34,79 @@ FEATURE_NAMES = ["Mean temperature", "Humidity", "Wind speed", "Mean pressure"] function plot_data(df) plots = map(enumerate(zip(FEATURES, FEATURE_NAMES, UNITS))) do (i, (f, n, u)) - plot(df[:, :date], df[:, f], - title=n, label=nothing, - ylabel=u, size=(800, 600), - color=i) + plot(df[:, :date], df[:, f]; title = n, label = nothing, + ylabel = u, size = (800, 600), color = i) end n = length(plots) - plot(plots..., layout=(Int(n / 2), Int(n / 2))) + plot(plots...; layout = (Int(n / 2), Int(n / 2))) end plot_data(df) ``` + The data show clear annual behaviour (it is difficult to see for pressure due to wild measurement errors but the pattern is there). It is concievable that this system can be described with an ODE, but which? Let us use an network to learn the dynamics from the dataset. Training neural networks is easier with standardised data so we will compute standardised features before training. Finally, we take the first 20 days for training and the rest for testing. ```julia function standardize(x) - μ = mean(x; dims=2) - σ = std(x; dims=2) + μ = mean(x; dims = 2) + σ = std(x; dims = 2) z = (x .- μ) ./ σ return z, μ, σ end -function featurize(raw_df, num_train=20) +function featurize(raw_df, num_train = 20) raw_df.year = Float64.(year.(raw_df.date)) raw_df.month = Float64.(month.(raw_df.date)) - df = combine( - groupby(raw_df, [:year, :month]), + df = combine(groupby(raw_df, [:year, :month]), :date => (d -> mean(year.(d)) .+ mean(month.(d)) ./ 12), :meantemp => mean, :humidity => mean, :wind_speed => mean, - :meanpressure => mean, - renamecols=false - ) + :meanpressure => mean; + renamecols = false) @show size(df) t_and_y(df) = df.date', Matrix(select(df, FEATURES))' - t_train, y_train = t_and_y(df[1:num_train,:]) - t_test, y_test = t_and_y(df[num_train+1:end,:]) + t_train, y_train = t_and_y(df[1:num_train, :]) + t_test, y_test = t_and_y(df[(num_train + 1):end, :]) t_train, t_mean, t_scale = standardize(t_train) y_train, y_mean, y_scale = standardize(y_train) t_test = (t_test .- t_mean) ./ t_scale y_test = (y_test .- y_mean) ./ y_scale - return ( - vec(t_train), y_train, - vec(t_test), y_test, + return (vec(t_train), y_train, + vec(t_test), y_test, (t_mean, t_scale), - (y_mean, y_scale) - ) + (y_mean, y_scale)) end function plot_features(t_train, y_train, t_test, y_test) - - plt_split = plot( - reshape(t_train, :), y_train', - linewidth = 3, colors = 1:4, - xlabel = "Normalized time", - ylabel = "Normalized values", - label = nothing, - title = "Features" - ) - plot!( - plt_split, reshape(t_test, :), y_test', - linewidth = 3, linestyle = :dash, - color = [1 2 3 4], label = nothing - ) - - plot!( - plt_split, [0], [0], linewidth = 0, - label = "Train", color = 1 - ) - plot!( - plt_split, [0], [0], linewidth = 0, - linestyle = :dash, label = "Test", - color = 1, - ylims=(-5, 5) - ) + plt_split = plot(reshape(t_train, :), y_train'; + linewidth = 3, colors = 1:4, + xlabel = "Normalized time", + ylabel = "Normalized values", + label = nothing, + title = "Features") + plot!(plt_split, reshape(t_test, :), y_test'; + linewidth = 3, linestyle = :dash, + color = [1 2 3 4], label = nothing) + + plot!(plt_split, [0], [0]; linewidth = 0, + label = "Train", color = 1) + plot!(plt_split, [0], [0]; linewidth = 0, + linestyle = :dash, label = "Test", + color = 1, + ylims = (-5, 5)) end -( - t_train, - y_train, - t_test, - y_test, - (t_mean, t_scale), - (y_mean, y_scale) -) = featurize(df) +(t_train, +y_train, +t_test, +y_test, +(t_mean, t_scale), +(y_mean, y_scale)) = featurize(df) plot_features(t_train, y_train, t_test, y_test) ``` @@ -149,38 +117,30 @@ We are now ready to construct and train our model! To avoid local minimas we wil ```julia function neural_ode(t, data_dim) - f = Lux.Chain( - Lux.Dense(data_dim, 64, swish), - Lux.Dense(64, 32, swish), - Lux.Dense(32, data_dim) - ) - - node = NeuralODE( - f, extrema(t), Tsit5(), - saveat=t, - abstol=1e-9, reltol=1e-9 - ) - + f = Chain(Dense(data_dim => 64, swish), Dense(64 => 32, swish), Dense(32 => data_dim)) + + node = NeuralODE(f, extrema(t), Tsit5(); saveat = t, + abstol = 1e-9, reltol = 1e-9) + rng = Random.default_rng() p, state = Lux.setup(rng, f) return node, ComponentArray(p), state end -function train_one_round(node, p, state, y, opt, maxiters, rng, y0=y[:, 1]; kwargs...) +function train_one_round(node, p, state, y, opt, maxiters, rng, y0 = y[:, 1]; kwargs...) predict(p) = Array(node(y0, p, state)[1]) loss(p) = sum(abs2, predict(p) .- y) - + adtype = Optimization.AutoZygote() optf = OptimizationFunction((p, _) -> loss(p), adtype) optprob = OptimizationProblem(optf, p) - res = solve(optprob, opt, maxiters=maxiters; kwargs...) + res = solve(optprob, opt; maxiters = maxiters, kwargs...) res.minimizer, state end -function train(t, y, obs_grid, maxiters, lr, rng, p=nothing, state=nothing; kwargs...) - log_results(ps, losses) = - (p, loss) -> begin +function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...) + log_results(ps, losses) = (p, loss) -> begin push!(ps, copy(p)) push!(losses, loss) false @@ -189,14 +149,11 @@ function train(t, y, obs_grid, maxiters, lr, rng, p=nothing, state=nothing; kwar ps, losses = ComponentArray[], Float32[] for k in obs_grid node, p_new, state_new = neural_ode(t, size(y, 1)) - if p === nothing p = p_new end - if state === nothing state = state_new end - - p, state = train_one_round( - node, p, state, y, AdamW(lr), maxiters, rng; - callback=log_results(ps, losses), - kwargs... - ) + p === nothing && (p = p_new) + state === nothing && (state = state_new) + + p, state = train_one_round(node, p, state, y, AdamW(lr), maxiters, rng; + callback = log_results(ps, losses), kwargs...) end ps, state, losses end @@ -205,7 +162,7 @@ rng = MersenneTwister(123) obs_grid = 4:4:length(t_train) # we train on an increasing amount of the first k obs maxiters = 150 lr = 5e-3 -ps, state, losses = train(t_train, y_train, obs_grid, maxiters, lr, rng, progress=true); +ps, state, losses = train(t_train, y_train, obs_grid, maxiters, lr, rng; progress = true); ``` We can now animate the training to get a better understanding of the fit. @@ -216,90 +173,59 @@ predict(y0, t, p, state) = begin Array(node(y0, p, state)[1]) end -function plot_pred( - t_train, - y_train, - t_grid, - rescale_t, - rescale_y, - num_iters, - p, - state, - loss, - y0=y_train[:, 1] -) +function plot_pred(t_train, y_train, t_grid, rescale_t, rescale_y, num_iters, p, state, + loss, y0 = y_train[:, 1]) y_pred = predict(y0, t_grid, p, state) - plot_result( - rescale_t(t_train), - rescale_y(y_train), - rescale_t(t_grid), - rescale_y(y_pred), - loss, - num_iters - ) + return plot_result(rescale_t(t_train), rescale_y(y_train), rescale_t(t_grid), + rescale_y(y_pred), loss, num_iters) end function plot_pred(t, y, y_pred) - plt = Plots.scatter(t, y, label="Observation") - Plots.plot!(plt, t, y_pred, label="Prediction") + plt = Plots.scatter(t, y; label = "Observation") + Plots.plot!(plt, t, y_pred; label = "Prediction") end function plot_pred(t, y, t_pred, y_pred; kwargs...) plot_params = zip(eachrow(y), eachrow(y_pred), FEATURE_NAMES, UNITS) map(enumerate(plot_params)) do (i, (yᵢ, ŷᵢ, name, unit)) - plt = Plots.plot( - t_pred, ŷᵢ, label="Prediction", color=i, linewidth=3, - legend=nothing, title=name; kwargs... - ) - Plots.scatter!( - plt, t, yᵢ, label="Observation", - xlabel="Time", ylabel=unit, - markersize=5, color=i - ) + plt = Plots.plot(t_pred, ŷᵢ; label = "Prediction", color = i, linewidth = 3, + legend = nothing, title = name, kwargs...) + Plots.scatter!(plt, t, yᵢ; label = "Observation", xlabel = "Time", ylabel = unit, + markersize = 5, color = i) end end function plot_result(t, y, t_pred, y_pred, loss, num_iters; kwargs...) plts_preds = plot_pred(t, y, t_pred, y_pred; kwargs...) - plot!(plts_preds[1], ylim=(10, 40), legend=(0.65, 1.0)) - plot!(plts_preds[2], ylim=(20, 100)) - plot!(plts_preds[3], ylim=(2, 12)) - plot!(plts_preds[4], ylim=(990, 1025)) - - p_loss = Plots.plot( - loss, label=nothing, linewidth=3, - title="Loss", xlabel="Iterations", - xlim=(0, num_iters) - ) + plot!(plts_preds[1]; ylim = (10, 40), legend = (0.65, 1.0)) + plot!(plts_preds[2]; ylim = (20, 100)) + plot!(plts_preds[3]; ylim = (2, 12)) + plot!(plts_preds[4]; ylim = (990, 1025)) + + p_loss = Plots.plot(loss; label = nothing, linewidth = 3, + title = "Loss", xlabel = "Iterations", xlim = (0, num_iters)) plots = [plts_preds..., p_loss] - plot(plots..., layout=grid(length(plots), 1), size=(900, 900)) + plot(plots...; layout = grid(length(plots), 1), size = (900, 900)) end -function animate_training( - plot_frame, - t_train, - y_train, - ps, - losses, - obs_grid; - pause_for=300 -) +function animate_training(plot_frame, t_train, y_train, ps, losses, obs_grid; + pause_for = 300) obs_count = Dict(i - 1 => n for (i, n) in enumerate(obs_grid)) is = [min(i, length(losses)) for i in 2:(length(losses) + pause_for)] @animate for i in is stage = Int(floor((i - 1) / length(losses) * length(obs_grid))) k = obs_count[stage] - plot_frame(t_train[1:k], y_train[:,1:k], ps[i], losses[1:i]) + plot_frame(t_train[1:k], y_train[:, 1:k], ps[i], losses[1:i]) end every 2 end num_iters = length(losses) -t_train_grid = collect(range(extrema(t_train)..., length=500)) +t_train_grid = collect(range(extrema(t_train)...; length = 500)) rescale_t(x) = t_scale .* x .+ t_mean rescale_y(x) = y_scale .* x .+ y_mean -plot_frame(t, y, p, loss) = plot_pred( - t, y, t_train_grid, rescale_t, rescale_y, num_iters, p, state, loss -) +function plot_frame(t, y, p, loss) + plot_pred(t, y, t_train_grid, rescale_t, rescale_y, num_iters, p, state, loss) +end anim = animate_training(plot_frame, t_train, y_train, ps, losses, obs_grid); gif(anim, "node_weather_forecast_training.gif") ``` @@ -310,26 +236,21 @@ Looks good! But how well does the model forecast? function plot_extrapolation(t_train, y_train, t_test, y_test, t̂, ŷ) plts = plot_pred(t_train, y_train, t̂, ŷ) for (i, (plt, y)) in enumerate(zip(plts, eachrow(y_test))) - scatter!(plt, t_test, y, color=i, markerstrokecolor=:white, label="Test observation") + scatter!(plt, t_test, y; color = i, markerstrokecolor = :white, + label = "Test observation") end - plot!(plts[1], ylim=(10, 40), legend=:topleft) - plot!(plts[2], ylim=(20, 100)) - plot!(plts[3], ylim=(2, 12)) - plot!(plts[4], ylim=(990, 1025)) - plot(plts..., layout=grid(length(plts), 1), size=(900, 900)) + plot!(plts[1]; ylim = (10, 40), legend = :topleft) + plot!(plts[2]; ylim = (20, 100)) + plot!(plts[3]; ylim = (2, 12)) + plot!(plts[4]; ylim = (990, 1025)) + plot(plts...; layout = grid(length(plts), 1), size = (900, 900)) end -t_grid = collect(range(minimum(t_train), maximum(t_test), length=500)) -y_pred = predict(y_train[:,1], t_grid, ps[end], state) -plot_extrapolation( - rescale_t(t_train), - rescale_y(y_train), - rescale_t(t_test), - rescale_y(y_test), - rescale_t(t_grid), - rescale_y(y_pred) -) +t_grid = collect(range(minimum(t_train), maximum(t_test); length = 500)) +y_pred = predict(y_train[:, 1], t_grid, ps[end], state) +plot_extrapolation(rescale_t(t_train), rescale_y(y_train), rescale_t(t_test), + rescale_y(y_test), rescale_t(t_grid), rescale_y(y_pred)) ``` -While there is some drift in the weather patterns, the model extrapolates very well. \ No newline at end of file +While there is some drift in the weather patterns, the model extrapolates very well. diff --git a/docs/src/examples/neural_sde.md b/docs/src/examples/neural_sde.md index 4d7d17cf3c..8885e86bae 100644 --- a/docs/src/examples/neural_sde.md +++ b/docs/src/examples/neural_sde.md @@ -28,28 +28,28 @@ dudt!(u, h, p, t) = model([u; h(t - p.tau)]) prob = DDEProblem(dudt_, u0, h, tspan, nothing) ``` - First, let's build training data from the same example as the neural ODE: ```@example nsde using Plots, Statistics -using Flux, Optimization, OptimizationOptimisers, DiffEqFlux, StochasticDiffEq, SciMLBase.EnsembleAnalysis +using ComponentArrays, Optimization, + OptimizationOptimisers, DiffEqFlux, StochasticDiffEq, SciMLBase.EnsembleAnalysis, Random -u0 = Float32[2.; 0.] +u0 = Float32[2.0; 0.0] datasize = 30 tspan = (0.0f0, 1.0f0) -tsteps = range(tspan[1], tspan[2], length = datasize) +tsteps = range(tspan[1], tspan[2]; length = datasize) ``` ```@example nsde function trueSDEfunc(du, u, p, t) true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' + du .= ((u .^ 3)'true_A)' end mp = Float32[0.2, 0.2] function true_noise_func(du, u, p, t) - du .= mp.*u + du .= mp .* u end prob_truesde = SDEProblem(trueSDEfunc, true_noise_func, u0, tspan) @@ -61,8 +61,8 @@ data from the average of 10,000 runs of the SDE: ```@example nsde # Take a typical sample from the mean -ensemble_prob = EnsembleProblem(prob_truesde, safetycopy = false) -ensemble_sol = solve(ensemble_prob, SOSRI(), trajectories = 10000) +ensemble_prob = EnsembleProblem(prob_truesde; safetycopy = false) +ensemble_sol = solve(ensemble_prob, SOSRI(); trajectories = 10000) ensemble_sum = EnsembleSummary(ensemble_sol) sde_data, sde_data_vars = Array.(timeseries_point_meanvar(ensemble_sol, tsteps)) @@ -72,40 +72,38 @@ Now we build a neural SDE. For simplicity, we will use the `NeuralDSDE` neural SDE with diagonal noise layer function: ```@example nsde -drift_dudt = Flux.Chain(x -> x.^3, - Flux.Dense(2, 50, tanh), - Flux.Dense(50, 2)) -p1, re1 = Flux.destructure(drift_dudt) - -diffusion_dudt = Flux.Chain(Flux.Dense(2, 2)) -p2, re2 = Flux.destructure(diffusion_dudt) +drift_dudt = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) +diffusion_dudt = Dense(2, 2) -neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(), - saveat = tsteps, reltol = 1e-1, abstol = 1e-1); -nothing +neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(); + saveat = tsteps, reltol = 1e-1, abstol = 1e-1) +ps, st = Lux.setup(Random.default_rng(), neuralsde) +ps = ComponentArray(ps) ``` Let's see what that looks like: ```@example nsde # Get the prediction using the correct initial condition -prediction0 = neuralsde(u0) +prediction0 = neuralsde(u0, ps, st)[1] -drift_(u, p, t) = re1(p[1:neuralsde.len])(u) -diffusion_(u, p, t) = re2(p[neuralsde.len+1:end])(u) +drift_model = Lux.Experimental.StatefulLuxLayer(drift_dudt, nothing, st.drift) +diffusion_model = Lux.Experimental.StatefulLuxLayer(diffusion_dudt, nothing, st.diffusion) -prob_neuralsde = SDEProblem(drift_, diffusion_, u0,(0.0f0, 1.2f0), neuralsde.p) +drift_(u, p, t) = drift_model(u, p.drift) +diffusion_(u, p, t) = diffusion_model(u, p.diffusion) -ensemble_nprob = EnsembleProblem(prob_neuralsde, safetycopy = false) -ensemble_nsol = solve(ensemble_nprob, SOSRI(), trajectories = 100, - saveat = tsteps) +prob_neuralsde = SDEProblem(drift_, diffusion_, u0, (0.0f0, 1.2f0), ps) + +ensemble_nprob = EnsembleProblem(prob_neuralsde; safetycopy = false) +ensemble_nsol = solve(ensemble_nprob, SOSRI(); trajectories = 100, saveat = tsteps) ensemble_nsum = EnsembleSummary(ensemble_nsol) -plt1 = plot(ensemble_nsum, title = "Neural SDE: Before Training") -scatter!(plt1, tsteps, sde_data', lw = 3) +plt1 = plot(ensemble_nsum; title = "Neural SDE: Before Training") +scatter!(plt1, tsteps, sde_data'; lw = 3) -scatter(tsteps, sde_data[1,:], label = "data") -scatter!(tsteps, prediction0[1,:], label = "prediction") +scatter(tsteps, sde_data[1, :]; label = "data") +scatter!(tsteps, prediction0[1, :]; label = "prediction") ``` Now just as with the neural ODE we define a loss function that calculates the @@ -113,18 +111,20 @@ mean and variance from `n` runs at each time point and uses the distance from the data values: ```@example nsde +neuralsde_model = Lux.Experimental.StatefulLuxLayer(neuralsde, nothing, st) + function predict_neuralsde(p, u = u0) - return Array(neuralsde(u, p)) + return Array(neuralsde_model(u, p)) end function loss_neuralsde(p; n = 100) - u = repeat(reshape(u0, :, 1), 1, n) - samples = predict_neuralsde(p, u) - means = mean(samples, dims = 2) - vars = var(samples, dims = 2, mean = means)[:, 1, :] - means = means[:, 1, :] - loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars) - return loss, means, vars + u = repeat(reshape(u0, :, 1), 1, n) + samples = predict_neuralsde(p, u) + means = mean(samples; dims = 2) + vars = var(samples; dims = 2, mean = means)[:, 1, :] + means = means[:, 1, :] + loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars) + return loss, means, vars end ``` @@ -134,26 +134,26 @@ iter = 0 # Callback function to observe training callback = function (p, loss, means, vars; doplot = false) - global list_plots, iter - - if iter == 0 - list_plots = [] - end - iter += 1 - - # loss against current data - display(loss) - - # plot current prediction against data - plt = Plots.scatter(tsteps, sde_data[1,:], yerror = sde_data_vars[1,:], - ylim = (-4.0, 8.0), label = "data") - Plots.scatter!(plt, tsteps, means[1,:], ribbon = vars[1,:], label = "prediction") - push!(list_plots, plt) - - if doplot - display(plt) - end - return false + global list_plots, iter + + if iter == 0 + list_plots = [] + end + iter += 1 + + # loss against current data + display(loss) + + # plot current prediction against data + plt = Plots.scatter(tsteps, sde_data[1, :]; yerror = sde_data_vars[1, :], + ylim = (-4.0, 8.0), label = "data") + Plots.scatter!(plt, tsteps, means[1, :]; ribbon = vars[1, :], label = "prediction") + push!(list_plots, plt) + + if doplot + display(plt) + end + return false end ``` @@ -166,36 +166,35 @@ opt = Adam(0.025) # First round of training with n = 10 adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((x,p) -> loss_neuralsde(x, n=10), adtype) -optprob = Optimization.OptimizationProblem(optf, neuralsde.p) -result1 = Optimization.solve(optprob, opt, - callback = callback, maxiters = 100) +optf = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x; n = 10), adtype) +optprob = Optimization.OptimizationProblem(optf, ps) +result1 = Optimization.solve(optprob, opt; callback, maxiters = 100) ``` We resume the training with a larger `n`. (WARNING - this step is a couple of orders of magnitude longer than the previous one). ```@example nsde -optf2 = Optimization.OptimizationFunction((x,p) -> loss_neuralsde(x, n=100), adtype) +optf2 = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x; n = 100), adtype) optprob2 = Optimization.OptimizationProblem(optf2, result1.u) -result2 = Optimization.solve(optprob2, opt, - callback = callback, maxiters = 20) +result2 = Optimization.solve(optprob2, opt; callback, maxiters = 20) ``` And now we plot the solution to an ensemble of the trained neural SDE: ```@example nsde -_, means, vars = loss_neuralsde(result2.u, n = 1000) +_, means, vars = loss_neuralsde(result2.u; n = 1000) -plt2 = Plots.scatter(tsteps, sde_data', yerror = sde_data_vars', - label = "data", title = "Neural SDE: After Training", - xlabel = "Time") -plot!(plt2, tsteps, means', lw = 8, ribbon = vars', label = "prediction") +plt2 = Plots.scatter(tsteps, sde_data'; yerror = sde_data_vars', + label = "data", title = "Neural SDE: After Training", + xlabel = "Time") +plot!(plt2, tsteps, means'; lw = 8, ribbon = vars', label = "prediction") -plt = plot(plt1, plt2, layout = (2, 1)) -savefig(plt, "NN_sde_combined.png"); nothing # sde +plt = plot(plt1, plt2; layout = (2, 1)) +savefig(plt, "NN_sde_combined.png"); +nothing; # sde ``` -![](https://user-images.githubusercontent.com/1814174/76975872-88dc9100-6909-11ea-80f7-242f661ebad1.png) +![Neural SDE Trained Example](https://user-images.githubusercontent.com/1814174/76975872-88dc9100-6909-11ea-80f7-242f661ebad1.png) Try this with GPUs as well! diff --git a/docs/src/examples/normalizing_flows.md b/docs/src/examples/normalizing_flows.md index 3dedaa148c..35f55198c1 100644 --- a/docs/src/examples/normalizing_flows.md +++ b/docs/src/examples/normalizing_flows.md @@ -8,55 +8,52 @@ Before getting to the explanation, here's some code to start with. We will follow a full explanation of the definition and training process: ```@example cnf -using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimisers, - OptimizationOptimJL, Distributions +using ComponentArrays, DiffEqFlux, OrdinaryDiffEq, Optimization, Distributions, + Random, OptimizationOptimisers, OptimizationOptimJL -nn = Flux.Chain( - Flux.Dense(1, 3, tanh), - Flux.Dense(3, 1, tanh), -) |> f32 +nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh)) tspan = (0.0f0, 10.0f0) -ffjord_mdl = FFJORD(nn, tspan, Tsit5()) +ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5(); ad = AutoZygote()) +ps, st = Lux.setup(Random.default_rng(), ffjord_mdl) +ps = ComponentArray(ps) +model = Lux.Experimental.StatefulLuxLayer(ffjord_mdl, nothing, st) # Training data_dist = Normal(6.0f0, 0.7f0) train_data = Float32.(rand(data_dist, 1, 100)) function loss(θ) - logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ) - -mean(logpx) + logpx, λ₁, λ₂ = model(train_data, θ) + return -mean(logpx) end function cb(p, l) - @info "Training" loss = loss(p) - false + @info "FFJORD Training" loss=loss(p) + return false end -adtype = Optimization.AutoZygote() +adtype = Optimization.AutoForwardDiff() optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) -optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) +optprob = Optimization.OptimizationProblem(optf, ps) -res1 = Optimization.solve(optprob, - Adam(0.1), - maxiters = 100, - callback=cb) +res1 = Optimization.solve(optprob, Adam(0.01); maxiters = 20, callback = cb) optprob2 = Optimization.OptimizationProblem(optf, res1.u) -res2 = Optimization.solve(optprob2, - Optim.LBFGS(), - allow_f_increases=false, - callback=cb) +res2 = Optimization.solve(optprob2, Optim.LBFGS(); allow_f_increases = false, + callback = cb) # Evaluation using Distances +st_ = (; st..., monte_carlo = false) + actual_pdf = pdf.(data_dist, train_data) -learned_pdf = exp.(ffjord_mdl(train_data, res2.u, monte_carlo=false)[1]) +learned_pdf = exp.(ffjord_mdl(train_data, res2.u, st_)[1][1]) train_dis = totalvariation(learned_pdf, actual_pdf) / size(train_data, 2) # Data Generation -ffjord_dist = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res2.u)) +ffjord_dist = FFJORDDistribution(ffjord_mdl, ps, st) new_data = rand(ffjord_dist, 100) ``` @@ -65,17 +62,17 @@ new_data = rand(ffjord_dist, 100) We can use DiffEqFlux.jl to define, train and output the densities computed by CNF layers. In the same way as a neural ODE, the layer takes a neural network that defines its derivative function (see [1] for a reference). A possible way to define a CNF layer, would be: ```@example cnf2 -using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimisers, - OptimizationOptimJL, Distributions +using ComponentArrays, DiffEqFlux, OrdinaryDiffEq, Optimization, OptimizationOptimisers, + OptimizationOptimJL, Distributions, Random -nn = Flux.Chain( - Flux.Dense(1, 3, tanh), - Flux.Dense(3, 1, tanh), -) |> f32 +nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh)) tspan = (0.0f0, 10.0f0) -ffjord_mdl = FFJORD(nn, tspan, Tsit5()) -nothing +ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5(); ad = AutoZygote()) +ps, st = Lux.setup(Random.default_rng(), ffjord_mdl) +ps = ComponentArray(ps) +model = Lux.Experimental.StatefulLuxLayer(ffjord_mdl, ps, st) +ffjord_mdl ``` where we also pass as an input the desired timespan for which the differential equation that defines `log p_x` and `z(t)` will be solved. @@ -94,13 +91,13 @@ Now we define a loss function that we wish to minimize and a callback function t ```@example cnf2 function loss(θ) - logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ) - -mean(logpx) + logpx, λ₁, λ₂ = model(train_data, θ) + return -mean(logpx) end function cb(p, l) - @info "Training" loss = loss(p) - false + @info "FFJORD Training" loss=loss(p) + return false end ``` @@ -111,24 +108,20 @@ We then train the neural network to learn the distribution of `x`. Here we showcase starting the optimization with `Adam` to more quickly find a minimum, and then honing in on the minimum by using `LBFGS`. ```@example cnf2 -adtype = Optimization.AutoZygote() + +adtype = Optimization.AutoForwardDiff() optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) -optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) +optprob = Optimization.OptimizationProblem(optf, ps) -res1 = Optimization.solve(optprob, - Adam(0.1), - maxiters = 100, - callback=cb) +res1 = Optimization.solve(optprob, Adam(0.01); maxiters = 20, callback = cb) ``` We then complete the training using a different optimizer, starting from where `Adam` stopped. ```@example cnf2 optprob2 = Optimization.OptimizationProblem(optf, res1.u) -res2 = Optimization.solve(optprob2, - Optim.LBFGS(), - allow_f_increases=false, - callback=cb) +res2 = Optimization.solve(optprob2, Optim.LBFGS(); allow_f_increases = false, + callback = cb) ``` ### Evaluation @@ -139,8 +132,10 @@ Then we use a distance function between these distributions. ```@example cnf2 using Distances +st_ = (; st..., monte_carlo = false) + actual_pdf = pdf.(data_dist, train_data) -learned_pdf = exp.(ffjord_mdl(train_data, res2.u, monte_carlo=false)[1]) +learned_pdf = exp.(ffjord_mdl(train_data, res2.u, st_)[1][1]) train_dis = totalvariation(learned_pdf, actual_pdf) / size(train_data, 2) ``` @@ -149,7 +144,7 @@ train_dis = totalvariation(learned_pdf, actual_pdf) / size(train_data, 2) What's more, we can generate new data by using FFJORD as a distribution in `rand`. ```@example cnf2 -ffjord_dist = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res2.u)) +ffjord_dist = FFJORDDistribution(ffjord_mdl, ps, st) new_data = rand(ffjord_dist, 100) ``` diff --git a/docs/src/examples/tensor_layer.md b/docs/src/examples/tensor_layer.md index c13e828316..21b6860141 100644 --- a/docs/src/examples/tensor_layer.md +++ b/docs/src/examples/tensor_layer.md @@ -13,19 +13,21 @@ To obtain the training data, we solve the equation of motion using one of the solvers in `DifferentialEquations`: ```@example tensor -using DiffEqFlux, Optimization, OptimizationOptimisers, DifferentialEquations, LinearAlgebra +using ComponentArrays, + DiffEqFlux, Optimization, OptimizationOptimisers, + OrdinaryDiffEq, LinearAlgebra, Random k, α, β, γ = 1, 0.1, 0.2, 0.3 -tspan = (0.0,10.0) +tspan = (0.0, 10.0) -function dxdt_train(du,u,p,t) - du[1] = u[2] - du[2] = -k*u[1] - α*u[1]^3 - β*u[2] - γ*u[2]^3 +function dxdt_train(du, u, p, t) + du[1] = u[2] + du[2] = -k * u[1] - α * u[1]^3 - β * u[2] - γ * u[2]^3 end -u0 = [1.0,0.0] +u0 = [1.0, 0.0] ts = collect(0.0:0.1:tspan[2]) -prob_train = ODEProblem{true}(dxdt_train,u0,tspan) -data_train = Array(solve(prob_train,Tsit5(),saveat=ts)) +prob_train = ODEProblem{true}(dxdt_train, u0, tspan) +data_train = Array(solve(prob_train, Tsit5(); saveat = ts)) ``` Now, we create a TensorLayer that will be able to perform 10th order expansions in @@ -34,22 +36,26 @@ a Legendre Basis: ```@example tensor A = [LegendreBasis(10), LegendreBasis(10)] nn = TensorLayer(A, 1) +ps, st = Lux.setup(Random.default_rng(), nn) +ps = ComponentArray(ps) +nn = Lux.Experimental.StatefulLuxLayer(nn, nothing, st) ``` and we also instantiate the model we are trying to learn, “informing” the neural about the `∝x` and `∝v` dependencies in the equation of motion: ```@example tensor -f = x -> min(30one(x),x) +f = x -> min(30one(x), x) -function dxdt_pred(du,u,p,t) - du[1] = u[2] - du[2] = -p[1]*u[1] - p[2]*u[2] + f(nn(u,p[3:end])[1]) +function dxdt_pred(du, u, p, t) + du[1] = u[2] + du[2] = -p.p_model[1] * u[1] - p.p_model[2] * u[2] + f(nn(u, p.ps)[1]) end -α = zeros(102) +p_model = zeros(2) +α = ComponentArray(; p_model, ps = ps .* 0) -prob_pred = ODEProblem{true}(dxdt_pred,u0,tspan) +prob_pred = ODEProblem{true}(dxdt_pred, u0, tspan, α) ``` Note that we introduced a “cap” in the neural network term to avoid instabilities @@ -60,24 +66,24 @@ Finally, we introduce the corresponding loss function: ```@example tensor function predict_adjoint(θ) - x = Array(solve(prob_pred,Tsit5(),p=θ,saveat=ts, - sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))) + x = Array(solve(prob_pred, Tsit5(); p = θ, saveat = ts, + sensealg = InterpolatingAdjoint(; autojacvec = ReverseDiffVJP(true)))) end function loss_adjoint(θ) - x = predict_adjoint(θ) - loss = sum(norm.(x - data_train)) - return loss + x = predict_adjoint(θ) + loss = sum(norm.(x - data_train)) + return loss end iter = 0 -function callback(θ,l) - global iter - iter += 1 - if iter%10 == 0 - println(l) - end - return false +function callback(θ, l) + global iter + iter += 1 + if iter % 10 == 0 + println(l) + end + return false end ``` @@ -85,12 +91,12 @@ and we train the network using two rounds of `Adam`: ```@example tensor adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((x,p) -> loss_adjoint(x), adtype) +optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype) optprob = Optimization.OptimizationProblem(optf, α) -res1 = Optimization.solve(optprob, Adam(0.05), callback = callback, maxiters = 150) +res1 = Optimization.solve(optprob, Adam(0.05); callback = callback, maxiters = 150) optprob2 = Optimization.OptimizationProblem(optf, res1.u) -res2 = Optimization.solve(optprob2, Adam(0.001), callback = callback,maxiters = 150) +res2 = Optimization.solve(optprob2, Adam(0.001); callback = callback, maxiters = 150) opt = res2.u ``` @@ -99,10 +105,10 @@ We plot the results, and we obtain a fairly accurate learned model: ```@example tensor using Plots data_pred = predict_adjoint(res1.u) -plot(ts, data_train[1,:], label = "X (ODE)") -plot!(ts, data_train[2,:], label = "V (ODE)") -plot!(ts, data_pred[1,:], label = "X (NN)") -plot!(ts, data_pred[2,:],label = "V (NN)") +plot(ts, data_train[1, :]; label = "X (ODE)") +plot!(ts, data_train[2, :]; label = "V (ODE)") +plot!(ts, data_pred[1, :]; label = "X (NN)") +plot!(ts, data_pred[2, :]; label = "V (NN)") ``` ![plot_tutorial](https://user-images.githubusercontent.com/61364108/85925795-e2d5e680-b868-11ea-9816-29f8125c8cb5.png) diff --git a/docs/src/index.md b/docs/src/index.md index 6b95bdfd87..db5ad93b80 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -5,7 +5,7 @@ a high-level interface that pulls together all the tools with heuristics and helper functions to make training such deep implicit layer models fast and easy. !!! note - + DiffEqFlux.jl is only for pre-built architectures and utility functions for deep implicit learning, mixing differential equations with machine learning. For details on automatic differentiation of equation solvers @@ -25,15 +25,15 @@ into larger machine learning applications. The following layer functions exist: -- [Neural Ordinary Differential Equations (Neural ODEs)](https://arxiv.org/abs/1806.07366) -- [Collocation-Based Neural ODEs (Neural ODEs without a solver, by far the fastest way!)](https://www.degruyter.com/document/doi/10.1515/sagmb-2020-0025/html) -- [Multiple Shooting Neural Ordinary Differential Equations](https://arxiv.org/abs/2109.06786) -- [Neural Stochastic Differential Equations (Neural SDEs)](https://arxiv.org/abs/1907.07587) -- [Neural Differential-Algebriac Equations (Neural DAEs)](https://arxiv.org/abs/2001.04385) -- [Neural Delay Differential Equations (Neural DDEs)](https://arxiv.org/abs/2001.04385) -- [Augmented Neural ODEs](https://arxiv.org/abs/1904.01681) -- [Hamiltonian Neural Networks (with specialized second order and symplectic integrators)](https://arxiv.org/abs/1906.01563) -- [Continuous Normalizing Flows (CNF)](https://arxiv.org/abs/1806.07366) and [FFJORD](https://arxiv.org/abs/1810.01367) + - [Neural Ordinary Differential Equations (Neural ODEs)](https://arxiv.org/abs/1806.07366) + - [Collocation-Based Neural ODEs (Neural ODEs without a solver, by far the fastest way!)](https://www.degruyter.com/document/doi/10.1515/sagmb-2020-0025/html) + - [Multiple Shooting Neural Ordinary Differential Equations](https://arxiv.org/abs/2109.06786) + - [Neural Stochastic Differential Equations (Neural SDEs)](https://arxiv.org/abs/1907.07587) + - [Neural Differential-Algebriac Equations (Neural DAEs)](https://arxiv.org/abs/2001.04385) + - [Neural Delay Differential Equations (Neural DDEs)](https://arxiv.org/abs/2001.04385) + - [Augmented Neural ODEs](https://arxiv.org/abs/1904.01681) + - [Hamiltonian Neural Networks (with specialized second order and symplectic integrators)](https://arxiv.org/abs/1906.01563) + - [Continuous Normalizing Flows (CNF)](https://arxiv.org/abs/1806.07366) and [FFJORD](https://arxiv.org/abs/1810.01367) Examples of how to build architectures from scratch, with tutorials on things like Graph Neural ODEs, can be found in the [SciMLSensitivity.jl documentation](https://docs.sciml.ai/SciMLSensitivity/stable/). @@ -57,8 +57,8 @@ While one may think this recreates the neural network to act in `Float64` precis and instead its values will silently downgrade everything to `Float32`. This is only fixed by `Chain(Dense(2, 10, tanh), Dense(10, 1)) |> f64`. Similar cases will [lead to dropped gradients with complex numbers](https://github.com/FluxML/Optimisers.jl/issues/95). This is not an issue with the automatic differentiation library commonly associated with Flux (Zygote.jl) but rather due to choices in the neural network library's -decision for how to approach type handling and precision. Thus when using DiffEqFlux.jl with Flux, the user must be very careful to ensure that -the precision of the arguments are correct, and anything that requires alternative types (like `TrackerAdjoint` tracked values and +decision for how to approach type handling and precision. Thus when using DiffEqFlux.jl with Flux, the user must be very careful to ensure that +the precision of the arguments are correct, and anything that requires alternative types (like `TrackerAdjoint` tracked values and `ForwardDiffSensitivity` dual numbers) are suspect. Lux.jl has none of these issues, is simpler to work with due to the parameters in its function calls being explicit rather than implicit global @@ -80,36 +80,46 @@ If you use DiffEqFlux.jl or are influenced by its ideas, please cite: ``` ## Reproducibility + ```@raw html
The documentation of this SciML package was built using these direct dependencies, ``` + ```@example using Pkg # hide Pkg.status() # hide ``` + ```@raw html
``` + ```@raw html
and using this machine and Julia version. ``` + ```@example using InteractiveUtils # hide versioninfo() # hide ``` + ```@raw html
``` + ```@raw html
A more complete overview of all dependencies and their versions is also provided. ``` + ```@example using Pkg # hide -Pkg.status(;mode = PKGMODE_MANIFEST) # hide +Pkg.status(; mode = PKGMODE_MANIFEST) # hide ``` + ```@raw html
``` + ```@eval using TOML using Markdown @@ -125,4 +135,4 @@ file and the [project]($link_project) file. """) -``` \ No newline at end of file +``` diff --git a/docs/src/layers/HamiltonianNN.md b/docs/src/layers/HamiltonianNN.md index 30a24333f2..d3216da587 100644 --- a/docs/src/layers/HamiltonianNN.md +++ b/docs/src/layers/HamiltonianNN.md @@ -6,4 +6,4 @@ dynamics and conservation laws by approximating the hamiltonian of a system. ```@docs HamiltonianNN NeuralHamiltonianDE -``` \ No newline at end of file +``` diff --git a/docs/src/layers/NeuralDELayers.md b/docs/src/layers/NeuralDELayers.md index e641ce48d6..717329b5c9 100644 --- a/docs/src/layers/NeuralDELayers.md +++ b/docs/src/layers/NeuralDELayers.md @@ -20,5 +20,4 @@ AugmentedNDELayer ```@docs DimMover -FluxBatchOrder ``` diff --git a/docs/src/layers/SplineLayer.md b/docs/src/layers/SplineLayer.md index 9e14bc186c..77f9607a3e 100644 --- a/docs/src/layers/SplineLayer.md +++ b/docs/src/layers/SplineLayer.md @@ -2,11 +2,11 @@ Constructs a Spline Layer. At a high-level, it performs the following: -1. Takes as input a one-dimensional training dataset, a time span, a time step and - an interpolation method. -2. During training, adjusts the values of the function at multiples of the time-step - such that the curve interpolated through these points has minimum loss on the corresponding - one-dimensional dataset. + 1. Takes as input a one-dimensional training dataset, a time span, a time step and + an interpolation method. + 2. During training, adjusts the values of the function at multiples of the time-step + such that the curve interpolated through these points has minimum loss on the corresponding + one-dimensional dataset. ```@docs SplineLayer diff --git a/docs/src/utilities/Collocation.md b/docs/src/utilities/Collocation.md index 721e71c654..13168a1343 100644 --- a/docs/src/utilities/Collocation.md +++ b/docs/src/utilities/Collocation.md @@ -10,7 +10,7 @@ extremely fast and robust to noise, though, because it does not accumulate through time, is not as exact as other methods. !!! note - + This is one of many methods for calculating the collocation coefficients for the training process. For a more comprehensive set of collocation methods, see the [JuliaSimModelOptimizer](https://help.juliahub.com/jsmo/stable/manual/collocation/). @@ -24,7 +24,7 @@ collocate_data Note that the kernel choices of DataInterpolations.jl, such as `CubicSpline()`, are exact, i.e. go through the data points, while the smoothed kernels are regression splines. Thus `CubicSpline()` is preferred if the data is not too -noisy or is relatively sparse. If data is sparse and very noisy, a `BSpline()` +noisy or is relatively sparse. If data is sparse and very noisy, a `BSpline()` can be the best regression spline, otherwise one of the other kernels such as as `EpanechnikovKernel`. @@ -36,22 +36,26 @@ is non-allocating and compatible with forward-mode automatic differentiation: ```julia using PreallocationTools du = PreallocationTools.dualcache(similar(prob.u0)) -preview_est_sol = [@view estimated_solution[:,i] for i in 1:size(estimated_solution,2)] -preview_est_deriv = [@view estimated_derivative[:,i] for i in 1:size(estimated_solution,2)] - -function construct_iip_cost_function(f,du,preview_est_sol,preview_est_deriv,tpoints) - function (p) - _du = PreallocationTools.get_tmp(du,p) - vecdu = vec(_du) - cost = zero(first(p)) - for i in 1:length(preview_est_sol) - est_sol = preview_est_sol[i] - f(_du,est_sol,p,tpoints[i]) - vecdu .= vec(preview_est_deriv[i]) .- vec(_du) - cost += sum(abs2,vecdu) - end - sqrt(cost) - end +preview_est_sol = [@view estimated_solution[:, i] for i in 1:size(estimated_solution, 2)] +preview_est_deriv = [@view estimated_derivative[:, i] for i in 1:size(estimated_solution, 2)] + +function construct_iip_cost_function(f, du, preview_est_sol, preview_est_deriv, tpoints) + function (p) + _du = PreallocationTools.get_tmp(du, p) + vecdu = vec(_du) + cost = zero(first(p)) + for i in 1:length(preview_est_sol) + est_sol = preview_est_sol[i] + f(_du, est_sol, p, tpoints[i]) + vecdu .= vec(preview_est_deriv[i]) .- vec(_du) + cost += sum(abs2, vecdu) + end + sqrt(cost) + end end -cost_function = construct_iip_cost_function(f,du,preview_est_sol,preview_est_deriv,tpoints) +cost_function = construct_iip_cost_function(f, + du, + preview_est_sol, + preview_est_deriv, + tpoints) ``` diff --git a/docs/src/utilities/MultipleShooting.md b/docs/src/utilities/MultipleShooting.md index 78018ab67e..28fa796165 100644 --- a/docs/src/utilities/MultipleShooting.md +++ b/docs/src/utilities/MultipleShooting.md @@ -1,8 +1,10 @@ # Multiple Shooting Functionality !!! note - The form of multiple shooting found here is a specialized form for implicit layer deep learning (known as data shooting) which assumes full observability of the underlying dynamics and lack of noise. For a more general implementation of multiple shooting, see the [JuliaSimModelOptimizer](https://help.juliahub.com/jsmo/stable/). For an implementation more directly tied to parameter estimation against data, see [DiffEqParamEstim.jl](https://docs.sciml.ai/DiffEqParamEstim/stable/). + The form of multiple shooting found here is a specialized form for implicit layer deep learning (known as data shooting) which assumes full observability of the underlying dynamics and lack of noise. For a more general implementation of multiple shooting, see the [JuliaSimModelOptimizer](https://help.juliahub.com/jsmo/stable/). For an implementation more directly tied to parameter estimation against data, see [DiffEqParamEstim.jl](https://docs.sciml.ai/DiffEqParamEstim/stable/). + ```@docs multiple_shoot +DiffEqFlux.group_ranges ``` diff --git a/src/DiffEqFlux.jl b/src/DiffEqFlux.jl index 204b0903c3..5c895d657e 100644 --- a/src/DiffEqFlux.jl +++ b/src/DiffEqFlux.jl @@ -1,55 +1,62 @@ module DiffEqFlux -using Adapt, Base.Iterators, ChainRulesCore, ConsoleProgressMonitor, - DataInterpolations, DiffEqBase, Distributions, DistributionsAD, - ForwardDiff, Functors, LinearAlgebra, Logging, LoggingExtras, LuxCore, - Printf, ProgressLogging, Random, RecursiveArrayTools, Reexport, - SciMLBase, TerminalLoggers, Zygote, ZygoteRules +import PrecompileTools -@reexport using Flux -@reexport using SciMLSensitivity +PrecompileTools.@recompile_invalidations begin + using ADTypes, ChainRulesCore, ComponentArrays, ConcreteStructs, Functors, + LinearAlgebra, Lux, LuxCore, Random, Reexport, SciMLBase, SciMLSensitivity -gpu_or_cpu(x) = Array + # AD Packages + using ForwardDiff, Tracker, Zygote -# ForwardDiff integration - -ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where {T} - @assert length(ẋ) == 1 - ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,)) + # FFJORD Specific + using Distributions, DistributionsAD end -ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where {T} = - d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),) +import ChainRulesCore as CRC +import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer +import Lux.Experimental: StatefulLuxLayer -ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where {T} = - d.value, ẋ -> (ForwardDiff.Dual{T}(0, ẋ),) +@reexport using ADTypes, Lux -ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:dl}) = A.dl, y -> Tridiagonal(dl, zeros(length(d)), zeros(length(du)),) -ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:d}) = A.d, y -> Tridiagonal(zeros(length(dl)), d, zeros(length(du)),) -ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:du}) = A.dl, y -> Tridiagonal(zeros(length(dl)), zeros(length(d), du),) -ZygoteRules.@adjoint Tridiagonal(dl, d, du) = Tridiagonal(dl, d, du), p̄ -> (diag(p̄[2:end, 1:end-1]), diag(p̄), diag(p̄[1:end-1, 2:end])) +# FIXME: Type Piracy +function CRC.rrule(::Type{Tridiagonal}, dl, d, du) + y = Tridiagonal(dl, d, du) + @views function ∇Tridiagonal(∂y) + return (NoTangent(), diag(∂y[2:end, 1:(end - 1)]), diag(∂y), + diag(∂y[1:(end - 1), 2:end])) + end + return y, ∇Tridiagonal +end include("ffjord.jl") include("neural_de.jl") include("spline_layer.jl") -include("tensor_product_basis.jl") -include("tensor_product_layer.jl") +include("tensor_product.jl") include("collocation.jl") include("hnn.jl") include("multiple_shooting.jl") -export FFJORD, NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, - NeuralODEMM, TensorLayer, AugmentedNDELayer, SplineLayer, NeuralHamiltonianDE -export HamiltonianNN -export ChebyshevBasis, SinBasis, CosBasis, FourierBasis, LegendreBasis, PolynomialBasis -export FFJORDDistribution -export DimMover, FluxBatchOrder - -export EpanechnikovKernel, UniformKernel, TriangularKernel, QuarticKernel -export TriweightKernel, TricubeKernel, GaussianKernel, CosineKernel -export LogisticKernel, SigmoidKernel, SilvermanKernel +export NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, AugmentedNDELayer, + NeuralODEMM, TensorLayer, SplineLayer +export NeuralHamiltonianDE, HamiltonianNN +export FFJORD, FFJORDDistribution +export TensorProductBasisFunction, + ChebyshevBasis, SinBasis, CosBasis, FourierBasis, LegendreBasis, PolynomialBasis +export DimMover + +export EpanechnikovKernel, UniformKernel, TriangularKernel, QuarticKernel, TriweightKernel, + TricubeKernel, GaussianKernel, CosineKernel, LogisticKernel, SigmoidKernel, + SilvermanKernel export collocate_data export multiple_shoot +# Reexporting only certain functions from SciMLSensitivity +export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, InterpolatingAdjoint, + TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint, ForwardSensitivity, + ForwardDiffSensitivity, ForwardDiffOverAdjoint, SteadyStateAdjoint, + ForwardLSS, AdjointLSS, NILSS, NILSAS +export TrackerVJP, ZygoteVJP, EnzymeVJP, ReverseDiffVJP + end diff --git a/src/collocation.jl b/src/collocation.jl index 93f7a16b17..c07a8a6e8d 100644 --- a/src/collocation.jl +++ b/src/collocation.jl @@ -11,101 +11,42 @@ struct LogisticKernel <: CollocationKernel end struct SigmoidKernel <: CollocationKernel end struct SilvermanKernel <: CollocationKernel end -function calckernel(::EpanechnikovKernel,t) - if abs(t) > 1 - return 0 - else - return 0.75*(1-t^2) - end -end - -function calckernel(::UniformKernel,t) - if abs(t) > 1 - return 0 - else - return 0.5 - end -end - -function calckernel(::TriangularKernel,t) - if abs(t) > 1 - return 0 - else - return (1-abs(t)) - end -end - -function calckernel(::QuarticKernel,t) - if abs(t) > 1 - return 0 - else - return (15*(1-t^2)^2)/16 - end -end - -function calckernel(::TriweightKernel,t) - if abs(t) > 1 - return 0 - else - return (35*(1-t^2)^3)/32 - end -end - -function calckernel(::TricubeKernel,t) - if abs(t) > 1 - return 0 - else - return (70*(1-abs(t)^3)^3)/81 - end +function calckernel(kernel, t::T) where {T} + abst = abs(t) + return ifelse(abst > 1, T(0), calckernel(kernel, t, abst)) end +calckernel(::EpanechnikovKernel, t::T, abst::T) where {T} = T(0.75) * (T(1) - t^2) +calckernel(::UniformKernel, t::T, abst::T) where {T} = T(0.5) +calckernel(::TriangularKernel, t::T, abst::T) where {T} = T(1) - abst +calckernel(::QuarticKernel, t::T, abst::T) where {T} = T(15) * (T(1) - t^2)^2 / T(16) +calckernel(::TriweightKernel, t::T, abst::T) where {T} = T(35) * (T(1) - t^2)^3 / T(32) +calckernel(::TricubeKernel, t::T, abst::T) where {T} = T(70) * (T(1) - abst^3)^3 / T(81) +calckernel(::CosineKernel, t::T, abst::T) where {T} = T(π) * cospi(t / T(2)) / T(4) -function calckernel(::GaussianKernel,t) - exp(-0.5*t^2)/(sqrt(2*π)) +calckernel(::GaussianKernel, t::T) where {T} = exp(-t^2 / T(2)) / sqrt(T(2) * π) +calckernel(::LogisticKernel, t::T) where {T} = T(1) / (exp(t) + T(2) + exp(-t)) +calckernel(::SigmoidKernel, t::T) where {T} = T(2) / (π * (exp(t) + exp(-t))) +function calckernel(::SilvermanKernel, t::T) where {T} + return sin(abs(t) / T(2) + π / T(4)) * T(0.5) * exp(-abs(t) / sqrt(T(2))) end -function calckernel(::CosineKernel,t) - if abs(t) > 1 - return 0 - else - return (π*cos(π*t/2))/4 - end -end - -function calckernel(::LogisticKernel,t) - 1/(exp(t)+2+exp(-t)) -end - -function calckernel(::SigmoidKernel,t) - 2/(π*(exp(t)+exp(-t))) -end - -function calckernel(::SilvermanKernel,t) - sin(abs(t)/2+π/4)*0.5*exp(-abs(t)/sqrt(2)) -end +construct_t1(t, tpoints) = hcat(ones(eltype(tpoints), length(tpoints)), tpoints .- t) -function construct_t1(t,tpoints) - hcat(ones(eltype(tpoints),length(tpoints)),tpoints.-t) +function construct_t2(t, tpoints) + return hcat(ones(eltype(tpoints), length(tpoints)), tpoints .- t, (tpoints .- t) .^ 2) end -function construct_t2(t,tpoints) - hcat(ones(eltype(tpoints),length(tpoints)),tpoints.-t,(tpoints.-t).^2) +function construct_w(t, tpoints, h, kernel) + W = @. calckernel((kernel,), ((tpoints - t) / (tpoints[end] - tpoints[begin])) / h) / h + return Diagonal(W) end -function construct_w(t,tpoints,h,kernel) - W = @. calckernel((kernel,),((tpoints-t)/(tpoints[end]-tpoints[begin]))/h)/h - Diagonal(W) -end - - """ -```julia -u′,u = collocate_data(data,tpoints,kernel=TriangularKernel(),bandwidth=nothing) -u′,u = collocate_data(data,tpoints,tpoints_sample,interp,args...) -``` + u′, u = collocate_data(data, tpoints, kernel = TriangularKernel(), bandwidth=nothing) + u′, u = collocate_data(data, tpoints, tpoints_sample, interp, args...) -Computes a non-parametrically smoothed estimate of `u'` and `u` -given the `data`, where each column is a snapshot of the timeseries at -`tpoints[i]`. +Computes a non-parametrically smoothed estimate of `u'` and `u` given the `data`, where each +column is a snapshot of the timeseries at `tpoints[i]`. For kernels, the following exist: @@ -128,50 +69,52 @@ Additionally, we can use interpolation methods from data from intermediate timesteps. In this case, pass any of the methods like `QuadraticInterpolation` as `interp`, and the timestamps to sample from as `tpoints_sample`. """ -function collocate_data(data, tpoints, kernel=TriangularKernel(), bandwidth=nothing) - _one = oneunit(first(data)) - _zero = zero(first(data)) - e1 = [_one;_zero] - e2 = [_zero;_one;_zero] - n = length(tpoints) - bandwidth = isnothing(bandwidth) ? (n^(-1/5))*(n^(-3/35))*((log(n))^(-1/16)) : bandwidth - - Wd = similar(data, n, size(data,1)) - WT1 = similar(data, n, 2) - WT2 = similar(data, n, 3) - T2WT2 = similar(data, 3, 3) - T1WT1 = similar(data, 2, 2) - x = map(tpoints) do _t - T1 = construct_t1(_t,tpoints) - T2 = construct_t2(_t,tpoints) - W = construct_w(_t,tpoints,bandwidth,kernel) - mul!(Wd,W,data') - mul!(WT1,W,T1) - mul!(WT2,W,T2) - mul!(T2WT2,T2',WT2) - mul!(T1WT1,T1',WT1) - (det(T2WT2) ≈ 0.0 || det(T1WT1) ≈ 0.0) && error("Collocation failed with bandwidth $bandwidth. Please choose a higher bandwidth") - (e2'*((T2'*WT2)\T2'))*Wd,(e1'*((T1'*WT1)\T1'))*Wd - end - estimated_derivative = reduce(hcat,transpose.(first.(x))) - estimated_solution = reduce(hcat,transpose.(last.(x))) - estimated_derivative,estimated_solution +function collocate_data(data, tpoints, kernel = TriangularKernel(), bandwidth = nothing) + _one = oneunit(first(data)) + _zero = zero(first(data)) + e1 = [_one; _zero] + e2 = [_zero; _one; _zero] + n = length(tpoints) + bandwidth = bandwidth === nothing ? + (n^(-1 / 5)) * (n^(-3 / 35)) * ((log(n))^(-1 / 16)) : bandwidth + + Wd = similar(data, n, size(data, 1)) + WT1 = similar(data, n, 2) + WT2 = similar(data, n, 3) + T2WT2 = similar(data, 3, 3) + T1WT1 = similar(data, 2, 2) + x = map(tpoints) do _t + T1 = construct_t1(_t, tpoints) + T2 = construct_t2(_t, tpoints) + W = construct_w(_t, tpoints, bandwidth, kernel) + mul!(Wd, W, data') + mul!(WT1, W, T1) + mul!(WT2, W, T2) + mul!(T2WT2, T2', WT2) + mul!(T1WT1, T1', WT1) + (det(T2WT2) ≈ 0.0 || det(T1WT1) ≈ 0.0) && + error("Collocation failed with bandwidth $bandwidth. Please choose a higher bandwidth") + (e2' * ((T2' * WT2) \ T2')) * Wd, (e1' * ((T1' * WT1) \ T1')) * Wd + end + estimated_derivative = mapreduce(xᵢ -> transpose(first(xᵢ)), hcat, x) + estimated_solution = mapreduce(xᵢ -> transpose(last(xᵢ)), hcat, x) + return estimated_derivative, estimated_solution end -function collocate_data(data::AbstractVector,tpoints::AbstractVector,tpoints_sample::AbstractVector, - interp,args...) - du, u = collocate_data(reshape(data, 1, :),tpoints,tpoints_sample,interp,args...) - return du[1, :], u[1, :] +@views function collocate_data(data::AbstractVector, tpoints::AbstractVector, + tpoints_sample::AbstractVector, interp, args...) + du, u = collocate_data(reshape(data, 1, :), tpoints, tpoints_sample, interp, args...) + return du[1, :], u[1, :] end -function collocate_data(data::AbstractMatrix{T},tpoints::AbstractVector{T}, - tpoints_sample::AbstractVector{T},interp,args...) where T - u = zeros(T,size(data, 1),length(tpoints_sample)) - du = zeros(T,size(data, 1),length(tpoints_sample)) - for d1 in axes(data, 1) - interpolation = interp(data[d1,:],tpoints,args...) - u[d1,:] .= interpolation.(tpoints_sample) - du[d1,:] .= DataInterpolations.derivative.((interpolation,), tpoints_sample) - end - return du, u +@views function collocate_data(data::AbstractMatrix{T}, tpoints::AbstractVector{T}, + tpoints_sample::AbstractVector{T}, interp, args...) where {T} + u = zeros(T, size(data, 1), length(tpoints_sample)) + du = zeros(T, size(data, 1), length(tpoints_sample)) + for d1 in axes(data, 1) + interpolation = interp(data[d1, :], tpoints, args...) + u[d1, :] .= interpolation.(tpoints_sample) + du[d1, :] .= DataInterpolations.derivative.((interpolation,), tpoints_sample) + end + return du, u end diff --git a/src/ffjord.jl b/src/ffjord.jl index 718f15e634..f9d25f18ce 100644 --- a/src/ffjord.jl +++ b/src/ffjord.jl @@ -1,9 +1,9 @@ abstract type CNFLayer <: LuxCore.AbstractExplicitContainerLayer{(:model,)} end -Flux.trainable(m::CNFLayer) = (m.p,) - -rng = Random.default_rng() """ + FFJORD(model, tspan, input_dims, args...; ad = AutoForwardDiff(), + basedist = nothing, kwargs...) + Constructs a continuous-time recurrent neural network, also known as a neural ordinary differential equation (neural ODE), with fast gradient calculation via adjoints [1] and specialized for density estimation based on continuous @@ -16,13 +16,15 @@ of the dynamics' jacobian. At a high level this corresponds to the following ste After these steps one may use the NN model and the learned θ to predict the density p_x for new values of x. -```julia -FFJORD(model, basedist=nothing, monte_carlo=false, tspan, args...; kwargs...) -``` Arguments: - `model`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the dynamics of the model. - `basedist`: Distribution of the base variable. Set to the unit normal by default. +- `input_dims`: Input Dimensions of the model. - `tspan`: The timespan to be solved on. +- `args`: Additional arguments splatted to the ODE solver. See the + [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) + documentation for more details. +- `ad`: The automatic differentiation method to use for the internal jacobian trace. Defaults to `AutoForwardDiff()`. - `kwargs`: Additional arguments splatted to the ODE solver. See the [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) documentation for more details. @@ -36,214 +38,236 @@ References: [3] Grathwohl, Will, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. "Ffjord: Free-form continuous dynamics for scalable reversible generative models." arXiv preprint arXiv:1810.01367 (2018). """ -struct FFJORD{M, P, RE, D, T, A, K} <: CNFLayer where {M, P <: Union{AbstractVector{<: AbstractFloat}, Nothing}, RE <: Union{Function, Nothing}, D <: Distribution, T, A, K} +@concrete struct FFJORD{M <: AbstractExplicitLayer, D <: Union{Nothing, Distribution}} <: + CNFLayer model::M - p::P - re::RE basedist::D - tspan::T - args::A - kwargs::K - - function FFJORD(model::LuxCore.AbstractExplicitLayer,tspan,args...;p=nothing,basedist=nothing,kwargs...) - re = nothing - p = nothing - if isnothing(basedist) - size_input = model.layers.layer_1.in_dims - type_input = eltype(tspan) - basedist = MvNormal(zeros(type_input, size_input), Diagonal(ones(type_input, size_input))) - end - new{typeof(model),typeof(p),typeof(re), - typeof(basedist),typeof(tspan),typeof(args),typeof(kwargs)}( - model,p,re,basedist,tspan,args,kwargs) - end - - function FFJORD(model, tspan, args...;p=nothing, basedist=nothing, kwargs...) - # - _p, re = Flux.destructure(model) - if isnothing(p) - p = _p - end - if isnothing(basedist) - size_input = size(model[1].weight, 2) - type_input = eltype(model[1].weight) - basedist = MvNormal(zeros(type_input, size_input), Diagonal(ones(type_input, size_input))) - end - new{typeof(model),typeof(p),typeof(re), - typeof(basedist),typeof(tspan),typeof(args),typeof(kwargs)}( - model,p,re,basedist,tspan,args,kwargs) - end + ad + input_dims + tspan + args + kwargs end -_norm_batched(x::AbstractMatrix) = sqrt.(sum(x.^2, dims=1)) +function LuxCore.initialstates(rng::AbstractRNG, n::FFJORD) + return (; + model = LuxCore.initialstates(rng, n.model), regularize = false, monte_carlo = true) +end -function jacobian_fn(f, x::AbstractVector, args...) - y::AbstractVector, back = Zygote.pullback(f, x) - ȳ(i) = [i == j for j = 1:length(y)] - vcat([transpose(back(ȳ(i))[1]) for i = 1:length(y)]...) +function FFJORD(model, tspan, input_dims, args...; ad = AutoForwardDiff(), + basedist = nothing, kwargs...) + !(model isa AbstractExplicitLayer) && (model = Lux.transform(model)) + return FFJORD(model, basedist, ad, input_dims, tspan, args, kwargs) end -function jacobian_fn(f, x::AbstractMatrix, args...) - y, back = Zygote.pullback(f, x) - z = ChainRulesCore.@ignore_derivatives similar(y) - ChainRulesCore.@ignore_derivatives fill!(z, zero(eltype(x))) - vec = Zygote.Buffer(x, size(x, 1), size(x, 1), size(x, 2)) - for i in 1:size(y, 1) - ChainRulesCore.@ignore_derivatives z[i, :] .= one(eltype(x)) - vec[i, :, :] = back(z)[1] - ChainRulesCore.@ignore_derivatives z[i, :] .= zero(eltype(x)) +function __jacobian_with_ps(model, psax, N, x) + function __jacobian_closure(psx) + x_ = reshape(psx[1:N], size(x)) + ps = ComponentArray(psx[(N + 1):end], psax) + return vec(model(x_, ps)) end - copy(vec) end -function jacobian_fn(f::LuxCore.AbstractExplicitLayer, x::AbstractMatrix, args...) - p,st = args - y, back = Zygote.pullback((z,ps,s)->f(z,ps,s)[1], x, p, st) - z = ChainRulesCore.@ignore_derivatives similar(y) - ChainRulesCore.@ignore_derivatives fill!(z, zero(eltype(x))) - vec = Zygote.Buffer(x, size(x, 1), size(x, 1), size(x, 2)) +function __jacobian(::AutoForwardDiff{nothing}, model, x::AbstractMatrix, + ps::ComponentArray) + psd = getdata(ps) + psx = vcat(vec(x), psd) + N = length(x) + J = ForwardDiff.jacobian(__jacobian_with_ps(model, getaxes(ps), N, x), psx) + return reshape(view(J, :, 1:N), :, size(x, 1), size(x, 2)) +end + +function __jacobian(::AutoForwardDiff{CS}, model, x::AbstractMatrix, ps) where {CS} + chunksize = CS === nothing ? ForwardDiff.pickchunksize(length(x)) : CS + __f = Base.Fix2(model, ps) + cfg = ForwardDiff.JacobianConfig(__f, x, ForwardDiff.Chunk{chunksize}()) + return reshape(ForwardDiff.jacobian(__f, x, cfg), :, size(x, 1), size(x, 2)) +end + +function __jacobian(::AutoZygote, model, x::AbstractMatrix, ps) + y, pb_f = Zygote.pullback(vec ∘ model, x, ps) + z = ChainRulesCore.@ignore_derivatives fill!(similar(y), __one(y)) + J = Zygote.Buffer(x, size(y, 1), size(x, 1), size(x, 2)) for i in 1:size(y, 1) - ChainRulesCore.@ignore_derivatives z[i, :] .= one(eltype(x)) - vec[i, :, :] = back(z)[1] - ChainRulesCore.@ignore_derivatives z[i, :] .= zero(eltype(x)) + ChainRulesCore.@ignore_derivatives z[i, :] .= __one(x) + J[i, :, :] = pb_f(z)[1] + ChainRulesCore.@ignore_derivatives z[i, :] .= __zero(x) end - copy(vec) + return copy(J) end -_trace_batched(x::AbstractArray{T, 3}) where T = - reshape([tr(x[:, :, i]) for i in 1:size(x, 3)], 1, size(x, 3)) +__one(::T) where {T <: Real} = one(T) +__one(x::T) where {T <: AbstractArray} = __one(first(x)) +__one(::Tracker.TrackedReal{T}) where {T <: Real} = one(T) -function ffjord(u, p, t, re, e, st; - regularize=false, monte_carlo=true) - m = re(p) - if regularize - z = u[1:end - 3, :] - if monte_carlo - mz, back = Zygote.pullback(m, z) - eJ = back(e)[1] - trace_jac = sum(eJ .* e, dims=1) - else - mz = m(z) - trace_jac = _trace_batched(jacobian_fn(m, z)) - end - vcat(mz, -trace_jac, sum(abs2, mz, dims=1), _norm_batched(eJ)) +__zero(::T) where {T <: Real} = zero(T) +__zero(x::T) where {T <: AbstractArray} = __zero(first(x)) +__zero(::Tracker.TrackedReal{T}) where {T <: Real} = zero(T) + +function _jacobian(ad, model, x, ps) + if ndims(x) == 1 + x_ = reshape(x, :, 1) + elseif ndims(x) > 2 + x_ = reshape(x, :, size(x, ndims(x))) else - z = u[1:end - 1, :] - if monte_carlo - mz, back = Zygote.pullback(m, z) - eJ = back(e)[1] - trace_jac = sum(eJ .* e, dims=1) - else - mz = m(z) - trace_jac = _trace_batched(jacobian_fn(m, z)) - end - vcat(mz, -trace_jac) + x_ = x end + return __jacobian(ad, model, x_, ps) end -function ffjord(u, p, t, re::LuxCore.AbstractExplicitLayer, e, st; - regularize=false, monte_carlo=true) +# This implementation constructs the final trace vector on the correct device +function __trace_batched(x::AbstractArray{T, 3}) where {T} + __diag(x) = reshape(@view(x[diagind(x)]), :, 1) + return sum(reduce(hcat, __diag.(eachslice(x; dims = 3))); dims = 1) +end + +__norm_batched(x) = sqrt.(sum(abs2, x; dims = 1:(ndims(x) - 1))) + +function __ffjord(model, u, p, ad = AutoForwardDiff(), regularize::Bool = false, + monte_carlo::Bool = true) + N = ndims(u) + L = size(u, N - 1) + z = selectdim(u, N - 1, 1:(L - ifelse(regularize, 3, 1))) + if monte_carlo + mz, pb_f = Zygote.pullback(model, z, p) + e = CRC.@ignore_derivatives randn!(similar(mz)) + eJ = first(pb_f(e)) + trace_jac = sum(eJ .* e; dims = 1:(N - 1)) + else + mz = model(z, p) + J = _jacobian(ad, model, z, p) + trace_jac = __trace_batched(J) + e = CRC.@ignore_derivatives randn!(similar(mz)) + eJ = vec(e)' * reshape(J, size(J, 1), :) + end if regularize - z = u[1:end - 3, :] - if monte_carlo - mz, back = Zygote.pullback((x,ps,s)->re(x,ps,s)[1], z, p, st) - eJ = back(e)[1] - trace_jac = sum(eJ .* e, dims=1) - else - mz = re(z, ps, st)[1] - trace_jac = _trace_batched(jacobian_fn(re, z, p, st)) - end - vcat(mz, -trace_jac, sum(abs2, mz, dims=1), _norm_batched(eJ)) + return cat(mz, -trace_jac, sum(abs2, mz; dims = 1:(N - 1)), __norm_batched(eJ); + dims = Val(N - 1)) else - z = u[1:end - 1, :] - if monte_carlo - mz, back = Zygote.pullback((x,ps,s)->re(x,ps,s)[1], z, p, st) - eJ = back(e)[1] - trace_jac = sum(eJ .* e, dims=1) - else - mz = re(z, p, st)[1] - trace_jac = _trace_batched(jacobian_fn(re, z, p, st)) - end - vcat(mz, -trace_jac) + return cat(mz, -trace_jac; dims = Val(N - 1)) end end -# When running on GPU e needs to be passed separately, when using Lux pass st as a kwarg -(n::FFJORD)(args...; kwargs...) = forward_ffjord(n, args...; kwargs...) +(n::FFJORD)(x, ps, st) = __forward_ffjord(n, x, ps, st) + +function __forward_ffjord(n::FFJORD, x, ps, st) + N, S, T = ndims(x), size(x), eltype(x) + (; regularize, monte_carlo) = st + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) + + model = StatefulLuxLayer(n.model, nothing, st.model) -function forward_ffjord(n::FFJORD, x, p=n.p, e=randn(eltype(x), size(x)); - regularize=false, monte_carlo=true, st=nothing) - pz = n.basedist - sensealg = InterpolatingAdjoint() - ffjord_(u, p, t) = ffjord(u, p, t, n.re, e, st; regularize, monte_carlo) - # ffjord_(u, p, t) = ffjord(u, p, t, n.re, e; regularize, monte_carlo) + ffjord(u, p, t) = __ffjord(model, u, p, n.ad, regularize, monte_carlo) + + _z = ChainRulesCore.@ignore_derivatives fill!(similar(x, + S[1:(N - 2)]..., ifelse(regularize, 3, 1), S[N]), zero(T)) + + prob = ODEProblem{false}(ffjord, cat(x, _z; dims = Val(N - 1)), n.tspan, ps) + sol = solve(prob, n.args...; sensealg, n.kwargs..., save_everystep = false, + save_start = false, save_end = true) + pred = __get_pred(sol) + L = size(pred, N - 1) + + z = selectdim(pred, N - 1, 1:(L - ifelse(regularize, 3, 1))) + i₁ = L - ifelse(regularize, 2, 0) + delta_logp = selectdim(pred, N - 1, i₁:i₁) if regularize - _z = ChainRulesCore.@ignore_derivatives similar(x, 3, size(x, 2)) - ChainRulesCore.@ignore_derivatives fill!(_z, zero(eltype(x))) - prob = ODEProblem{false}(ffjord_, vcat(x, _z), n.tspan, p) - pred = solve(prob, n.args...; sensealg, n.kwargs...)[:, :, end] - z = pred[1:end - 3, :] - delta_logp = pred[end - 2:end - 2, :] - λ₁ = pred[end - 1, :] - λ₂ = pred[end, :] + λ₁ = selectdim(pred, N, (L - 1):(L - 1)) + λ₂ = selectdim(pred, N, L:L) else - _z = ChainRulesCore.@ignore_derivatives similar(x, 1, size(x, 2)) - ChainRulesCore.@ignore_derivatives fill!(_z, zero(eltype(x))) - prob = ODEProblem{false}(ffjord_, vcat(x, _z), n.tspan, p) - pred = solve(prob, n.args...; sensealg, n.kwargs...)[:, :, end] - z = pred[1:end - 1, :] - delta_logp = pred[end:end, :] - λ₁ = λ₂ = _z[1, :] + # For Type Stability + λ₁ = λ₂ = delta_logp end - logpz = reshape(logpdf(pz, z), 1, size(x, 2)) - logpx = logpz .- delta_logp + if n.basedist === nothing + logpz = -sum(abs2, z; dims = 1:(N - 1)) / T(2) .- + T(prod(S[1:(N - 1)]) / 2 * log(2π)) + else + logpz = logpdf(n.basedist, z) + end + logpx = reshape(logpz, 1, S[N]) .- delta_logp - logpx, λ₁, λ₂ + return (logpx, λ₁, λ₂), (; model = model.st, regularize, monte_carlo) end -function backward_ffjord(n::FFJORD, n_samples, p=n.p, e=randn(eltype(n.model[1].weight), n_samples); - regularize=false, monte_carlo=true, rng=nothing, st=nothing) +__get_pred(sol::ODESolution) = last(sol.u) +__get_pred(sol::AbstractArray{T, N}) where {T, N} = selectdim(sol, N, size(sol, N)) + +function __backward_ffjord(::Type{T1}, n::FFJORD, n_samples::Int, ps, st, rng) where {T1} px = n.basedist - x = isnothing(rng) ? rand(px, n_samples) : rand(rng, px, n_samples) - sensealg = InterpolatingAdjoint() - ffjord_(u, p, t) = ffjord(u, p, t, n.re, e, st; regularize, monte_carlo) - if regularize - _z = ChainRulesCore.@ignore_derivatives similar(x, 3, size(x, 2)) - ChainRulesCore.@ignore_derivatives fill!(_z, zero(eltype(x))) - prob = ODEProblem{false}(ffjord_, vcat(x, _z), reverse(n.tspan), p) - pred = solve(prob, n.args...; sensealg, n.kwargs...)[:, :, end] - z = pred[1:end - 3, :] + + if px === nothing + if rng === nothing + x = randn(T1, (n.input_dims..., n_samples)) + else + x = randn(rng, T1, (n.input_dims..., n_samples)) + end else - _z = ChainRulesCore.@ignore_derivatives similar(x, 1, size(x, 2)) - ChainRulesCore.@ignore_derivatives fill!(_z, zero(eltype(x))) - prob = ODEProblem{false}(ffjord_, vcat(x, _z), reverse(n.tspan), p) - pred = solve(prob, n.args...; sensealg, n.kwargs...)[:, :, end] - z = pred[1:end - 1, :] + if rng === nothing + x = rand(px, n_samples) + else + x = rand(rng, px, n_samples) + end end - z + N, S, T = ndims(x), size(x), eltype(x) + (; regularize, monte_carlo) = st + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) + + model = StatefulLuxLayer(n.model, nothing, st.model) + + ffjord(u, p, t) = __ffjord(model, u, p, n.ad, regularize, monte_carlo) + + _z = ChainRulesCore.@ignore_derivatives fill!(similar(x, + S[1:(N - 2)]..., ifelse(regularize, 3, 1), S[N]), zero(T)) + + prob = ODEProblem{false}(ffjord, cat(x, _z; dims = Val(N - 1)), reverse(n.tspan), ps) + sol = solve(prob, n.args...; sensealg, n.kwargs..., save_everystep = false, + save_start = false, save_end = true) + pred = __get_pred(sol) + L = size(pred, N - 1) + + return selectdim(pred, N - 1, 1:(L - ifelse(regularize, 3, 1))) end """ -FFJORD can be used as a distribution to generate new samples by `rand` or estimate densities by `pdf` or `logpdf` (from `Distributions.jl`). +FFJORD can be used as a distribution to generate new samples by `rand` or estimate densities +by `pdf` or `logpdf` (from `Distributions.jl`). Arguments: + - `model`: A FFJORD instance - `regularize`: Whether we use regularization (default: `false`) - `monte_carlo`: Whether we use monte carlo (default: `true`) - """ -struct FFJORDDistribution <: ContinuousMultivariateDistribution - model::FFJORD - regularize::Bool - monte_carlo::Bool +@concrete struct FFJORDDistribution{F <: FFJORD} <: ContinuousMultivariateDistribution + model::F + ps + st end -FFJORDDistribution(model; regularize=false, monte_carlo=true) = FFJORDDistribution(model, regularize, monte_carlo) +Base.length(d::FFJORDDistribution) = prod(d.model.input_dims) +Base.eltype(d::FFJORDDistribution) = __eltype(d.ps) + +__eltype(ps::ComponentArray) = __eltype(getdata(ps)) +__eltype(x::AbstractArray) = eltype(x) +function __eltype(x::NamedTuple) + T = Ref(Bool) + fmap(x) do x_ + T[] = promote_type(T[], __eltype(x_)) + x_ + end + return T[] +end -Base.length(d::FFJORDDistribution) = size(d.model.model[1].weight, 2) -Base.eltype(d::FFJORDDistribution) = eltype(d.model.model[1].weight) -Distributions._logpdf(d::FFJORDDistribution, x::AbstractArray) = forward_ffjord(d.model, x; d.regularize, d.monte_carlo)[1] -Distributions._rand!(rng::AbstractRNG, d::FFJORDDistribution, x::AbstractVector{<: Real}) = (x[:] = backward_ffjord(d.model, size(x, 2); d.regularize, d.monte_carlo, rng)) -Distributions._rand!(rng::AbstractRNG, d::FFJORDDistribution, A::DenseMatrix{<: Real}) = (A[:] = backward_ffjord(d.model, size(A, 2); d.regularize, d.monte_carlo, rng)) +function Distributions._logpdf(d::FFJORDDistribution, x::AbstractVector) + return first(first(__forward_ffjord(d.model, reshape(x, :, 1), d.ps, d.st))) +end +function Distributions._logpdf(d::FFJORDDistribution, x::AbstractArray) + return first(first(__forward_ffjord(d.model, x, d.ps, d.st))) +end +function Distributions._rand!(rng::AbstractRNG, d::FFJORDDistribution, + x::AbstractArray{<:Real}) + x[:] = __backward_ffjord(eltype(d), d.model, size(x, ndims(x)), d.ps, d.st, rng) + return x +end diff --git a/src/hnn.jl b/src/hnn.jl index 1887393bc5..c5a4bed22d 100644 --- a/src/hnn.jl +++ b/src/hnn.jl @@ -1,131 +1,108 @@ """ -Constructs a Hamiltonian Neural Network [1]. This neural network is useful for -learning symmetries and conservation laws by supervision on the gradients -of the trajectories. It takes as input a concatenated vector of length `2n` -containing the position (of size `n`) and momentum (of size `n`) of the -particles. It then returns the time derivatives for position and momentum. + HamiltonianNN(model; ad = AutoForwardDiff()) + +Constructs a Hamiltonian Neural Network [1]. This neural network is useful for learning +symmetries and conservation laws by supervision on the gradients of the trajectories. It +takes as input a concatenated vector of length `2n` containing the position (of size `n`) +and momentum (of size `n`) of the particles. It then returns the time derivatives for +position and momentum. !!! note This doesn't solve the Hamiltonian Problem. Use [`NeuralHamiltonianDE`](@ref) for such applications. -!!! note - To compute the gradients for this layer, it is recommended to use ForwardDiff.jl - -To obtain the gradients to train this network, ForwardDiff.gradient is supposed to -be used. Follow this -[tutorial](https://docs.sciml.ai/DiffEqFlux/stable/examples/hamiltonian_nn/) to see how -to define a training loop to circumvent this issue. +Arguments: -```julia -HamiltonianNN(model; p = nothing) -``` +1. `model`: A `Flux.Chain` or `Lux.AbstractExplicitLayer` neural network that returns the + Hamiltonian of the system. +2. `ad`: The autodiff framework to be used for the internal Hamiltonian computation. The + default is `AutoForwardDiff()` -Arguments: -1. `model`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that returns the Hamiltonian of the - system. -2. `p`: The initial parameters of the neural network. +!!! note + If training with Zygote, ensure that the `chunksize` for `AutoForwardDiff` is set to + `nothing`. References: [1] Greydanus, Samuel, Misko Dzamba, and Jason Yosinski. "Hamiltonian Neural Networks." Advances in Neural Information Processing Systems 32 (2019): 15379-15389. - """ -struct HamiltonianNN{M,R,P} <: LuxCore.AbstractExplicitContainerLayer{(:model,)} +@concrete struct HamiltonianNN{M <: AbstractExplicitLayer} <: + AbstractExplicitContainerLayer{(:model,)} model::M - re::R - p::P + ad end -function HamiltonianNN(model; p=nothing) - _p, re = Flux.destructure(model) - p === nothing && (p = _p) - return HamiltonianNN{typeof(model),typeof(re),typeof(p)}(model, re, p) +function HamiltonianNN(model; ad = AutoForwardDiff()) + @assert ad isa AutoForwardDiff || ad isa AutoZygote || ad isa AutoEnzyme + !(model isa AbstractExplicitLayer) && (model = Lux.transform(model)) + return HamiltonianNN(model, ad) end -function HamiltonianNN(model::LuxCore.AbstractExplicitLayer; p=nothing) - @assert p === nothing - return HamiltonianNN{typeof(model),Nothing,Nothing}(model, nothing, nothing) +function __gradient_with_ps(model, psax, N, x) + function __gradient_closure(psx) + x_ = reshape(psx[1:N], size(x)) + ps = ComponentArray(psx[(N + 1):end], psax) + return sum(model(x_, ps)) + end end -Flux.trainable(hnn::HamiltonianNN) = (hnn.p,) +function __hamiltonian_forward(::AutoForwardDiff{nothing}, model, x, ps::ComponentArray) + psd = getdata(ps) + psx = vcat(vec(x), psd) + N = length(x) + H = ForwardDiff.gradient(__gradient_with_ps(model, getaxes(ps), N, x), psx) + return reshape(view(H, 1:N), size(x)) +end -function _hamiltonian_forward(re, p, x) - H = only(Zygote.gradient(x -> sum(re(p)(x)), x)) - n = size(x, 1) ÷ 2 - return vcat(selectdim(H, 1, (n+1):2n), -selectdim(H, 1, 1:n)) +function __hamiltonian_forward(::AutoForwardDiff{CS}, model, x, ps) where {CS} + chunksize = CS === nothing ? ForwardDiff.pickchunksize(length(x)) : CS + __f = sum ∘ Base.Fix2(model, ps) + cfg = ForwardDiff.GradientConfig(__f, x, ForwardDiff.Chunk{chunksize}()) + return ForwardDiff.gradient(__f, x, cfg) end -(hnn::HamiltonianNN)(x, p=hnn.p) = _hamiltonian_forward(hnn.re, p, x) +function __hamiltonian_forward(::AutoZygote, model, x, ps) + return first(Zygote.gradient(sum ∘ model, x, ps)) +end function (hnn::HamiltonianNN{<:LuxCore.AbstractExplicitLayer})(x, ps, st) - (_, st), pb_f = Zygote.pullback(x) do x - y, st_ = hnn.model(x, ps, st) - return sum(y), st_ - end - H = only(pb_f((one(eltype(x)), nothing))) + model = StatefulLuxLayer(hnn.model, nothing, st) + H = __hamiltonian_forward(hnn.ad, model, x, ps) n = size(x, 1) ÷ 2 - return vcat(selectdim(H, 1, (n+1):2n), -selectdim(H, 1, 1:n)), st + return vcat(selectdim(H, 1, (n + 1):(2n)), -selectdim(H, 1, 1:n)), model.st end """ -Contructs a Neural Hamiltonian DE Layer for solving Hamiltonian Problems -parameterized by a Neural Network [`HamiltonianNN`](@ref). + NeuralHamiltonianDE(model, tspan, args...; kwargs...) -```julia -NeuralHamiltonianDE(model, tspan, args...; kwargs...) -``` +Contructs a Neural Hamiltonian DE Layer for solving Hamiltonian Problems parameterized by a +Neural Network [`HamiltonianNN`](@ref). Arguments: -- `model`: A Flux.Chain, Lux.AbstractExplicitLayer, or Hamiltonian Neural Network that predicts the - Hamiltonian of the system. +- `model`: A Flux.Chain, Lux.AbstractExplicitLayer, or Hamiltonian Neural Network that + predicts the Hamiltonian of the system. - `tspan`: The timespan to be solved on. - `kwargs`: Additional arguments splatted to the ODE solver. See the - [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) - documentation for more details. + [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) + documentation for more details. """ -struct NeuralHamiltonianDE{M,P,RE,T,A,K} <: NeuralDELayer - model::HamiltonianNN{M,RE,P} - p::P - tspan::T - args::A - kwargs::K -end - -# TODO: Make sensealg an argument -function NeuralHamiltonianDE(model, tspan, args...; p=nothing, kwargs...) - hnn = HamiltonianNN(model, p=p) - return NeuralHamiltonianDE{typeof(hnn.model),typeof(hnn.p),typeof(hnn.re), - typeof(tspan),typeof(args),typeof(kwargs)}( - hnn, hnn.p, tspan, args, kwargs) -end - -function NeuralHamiltonianDE(hnn::HamiltonianNN{M,RE,P}, tspan, args...; - p=hnn.p, kwargs...) where {M,RE,P} - return NeuralHamiltonianDE{M,P,RE,typeof(tspan),typeof(args), - typeof(kwargs)}(hnn, hnn.p, tspan, args, kwargs) +@concrete struct NeuralHamiltonianDE{M <: HamiltonianNN} <: NeuralDELayer + model::M + tspan + args + kwargs end -function (nhde::NeuralHamiltonianDE)(x, p=nhde.p) - function neural_hamiltonian!(du, u, p, t) - du .= reshape(nhde.model(u, p), size(du)) - end - prob = ODEProblem(ODEFunction{true}(neural_hamiltonian!), x, nhde.tspan, p) - # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP. Instead we use - # ForwardDiff.jl internally. - sensealg = InterpolatingAdjoint(; autojacvec=true) - return solve(prob, nhde.args...; sensealg, nhde.kwargs...) +function NeuralHamiltonianDE(model, tspan, args...; ad = AutoForwardDiff(), kwargs...) + hnn = model isa HamiltonianNN ? model : HamiltonianNN(model; ad) + return NeuralHamiltonianDE(hnn, tspan, args, kwargs) end -function (nhde::NeuralHamiltonianDE{<:LuxCore.AbstractExplicitLayer})(x, ps, st) - function neural_hamiltonian!(du, u, p, t) - y, st = nhde.model(u, p, st) - du .= reshape(y, size(du)) - end - prob = ODEProblem(ODEFunction{true}(neural_hamiltonian!), x, nhde.tspan, ps) - # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP. Instead we use - # ForwardDiff.jl internally. - sensealg = InterpolatingAdjoint(; autojacvec=true) - return solve(prob, nhde.args...; sensealg, nhde.kwargs...), st +function (nhde::NeuralHamiltonianDE)(x, ps, st) + model = StatefulLuxLayer(nhde.model, nothing, st) + neural_hamiltonian(u, p, t) = model(u, p) + prob = ODEProblem{false}(neural_hamiltonian, x, nhde.tspan, ps) + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) + return solve(prob, nhde.args...; sensealg, nhde.kwargs...), model.st end diff --git a/src/multiple_shooting.jl b/src/multiple_shooting.jl index 81edf73220..746c8645f6 100644 --- a/src/multiple_shooting.jl +++ b/src/multiple_shooting.jl @@ -1,52 +1,41 @@ """ -Returns a total loss after trying a 'Direct multiple shooting' on ODE data -and an array of predictions from each of the groups (smaller intervals). -In Direct Multiple Shooting, the Neural Network divides the interval into smaller intervals -and solves for them separately. + multiple_shoot(p, ode_data, tsteps, prob, loss_function, + [continuity_loss = _default_continuity_loss], solver, group_size; + continuity_term = 100, kwargs...) + +Returns a total loss after trying a 'Direct multiple shooting' on ODE data and an array of +predictions from each of the groups (smaller intervals). In Direct Multiple Shooting, the +Neural Network divides the interval into smaller intervals and solves for them separately. The default continuity term is 100, implying any losses arising from the non-continuity of 2 different groups will be scaled by 100. -```julia -multiple_shoot(p, ode_data, tsteps, prob, loss_function, solver, group_size; - continuity_term=100, kwargs...) -multiple_shoot(p, ode_data, tsteps, prob, loss_function, continuity_loss, solver, group_size; - continuity_term=100, kwargs...) -``` - Arguments: + - `p`: The parameters of the Neural Network to be trained. - `ode_data`: Original Data to be modelled. - `tsteps`: Timesteps on which ode_data was calculated. - `prob`: ODE problem that the Neural Network attempts to solve. - `loss_function`: Any arbitrary function to calculate loss. - `continuity_loss`: Function that takes states ``\\hat{u}_{end}`` of group ``k`` and - ``u_{0}`` of group ``k+1`` as input and calculates prediction continuity loss - between them. - If no custom `continuity_loss` is specified, `sum(abs, û_end - u_0)` is used. + ``u_{0}`` of group ``k+1`` as input and calculates prediction continuity loss between + them. If no custom `continuity_loss` is specified, `sum(abs, û_end - u_0)` is used. - `solver`: ODE Solver algorithm. - `group_size`: The group size achieved after splitting the ode_data into equal sizes. - `continuity_term`: Weight term to ensure continuity of predictions throughout different groups. - `kwargs`: Additional arguments splatted to the ODE solver. Refer to the - [Local Sensitivity Analysis](https://docs.sciml.ai/DiffEqDocs/stable/analysis/sensitivity/) and - [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) - documentation for more details. -Note: -The parameter 'continuity_term' should be a relatively big number to enforce a large penalty -whenever the last point of any group doesn't coincide with the first point of next group. + [Local Sensitivity Analysis](https://docs.sciml.ai/DiffEqDocs/stable/analysis/sensitivity/) and + [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) + documentation for more details. + +!!! note + + The parameter 'continuity_term' should be a relatively big number to enforce a large penalty + whenever the last point of any group doesn't coincide with the first point of next group. """ -function multiple_shoot( - p::AbstractArray, - ode_data::AbstractArray, - tsteps::AbstractArray, - prob::ODEProblem, - loss_function, - continuity_loss, - solver::DiffEqBase.AbstractODEAlgorithm, - group_size::Integer; - continuity_term::Real=100, - kwargs... -) +function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F, + continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer; + continuity_term::Real = 100, kwargs...) where {F, C} datasize = size(ode_data, 2) if group_size < 2 || group_size > datasize @@ -57,26 +46,14 @@ function multiple_shoot( ranges = group_ranges(datasize, group_size) # Multiple shooting predictions - sols = [ - solve( - remake( - prob; - p=p, - tspan=(tsteps[first(rg)], tsteps[last(rg)]), - u0=ode_data[:, first(rg)], - ), - solver; - saveat=tsteps[rg], - kwargs... - ) for rg in ranges - ] + sols = [solve(remake(prob; p, tspan = (tsteps[first(rg)], tsteps[last(rg)]), + u0 = ode_data[:, first(rg)],), solver; saveat = tsteps[rg], kwargs...) + for rg in ranges] group_predictions = Array.(sols) # Abort and return infinite loss if one of the integrations failed retcodes = [sol.retcode for sol in sols] - if any(retcodes .!= :Success) - return Inf, group_predictions - end + all(SciMLBase.successful_retcode, retcodes) || return Inf, group_predictions # Calculate multiple shooting loss loss = 0 @@ -88,90 +65,61 @@ function multiple_shoot( if i > 1 # Ensure continuity between last state in previous prediction # and current initial condition in ode_data - loss += - continuity_term * continuity_loss(group_predictions[i - 1][:, end], u[:, 1]) + loss += continuity_term * + continuity_loss(group_predictions[i - 1][:, end], u[:, 1]) end end return loss, group_predictions end -function multiple_shoot( - p::AbstractArray, - ode_data::AbstractArray, - tsteps::AbstractArray, - prob::ODEProblem, - loss_function::Function, - solver::DiffEqBase.AbstractODEAlgorithm, - group_size::Integer; - kwargs..., -) - - return multiple_shoot( - p, - ode_data, - tsteps, - prob, - loss_function, - _default_continuity_loss, - solver, - group_size; - kwargs..., - ) +function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F, + solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer; kwargs...) where {F} + return multiple_shoot(p, ode_data, tsteps, prob, loss_function, + _default_continuity_loss, solver, group_size; kwargs...,) end """ -Returns a total loss after trying a 'Direct multiple shooting' on ODE data -and an array of predictions from each of the groups (smaller intervals). -In Direct Multiple Shooting, the Neural Network divides the interval into smaller intervals -and solves for them separately. + multiple_shoot(p, ode_data, tsteps, ensembleprob, ensemblealg, loss_function, + [continuity_loss = _default_continuity_loss], solver, group_size; + continuity_term = 100, kwargs...) + +Returns a total loss after trying a 'Direct multiple shooting' on ODE data and an array of +predictions from each of the groups (smaller intervals). In Direct Multiple Shooting, the +Neural Network divides the interval into smaller intervals and solves for them separately. The default continuity term is 100, implying any losses arising from the non-continuity of 2 different groups will be scaled by 100. -```julia -multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg, loss_function, solver, - group_size; continuity_term=100, trajectories) -multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg, loss_function, - continuity_loss, solver, group_size; continuity_term=100, trajectories) -``` - Arguments: - - `p`: The parameters of the Neural Network to be trained. - - `ode_data_ensemble`: Original Data to be modelled. Batches (or equivalently "trajectories") are located in the third dimension. - - `tsteps`: Timesteps on which `ode_data_ensemble` was calculated. - - `ensemble_prob`: Ensemble problem that the Neural Network attempts to solve. - - `ensemble_alg`: Ensemble algorithm, e.g. `EnsembleThreads()` - - `loss_function`: Any arbitrary function to calculate loss. - - `continuity_loss`: Function that takes states ``\\hat{u}_{end}`` of group ``k`` and - ``u_{0}`` of group ``k+1`` as input and calculates prediction continuity loss - between them. - If no custom `continuity_loss` is specified, `sum(abs, û_end - u_0)` is used. - - `solver`: ODE Solver algorithm. - - `group_size`: The group size achieved after splitting the ode_data into equal sizes. - - `continuity_term`: Weight term to ensure continuity of predictions throughout - different groups. - - `trajectories`: number of trajectories for `ensemble_prob`. - - `kwargs`: Additional arguments splatted to the ODE solver. Refer to the - [Local Sensitivity Analysis](https://docs.sciml.ai/DiffEqDocs/stable/analysis/sensitivity/) and - [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) - documentation for more details. -Note: -The parameter 'continuity_term' should be a relatively big number to enforce a large penalty -whenever the last point of any group doesn't coincide with the first point of next group. + +- `p`: The parameters of the Neural Network to be trained. +- `ode_data`: Original Data to be modelled. +- `tsteps`: Timesteps on which ode_data was calculated. +- `ensemble_prob`: Ensemble problem that the Neural Network attempts to solve. +- `ensemble_alg`: Ensemble algorithm, e.g. `EnsembleThreads()` +- `prob`: ODE problem that the Neural Network attempts to solve. +- `loss_function`: Any arbitrary function to calculate loss. +- `continuity_loss`: Function that takes states ``\\hat{u}_{end}`` of group ``k`` and +``u_{0}`` of group ``k+1`` as input and calculates prediction continuity loss between +them. If no custom `continuity_loss` is specified, `sum(abs, û_end - u_0)` is used. +- `solver`: ODE Solver algorithm. +- `group_size`: The group size achieved after splitting the ode_data into equal sizes. +- `continuity_term`: Weight term to ensure continuity of predictions throughout +different groups. +- `kwargs`: Additional arguments splatted to the ODE solver. Refer to the +[Local Sensitivity Analysis](https://docs.sciml.ai/DiffEqDocs/stable/analysis/sensitivity/) and +[Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) +documentation for more details. + +!!! note + + The parameter 'continuity_term' should be a relatively big number to enforce a large penalty + whenever the last point of any group doesn't coincide with the first point of next group. """ -function multiple_shoot( - p::AbstractArray, - ode_data::AbstractArray, - tsteps::AbstractArray, - ensembleprob::EnsembleProblem, - ensemblealg::SciMLBase.BasicEnsembleAlgorithm, - loss_function, - continuity_loss, - solver::DiffEqBase.AbstractODEAlgorithm, - group_size::Integer; - continuity_term::Real=100, - kwargs... -) +function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem, + ensemblealg::SciMLBase.BasicEnsembleAlgorithm, loss_function::F, continuity_loss::C, + solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer; + continuity_term::Real = 100, kwargs...) where {F, C} datasize = size(ode_data, 2) prob = ensembleprob.prob @@ -179,46 +127,30 @@ function multiple_shoot( throw(DomainError(group_size, "group_size can't be < 2 or > number of data points")) end - @assert ndims(ode_data) == 3 "ode_data must have three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)" - @assert size(ode_data,2) == length(tsteps) - @assert size(ode_data,3) == kwargs[:trajectories] + @assert ndims(ode_data)==3 "ode_data must have three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)" + @assert size(ode_data, 2) == length(tsteps) + @assert size(ode_data, 3) == kwargs[:trajectories] # Get ranges that partition data to groups of size group_size ranges = group_ranges(datasize, group_size) - # Multiple shooting predictions - # by using map we avoid mutating an array - sols = map( - rg -> begin - newprob = remake( - prob; - p=p, - tspan=(tsteps[first(rg)], tsteps[last(rg)]), - ) + # Multiple shooting predictions by using map we avoid mutating an array + sols = map(rg -> begin + newprob = remake(prob; + p = p, + tspan = (tsteps[first(rg)], tsteps[last(rg)]),) function prob_func(prob, i, repeat) - remake(prob, u0 = ode_data[:, first(rg), i]) + remake(prob; u0 = ode_data[:, first(rg), i]) end - newensembleprob = EnsembleProblem(newprob, - prob_func, - ensembleprob.output_func, - ensembleprob.reduction, - ensembleprob.u_init, - ensembleprob.safetycopy); - solve(newensembleprob, - solver, - ensemblealg; - saveat=tsteps[rg], - kwargs... - ) - end, - ranges) + newensembleprob = EnsembleProblem(newprob, prob_func, ensembleprob.output_func, + ensembleprob.reduction, ensembleprob.u_init, ensembleprob.safetycopy) + solve(newensembleprob, solver, ensemblealg; saveat = tsteps[rg], kwargs...) + end, ranges) group_predictions = Array.(sols) # Abort and return infinite loss if one of the integrations did not converge? convergeds = [sol.converged for sol in sols] - if any(.! convergeds) - return Inf, group_predictions - end + any(.!convergeds) && return Inf, group_predictions # Calculate multiple shooting loss loss = 0 @@ -233,53 +165,31 @@ function multiple_shoot( if i > 1 # Ensure continuity between last state in previous prediction # and current initial condition in ode_data - loss += - continuity_term * continuity_loss(group_predictions[i - 1][:, end, :], u[:, 1, :]) + loss += continuity_term * + continuity_loss(group_predictions[i - 1][:, end, :], u[:, 1, :]) end end return loss, group_predictions end -function multiple_shoot( - p::AbstractArray, - ode_data::AbstractArray, - tsteps::AbstractArray, - ensembleprob::EnsembleProblem, - ensemblealg::SciMLBase.BasicEnsembleAlgorithm, - loss_function::Function, - solver::DiffEqBase.AbstractODEAlgorithm, - group_size::Integer; - continuity_term::Real=100, - kwargs... -) - - return multiple_shoot( - p, - ode_data, - tsteps, - ensembleprob, - ensemblealg, - loss_function, - _default_continuity_loss, - solver, - group_size; - continuity_term, - kwargs... - ) - +function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem, + ensemblealg::SciMLBase.BasicEnsembleAlgorithm, loss_function::F, + solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer; + continuity_term::Real = 100, kwargs...) where {F} + return multiple_shoot(p, ode_data, tsteps, ensembleprob, ensemblealg, loss_function, + _default_continuity_loss, solver, group_size; continuity_term, kwargs...) end """ + group_ranges(datasize, groupsize) + Get ranges that partition data of length `datasize` in groups of `groupsize` observations. If the data isn't perfectly dividable by `groupsize`, the last group contains the reminding observations. -```julia -group_ranges(datasize, groupsize) -``` - Arguments: + - `datasize`: amount of data points to be partitioned - `groupsize`: maximum amount of observations in each group @@ -293,18 +203,11 @@ julia> group_ranges(10, 5) ``` """ function group_ranges(datasize::Integer, groupsize::Integer) - 2 <= groupsize <= datasize || throw( - DomainError( - groupsize, - "datasize must be positive and groupsize must to be within [2, datasize]", - ), - ) - return [i:min(datasize, i + groupsize - 1) for i in 1:groupsize-1:datasize-1] + 2 <= groupsize <= datasize || throw(DomainError(groupsize, + "datasize must be positive and groupsize must to be within [2, datasize]")) + return [i:min(datasize, i + groupsize - 1) for i in 1:(groupsize - 1):(datasize - 1)] end # Default ontinuity loss between last state in previous prediction # and current initial condition in ode_data -function _default_continuity_loss(û_end::AbstractArray, - u_0::AbstractArray) - return sum(abs, û_end - u_0) -end +_default_continuity_loss(û_end, u_0) = sum(abs, û_end - u_0) diff --git a/src/neural_de.jl b/src/neural_de.jl index 0819b80343..7d4d54ec56 100644 --- a/src/neural_de.jl +++ b/src/neural_de.jl @@ -1,22 +1,21 @@ -abstract type NeuralDELayer <: LuxCore.AbstractExplicitContainerLayer{(:model,)} end -abstract type NeuralSDELayer <: LuxCore.AbstractExplicitContainerLayer{(:drift,:diffusion,)} end -basic_tgrad(u,p,t) = zero(u) -basic_dde_tgrad(u,h,p,t) = zero(u) +abstract type NeuralDELayer <: AbstractExplicitContainerLayer{(:model,)} end +abstract type NeuralSDELayer <: AbstractExplicitContainerLayer{(:drift, :diffusion)} end + +basic_tgrad(u, p, t) = zero(u) +basic_dde_tgrad(u, h, p, t) = zero(u) """ + NeuralODE(model, tspan, alg = nothing, args...; kwargs...) + Constructs a continuous-time recurrant neural network, also known as a neural ordinary differential equation (neural ODE), with a fast gradient calculation via adjoints [1]. At a high level this corresponds to solving the forward differential equation, using a second differential equation that propagates the derivatives of the loss backwards in time. -```julia -NeuralODE(model,tspan,alg=nothing,args...;kwargs...) -``` - Arguments: -- `model`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the ̇x. +- `model`: A `Flux.Chain` or `Lux.AbstractExplicitLayer` neural network that defines the ̇x. - `tspan`: The timespan to be solved on. - `alg`: The algorithm used to solve the ODE. Defaults to `nothing`, i.e. the default algorithm from DifferentialEquations.jl. @@ -31,240 +30,145 @@ Arguments: References: [1] Pontryagin, Lev Semenovich. Mathematical theory of optimal processes. CRC press, 1987. - """ -struct NeuralODE{M,P,RE,T,A,K} <: NeuralDELayer +@concrete struct NeuralODE{M <: AbstractExplicitLayer} <: NeuralDELayer model::M - p::P - re::RE - tspan::T - args::A - kwargs::K + tspan + args + kwargs end -function NeuralODE(model,tspan,args...;p = nothing,kwargs...) - _p,re = Flux.destructure(model) - if p === nothing - p = _p - end - NeuralODE{typeof(model),typeof(p),typeof(re), - typeof(tspan),typeof(args),typeof(kwargs)}( - model,p,re,tspan,args,kwargs) +function NeuralODE(model, tspan, args...; kwargs...) + !(model isa AbstractExplicitLayer) && (model = Lux.transform(model)) + return NeuralODE(model, tspan, args, kwargs) end -function NeuralODE(model::LuxCore.AbstractExplicitLayer,tspan,args...;p=nothing,kwargs...) - re = nothing - NeuralODE{typeof(model),typeof(p),typeof(re), - typeof(tspan),typeof(args),typeof(kwargs)}( - model,p,re,tspan,args,kwargs) -end +function (n::NeuralODE)(x, p, st) + model = StatefulLuxLayer(n.model, nothing, st) -@functor NeuralODE (p,) + dudt(u, p, t) = model(u, p) + ff = ODEFunction{false}(dudt; tgrad = basic_tgrad) + prob = ODEProblem{false}(ff, x, n.tspan, p) -function (n::NeuralODE)(x,p=n.p) - dudt_(u,p,t) = n.re(p)(u) - ff = ODEFunction{false}(dudt_,tgrad=basic_tgrad) - prob = ODEProblem{false}(ff,x,getfield(n,:tspan),p) - sense = InterpolatingAdjoint(autojacvec=ZygoteVJP()) - solve(prob,n.args...;sensealg=sense,n.kwargs...) -end - -function (n::NeuralODE{M})(x,p,st) where {M<:LuxCore.AbstractExplicitLayer} - function dudt(u,p,t;st=st) - u_, st = n.model(u,p,st) - return u_ - end - - ff = ODEFunction{false}(dudt,tgrad=basic_tgrad) - prob = ODEProblem{false}(ff,x,n.tspan,p) - sense = InterpolatingAdjoint(autojacvec=ZygoteVJP()) - return solve(prob,n.args...;sensealg=sense,n.kwargs...), st + return (solve(prob, n.args...; + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()), n.kwargs...), model.st) end """ -Constructs a neural stochastic differential equation (neural SDE) with diagonal noise. + NeuralDSDE(drift, diffusion, tspan, alg = nothing, args...; sensealg = TrackerAdjoint(), + kwargs...) -```julia -NeuralDSDE(drift,diffusion,tspan,alg=nothing,args...; - sensealg=TrackerAdjoint(),kwargs...) -``` +Constructs a neural stochastic differential equation (neural SDE) with diagonal noise. Arguments: -- `drift`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the drift function. -- `diffusion`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the diffusion function. - Should output a vector of the same size as the input. +- `drift`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the drift + function. +- `diffusion`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the + diffusion function. Should output a vector of the same size as the input. - `tspan`: The timespan to be solved on. - `alg`: The algorithm used to solve the ODE. Defaults to `nothing`, i.e. the default algorithm from DifferentialEquations.jl. - `sensealg`: The choice of differentiation algorthm used in the backpropogation. - Defaults to using reverse-mode automatic differentiation via Tracker.jl - `kwargs`: Additional arguments splatted to the ODE solver. See the [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) documentation for more details. - """ -struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralSDELayer - p::P - len::Int - drift::M - re1::RE +@concrete struct NeuralDSDE{M1 <: AbstractExplicitLayer, M2 <: AbstractExplicitLayer} <: + NeuralSDELayer + drift::M1 diffusion::M2 - re2::RE2 - tspan::T - args::A - kwargs::K -end - -function NeuralDSDE(drift,diffusion,tspan,args...;p = nothing, kwargs...) - p1,re1 = Flux.destructure(drift) - p2,re2 = Flux.destructure(diffusion) - if p === nothing - p = [p1;p2] - end - NeuralDSDE{typeof(drift),typeof(p),typeof(re1),typeof(diffusion),typeof(re2), - typeof(tspan),typeof(args),typeof(kwargs)}(p, - length(p1),drift,re1,diffusion,re2,tspan,args,kwargs) + tspan + args + kwargs end -function NeuralDSDE(drift::LuxCore.AbstractExplicitLayer,diffusion::LuxCore.AbstractExplicitLayer,tspan,args...; - p1 =nothing, - p = nothing, kwargs...) - re1 = nothing - re2 = nothing - NeuralDSDE{typeof(drift),typeof(p),typeof(re1),typeof(diffusion),typeof(re2), - typeof(tspan),typeof(args),typeof(kwargs)}(p, - Int(1),drift,re1,diffusion,re2,tspan,args,kwargs) +function NeuralDSDE(drift, diffusion, tspan, args...; kwargs...) + !(drift isa AbstractExplicitLayer) && (drift = Lux.transform(drift)) + !(diffusion isa AbstractExplicitLayer) && (diffusion = Lux.transform(diffusion)) + return NeuralDSDE(drift, diffusion, tspan, args, kwargs) end -@functor NeuralDSDE (p,) +function (n::NeuralDSDE)(x, p, st) + drift = StatefulLuxLayer(n.drift, nothing, st.drift) + diffusion = StatefulLuxLayer(n.diffusion, nothing, st.diffusion) -function (n::NeuralDSDE)(x,p=n.p) - dudt_(u,p,t) = n.re1(p[1:n.len])(u) - g(u,p,t) = n.re2(p[(n.len+1):end])(u) - ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad) - prob = SDEProblem{false}(ff,g,x,n.tspan,p) - solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...) -end + dudt(u, p, t) = drift(u, p.drift) + g(u, p, t) = diffusion(u, p.diffusion) -function (n::NeuralDSDE{M})(x,p,st) where {M<:LuxCore.AbstractExplicitLayer} - st1 = st.drift - st2 = st.diffusion - function dudt_(u,p,t;st=st1) - u_, st = n.drift(u,p.drift,st) - return u_ - end - function g(u,p,t;st=st2) - u_, st = n.diffusion(u,p.diffusion,st) - return u_ - end - - ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad) - prob = SDEProblem{false}(ff,g,x,n.tspan,p) - return solve(prob,n.args...;sensealg=BacksolveAdjoint(),n.kwargs...), (drift = st1, diffusion = st2) + ff = SDEFunction{false}(dudt, g; tgrad = basic_tgrad) + prob = SDEProblem{false}(ff, g, x, n.tspan, p) + return (solve(prob, n.args...; u0 = x, p, sensealg = TrackerAdjoint(), n.kwargs...), + (; drift = drift.st, diffusion = diffusion.st)) end """ -Constructs a neural stochastic differential equation (neural SDE). + NeuralSDE(drift, diffusion, tspan, nbrown, alg = nothing, args...; + sensealg=TrackerAdjoint(),kwargs...) -```julia -NeuralSDE(drift,diffusion,tspan,nbrown,alg=nothing,args...; - sensealg=TrackerAdjoint(),kwargs...) -``` +Constructs a neural stochastic differential equation (neural SDE). Arguments: -- `drift`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the drift function. -- `diffusion`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the diffusion function. - Should output a matrix that is nbrown x size(x,1). +- `drift`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the drift + function. +- `diffusion`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the + diffusion function. Should output a matrix that is `nbrown x size(x, 1)`. - `tspan`: The timespan to be solved on. - `nbrown`: The number of Brownian processes - `alg`: The algorithm used to solve the ODE. Defaults to `nothing`, i.e. the default algorithm from DifferentialEquations.jl. - `sensealg`: The choice of differentiation algorthm used in the backpropogation. - Defaults to using reverse-mode automatic differentiation via Tracker.jl - `kwargs`: Additional arguments splatted to the ODE solver. See the [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) documentation for more details. - """ -struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralSDELayer - p::P - len::Int - drift::M - re1::RE +@concrete struct NeuralSDE{M1 <: AbstractExplicitLayer, M2 <: AbstractExplicitLayer} <: + NeuralSDELayer + drift::M1 diffusion::M2 - re2::RE2 - tspan::T + tspan nbrown::Int - args::A - kwargs::K -end - -function NeuralSDE(drift,diffusion,tspan,nbrown,args...;p=nothing,kwargs...) - p1,re1 = Flux.destructure(drift) - p2,re2 = Flux.destructure(diffusion) - if p === nothing - p = [p1;p2] - end - NeuralSDE{typeof(p),typeof(drift),typeof(re1),typeof(diffusion),typeof(re2), - typeof(tspan),typeof(args),typeof(kwargs)}( - p,length(p1),drift,re1,diffusion,re2,tspan,nbrown,args,kwargs) + args + kwargs end -function NeuralSDE(drift::LuxCore.AbstractExplicitLayer, diffusion::LuxCore.AbstractExplicitLayer,tspan,nbrown,args...; - p1 = nothing, p = nothing, kwargs...) - re1 = nothing - re2 = nothing - NeuralSDE{typeof(p),typeof(drift),typeof(re1),typeof(diffusion),typeof(re2), - typeof(tspan),typeof(args),typeof(kwargs)}( - p,Int(1),drift,re1,diffusion,re2,tspan,nbrown,args,kwargs) +function NeuralSDE(drift, diffusion, tspan, nbrown, args...; kwargs...) + !(drift isa AbstractExplicitLayer) && (drift = Lux.transform(drift)) + !(diffusion isa AbstractExplicitLayer) && (diffusion = Lux.transform(diffusion)) + return NeuralSDE(drift, diffusion, tspan, nbrown, args, kwargs) end -@functor NeuralSDE (p,) +function (n::NeuralSDE)(x, p, st) + drift = StatefulLuxLayer(n.drift, p.drift, st.drift) + diffusion = StatefulLuxLayer(n.diffusion, p.diffusion, st.diffusion) -function (n::NeuralSDE)(x,p=n.p) - dudt_(u,p,t) = n.re1(p[1:n.len])(u) - g(u,p,t) = n.re2(p[(n.len+1):end])(u) - ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad) - prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown)) - solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...) -end + dudt(u, p, t) = drift(u, p.drift) + g(u, p, t) = diffusion(u, p.diffusion) -function (n::NeuralSDE{P,M})(x,p,st) where {P,M<:LuxCore.AbstractExplicitLayer} - st1 = st.drift - st2 = st.diffusion - function dudt_(u,p,t;st=st1) - u_, st = n.drift(u,p.drift,st) - return u_ - end - function g(u,p,t;st=st2) - u_, st = n.diffusion(u,p.diffusion,st) - return u_ - end + noise_rate_prototype = CRC.@ignore_derivatives fill!(similar(x, length(x), n.nbrown), 0) - ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad) - prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown)) - return solve(prob,n.args...;sensealg=BacksolveAdjoint(),n.kwargs...), (drift = st1, diffusion = st2) + ff = SDEFunction{false}(dudt, g; tgrad = basic_tgrad) + prob = SDEProblem{false}(ff, g, x, n.tspan, p; noise_rate_prototype) + return (solve(prob, n.args...; u0 = x, p, sensealg = TrackerAdjoint(), n.kwargs...), + (; drift = drift.st, diffusion = diffusion.st)) end """ -Constructs a neural delay differential equation (neural DDE) with constant -delays. + NeuralCDDE(model, tspan, hist, lags, alg = nothing, args...; + sensealg = TrackerAdjoint(), kwargs...) -```julia -NeuralCDDE(model,tspan,hist,lags,alg=nothing,args...; - sensealg=TrackerAdjoint(),kwargs...) -``` +Constructs a neural delay differential equation (neural DDE) with constant delays. Arguments: -- `model`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the derivative function. - Should take an input of size `[x;x(t-lag_1);...;x(t-lag_n)]` and produce and - output shaped like `x`. +- `model`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the + derivative function. Should take an input of size `[x; x(t - lag_1); ...; x(t - lag_n)]` + and produce and output shaped like `x`. - `tspan`: The timespan to be solved on. -- `hist`: Defines the history function `h(t)` for values before the start of the - integration. +- `hist`: Defines the history function `h(u, p, t)` for values before the start of the + integration. Note that `u` is supposed to be used to return a value that matches the size + of `u`. - `lags`: Defines the lagged values that should be utilized in the neural network. - `alg`: The algorithm used to solve the ODE. Defaults to `nothing`, i.e. the default algorithm from DifferentialEquations.jl. @@ -273,74 +177,47 @@ Arguments: - `kwargs`: Additional arguments splatted to the ODE solver. See the [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) documentation for more details. - """ -Unsupported_NeuralCDDE_pairing_message = """ - NeuralCDDE can only be instantiated with a Flux chain - """ - -struct Unsupported_pairing <:Exception - msg::Any -end - -function Base.showerror(io::IO, e::Unsupported_pairing) - println(io, e.msg) -end - -struct NeuralCDDE{P,M,RE,H,L,T,A,K} <: NeuralDELayer - p::P +@concrete struct NeuralCDDE{M <: AbstractExplicitLayer} <: NeuralDELayer model::M - re::RE - hist::H - lags::L - tspan::T - args::A - kwargs::K -end - -function NeuralCDDE(model,tspan,hist,lags,args...;p=nothing,kwargs...) - _p,re = Flux.destructure(model) - if p === nothing - p = _p - end - NeuralCDDE{typeof(p),typeof(model),typeof(re),typeof(hist),typeof(lags), - typeof(tspan),typeof(args),typeof(kwargs)}(p,model, - re,hist,lags,tspan,args,kwargs) + tspan + hist + lags + args + kwargs end -function NeuralCDDE(model::LuxCore.AbstractExplicitLayer,tspan,hist,lags,args...;p = nothing,kwargs...) - throw(Unsupported_pairing(Unsupported_NeuralCDDE_pairing_message)) -# re = nothing -# new{typeof(p),typeof(model),typeof(re),typeof(hist),typeof(lags), -# typeof(tspan),typeof(args),typeof(kwargs)}(p,model, -# re,hist,lags,tspan,args,kwargs) +function NeuralCDDE(model, tspan, hist, lags, args...; kwargs...) + !(model isa AbstractExplicitLayer) && (model = Lux.transform(model)) + return NeuralCDDE(model, tspan, hist, lags, args, kwargs) end -@functor NeuralCDDE (p,) +function (n::NeuralCDDE)(x, ps, st) + model = StatefulLuxLayer(n.model, nothing, st) -function (n::NeuralCDDE)(x,p=n.p) - function dudt_(u,h,p,t) - _u = vcat(u,(h(p,t-lag) for lag in n.lags)...) - n.re(p)(_u) + function dudt(u, h, p, t) + xs = mapfoldl(lag -> h(p, t - lag), vcat, n.lags) + return model(vcat(u, xs), p) end - ff = DDEFunction{false}(dudt_,tgrad=basic_dde_tgrad) - prob = DDEProblem{false}(ff,x,n.hist,n.tspan,p,constant_lags = n.lags) - solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...) + + ff = DDEFunction{false}(dudt; tgrad = basic_dde_tgrad) + prob = DDEProblem{false}(ff, x, (p, t) -> n.hist(x, p, t), n.tspan, ps; + constant_lags = n.lags) + + return (solve(prob, n.args...; sensealg = TrackerAdjoint(), n.kwargs...), model.st) end """ -Constructs a neural differential-algebraic equation (neural DAE). + NeuralDAE(model, constraints_model, tspan, args...; differential_vars = nothing, + sensealg = TrackerAdjoint(), kwargs...) -```julia -NeuralDAE(model,constraints_model,tspan,alg=nothing,args...; - sensealg=TrackerAdjoint(),kwargs...) -``` +Constructs a neural differential-algebraic equation (neural DAE). Arguments: -- `model`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the derivative function. - Should take an input of size `x` and produce the residual of `f(dx,x,t)` - for only the differential variables. +- `model`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the + derivative function. Should take an input of size `x` and produce the residual of + `f(dx,x,t)` for only the differential variables. - `constraints_model`: A function `constraints_model(u,p,t)` for the fixed constaints to impose on the algebraic equations. - `tspan`: The timespan to be solved on. @@ -351,115 +228,73 @@ Arguments: - `kwargs`: Additional arguments splatted to the ODE solver. See the [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) documentation for more details. - """ -struct NeuralDAE{P,M,M2,D,RE,T,DV,A,K} <: NeuralDELayer +@concrete struct NeuralDAE{M <: AbstractExplicitLayer} <: NeuralDELayer model::M - constraints_model::M2 - p::P - du0::D - re::RE - tspan::T - differential_vars::DV - args::A - kwargs::K -end - -function NeuralDAE(model,constraints_model,tspan,du0=nothing,args...;p=nothing,differential_vars=nothing,kwargs...) - _p,re = Flux.destructure(model) - - if p === nothing - p = _p - end - - NeuralDAE{typeof(p),typeof(model),typeof(constraints_model), - typeof(du0),typeof(re),typeof(tspan), - typeof(differential_vars),typeof(args),typeof(kwargs)}( - model,constraints_model,p,du0,re,tspan,differential_vars, - args,kwargs) + constraints_model + tspan + args + differential_vars + kwargs end -function NeuralDAE(model::LuxCore.AbstractExplicitLayer,constraints_model,tspan,du0=nothing,args...;p=nothing,differential_vars=nothing,kwargs...) - re = nothing - - NeuralDAE{typeof(p),typeof(model),typeof(constraints_model), - typeof(du0),typeof(re),typeof(tspan), - typeof(differential_vars),typeof(args),typeof(kwargs)}( - model,constraints_model,p,du0,re,tspan,differential_vars, - args,kwargs) +function NeuralDAE(model, constraints_model, tspan, args...; differential_vars = nothing, + kwargs...) + !(model isa AbstractExplicitLayer) && (model = Lux.transform(model)) + return NeuralDAE(model, constraints_model, tspan, args, differential_vars, kwargs) end -@functor NeuralDAE (p,) +function (n::NeuralDAE)(u_du::Tuple, p, st) + u0, du0 = u_du + model = StatefulLuxLayer(n.model, nothing, st) -function (n::NeuralDAE)(x,du0=n.du0,p=n.p) - function f(du,u,p,t) - nn_out = n.re(p)(vcat(u,du)) - alg_out = n.constraints_model(u,p,t) - iter_nn = 0 - iter_consts = 0 - map(n.differential_vars) do isdiff + function f(du, u, p, t) + nn_out = model(vcat(u, du), p) + alg_out = n.constraints_model(u, p, t) + iter_nn, iter_const = 0, 0 + res = map(n.differential_vars) do isdiff if isdiff iter_nn += 1 nn_out[iter_nn] else - iter_consts += 1 - alg_out[iter_consts] + iter_const += 1 + alg_out[iter_const] end end + return res end - prob = DAEProblem{false}(f,du0,x,n.tspan,p,differential_vars=n.differential_vars) - solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...) -end -function (n::NeuralDAE{P,M})(x,p,st) where {P,M<:LuxCore.AbstractExplicitLayer} - du0 = n.du0 - function f(du,u,p,t;st=st) - nn_out, st = n.model(vcat(u,du),p,st) - alg_out = n.constraints_model(u,p,t) - iter_nn = 0 - iter_consts = 0 - map(n.differential_vars) do isdiff - if isdiff - iter_nn += 1 - nn_out[iter_nn] - else - iter_consts += 1 - alg_out[iter_consts] - end - end - end - prob = DAEProblem{false}(f,du0,x,n.tspan,p,differential_vars=n.differential_vars) - return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st + prob = DAEProblem{false}(f, du0, u0, n.tspan, p; n.differential_vars) + return solve(prob, n.args...; sensealg = TrackerAdjoint(), n.kwargs...), st end """ -Constructs a physically-constrained continuous-time recurrant neural network, -also known as a neural differential-algebraic equation (neural DAE), with a -mass matrix and a fast gradient calculation via adjoints [1]. The mass matrix -formulation is: + NeuralODEMM(model, constraints_model, tspan, mass_matrix, alg = nothing, args...; + sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()), kwargs...) + +Constructs a physically-constrained continuous-time recurrant neural network, also known as +a neural differential-algebraic equation (neural DAE), with a mass matrix and a fast +gradient calculation via adjoints [1]. The mass matrix formulation is: ```math Mu' = f(u,p,t) ``` -where `M` is semi-explicit, i.e. singular with zeros for rows corresponding to -the constraint equations. - -```julia -NeuralODEMM(model,constraints_model,tspan,mass_matrix,alg=nothing,args...;kwargs...) -``` +where `M` is semi-explicit, i.e. singular with zeros for rows corresponding to the +constraint equations. Arguments: - `model`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that defines the ̇`f(u,p,t)` -- `constraints_model`: A function `constraints_model(u,p,t)` for the fixed - constaints to impose on the algebraic equations. +- `constraints_model`: A function `constraints_model(u,p,t)` for the fixed constaints to + impose on the algebraic equations. - `tspan`: The timespan to be solved on. - `mass_matrix`: The mass matrix associated with the DAE -- `alg`: The algorithm used to solve the ODE. Defaults to `nothing`, i.e. the - default algorithm from DifferentialEquations.jl. This method requires an - implicit ODE solver compatible with singular mass matrices. Consult the - [DAE solvers](https://docs.sciml.ai/DiffEqDocs/stable/solvers/dae_solve/) documentation for more details. +- `alg`: The algorithm used to solve the ODE. Defaults to `nothing`, i.e. the default + algorithm from DifferentialEquations.jl. This method requires an implicit ODE solver + compatible with singular mass matrices. Consult the + [DAE solvers](https://docs.sciml.ai/DiffEqDocs/stable/solvers/dae_solve/) documentation + for more details. - `sensealg`: The choice of differentiation algorthm used in the backpropogation. Defaults to an adjoint method. See the [Local Sensitivity Analysis](https://docs.sciml.ai/DiffEqDocs/stable/analysis/sensitivity/) @@ -467,73 +302,41 @@ Arguments: - `kwargs`: Additional arguments splatted to the ODE solver. See the [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) documentation for more details. - """ -struct NeuralODEMM{M,M2,P,RE,T,MM,A,K} <: NeuralDELayer +@concrete struct NeuralODEMM{M <: AbstractExplicitLayer} <: NeuralDELayer model::M - constraints_model::M2 - p::P - re::RE - tspan::T - mass_matrix::MM - args::A - kwargs::K -end - -function NeuralODEMM(model,constraints_model,tspan,mass_matrix,args...; - p = nothing, kwargs...) - _p,re = Flux.destructure(model) - - if p === nothing - p = _p - end - NeuralODEMM{typeof(model),typeof(constraints_model),typeof(p),typeof(re), - typeof(tspan),typeof(mass_matrix),typeof(args),typeof(kwargs)}( - model,constraints_model,p,re,tspan,mass_matrix,args,kwargs) + constraints_model + tspan + mass_matrix + args + kwargs end -function NeuralODEMM(model::LuxCore.AbstractExplicitLayer,constraints_model,tspan,mass_matrix,args...; - p=nothing,kwargs...) - re = nothing - NeuralODEMM{typeof(model),typeof(constraints_model),typeof(p),typeof(re), - typeof(tspan),typeof(mass_matrix),typeof(args),typeof(kwargs)}( - model,constraints_model,p,re,tspan,mass_matrix,args,kwargs) +function NeuralODEMM(model, constraints_model, tspan, mass_matrix, args...; kwargs...) + !(model isa AbstractExplicitLayer) && (model = Lux.transform(model)) + return NeuralODEMM(model, constraints_model, tspan, mass_matrix, args, kwargs) end -@functor NeuralODEMM (p,) +function (n::NeuralODEMM)(x, ps, st) + model = StatefulLuxLayer(n.model, nothing, st) -function (n::NeuralODEMM)(x,p=n.p) - function f(u,p,t) - nn_out = n.re(p)(u) - alg_out = n.constraints_model(u,p,t) - vcat(nn_out,alg_out) + function f(u, p, t) + nn_out = model(u, p) + alg_out = n.constraints_model(u, p, t) + return vcat(nn_out, alg_out) end - dudt_= ODEFunction{false}(f,mass_matrix=n.mass_matrix,tgrad=basic_tgrad) - prob = ODEProblem{false}(dudt_,x,n.tspan,p) - sense = InterpolatingAdjoint(autojacvec=ZygoteVJP()) - solve(prob,n.args...;sensealg=sense,n.kwargs...) -end + dudt = ODEFunction{false}(f; mass_matrix = n.mass_matrix, tgrad = basic_tgrad) + prob = ODEProblem{false}(dudt, x, n.tspan, ps) -function (n::NeuralODEMM{M})(x,p,st) where {M<:LuxCore.AbstractExplicitLayer} - function f(u,p,t;st=st) - nn_out,st = n.model(u,p,st) - alg_out = n.constraints_model(u,p,t) - return vcat(nn_out,alg_out) - end - dudt_= ODEFunction{false}(f;mass_matrix=n.mass_matrix,tgrad=basic_tgrad) - prob = ODEProblem{false}(dudt_,x,n.tspan,p) - - sense = InterpolatingAdjoint(autojacvec=ZygoteVJP()) - return solve(prob,n.args...;sensealg=sense,n.kwargs...), st + return (solve(prob, n.args...; + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()), n.kwargs...), model.st) end """ -Constructs an Augmented Neural Differential Equation Layer. + AugmentedNDELayer(nde, adim::Int) -```julia -AugmentedNDELayer(nde, adim::Int) -``` +Constructs an Augmented Neural Differential Equation Layer. Arguments: @@ -543,56 +346,44 @@ Arguments: References: [1] Dupont, Emilien, Arnaud Doucet, and Yee Whye Teh. "Augmented neural ODEs." In Proceedings of the 33rd International Conference on Neural Information Processing Systems, pp. 3140-3150. 2019. - """ -abstract type AugmentedNDEType <: LuxCore.AbstractExplicitContainerLayer{(:nde,)} end -struct AugmentedNDELayer{DE<:Union{NeuralDELayer,NeuralSDELayer}} <: AugmentedNDEType - nde::DE - adim::Int +function AugmentedNDELayer(model::Union{NeuralDELayer, NeuralSDELayer}, adim::Int) + return Chain(Base.Fix2(__augment, adim), model) end -(ande::AugmentedNDELayer)(x, args...) = ande.nde(augment(x, ande.adim), args...) - -augment(x::AbstractVector{S}, augment_dim::Int) where S = - cat(x, zeros(S, (augment_dim,)), dims = 1) - -augment(x::AbstractArray{S, T}, augment_dim::Int) where {S, T} = - cat(x, zeros(S, (size(x)[1:(T - 2)]..., augment_dim, size(x, T))), dims = T - 1) - -Base.getproperty(ande::AugmentedNDELayer, sym::Symbol) = - hasproperty(ande, sym) ? getfield(ande, sym) : getfield(ande.nde, sym) +function __augment(x::AbstractVector, augment_dim::Int) + y = CRC.@ignore_derivatives fill!(similar(x, augment_dim), 0) + return vcat(x, y) +end -abstract type HelperLayer <: Function end +function __augment(x::AbstractArray, augment_dim::Int) + y = CRC.@ignore_derivatives fill!(similar(x, size(x)[1:(ndims(x) - 2)]..., + augment_dim, size(x, ndims(x))), 0) + return cat(x, y; dims = Val(ndims(x) - 1)) +end """ + DimMover(from, to) + Constructs a Dimension Mover Layer. -```julia -DimMover(from, to) -``` +We can have Flux's conventional order `(data, channel, batch)` by using it as the last layer +of `Flux.Chain` to swap the batch-index and the time-index of the Neural DE's output +considering that each time point is a channel. """ -struct DimMover <: HelperLayer - from::Integer - to::Integer +@concrete struct DimMover <: AbstractExplicitLayer + from + to end -function (dm::DimMover)(x) - @assert !iszero(dm.from) - @assert !iszero(dm.to) - - from = dm.from > 0 ? dm.from : (length(size(x)) + 1 + dm.from) - to = dm.to > 0 ? dm.to : (length(size(x)) + 1 + dm.to) - - cat(eachslice(x; dims=from)...; dims=to) +function DimMover(; from = -2, to = -1) + @assert from !== 0 && to !== 0 + return DimMover(from, to) end -""" -We can have Flux's conventional order (data, channel, batch) -by using it as the last layer of `Flux.Chain` to swap the batch-index and the time-index of the Neural DE's output. -considering that each time point is a channel. +function (dm::DimMover)(x, ps, st) + from = dm.from > 0 ? dm.from : (ndims(x) + 1 + dm.from) + to = dm.to > 0 ? dm.to : (ndims(x) + 1 + dm.to) -```julia -FluxBatchOrder = DimMover(-2, -1) -``` -""" -const FluxBatchOrder = DimMover(-2, -1) + return cat(eachslice(x; dims = from)...; dims = to), st +end diff --git a/src/spline_layer.jl b/src/spline_layer.jl index a6089176ff..b4fbee70ee 100644 --- a/src/spline_layer.jl +++ b/src/spline_layer.jl @@ -1,37 +1,47 @@ -abstract type AbstractSplineLayer <: Function end -Flux.trainable(m::AbstractSplineLayer) = (m.p,) - """ + SplineLayer(time_span, time_step, spline_basis, init_saved_points = nothing) + Constructs a Spline Layer. At a high-level, it performs the following: + 1. Takes as input a one-dimensional training dataset, a time span, a time step and -an interpolation method. -2. During training, adjusts the values of the function at multiples of the time-step -such that the curve interpolated through these points has minimum loss on the corresponding -one-dimensional dataset. + an interpolation method. +2. During training, adjusts the values of the function at multiples of the time-step such + that the curve interpolated through these points has minimum loss on the corresponding + one-dimensional dataset. -```julia -SplineLayer(time_span,time_step,spline_basis,saved_points=nothing) -``` Arguments: + - `time_span`: Tuple of real numbers corresponding to the time span. - `time_step`: Real number corresponding to the time step. - `spline_basis`: Interpolation method to be used yb the basis (current supported - interpolation methods: ConstantInterpolation, LinearInterpolation, QuadraticInterpolation, - QuadraticSpline, CubicSpline). -- 'saved_points': values of the function at multiples of the time step. Initialized by default -to a random vector sampled from the unit normal. + interpolation methods: `ConstantInterpolation`, `LinearInterpolation`, + `QuadraticInterpolation`, `QuadraticSpline`, `CubicSpline`). +- 'init_saved_points': values of the function at multiples of the time step. Initialized by + default to a random vector sampled from the unit normal. Alternatively, can take a + function with the signature `init_saved_points(rng, time_span, time_step)`. """ -struct SplineLayer{T<:Tuple{Real, Real},R<:Real,S1<:AbstractVector,S2<:UnionAll} <: AbstractSplineLayer - time_span::T - time_step::R - saved_points::S1 - spline_basis::S2 - function SplineLayer(time_span,time_step,spline_basis,saved_points=nothing) - saved_points = randn(length(time_span[1]:time_step:time_span[2])) - new{typeof(time_span),typeof(time_step),typeof(saved_points),typeof(spline_basis)}(time_span,time_step,saved_points,spline_basis) +@concrete struct SplineLayer <: AbstractExplicitLayer + tspan + tstep + spline_basis + init_saved_points +end + +function SplineLayer(tspan, tstep, spline_basis; init_saved_points::F = nothing) where {F} + return SplineLayer(tspan, tstep, spline_basis, init_saved_points) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::SplineLayer) + if l.init_saved_points === nothing + return (; + saved_points = randn(rng, typeof(l.tspan[1]), + length(l.tspan[1]:(l.tstep):l.tspan[2]))) + else + return (; saved_points = l.init_saved_points(rng, l.tspan, l.tstep)) end end -function (layer::SplineLayer)(t::Real,p=layer.saved_points) - return layer.spline_basis(p,layer.time_span[1]:layer.time_step:layer.time_span[2])(t) +function (layer::SplineLayer)(t, ps, st) + return (layer.spline_basis(ps.saved_points, + layer.tspan[1]:(layer.tstep):layer.tspan[2])(t), st) end diff --git a/src/tensor_product.jl b/src/tensor_product.jl new file mode 100644 index 0000000000..da0dc4d95a --- /dev/null +++ b/src/tensor_product.jl @@ -0,0 +1,121 @@ +abstract type TensorProductBasis <: Function end + +@concrete struct TensorProductBasisFunction + f + n +end + +(basis::TensorProductBasisFunction)(x) = map(i -> basis.f(i, x), 1:(basis.n)) + +""" + ChebyshevBasis(n) + +Constructs a Chebyshev basis of the form [T_{0}(x), T_{1}(x), ..., T_{n-1}(x)] where T_j(.) +is the j-th Chebyshev polynomial of the first kind. + +Arguments: + +- `n`: number of terms in the polynomial expansion. +""" +ChebyshevBasis(n) = TensorProductBasisFunction(__chebyshev, n) + +__chebyshev(i, x) = cos(i * acos(x)) + +""" + SinBasis(n) + +Constructs a sine basis of the form [sin(x), sin(2*x), ..., sin(n*x)]. + +Arguments: + +- `n`: number of terms in the sine expansion. +""" +SinBasis(n) = TensorProductBasisFunction(sin ∘ *, n) + +""" + CosBasis(n) + +Constructs a cosine basis of the form [cos(x), cos(2*x), ..., cos(n*x)]. + +Arguments: + +- `n`: number of terms in the cosine expansion. +""" +CosBasis(n) = TensorProductBasisFunction(cos ∘ *, n) + +""" + FourierBasis(n) + +Constructs a Fourier basis of the form +F_j(x) = j is even ? cos((j÷2)*x) : sin((j÷2)*x) => [F_0(x), F_1(x), ..., F_n(x)]. + +Arguments: + +- `n`: number of terms in the Fourier expansion. +""" +FourierBasis(n) = TensorProductBasisFunction(__fourier, n) + +__fourier(i::Int, x) = ifelse(iseven(i), cos(i * x / 2), sin(i * x / 2)) + +""" + LegendreBasis(n) + +Constructs a Legendre basis of the form [P_{0}(x), P_{1}(x), ..., P_{n-1}(x)] where +P_j(.) is the j-th Legendre polynomial. + +Arguments: + +- `n`: number of terms in the polynomial expansion. +""" +LegendreBasis(n) = TensorProductBasisFunction(__legendre_poly, n) + +## Source: https://github.com/ranocha/PolynomialBases.jl/blob/master/src/legendre.jl +function __legendre_poly(i::Int, x) + p = i - 1 + a = one(x) + b = x + + p ≤ 0 && return a + p == 1 && return b + + for j in 2:p + a, b = b, ((2j - 1) * x * b - (j - 1) * a) / j + end + + return b +end + +""" + PolynomialBasis(n) + +Constructs a Polynomial basis of the form [1, x, ..., x^(n-1)]. + +Arguments: + +- `n`: number of terms in the polynomial expansion. +""" +PolynomialBasis(n) = TensorProductBasisFunction(__polynomial, n) + +__polynomial(i, x) = x^(i - 1) + +""" + TensorLayer(model, out_dim::Int, init_p::F = randn) where {F <: Function} + +Constructs the Tensor Product Layer, which takes as input an array of n tensor +product basis, [B_1, B_2, ..., B_n] a data point x, computes +z[i] = W[i,:] ⨀ [B_1(x[1]) ⨂ B_2(x[2]) ⨂ ... ⨂ B_n(x[n])], where W is the layer's weight, +and returns [z[1], ..., z[out]]. + +Arguments: + +- `model`: Array of TensorProductBasis [B_1(n_1), ..., B_k(n_k)], where k corresponds to the + dimension of the input. +- `out`: Dimension of the output. +- `p`: Optional initialization of the layer's weight. Initialized to standard normal by + default. +""" +function TensorLayer(model, out_dim::Int, init_p::F = randn) where {F <: Function} + number_of_weights = prod(Base.Fix2(getproperty, :n), model) + return Chain(x -> mapfoldl(((m, xᵢ),) -> m(xᵢ), kron, zip(model, x)), + Dense(number_of_weights => out_dim; use_bias = false, init_weight = init_p)) +end diff --git a/src/tensor_product_basis.jl b/src/tensor_product_basis.jl deleted file mode 100644 index 2333f1dc1e..0000000000 --- a/src/tensor_product_basis.jl +++ /dev/null @@ -1,122 +0,0 @@ -abstract type TensorProductBasis <: Function end - -""" -Constructs a Chebyshev basis of the form [T_{0}(x), T_{1}(x), ..., T_{n-1}(x)] where T_j(.) is the j-th Chebyshev polynomial of the first kind. -```julia -ChebyshevBasis(n) -``` -Arguments: -- `n`: number of terms in the polynomial expansion. -""" -struct ChebyshevBasis <: TensorProductBasis - n::Int -end - -function (basis::ChebyshevBasis)(x) - return map(j -> cos(j*acos(x)), 1:basis.n) -end - -""" -Constructs a sine basis of the form [sin(x), sin(2*x), ..., sin(n*x)]. -```julia -SinBasis(n) -``` -Arguments: -- `n`: number of terms in the sine expansion. -""" -struct SinBasis <: TensorProductBasis - n::Int -end - -function (basis::SinBasis)(x) - return map(j -> sin(j*x), 1:basis.n) -end - -""" -Constructs a cosine basis of the form [cos(x), cos(2*x), ..., cos(n*x)]. -```julia -CosBasis(n) -``` -Arguments: -- `n`: number of terms in the cosine expansion. -""" -struct CosBasis <: TensorProductBasis - n::Int -end - -function (basis::CosBasis)(x) - return map(j -> cos(j*x), 1:basis.n) -end - -#auxiliary function -function fourier(i::Int, x::Real) - return iseven(i) ? cos(i*x/2) : sin(i*x/2) -end - -""" -Constructs a Fourier basis of the form F_j(x) = j is even ? cos((j÷2)*x) : sin((j÷2)*x) => [F_0(x), F_1(x), ..., F_n(x)]. -```julia -FourierBasis(n) -``` -Arguments: -- `n`: number of terms in the Fourier expansion. -""" -struct FourierBasis <: TensorProductBasis - n::Int -end - -function (basis::FourierBasis)(x) - return map(j -> fourier(j,x), 1:basis.n) -end - -#auxiliary function -##Source: https://github.com/ranocha/PolynomialBases.jl/blob/master/src/legendre.jl -function legendre_poly(x, p::Integer) - a::typeof(x) = one(x) - b::typeof(x) = x - - if p <= 0 - return a - elseif p == 1 - return b - end - - for j in 2:p - a, b = b, ((2j-1)*x*b - (j-1)*a) / j - end - - b -end - -""" -Constructs a Legendre basis of the form [P_{0}(x), P_{1}(x), ..., P_{n-1}(x)] where P_j(.) is the j-th Legendre polynomial. -```julia -LegendreBasis(n) -``` -Arguments: -- `n`: number of terms in the polynomial expansion. -""" -struct LegendreBasis <: TensorProductBasis - n::Int -end - -function (basis::LegendreBasis)(x) - f = k -> legendre_poly(x,k-1) - return map(f, 1:basis.n) -end - -""" -Constructs a Polynomial basis of the form [1, x, ..., x^(n-1)]. -```julia -PolynomialBasis(n) -``` -Arguments: -- `n`: number of terms in the polynomial expansion. -""" -struct PolynomialBasis <: TensorProductBasis - n::Int -end - -function (basis::PolynomialBasis)(x) - return [evalpoly(x, (I+zeros(basis.n,basis.n))[k,:]) for k in 1:basis.n] -end diff --git a/src/tensor_product_layer.jl b/src/tensor_product_layer.jl deleted file mode 100644 index 926a3f7aef..0000000000 --- a/src/tensor_product_layer.jl +++ /dev/null @@ -1,42 +0,0 @@ -abstract type AbstractTensorProductLayer <: Function end -""" -Constructs the Tensor Product Layer, which takes as input an array of n tensor -product basis, [B_1, B_2, ..., B_n] a data point x, computes -z[i] = W[i,:] ⨀ [B_1(x[1]) ⨂ B_2(x[2]) ⨂ ... ⨂ B_n(x[n])], where W is the layer's weight, -and returns [z[1], ..., z[out]]. - -```julia -TensorLayer(model,out,p=nothing) -``` -Arguments: -- `model`: Array of TensorProductBasis [B_1(n_1), ..., B_k(n_k)], where k corresponds to the dimension of the input. -- `out`: Dimension of the output. -- `p`: Optional initialization of the layer's weight. Initialized to standard normal by default. -""" -struct TensorLayer{M<:Array{TensorProductBasis},P<:AbstractArray,Int} <: AbstractTensorProductLayer - model::M - p::P - in::Int - out::Int - function TensorLayer(model,out,p=nothing) - number_of_weights = 1 - for basis in model - number_of_weights *= basis.n - end - if p === nothing - p = randn(out*number_of_weights) - end - new{Array{TensorProductBasis},typeof(p),Int}(model,p,length(model),out) - end -end - -function (layer::TensorLayer)(x,p=layer.p) - model,out = layer.model,layer.out - W = reshape(p, out, Int(length(p)/out)) - tensor_prod = model[1](x[1]) - for i in 2:length(model) - tensor_prod = kron(tensor_prod,model[i](x[i])) - end - z = W*tensor_prod - return z -end diff --git a/test/Project.toml b/test/Project.toml index c2a3a409cb..1340f5e68d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -11,13 +12,16 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea" GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" diff --git a/test/augmented_nde.jl b/test/augmented_nde.jl deleted file mode 100644 index 859d19d90b..0000000000 --- a/test/augmented_nde.jl +++ /dev/null @@ -1,95 +0,0 @@ -using ComponentArrays, DiffEqFlux, Flux, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Test, Random -import Lux - -x = Float32[2.; 0.] -xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.])) -tspan = (0.0f0, 1.0f0) -fluxdudt = Flux.Chain(Flux.Dense(4, 50, tanh), Flux.Dense(50, 4)) -fluxdudt2 = Flux.Chain(Flux.Dense(4, 50, tanh), Flux.Dense(50, 4)) -fluxdudt22 = Flux.Chain(Flux.Dense(4, 50, tanh), Flux.Dense(50, 16), x -> reshape(x, 4, 4)) -fluxddudt = Flux.Chain(Flux.Dense(12, 50, tanh), Flux.Dense(50, 4)) - -# Augmented Neural ODE -anode = AugmentedNDELayer( - NeuralODE(fluxdudt, tspan, Tsit5(), save_everystep=false, save_start=false), 2 -) -anode(x) - -grads = Zygote.gradient(() -> sum(anode(x)), Flux.params(x, anode.nde)) -@test ! iszero(grads[x]) -@test ! iszero(grads[anode.p]) - -# Augmented Neural DSDE -andsde = AugmentedNDELayer( - NeuralDSDE(fluxdudt, fluxdudt2, (0.0f0, 0.1f0), SOSRI(), saveat=0.0:0.01:0.1), 2 -) -andsde(x) - -grads = Zygote.gradient(() -> sum(andsde(x)), Flux.params(x, andsde.nde)) -@test ! iszero(grads[x]) -@test ! iszero(grads[andsde.p]) - -# Augmented Neural SDE -asode = AugmentedNDELayer( - NeuralSDE(fluxdudt, fluxdudt22,(0.0f0, 0.1f0), 4, LambaEM(), saveat=0.0:0.01:0.1), 2 -) -asode(x) - -grads = Zygote.gradient(() -> sum(asode(x)), Flux.params(x, asode.nde)) -@test ! iszero(grads[x]) -@test ! iszero(grads[asode.p]) - -# Augmented Neural CDDE -adode = AugmentedNDELayer( - NeuralCDDE(fluxddudt, (0.0f0, 2.0f0), (p, t) -> zeros(Float32, 4), (1f-1, 2f-1), - MethodOfSteps(Tsit5()), saveat=0.0:0.1:2.0), 2 -) -adode(x) - -grads = Zygote.gradient(() -> sum(adode(x)), Flux.params(x, adode.nde)) -@test ! iszero(grads[x]) -@test ! iszero(grads[adode.p]) - -## AugmentedNDELayer with Lux - -rng = Random.default_rng() - -dudt = Lux.Chain(Lux.Dense(4, 50, tanh), Lux.Dense(50, 4)) -dudt2 = Lux.Chain(Lux.Dense(4, 50, tanh), Lux.Dense(50, 4)) -dudt22 = Lux.Chain(Lux.Dense(4, 50, tanh), Lux.Dense(50, 16), (x) -> reshape(x, 4, 4)) - -# Augmented Neural ODE -anode = AugmentedNDELayer( - NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false), 2 -) -pd, st = Lux.setup(rng, anode) -pd = ComponentArray(pd) -anode(x,pd,st) - -grads = Zygote.gradient((x,p,st) -> sum(anode(x,p,st)[1]), x, pd, st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) - -# Augmented Neural DSDE -andsde = AugmentedNDELayer( - NeuralDSDE(dudt, dudt2, (0.0f0, 0.1f0), EulerHeun(), saveat=0.0:0.01:0.1, dt=0.01), 2 -) -pd, st = Lux.setup(rng, andsde) -pd = ComponentArray(pd) -andsde(x,pd,st) - -grads = Zygote.gradient((x,p,st) -> sum(andsde(x,p,st)[1]), x, pd, st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) - -# Augmented Neural SDE -asode = AugmentedNDELayer( - NeuralSDE(dudt, dudt22,(0.0f0, 0.1f0), 4, EulerHeun(), saveat=0.0:0.01:0.1, dt=0.01), 2 -) -pd, st = Lux.setup(rng, asode) -pd = ComponentArray(pd) -asode(x,pd,st) - -grads = Zygote.gradient((x,p,st) -> sum(asode(x,p,st)[1]), x, pd, st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) diff --git a/test/cnf_test.jl b/test/cnf_test.jl index a142f730e6..4bf2d96e0b 100644 --- a/test/cnf_test.jl +++ b/test/cnf_test.jl @@ -1,236 +1,89 @@ using DiffEqFlux, Zygote, Distances, Distributions, DistributionsAD, Optimization, - LinearAlgebra, OrdinaryDiffEq, Random, Test, OptimizationOptimisers + LinearAlgebra, OrdinaryDiffEq, Random, Test, OptimizationOptimisers, Statistics, + ComponentArrays Random.seed!(1999) ## callback to be used by all tests -function callback(p, l) - @show l - false +function callback(adtype) + return function (p, l) + @info "[FFJORD $(nameof(typeof(adtype)))] Loss: $(l)" + false + end end @testset "Smoke test for FFJORD" begin - nn = Flux.Chain( - Flux.Dense(1, 1, tanh), - ) |> f32 + nn = Chain(Dense(1, 1, tanh)) tspan = (0.0f0, 1.0f0) - ffjord_mdl = FFJORD(nn, tspan, Tsit5()) + ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5()) + ps, st = Lux.setup(Random.default_rng(), ffjord_mdl) + ps = ComponentArray(ps) data_dist = Beta(2.0f0, 2.0f0) train_data = Float32.(rand(data_dist, 1, 100)) - function loss(θ; regularize, monte_carlo) - logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo) - -mean(logpx) - end - - @testset "AutoForwardDiff as adtype" begin - adtype = Optimization.AutoForwardDiff() - - @testset "regularize=false & monte_carlo=false" begin - regularize = false - monte_carlo = false - - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=false & monte_carlo=true" begin - regularize = false - monte_carlo = true - - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=true & monte_carlo=false" begin - regularize = true - monte_carlo = false - - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test_broken !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=true & monte_carlo=true" begin - regularize = true - monte_carlo = true - - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - end - @testset "AutoReverseDiff as adtype" begin - adtype = Optimization.AutoReverseDiff() - - @testset "regularize=false & monte_carlo=false" begin - regularize = false - monte_carlo = false - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=false & monte_carlo=true" begin - regularize = false - monte_carlo = true - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=true & monte_carlo=false" begin - regularize = true - monte_carlo = false - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test_broken !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=true & monte_carlo=true" begin - regularize = true - monte_carlo = true - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - end - @testset "AutoTracker as adtype" begin - adtype = Optimization.AutoTracker() - - @testset "regularize=false & monte_carlo=false" begin - regularize = false - monte_carlo = false - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=false & monte_carlo=true" begin - regularize = false - monte_carlo = true - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=true & monte_carlo=false" begin - regularize = true - monte_carlo = false - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test_broken !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=true & monte_carlo=true" begin - regularize = true - monte_carlo = true - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - end - @testset "AutoZygote as adtype" begin - adtype = Optimization.AutoZygote() - - @testset "regularize=false & monte_carlo=false" begin - regularize = false - monte_carlo = false - - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=false & monte_carlo=true" begin - regularize = false - monte_carlo = true - - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=true & monte_carlo=false" begin - regularize = true - monte_carlo = false - - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test_broken !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=true & monte_carlo=true" begin - regularize = true - monte_carlo = true - - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end + function loss(model, θ) + logpx, λ₁, λ₂ = model(train_data, θ) + return -mean(logpx) end - @testset "AutoFiniteDiff as adtype" begin - adtype = Optimization.AutoFiniteDiff() - - @testset "regularize=false & monte_carlo=false" begin - regularize = false - monte_carlo = false - - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=false & monte_carlo=true" begin - regularize = false - monte_carlo = true - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=true & monte_carlo=false" begin - regularize = true - monte_carlo = false - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test_broken !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) - end - @testset "regularize=true & monte_carlo=true" begin - regularize = true - monte_carlo = true - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - @test !isnothing(Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=10)) + @testset "ADType: $(adtype)" for adtype in (Optimization.AutoForwardDiff(), + Optimization.AutoReverseDiff(), Optimization.AutoTracker(), + Optimization.AutoZygote(), Optimization.AutoFiniteDiff()) + @testset "regularize = $(regularize) & monte_carlo = $(monte_carlo)" for regularize in (true, + false), monte_carlo in (true, false) + @info "regularize = $(regularize) & monte_carlo = $(monte_carlo)" + st_ = (; st..., regularize, monte_carlo) + model = Lux.Experimental.StatefulLuxLayer(ffjord_mdl, nothing, st_) + optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype) + optprob = Optimization.OptimizationProblem(optf, ps) + @test !isnothing(Optimization.solve(optprob, Adam(0.1); + callback = callback(adtype), maxiters = 3)) broken=(adtype isa Optimization.AutoTracker) end end end + @testset "Smoke test for FFJORDDistribution (sampling & pdf)" begin - nn = Flux.Chain( - Flux.Dense(1, 1, tanh), - ) |> f32 + nn = Chain(Dense(1, 1, tanh)) tspan = (0.0f0, 1.0f0) - ffjord_mdl = FFJORD(nn, tspan, Tsit5()) + ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5()) + ps, st = Lux.setup(Random.default_rng(), ffjord_mdl) + ps = ComponentArray(ps) + + regularize = false + monte_carlo = false data_dist = Beta(2.0f0, 2.0f0) train_data = Float32.(rand(data_dist, 1, 100)) - function loss(θ; regularize, monte_carlo) - logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo) - -mean(logpx) + function loss(model, θ) + logpx, λ₁, λ₂ = model(train_data, θ) + return -mean(logpx) end adtype = Optimization.AutoZygote() - regularize = false - monte_carlo = false + st_ = (; st..., regularize, monte_carlo) + model = Lux.Experimental.StatefulLuxLayer(ffjord_mdl, nothing, st_) - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ; regularize, monte_carlo), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - res = Optimization.solve(optprob, Adam(0.1); callback= callback, maxiters=10) + optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype) + optprob = Optimization.OptimizationProblem(optf, ps) + res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10) - ffjord_d = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res.u); regularize, monte_carlo) + ffjord_d = FFJORDDistribution(ffjord_mdl, res.u, st_) @test !isnothing(pdf(ffjord_d, train_data)) @test !isnothing(rand(ffjord_d)) @test !isnothing(rand(ffjord_d, 10)) end + @testset "Test for default base distribution and deterministic trace FFJORD" begin - nn = Flux.Chain( - Flux.Dense(1, 1, tanh), - ) |> f32 + nn = Chain(Dense(1, 1, tanh)) tspan = (0.0f0, 1.0f0) - ffjord_mdl = FFJORD(nn, tspan, Tsit5()) + ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5()) + ps, st = Lux.setup(Random.default_rng(), ffjord_mdl) + ps = ComponentArray(ps) + regularize = false monte_carlo = false @@ -238,29 +91,34 @@ end train_data = Float32.(rand(data_dist, 1, 100)) test_data = Float32.(rand(data_dist, 1, 100)) - function loss(θ) - logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo) - -mean(logpx) + function loss(model, θ) + logpx, λ₁, λ₂ = model(train_data, θ) + return -mean(logpx) end adtype = Optimization.AutoZygote() - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ), adtype) - optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p) - res = Optimization.solve(optprob, Adam(0.1); callback= callback, maxiters=10) + st_ = (; st..., regularize, monte_carlo) + model = Lux.Experimental.StatefulLuxLayer(ffjord_mdl, nothing, st_) + + optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype) + optprob = Optimization.OptimizationProblem(optf, ps) + res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10) actual_pdf = pdf.(data_dist, test_data) - learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1]) + learned_pdf = exp.(model(test_data, res.u)[1]) - @test ffjord_mdl.p != res.u + @test ps != res.u @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.9 end + @testset "Test for alternative base distribution and deterministic trace FFJORD" begin - nn = Flux.Chain( - Flux.Dense(1, 3, tanh), - Flux.Dense(3, 1, tanh), - ) |> f32 + nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh)) tspan = (0.0f0, 1.0f0) - ffjord_mdl = FFJORD(nn, tspan, Tsit5(); basedist=MvNormal([0.0f0], Diagonal([4.0f0]))) + ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5(); + basedist = MvNormal([0.0f0], Diagonal([4.0f0]))) + ps, st = Lux.setup(Random.default_rng(), ffjord_mdl) + ps = ComponentArray(ps) + regularize = false monte_carlo = false @@ -268,28 +126,34 @@ end train_data = Float32.(rand(data_dist, 1, 100)) test_data = Float32.(rand(data_dist, 1, 100)) - function loss(θ) - logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo) - -mean(logpx) + function loss(model, θ) + logpx, λ₁, λ₂ = model(train_data, θ) + return -mean(logpx) end adtype = Optimization.AutoZygote() - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ), adtype) - optprob = Optimization.OptimizationProblem(optf, 0.01f0 * ffjord_mdl.p) - res = Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=300) + st_ = (; st..., regularize, monte_carlo) + model = Lux.Experimental.StatefulLuxLayer(ffjord_mdl, nothing, st_) + + optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype) + optprob = Optimization.OptimizationProblem(optf, ps) + res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), + maxiters = 30) actual_pdf = pdf.(data_dist, test_data) - learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1]) + learned_pdf = exp.(model(test_data, res.u)[1]) - @test 0.01f0 * ffjord_mdl.p != res.u + @test ps != res.u @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25 end + @testset "Test for multivariate distribution and deterministic trace FFJORD" begin - nn = Flux.Chain( - Flux.Dense(2, 2, tanh), - ) |> f32 + nn = Chain(Dense(2, 2, tanh)) tspan = (0.0f0, 1.0f0) - ffjord_mdl = FFJORD(nn, tspan, Tsit5()) + ffjord_mdl = FFJORD(nn, tspan, (2,), Tsit5()) + ps, st = Lux.setup(Random.default_rng(), ffjord_mdl) + ps = ComponentArray(ps) + regularize = false monte_carlo = false @@ -299,28 +163,34 @@ end train_data = Float32.(rand(data_dist, 100)) test_data = Float32.(rand(data_dist, 100)) - function loss(θ) - logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo) - -mean(logpx) + function loss(model, θ) + logpx, λ₁, λ₂ = model(train_data, θ) + return -mean(logpx) end adtype = Optimization.AutoZygote() - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ), adtype) - optprob = Optimization.OptimizationProblem(optf, 0.01f0 * ffjord_mdl.p) - res = Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=300) + st_ = (; st..., regularize, monte_carlo) + model = Lux.Experimental.StatefulLuxLayer(ffjord_mdl, nothing, st_) + + optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype) + optprob = Optimization.OptimizationProblem(optf, ps) + res = Optimization.solve(optprob, Adam(0.01); callback = callback(adtype), + maxiters = 30) actual_pdf = pdf(data_dist, test_data) - learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1]) + learned_pdf = exp.(model(test_data, res.u)[1]) - @test 0.01f0 * ffjord_mdl.p != res.u + @test ps != res.u @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25 end + @testset "Test for default multivariate distribution and FFJORD with regularizers" begin - nn = Flux.Chain( - Flux.Dense(2, 2, tanh), - ) |> f32 + nn = Chain(Dense(2, 2, tanh)) tspan = (0.0f0, 1.0f0) - ffjord_mdl = FFJORD(nn, tspan, Tsit5()) + ffjord_mdl = FFJORD(nn, tspan, (2,), Tsit5()) + ps, st = Lux.setup(Random.default_rng(), ffjord_mdl) + ps = ComponentArray(ps) + regularize = true monte_carlo = true @@ -330,19 +200,23 @@ end train_data = Float32.(rand(data_dist, 100)) test_data = Float32.(rand(data_dist, 100)) - function loss(θ) - logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ; regularize, monte_carlo) - mean(-logpx .+ 1f-1 * λ₁ .+ 1f-1 * λ₂) + function loss(model, θ) + logpx, λ₁, λ₂ = model(train_data, θ) + return -mean(logpx) end adtype = Optimization.AutoZygote() - optf = Optimization.OptimizationFunction((θ,_) -> loss(θ), adtype) - optprob = Optimization.OptimizationProblem(optf, 0.01f0 * ffjord_mdl.p) - res = Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters=300) + st_ = (; st..., regularize, monte_carlo) + model = Lux.Experimental.StatefulLuxLayer(ffjord_mdl, nothing, st_) + + optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype) + optprob = Optimization.OptimizationProblem(optf, ps) + res = Optimization.solve(optprob, Adam(0.01); callback = callback(adtype), + maxiters = 30) actual_pdf = pdf(data_dist, test_data) - learned_pdf = exp.(ffjord_mdl(test_data, res.u; regularize, monte_carlo)[1]) + learned_pdf = exp.(model(test_data, res.u)[1]) - @test 0.01f0 * ffjord_mdl.p != res.u - @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.40 + @test ps != res.u + @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25 end diff --git a/test/collocation.jl b/test/collocation.jl index 14cac4017c..1f2b568133 100644 --- a/test/collocation.jl +++ b/test/collocation.jl @@ -1,43 +1,32 @@ using DiffEqFlux, OrdinaryDiffEq, Test -bounded_support_kernels = [ - EpanechnikovKernel(), - UniformKernel(), - TriangularKernel(), - QuarticKernel(), - TriweightKernel(), - TricubeKernel(), - CosineKernel(), -] +bounded_support_kernels = [EpanechnikovKernel(), UniformKernel(), TriangularKernel(), + QuarticKernel(), TriweightKernel(), TricubeKernel(), CosineKernel()] -unbounded_support_kernels = - [GaussianKernel(), LogisticKernel(), SigmoidKernel(), SilvermanKernel()] +unbounded_support_kernels = [GaussianKernel(), LogisticKernel(), SigmoidKernel(), + SilvermanKernel()] @testset "Kernel Functions" begin ts = collect(-5.0:0.1:5.0) @testset "Kernels with support from -1 to 1" begin minus_one_index = findfirst(x -> ==(x, -1.0), ts) plus_one_index = findfirst(x -> ==(x, 1.0), ts) - @testset "$kernel" for (kernel, x0) in zip( - bounded_support_kernels, - [0.75, 0.50, 1.0, 15.0 / 16.0, 35.0 / 32.0, 70.0 / 81.0, pi / 4.0], - ) + @testset "$kernel" for (kernel, x0) in zip(bounded_support_kernels, + [0.75, 0.50, 1.0, 15.0 / 16.0, 35.0 / 32.0, 70.0 / 81.0, pi / 4.0]) ws = DiffEqFlux.calckernel.((kernel,), ts) # t < -1 - @test all(ws[1:minus_one_index-1] .== 0.0) + @test all(ws[1:(minus_one_index - 1)] .== 0.0) # t > 1 - @test all(ws[plus_one_index+1:end] .== 0.0) + @test all(ws[(plus_one_index + 1):end] .== 0.0) # -1 < t <1 - @test all(ws[minus_one_index+1:plus_one_index-1] .> 0.0) + @test all(ws[(minus_one_index + 1):(plus_one_index - 1)] .> 0.0) # t = 0 @test DiffEqFlux.calckernel(kernel, 0.0) == x0 end end @testset "Kernels with unbounded support" begin - @testset "$kernel" for (kernel, x0) in zip( - unbounded_support_kernels, - [1 / (sqrt(2 * pi)), 0.25, 1 / pi, 1 / (2 * sqrt(2))], - ) + @testset "$kernel" for (kernel, x0) in zip(unbounded_support_kernels, + [1 / (sqrt(2 * pi)), 0.25, 1 / pi, 1 / (2 * sqrt(2))]) # t = 0 @test DiffEqFlux.calckernel(kernel, 0.0) == x0 end @@ -45,18 +34,18 @@ unbounded_support_kernels = end @testset "Collocation of data" begin - function f(u, p, t) - p .* u - end + f(u, p, t) = p .* u rc = 2 ps = repeat([-0.001], rc) tspan = (0.0, 50.0) u0 = 3.4 .+ ones(rc) - t = collect(range(minimum(tspan), stop = maximum(tspan), length = 1000)) + t = collect(range(minimum(tspan); stop = maximum(tspan), length = 1000)) prob = ODEProblem(f, u0, tspan, ps) - data = Array(solve(prob, Tsit5(), saveat = t, abstol = 1e-12, reltol = 1e-12)) - @testset "$kernel" for kernel in - [bounded_support_kernels..., unbounded_support_kernels...] + data = Array(solve(prob, Tsit5(); saveat = t, abstol = 1e-12, reltol = 1e-12)) + @testset "$kernel" for kernel in [ + bounded_support_kernels..., + unbounded_support_kernels..., + ] u′, u = collocate_data(data, t, kernel, 0.003) @test sum(abs2, u - data) < 1e-8 end diff --git a/test/hamiltonian_nn.jl b/test/hamiltonian_nn.jl index 6221fa1452..5437963ad5 100644 --- a/test/hamiltonian_nn.jl +++ b/test/hamiltonian_nn.jl @@ -1,26 +1,25 @@ -using DiffEqFlux, Zygote, OrdinaryDiffEq, ForwardDiff, Test, Optimisers, Random, Lux, ComponentArrays, Statistics +using DiffEqFlux, Zygote, OrdinaryDiffEq, ForwardDiff, Test, Optimisers, Random, Lux, + ComponentArrays, Statistics # Checks for Shapes and Non-Zero Gradients u0 = rand(Float32, 6, 1) -hnn = HamiltonianNN(Lux.Chain(Lux.Dense(6, 12, relu), Lux.Dense(12, 1))) -ps, st = Lux.setup(Random.default_rng(), hnn) -ps = ps |> ComponentArray - -@test size(first(hnn(u0, ps, st))) == (6, 1) +for ad in (AutoForwardDiff(), AutoZygote()) + hnn = HamiltonianNN(Chain(Dense(6 => 12, relu), Dense(12 => 1)); ad) + ps, st = Lux.setup(Random.default_rng(), hnn) + ps = ps |> ComponentArray -@test !iszero(ForwardDiff.gradient(ps -> sum(first(hnn(u0, ps, st))), ps)) + @test size(first(hnn(u0, ps, st))) == (6, 1) -hnn = HamiltonianNN(Lux.Chain(Lux.Dense(6, 12, relu), Lux.Dense(12, 1))) -ps, st = Lux.setup(Random.default_rng(), hnn) -ps = ps |> ComponentArray + @test !iszero(ForwardDiff.gradient(ps -> sum(first(hnn(u0, ps, st))), ps)) -@test size(first(hnn(u0, ps, st))) == (6, 1) + ad isa AutoZygote && continue -@test !iszero(ForwardDiff.gradient(ps -> sum(first(hnn(u0, ps, st))), ps)) + @test !iszero(only(Zygote.gradient(ps -> sum(first(hnn(u0, ps, st))), ps))) +end # Test Convergence on a toy problem -t = range(0.0f0, 1.0f0, length=64) +t = range(0.0f0, 1.0f0; length = 64) π_32 = Float32(π) q_t = reshape(sin.(2π_32 * t), 1, :) p_t = reshape(cos.(2π_32 * t), 1, :) @@ -30,7 +29,7 @@ dpdt = -2π_32 .* q_t data = vcat(q_t, p_t) target = vcat(dqdt, dpdt) -hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 16, relu), Lux.Dense(16, 1))) +hnn = HamiltonianNN(Chain(Dense(2 => 16, relu), Dense(16 => 1)); ad = AutoForwardDiff()) ps, st = Lux.setup(Random.default_rng(), hnn) ps = ps |> ComponentArray @@ -42,8 +41,7 @@ initial_loss = loss(data, target, ps) for epoch in 1:100 global ps, st_opt - # Forward Mode over Reverse Mode for Training - gs = ForwardDiff.gradient(ps -> loss(data, target, ps), ps) + gs = last(Zygote.gradient(loss, data, target, ps)) st_opt, ps = Optimisers.update!(st_opt, ps, gs) end @@ -54,11 +52,8 @@ final_loss = loss(data, target, ps) # Test output and gradient of NeuralHamiltonianDE Layer tspan = (0.0f0, 1.0f0) -model = NeuralHamiltonianDE( - hnn, tspan, Tsit5(), - save_everystep=false, save_start=true, - saveat=range(tspan[1], tspan[2], length=10) -) +model = NeuralHamiltonianDE(hnn, tspan, Tsit5(); save_everystep = false, save_start = true, + saveat = range(tspan[1], tspan[2]; length = 10)) sol = Array(first(model(data[:, 1], ps, st))) @test size(sol) == (2, 10) diff --git a/test/mnist_conv_gpu.jl b/test/mnist_conv_gpu.jl index dc16f76ea0..bc8ed8945f 100644 --- a/test/mnist_conv_gpu.jl +++ b/test/mnist_conv_gpu.jl @@ -1,112 +1,116 @@ -using DiffEqFlux, CUDA, Zygote, MLDatasets, OrdinaryDiffEq, Printf, Test -using Flux.Losses: logitcrossentropy -using Flux.Data: DataLoader -using MLDataUtils: LabelEnc, convertlabel, stratifiedobs - -CUDA.allowscalar(false) -ENV["DATADEPS_ALWAYS_ACCEPT"] = true - -function loadmnist(batchsize = bs, train_split = 0.9) - # Use MLDataUtils LabelEnc for natural onehot conversion - onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, - LabelEnc.NativeLabels(collect(0:9))) - # Load MNIST - mnist = MNIST(split = :train) - imgs, labels_raw = mnist.features, mnist.targets - # Process images into (H,W,C,BS) batches - x_data = Float32.(reshape(imgs, size(imgs,1), size(imgs,2), 1, size(imgs,3))) - y_data = onehot(labels_raw) - (x_train, y_train), (x_test, y_test) = stratifiedobs((x_data, y_data), - p = train_split) - return ( - # Use Flux's DataLoader to automatically minibatch and shuffle the data - DataLoader(Flux.gpu.(collect.((x_train, y_train))); batchsize = batchsize, - shuffle = true), - # Don't shuffle the test data - DataLoader(Flux.gpu.(collect.((x_test, y_test))); batchsize = batchsize, - shuffle = false) - ) -end - -# Main -const bs = 128 -const train_split = 0.9 -train_dataloader, test_dataloader = loadmnist(bs, train_split) - -down = Flux.Chain(Flux.Conv((3, 3), 1=>64, relu, stride = 1), Flux.GroupNorm(64, 64), - Flux.Conv((4, 4), 64=>64, relu, stride = 2, pad=1), Flux.GroupNorm(64, 64), - Flux.Conv((4, 4), 64=>64, stride = 2, pad = 1)) |>Flux.gpu - -dudt = Flux.Chain(Flux.Conv((3, 3), 64=>64, tanh, stride=1, pad=1), - Flux.Conv((3, 3), 64=>64, tanh, stride=1, pad=1)) |>Flux.gpu - -fc = Flux.Chain(Flux.GroupNorm(64, 64), x -> relu.(x), Flux.MeanPool((6, 6)), - x -> reshape(x, (64, :)), Flux.Dense(64,10)) |> Flux.gpu - -nn_ode = NeuralODE(dudt, (0.f0, 1.f0), Tsit5(), - save_everystep = false, - reltol = 1e-3, abstol = 1e-3, - save_start = false) |> Flux.gpu - -function DiffEqArray_to_Array(x) - xarr = Flux.gpu(x) - return xarr[:,:,:,:,1] -end - -# Build our over-all model topology -model = Flux.Chain(down, # (28, 28, 1, BS) -> (6, 6, 64, BS) - nn_ode, # (6, 6, 64, BS) -> (6, 6, 64, BS, 1) - DiffEqArray_to_Array, # (6, 6, 64, BS, 1) -> (6, 6, 64, BS) - fc) # (6, 6, 64, BS) -> (10, BS) - -# To understand the intermediate NN-ODE layer, we can examine it's dimensionality -img, lab = train_dataloader.data[1][:, :, :, 1:1], train_dataloader.data[2][:, 1:1] - -x_d = down(img) - -# We can see that we can compute the forward pass through the NN topology -# featuring an NNODE layer. -x_m = model(img) - -classify(x) = argmax.(eachcol(x)) - -function accuracy(model, data; n_batches = 10) - total_correct = 0 - total = 0 - for (i, (x, y)) in enumerate(data) - # Only evaluate accuracy for n_batches - i > n_batches && break - target_class = classify(Flux.cpu(y)) - predicted_class = classify(Flux.cpu(model(x))) - total_correct += sum(target_class .== predicted_class) - total += length(target_class) - end - return total_correct / total -end - -# burn in accuracy -accuracy(model, train_dataloader) - -loss(x, y) = logitcrossentropy(model(x), y) - -# burn in loss -loss(img, lab) - -opt = Adam(0.05) -iter = 0 - -cb() = begin - global iter += 1 - # Monitor that the weights do infact update - # Every 10 training iterations show accuracy - if iter % 10 == 1 - train_accuracy = accuracy(model, train_dataloader) * 100 - test_accuracy = accuracy(model, test_dataloader; - n_batches = length(test_dataloader)) * 100 - @printf("Iter: %3d || Train Accuracy: %2.3f || Test Accuracy: %2.3f\n", - iter, train_accuracy, test_accuracy) - end -end - -Flux.train!(loss, Flux.params(down, nn_ode.p, fc), train_dataloader, opt, cb = cb) -@test accuracy(model, test_dataloader; n_batches = length(test_dataloader)) > 0.8 +using DiffEqFlux, Statistics, + ComponentArrays, CUDA, Zygote, MLDatasets, OrdinaryDiffEq, Printf, Test, LuxCUDA, Random +using Optimization, OptimizationOptimisers +using MLDatasets: MNIST +using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview +using OneHotArrays + +const cdev = cpu_device() +const gdev = gpu_device() + +CUDA.allowscalar(false) +ENV["DATADEPS_ALWAYS_ACCEPT"] = true + +logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims = 1); dims = 1)) + +function loadmnist(batchsize = bs) + # Use MLDataUtils LabelEnc for natural onehot conversion + function onehot(labels_raw) + convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) + end + # Load MNIST + mnist = MNIST(; split = :train) + imgs, labels_raw = mnist.features, mnist.targets + # Process images into (H,W,C,BS) batches + x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> + gdev + x_train = batchview(x_train, batchsize) + # Onehot and batch the labels + y_train = onehot(labels_raw) |> gdev + y_train = batchview(y_train, batchsize) + return x_train, y_train +end + +# Main +const bs = 128 +x_train, y_train = loadmnist(bs) + +down = Chain(Conv((3, 3), 1 => 64, relu; stride = 1), GroupNorm(64, 64), + Conv((4, 4), 64 => 64, relu; stride = 2, pad = 1), GroupNorm(64, 64), + Conv((4, 4), 64 => 64; stride = 2, pad = 1)) + +dudt = Chain(Conv((3, 3), 64 => 64, tanh; stride = 1, pad = 1), + Conv((3, 3), 64 => 64, tanh; stride = 1, pad = 1)) + +fc = Chain(GroupNorm(64, 64), x -> relu.(x), MeanPool((6, 6)), + x -> reshape(x, (64, :)), Dense(64, 10)) + +nn_ode = NeuralODE(dudt, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, + reltol = 1e-3, abstol = 1e-3, save_start = false) + +function DiffEqArray_to_Array(x) + xarr = gdev(x) + return xarr[:, :, :, :, 1] +end + +# Build our over-all model topology +m = Chain(down, # (28, 28, 1, BS) -> (6, 6, 64, BS) + nn_ode, # (6, 6, 64, BS) -> (6, 6, 64, BS, 1) + DiffEqArray_to_Array, # (6, 6, 64, BS, 1) -> (6, 6, 64, BS) + fc) # (6, 6, 64, BS) -> (10, BS) +ps, st = Lux.setup(Random.default_rng(), m) +ps = ComponentArray(ps) |> gdev +st = st |> gdev + +# To understand the intermediate NN-ODE layer, we can examine it's dimensionality +img = x_train[1][:, :, :, 1:1] |> gdev +lab = x_train[2][:, 1:1] |> gdev + +x_m, _ = m(img, ps, st) + +classify(x) = argmax.(eachcol(x)) + +function accuracy(model, data, ps, st; n_batches = 10) + total_correct = 0 + total = 0 + st = Lux.testmode(st) + for (x, y) in collect(data)[1:n_batches] + target_class = classify(cdev(y)) + predicted_class = classify(cdev(first(model(x, ps, st)))) + total_correct += sum(target_class .== predicted_class) + total += length(target_class) + end + return total_correct / total +end + +# burn in accuracy +accuracy(m, zip(x_train, y_train), ps, st) + +function loss_function(ps, x, y) + pred, st_ = m(x, ps, st) + return logitcrossentropy(pred, y), pred +end + +#burn in loss +loss_function(ps, x_train[1], y_train[1]) + +opt = OptimizationOptimisers.Adam(0.05) +iter = 0 + +opt_func = OptimizationFunction((ps, _, x, y) -> loss_function(ps, x, y), + Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps) + +function callback(ps, l, pred) + global iter += 1 + #Monitor that the weights do infact update + #Every 10 training iterations show accuracy + if (iter % 10 == 0) + @info "[MNIST Conv GPU] Accuracy: $(accuracy(m, zip(x_train, y_train), ps, st))" + end + return false +end + +# Train the NN-ODE and monitor the loss and weights. +res = Optimization.solve(opt_prob, opt, zip(x_train, y_train); maxiters = 10, callback) +@test accuracy(m, zip(x_train, y_train), res.u, st) > 0.8 diff --git a/test/mnist_gpu.jl b/test/mnist_gpu.jl index 85033eabfe..4f94c76188 100644 --- a/test/mnist_gpu.jl +++ b/test/mnist_gpu.jl @@ -1,110 +1,118 @@ -using DiffEqFlux, CUDA, Zygote, MLDataUtils, NNlib, OrdinaryDiffEq, Test, Lux, Statistics, - ComponentArrays, Random, Optimization, OptimizationOptimisers -using MLDatasets: MNIST -using MLDataUtils: LabelEnc, convertlabel, stratifiedobs - -CUDA.allowscalar(false) -ENV["DATADEPS_ALWAYS_ACCEPT"] = true - -logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1)) - -function loadmnist(batchsize=bs) - # Use MLDataUtils LabelEnc for natural onehot conversion - onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) - # Load MNIST - mnist = MNIST(split=:train) - imgs, labels_raw = mnist.features, mnist.targets - # Process images into (H,W,C,BS) batches - x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> Lux.gpu - x_train = batchview(x_train, batchsize) - # Onehot and batch the labels - y_train = onehot(labels_raw) |> Lux.gpu - y_train = batchview(y_train, batchsize) - return x_train, y_train -end - -# Main -const bs = 128 -x_train, y_train = loadmnist(bs) - -down = Lux.Chain(Lux.FlattenLayer(), Lux.Dense(784, 20, tanh)) -nn = Lux.Chain(Lux.Dense(20, 10, tanh), Lux.Dense(10, 10, tanh), - Lux.Dense(10, 20, tanh)) -fc = Lux.Dense(20, 10) - -nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false, reltol=1e-3, - abstol=1e-3, save_start=false) - -""" - DiffEqArray_to_Array(x) - -Cheap conversion of a `DiffEqArray` instance to a Matrix. -""" -function DiffEqArray_to_Array(x) - xarr = Lux.gpu(x) - return reshape(xarr, size(xarr)[1:2]) -end - -#Build our over-all model topology -m = Lux.Chain(; down, nn_ode, convert=Lux.WrappedFunction(DiffEqArray_to_Array), fc) -ps, st = Lux.setup(Random.default_rng(), m) -ps = ComponentArray(ps) |> Lux.gpu -st = st |> Lux.gpu - -#We can also build the model topology without a NN-ODE -m_no_ode = Lux.Chain(; down, nn, fc) -ps_no_ode, st_no_ode = Lux.setup(Random.default_rng(), m_no_ode) -ps_no_ode = ComponentArray(ps_no_ode) |> Lux.gpu -st_no_ode = st_no_ode |> Lux.gpu - -#To understand the intermediate NN-ODE layer, we can examine it's dimensionality -x_d = first(down(x_train[1], ps.down, st.down)) - -# We can see that we can compute the forward pass through the NN topology featuring an NNODE layer. -x_m = first(m(x_train[1], ps, st)) -#Or without the NN-ODE layer. -x_m = first(m_no_ode(x_train[1], ps_no_ode, st_no_ode)) - -classify(x) = argmax.(eachcol(x)) - -function accuracy(model, data, ps, st; n_batches=100) - total_correct = 0 - total = 0 - st = Lux.testmode(st) - for (x, y) in collect(data)[1:n_batches] - target_class = classify(Lux.cpu(y)) - predicted_class = classify(Lux.cpu(first(model(x, ps, st)))) - total_correct += sum(target_class .== predicted_class) - total += length(target_class) - end - return total_correct / total -end -#burn in accuracy -accuracy(m, zip(x_train, y_train), ps, st) - -function loss_function(ps, x, y) - pred, st_ = m(x, ps, st) - return logitcrossentropy(pred, y), pred -end - -#burn in loss -loss_function(ps, x_train[1], y_train[1]) - -opt = OptimizationOptimisers.Adam(0.05) -iter = 0 - -opt_func = OptimizationFunction((ps, _, x, y) -> loss_function(ps, x, y), - Optimization.AutoZygote()) -opt_prob = OptimizationProblem(opt_func, ps) - -function callback(ps, l, pred) - global iter += 1 - #Monitor that the weights do infact update - #Every 10 training iterations show accuracy - (iter % 10 == 0) && @show accuracy(m, zip(x_train, y_train), ps, st) - return false -end - -# Train the NN-ODE and monitor the loss and weights. -res = Optimization.solve(opt_prob, opt, zip(x_train, y_train); callback) -@test accuracy(m, zip(x_train, y_train), res.u, st) > 0.8 +using DiffEqFlux, CUDA, Zygote, MLDataUtils, NNlib, OrdinaryDiffEq, Test, Lux, Statistics, + ComponentArrays, Random, Optimization, OptimizationOptimisers, LuxCUDA +using MLDatasets: MNIST +using MLDataUtils: LabelEnc, convertlabel, stratifiedobs + +CUDA.allowscalar(false) +ENV["DATADEPS_ALWAYS_ACCEPT"] = true + +const cdev = cpu_device() +const gdev = gpu_device() + +logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims = 1); dims = 1)) + +function loadmnist(batchsize = bs) + # Use MLDataUtils LabelEnc for natural onehot conversion + function onehot(labels_raw) + convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) + end + # Load MNIST + mnist = MNIST(; split = :train) + imgs, labels_raw = mnist.features, mnist.targets + # Process images into (H,W,C,BS) batches + x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> + gdev + x_train = batchview(x_train, batchsize) + # Onehot and batch the labels + y_train = onehot(labels_raw) |> gdev + y_train = batchview(y_train, batchsize) + return x_train, y_train +end + +# Main +const bs = 128 +x_train, y_train = loadmnist(bs) + +down = Lux.Chain(Lux.FlattenLayer(), Lux.Dense(784, 20, tanh)) +nn = Lux.Chain(Lux.Dense(20, 10, tanh), Lux.Dense(10, 10, tanh), + Lux.Dense(10, 20, tanh)) +fc = Lux.Dense(20, 10) + +nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, reltol = 1e-3, + abstol = 1e-3, save_start = false) + +""" + DiffEqArray_to_Array(x) + +Cheap conversion of a `DiffEqArray` instance to a Matrix. +""" +function DiffEqArray_to_Array(x) + xarr = gdev(x) + return reshape(xarr, size(xarr)[1:2]) +end + +#Build our over-all model topology +m = Lux.Chain(; down, nn_ode, convert = Lux.WrappedFunction(DiffEqArray_to_Array), fc) +ps, st = Lux.setup(Random.default_rng(), m) +ps = ComponentArray(ps) |> gdev +st = st |> gdev + +#We can also build the model topology without a NN-ODE +m_no_ode = Lux.Chain(; down, nn, fc) +ps_no_ode, st_no_ode = Lux.setup(Random.default_rng(), m_no_ode) +ps_no_ode = ComponentArray(ps_no_ode) |> gdev +st_no_ode = st_no_ode |> gdev + +#To understand the intermediate NN-ODE layer, we can examine it's dimensionality +x_d = first(down(x_train[1], ps.down, st.down)) + +# We can see that we can compute the forward pass through the NN topology featuring an NNODE layer. +x_m = first(m(x_train[1], ps, st)) +#Or without the NN-ODE layer. +x_m = first(m_no_ode(x_train[1], ps_no_ode, st_no_ode)) + +classify(x) = argmax.(eachcol(x)) + +function accuracy(model, data, ps, st; n_batches = 100) + total_correct = 0 + total = 0 + st = Lux.testmode(st) + for (x, y) in collect(data)[1:n_batches] + target_class = classify(cdev(y)) + predicted_class = classify(cdev(first(model(x, ps, st)))) + total_correct += sum(target_class .== predicted_class) + total += length(target_class) + end + return total_correct / total +end +#burn in accuracy +accuracy(m, zip(x_train, y_train), ps, st) + +function loss_function(ps, x, y) + pred, st_ = m(x, ps, st) + return logitcrossentropy(pred, y), pred +end + +#burn in loss +loss_function(ps, x_train[1], y_train[1]) + +opt = OptimizationOptimisers.Adam(0.05) +iter = 0 + +opt_func = OptimizationFunction((ps, _, x, y) -> loss_function(ps, x, y), + Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps) + +function callback(ps, l, pred) + global iter += 1 + #Monitor that the weights do infact update + #Every 10 training iterations show accuracy + if (iter % 10 == 0) + @info "[MNIST GPU] Accuracy: $(accuracy(m, zip(x_train, y_train), ps, st))" + end + return false +end + +# Train the NN-ODE and monitor the loss and weights. +res = Optimization.solve(opt_prob, opt, zip(x_train, y_train); callback) +@test accuracy(m, zip(x_train, y_train), res.u, st) > 0.8 diff --git a/test/multiple_shoot.jl b/test/multiple_shoot.jl index 765d984bd7..ed3f715ab4 100644 --- a/test/multiple_shoot.jl +++ b/test/multiple_shoot.jl @@ -1,4 +1,5 @@ -using ComponentArrays, DiffEqFlux, Zygote, Lux, Optimization, OptimizationOptmisers, OrdinaryDiffEq, Test, Random +using ComponentArrays, DiffEqFlux, Zygote, Lux, Optimization, OptimizationOptimisers, + OrdinaryDiffEq, Test, Random using DiffEqFlux: group_ranges rng = Random.default_rng() @@ -13,34 +14,28 @@ rng = Random.default_rng() datasize = 30 u0 = Float32[2.0, 0.0] tspan = (0.0f0, 5.0f0) -tsteps = range(tspan[1], tspan[2], length = datasize) +tsteps = range(tspan[1], tspan[2]; length = datasize) # Get the data function trueODEfunc(du, u, p, t) true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' + du .= ((u .^ 3)'true_A)' end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) -ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) +ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) # Define the Neural Network -nn = Lux.Chain(x -> x.^3, - Lux.Dense(2, 16, tanh), - Lux.Dense(16, 2)) +nn = Chain(x -> x .^ 3, Dense(2 => 16, tanh), Dense(16 => 2)) p_init, st = Lux.setup(rng, nn) p_init = ComponentArray(p_init) -neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps) -prob_node = ODEProblem((u,p,t)->nn(u,p,st)[1], u0, tspan, p_init) +neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps) +prob_node = ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, p_init) -function predict_single_shooting(p) - return Array(neuralode(u0, p, st)[1]) -end +predict_single_shooting(p) = Array(first(neuralode(u0, p, st))) # Define loss function -function loss_function(data, pred) - return sum(abs2, data - pred) -end +loss_function(data, pred) = sum(abs2, data - pred) ## Evaluate Single Shooting function loss_single_shooting(p) @@ -50,13 +45,12 @@ function loss_single_shooting(p) end adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((p,_)->loss_single_shooting(p), adtype) +optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype) optprob = Optimization.OptimizationProblem(optf, p_init) -res_single_shooting = Optimization.solve(optprob, Adam(0.05), - maxiters = 300) +res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300) loss_ss, _ = loss_single_shooting(res_single_shooting.minimizer) -println("Single shooting loss: $(loss_ss)") +@info "Single shooting loss: $(loss_ss)" ## Test Multiple Shooting group_size = 3 @@ -64,14 +58,14 @@ continuity_term = 200 function loss_multiple_shooting(p) return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(), - group_size; continuity_term, - abstol=1e-8, reltol=1e-6) # test solver kwargs + group_size; continuity_term, + abstol = 1e-8, reltol = 1e-6) # test solver kwargs end adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((p,_)->loss_multiple_shooting(p), adtype) +optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting(p), adtype) optprob = Optimization.OptimizationProblem(optf, p_init) -res_ms = Optimization.solve(optprob, Adam(0.05), maxiters = 300) +res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300) # Calculate single shooting loss with parameter from multiple_shoot training loss_ms, _ = loss_single_shooting(res_ms.minimizer) @@ -88,14 +82,14 @@ end function loss_multiple_shooting_abs2(p) return multiple_shoot(p, ode_data, tsteps, prob_node, - loss_function, continuity_loss_abs2, Tsit5(), - group_size; continuity_term) + loss_function, continuity_loss_abs2, Tsit5(), + group_size; continuity_term) end adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((p,_)->loss_multiple_shooting_abs2(p), adtype) +optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_abs2(p), adtype) optprob = Optimization.OptimizationProblem(optf, p_init) -res_ms_abs2 = Optimization.solve(optprob, Adam(0.05), maxiters = 300) +res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300) loss_ms_abs2, _ = loss_single_shooting(res_ms_abs2.minimizer) println("Multiple shooting loss with abs2: $(loss_ms_abs2)") @@ -103,16 +97,15 @@ println("Multiple shooting loss with abs2: $(loss_ms_abs2)") ## Test different SensitivityAlgorithm (default is InterpolatingAdjoint) function loss_multiple_shooting_fd(p) - return multiple_shoot(p, ode_data, tsteps, prob_node, - loss_function, continuity_loss_abs2, Tsit5(), - group_size; continuity_term, - sensealg=ForwardDiffSensitivity()) + return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, + continuity_loss_abs2, Tsit5(), group_size; continuity_term, + sensealg = ForwardDiffSensitivity()) end adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((p,_)->loss_multiple_shooting_fd(p), adtype) +optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_fd(p), adtype) optprob = Optimization.OptimizationProblem(optf, p_init) -res_ms_fd = Optimization.solve(optprob, Adam(0.05), maxiters = 300) +res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300) # Calculate single shooting loss with parameter from multiple_shoot training loss_ms_fd, _ = loss_single_shooting(res_ms_fd.minimizer) @@ -122,41 +115,39 @@ println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)") # Integration return codes `!= :Success` should return infinite loss. # In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`. loss_fail, _ = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), - datasize; maxiters=1, verbose=false) + datasize; maxiters = 1, verbose = false) @test loss_fail == Inf ## Test for DomainErrors @test_throws DomainError multiple_shoot(p_init, ode_data, tsteps, prob_node, - loss_function, Tsit5(), 1) + loss_function, Tsit5(), 1) @test_throws DomainError multiple_shoot(p_init, ode_data, tsteps, prob_node, - loss_function, Tsit5(), datasize + 1) + loss_function, Tsit5(), datasize + 1) ## Ensembles u0s = [Float32[2.0, 0.0], Float32[3.0, 1.0]] function prob_func(prob, i, repeat) - remake(prob, u0 = u0s[i]) + remake(prob; u0 = u0s[i]) end -ensemble_prob = EnsembleProblem(prob_node, prob_func = prob_func) -ensemble_prob_trueODE = EnsembleProblem(prob_trueode, prob_func = prob_func) +ensemble_prob = EnsembleProblem(prob_node; prob_func = prob_func) +ensemble_prob_trueODE = EnsembleProblem(prob_trueode; prob_func = prob_func) ensemble_alg = EnsembleThreads() trajectories = 2 -ode_data_ensemble = Array(solve(ensemble_prob_trueODE, Tsit5(), ensemble_alg, trajectories = trajectories, saveat = tsteps)) +ode_data_ensemble = Array(solve(ensemble_prob_trueODE, Tsit5(), ensemble_alg; trajectories, + saveat = tsteps)) group_size = 3 continuity_term = 200 function loss_multiple_shooting_ens(p) return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg, - loss_function, Tsit5(), - group_size; continuity_term, - trajectories, - abstol=1e-8, reltol=1e-6) # test solver kwargs + loss_function, Tsit5(), group_size; continuity_term, trajectories, + abstol = 1e-8, reltol = 1e-6) # test solver kwargs end adtype = Optimization.AutoZygote() -optf = Optimization.OptimizationFunction((p,_)->loss_multiple_shooting_ens(p), adtype) +optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_ens(p), adtype) optprob = Optimization.OptimizationProblem(optf, p_init) -res_ms_ensembles = Optimization.solve(optprob, - Adam(0.05), maxiters = 300) +res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300) loss_ms_ensembles, _ = loss_single_shooting(res_ms_ensembles.minimizer) diff --git a/test/neural_dae.jl b/test/neural_dae.jl index e366210d33..c92fdd0aaf 100644 --- a/test/neural_dae.jl +++ b/test/neural_dae.jl @@ -1,69 +1,72 @@ -using ComponentArrays, DiffEqFlux, Zygote, Optimization, OrdinaryDiffEq +using ComponentArrays, + DiffEqFlux, Zygote, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Random #A desired MWE for now, not a test yet. -function rober(du,u,p,t) - y₁,y₂,y₃ = u - k₁,k₂,k₃ = p - du[1] = -k₁*y₁ + k₃*y₂*y₃ - du[2] = k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2 - du[3] = y₁ + y₂ + y₃ - 1 - nothing +function rober(du, u, p, t) + y₁, y₂, y₃ = u + k₁, k₂, k₃ = p + du[1] = -k₁ * y₁ + k₃ * y₂ * y₃ + du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2 + du[3] = y₁ + y₂ + y₃ - 1 + nothing end -M = [1. 0 0 - 0 1. 0 - 0 0 0] -prob_mm = ODEProblem(ODEFunction(rober,mass_matrix=M),[1.0,0.0,0.0],(0.0,10.0),(0.04,3e7,1e4)) -sol = solve(prob_mm,Rodas5(),reltol=1e-8,abstol=1e-8) +M = [1.0 0 0 + 0 1.0 0 + 0 0 0] +prob_mm = ODEProblem(ODEFunction(rober; mass_matrix = M), + [1.0, 0.0, 0.0], + (0.0, 10.0), + (0.04, 3e7, 1e4)) +sol = solve(prob_mm, Rodas5(); reltol = 1e-8, abstol = 1e-8) +dudt2 = Chain(x -> x .^ 3, Dense(6, 50, tanh), Dense(50, 3)) -dudt2 = Flux.Chain(x -> x.^3,Flux.Dense(6,50,tanh),Flux.Dense(50,2)) +u₀ = [1.0, 0, 0] +du₀ = [-0.04, 0.04, 0.0] +tspan = (0.0, 10.0) -ndae = NeuralDAE(dudt2, (u,p,t) -> [u[1] + u[2] + u[3] - 1], tspan, M, DImplicitEuler(), - differential_vars = [true,true,false]) -truedu0 = similar(u₀) -f(truedu0,u₀,p,0.0) +ndae = NeuralDAE(dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan, DFBDF(); + differential_vars = [true, true, false]) +ps, st = Lux.setup(Xoshiro(0), ndae) +ps = ComponentArray(ps) -ndae(u₀,truedu0,Float64.(ndae.p)) +ndae((u₀, du₀), ps, st) -function predict_n_dae(p) - ndae(u₀,p) -end +predict_n_dae(p) = first(ndae(u₀, p, st)) function loss(p) pred = predict_n_dae(p) - loss = sum(abs2,sol .- pred) - loss,pred + loss = sum(abs2, sol .- pred) + return loss, pred end -p = p .+ rand(3) .* p - optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote()) -optprob = Optimization.OptimizationProblem(optfunc, p) -res = Optimization.solve(optprob, BFGS(initial_stepnorm = 0.0001)) +optprob = Optimization.OptimizationProblem(optfunc, ps) +res = Optimization.solve(optprob, BFGS(; initial_stepnorm = 0.0001)) # Same stuff with Lux rng = Random.default_rng() -dudt2 = Lux.Chain(x -> x.^3,Lux.Dense(6,50,tanh),Lux.Dense(50,2)) +dudt2 = Chain(x -> x .^ 3, Dense(6, 50, tanh), Dense(50, 2)) p, st = Lux.setup(rng, dudt2) p = ComponentArray(p) -ndae = NeuralDAE(dudt2, (u,p,t) -> [u[1] + u[2] + u[3] - 1], tspan, M, DImplicitEuler(), - differential_vars = [true,true,false]) +ndae = NeuralDAE(dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan, M, DImplicitEuler(); + differential_vars = [true, true, false]) truedu0 = similar(u₀) -f(truedu0,u₀,p,0.0) +f(truedu0, u₀, p, 0.0) -ndae(u₀,p,st,truedu0) +ndae(u₀, p, st, truedu0) function predict_n_dae(p) - ndae(u₀,p,st)[1] + ndae(u₀, p, st)[1] end function loss(p) pred = predict_n_dae(p) - loss = sum(abs2,sol .- pred) - loss,pred + loss = sum(abs2, sol .- pred) + loss, pred end optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optfunc, p) -res = Optimization.solve(optprob, BFGS(initial_stepnorm = 0.0001)) \ No newline at end of file +res = Optimization.solve(optprob, BFGS(; initial_stepnorm = 0.0001)) diff --git a/test/neural_de.jl b/test/neural_de.jl index 3b6894b772..799514d6a8 100644 --- a/test/neural_de.jl +++ b/test/neural_de.jl @@ -1,165 +1,170 @@ -using DiffEqFlux, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Test, Random - -mp = Float32[0.1,0.1] -x = Float32[2.; 0.] -xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.])) -tspan = (0.0f0,1.0f0) -dudt = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,2)) - -NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(x) -NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(x) -NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())(x) - -NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(xs) -NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(xs) -NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())(xs) - -@info "Test some gradients" - -node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false) -grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) -@test ! iszero(grads[x]) -@test ! iszero(grads[node.p]) - -grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) -@test ! iszero(grads[xs]) -@test ! iszero(grads[node.p]) - -node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint()) -grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) -@test ! iszero(grads[x]) -@test ! iszero(grads[node.p]) - -grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) -@test ! iszero(grads[xs]) -@test ! iszero(grads[node.p]) - -node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint()) -grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) -@test ! iszero(grads[x]) -@test ! iszero(grads[node.p]) - -grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) -@test ! iszero(grads[xs]) -@test ! iszero(grads[node.p]) - -node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) -grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) -@test ! iszero(grads[x]) -@test ! iszero(grads[node.p]) - -grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) -@test ! iszero(grads[xs]) -@test ! iszero(grads[node.p]) - -@info "Test some adjoints" - -# Adjoint -@testset "adjoint mode" begin - node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false) - grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) - @test ! iszero(grads[x]) - @test ! iszero(grads[node.p]) - - grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) - @test ! iszero(grads[xs]) - @test ! iszero(grads[node.p]) - - node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.0:0.1:1.0) - grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) - @test ! iszero(grads[x]) - @test ! iszero(grads[node.p]) - - grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) - @test ! iszero(grads[xs]) - @test ! iszero(grads[node.p]) - - node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.1) - grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) - @test ! iszero(grads[x]) - @test ! iszero(grads[node.p]) - - grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) - @test ! iszero(grads[xs]) - @test ! iszero(grads[node.p]) -end - -@info "Test Tracker" - -# RD -@testset "Tracker mode" begin - node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint()) - grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) - @test ! iszero(grads[x]) - @test ! iszero(grads[node.p]) - - grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) - @test ! iszero(grads[xs]) - @test ! iszero(grads[node.p]) - - node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.0:0.1:1.0,sensealg=TrackerAdjoint()) - grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) - @test ! iszero(grads[x]) - @test ! iszero(grads[node.p]) - - grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) - @test ! iszero(grads[xs]) - @test ! iszero(grads[node.p]) - - node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint()) - grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) - @test ! iszero(grads[x]) - @test ! iszero(grads[node.p]) - - grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) - @test ! iszero(grads[xs]) - @test ! iszero(grads[node.p]) -end - -@info "Test non-ODEs" - -dudt2 = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,2)) -NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),SOSRI(),saveat=0.1)(x) -sode = NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),SOSRI(),saveat=0.0:0.01:0.1) - -grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode)) -@test ! iszero(grads[x]) -@test ! iszero(grads[sode.p]) -@test ! iszero(grads[sode.p][end]) - -grads = Zygote.gradient(()->sum(sode(xs)),Flux.params(xs,sode)) -@test ! iszero(grads[xs]) -@test ! iszero(grads[sode.p]) -@test ! iszero(grads[sode.p][end]) - -dudt22 = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,4),x->reshape(x,2,2)) -NeuralSDE(dudt,dudt22,(0.0f0,.1f0),2,LambaEM(),saveat=0.01)(x) - -sode = NeuralSDE(dudt,dudt22,(0.0f0,0.1f0),2,LambaEM(),saveat=0.0:0.01:0.1) - -grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode)) -@test ! iszero(grads[x]) -@test ! iszero(grads[sode.p]) -@test ! iszero(grads[sode.p][end]) - -@test_broken grads = Zygote.gradient(()->sum(sode(xs)),Flux.params(xs,sode)) -@test_broken ! iszero(grads[xs]) -@test ! iszero(grads[sode.p]) -@test ! iszero(grads[sode.p][end]) - -ddudt = Flux.Chain(Flux.Dense(6,50,tanh),Flux.Dense(50,2)) -NeuralCDDE(ddudt,(0.0f0,2.0f0),(p,t)->zero(x),(1f-1,2f-1),MethodOfSteps(Tsit5()),saveat=0.1)(x) -dode = NeuralCDDE(ddudt,(0.0f0,2.0f0),(p,t)->zero(x),(1f-1,2f-1),MethodOfSteps(Tsit5()),saveat=0.0:0.1:2.0) - -grads = Zygote.gradient(()->sum(dode(x)),Flux.params(x,dode)) -@test ! iszero(grads[x]) -@test ! iszero(grads[dode.p]) - -@test_broken grads = Zygote.gradient(()->sum(dode(xs)),Flux.params(xs,dode)) isa Tuple -@test_broken ! iszero(grads[xs]) -@test ! iszero(grads[dode.p]) - -@testset "DimMover" begin - r = rand(2, 3, 4, 5) - @test r[:, :, 1, :] == FluxBatchOrder(r)[:, :, :, 1] -end +using ComponentArrays, + DiffEqFlux, Lux, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Test, Random +import Flux + +rng = Random.default_rng() + +@testset "Neural DE: $(nnlib)" for nnlib in ("Flux", "Lux") + mp = Float32[0.1, 0.1] + x = Float32[2.0; 0.0] + xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) + tspan = (0.0f0, 1.0f0) + + dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) + else + Chain(Dense(2 => 50, tanh), Dense(50 => 2)) + end + + aug_dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) + else + Chain(Dense(4 => 50, tanh), Dense(50 => 4)) + end + + @testset "Neural ODE" begin + @testset "u0: $(typeof(u0))" for u0 in (x, xs) + @testset "kwargs: $(kwargs))" for kwargs in ((; save_everystep = false, + save_start = false), + (; abstol = 1e-12, reltol = 1e-12, save_everystep = false, + save_start = false), + (; save_everystep = false, save_start = false, sensealg = TrackerAdjoint()), + (; save_everystep = false, save_start = false, + sensealg = BacksolveAdjoint()), + (; saveat = 0.0f0:0.1f0:1.0f0), + (; saveat = 0.1f0), + (; saveat = 0.0f0:0.1f0:1.0f0, sensealg = TrackerAdjoint()), + (; saveat = 0.1f0, sensealg = TrackerAdjoint())) + node = NeuralODE(dudt, tspan, Tsit5(); kwargs...) + pd, st = Lux.setup(rng, node) + pd = ComponentArray(pd) + grads = Zygote.gradient(sum ∘ first ∘ node, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + + anode = AugmentedNDELayer(NeuralODE(aug_dudt, tspan, Tsit5(); kwargs...), 2) + pd, st = Lux.setup(rng, anode) + pd = ComponentArray(pd) + grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + end + end + end + + diffusion = if nnlib == "Flux" + Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) + else + Chain(Dense(2 => 50, tanh), Dense(50 => 2)) + end + + aug_diffusion = if nnlib == "Flux" + Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) + else + Chain(Dense(4 => 50, tanh), Dense(50 => 4)) + end + + tspan = (0.0f0, 0.1f0) + @testset "NeuralDSDE u0: $(typeof(u0)), solver: $(solver)" for u0 in (x, xs), + solver in (EulerHeun(), LambaEM(), SOSRI()) + + sode = NeuralDSDE(dudt, diffusion, tspan, solver; saveat = 0.0f0:0.01f0:0.1f0, + dt = 0.01f0) + pd, st = Lux.setup(rng, sode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ sode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + @test !iszero(grads[2][end]) + + sode = NeuralDSDE(aug_dudt, aug_diffusion, tspan, solver; + saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) + anode = AugmentedNDELayer(sode, 2) + pd, st = Lux.setup(rng, anode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + @test !iszero(grads[2][end]) + end + + diffusion_sde = if nnlib == "Flux" + Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 4), x -> reshape(x, 2, 2)) + else + Chain(Dense(2 => 50, tanh), Dense(50 => 4), x -> reshape(x, 2, 2)) + end + + aug_diffusion_sde = if nnlib == "Flux" + Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 16), x -> reshape(x, 4, 4)) + else + Chain(Dense(4 => 50, tanh), Dense(50 => 16), x -> reshape(x, 4, 4)) + end + + @testset "NeuralSDE u0: $(typeof(u0)), solver: $(solver)" for u0 in (x,), + solver in (EulerHeun(), LambaEM()) + + sode = NeuralSDE(dudt, diffusion_sde, tspan, 2, solver; + saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) + pd, st = Lux.setup(rng, sode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ sode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + @test !iszero(grads[2][end]) + + sode = NeuralSDE(aug_dudt, aug_diffusion_sde, tspan, 4, solver; + saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) + anode = AugmentedNDELayer(sode, 2) + pd, st = Lux.setup(rng, anode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + @test !iszero(grads[2][end]) + end + + dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(6 => 50, tanh), Flux.Dense(50 => 2)) + else + Chain(Dense(6 => 50, tanh), Dense(50 => 2)) + end + + aug_dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(12 => 50, tanh), Flux.Dense(50 => 4)) + else + Chain(Dense(12 => 50, tanh), Dense(50 => 4)) + end + + @testset "NeuralCDDE u0: $(typeof(u0))" for u0 in (x, xs) + dode = NeuralCDDE(dudt, (0.0f0, 2.0f0), (u, p, t) -> zero(u), (0.1f0, 0.2f0), + MethodOfSteps(Tsit5()); saveat = 0.0f0:0.1f0:2.0f0) + pd, st = Lux.setup(rng, dode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ dode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + + dode = NeuralCDDE(aug_dudt, (0.0f0, 2.0f0), (u, p, t) -> zero(u), (0.1f0, 0.2f0), + MethodOfSteps(Tsit5()); saveat = 0.0f0:0.1f0:2.0f0) + anode = AugmentedNDELayer(dode, 2) + pd, st = Lux.setup(rng, anode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + end +end + +@testset "DimMover" begin + r = rand(2, 3, 4, 5) + layer = DimMover() + ps, st = Lux.setup(rng, layer) + + @test first(layer(r, ps, st))[:, :, :, 1] == r[:, :, 1, :] +end diff --git a/test/neural_de_gpu.jl b/test/neural_de_gpu.jl index a4d1c03b1c..2b752936bc 100644 --- a/test/neural_de_gpu.jl +++ b/test/neural_de_gpu.jl @@ -1,104 +1,94 @@ -using DiffEqFlux, Lux, CUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, Random, - ComponentArrays - -CUDA.allowscalar(false) - -x = Float32[2.0; 0.0] |> Lux.gpu -xs = hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]) |> Lux.gpu -tspan = (0.0f0, 25.0f0) - -mp = Lux.Chain(Lux.Dense(2, 2)) - -dudt = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2)) -ps_dudt, st_dudt = Lux.setup(Random.default_rng(), dudt) -ps_dudt = ComponentArray(ps_dudt) |> Lux.gpu -st_dudt = st_dudt |> Lux.gpu - -NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)(x, ps_dudt, st_dudt) -NeuralODE(dudt, tspan, Tsit5(), saveat=0.1)(x, ps_dudt, st_dudt) -NeuralODE(dudt, tspan, Tsit5(), saveat=0.1, sensealg=TrackerAdjoint())(x, ps_dudt, st_dudt) - -NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)(xs, ps_dudt, st_dudt) -NeuralODE(dudt, tspan, Tsit5(), saveat=0.1)(xs, ps_dudt, st_dudt) -NeuralODE(dudt, tspan, Tsit5(), saveat=0.1, sensealg=TrackerAdjoint())(xs, ps_dudt, st_dudt) - -node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false) -ps_node, st_node = Lux.setup(Random.default_rng(), node) -ps_node = ComponentArray(ps_node) |> Lux.gpu -st_node = st_node |> Lux.gpu -grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) -@test !iszero(grads[1]) -@test !iszero(grads[2]) - -grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) -@test !iszero(grads[1]) -@test !iszero(grads[2]) - -node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false, - sensealg=TrackerAdjoint()) -grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) -@test !iszero(grads[1]) -@test !iszero(grads[2]) - -grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) -@test !iszero(grads[1]) -@test !iszero(grads[2]) - -node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false, - sensealg=BacksolveAdjoint()) -grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) -@test !iszero(grads[1]) -@test !iszero(grads[2]) - -grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) -@test !iszero(grads[1]) -@test !iszero(grads[2]) - -# Adjoint -@testset "adjoint mode" begin - node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false) - grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - - grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - - node = NeuralODE(dudt, tspan, Tsit5(), saveat=0.0:0.1:10.0) - grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - - grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - - node = NeuralODE(dudt, tspan, Tsit5(), saveat=1.0f-1) - grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - - grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) - @test !iszero(grads[1]) - @test !iszero(grads[2]) -end - -ndsde = NeuralDSDE(dudt, mp, (0.0f0, 2.0f0), SOSRI(), saveat=0.0:0.1:2.0) -ps_ndsde, st_ndsde = Lux.setup(Random.default_rng(), ndsde) -ps_ndsde = ComponentArray(ps_ndsde) |> Lux.gpu -st_ndsde = st_ndsde |> Lux.gpu -ndsde(x, ps_ndsde, st_ndsde) - -sode = NeuralDSDE(dudt, mp, (0.0f0, 2.0f0), SOSRI(), saveat=Float32.(0.0:0.1:2.0), - dt=1.0f-1, sensealg=TrackerAdjoint()) -ps_sode, st_sode = Lux.setup(Random.default_rng(), sode) -ps_sode = ComponentArray(ps_sode) |> Lux.gpu -st_sode = st_sode |> Lux.gpu -grads = Zygote.gradient((x, ps) -> sum(first(sode(x, ps, st_sode))), x, ps_sode) -@test !iszero(grads[1]) -@test !iszero(grads[2]) - -grads = Zygote.gradient((xs, ps) -> sum(first(sode(xs, ps, st_sode))), xs, ps_sode) -@test !iszero(grads[1]) -@test !iszero(grads[2]) +using DiffEqFlux, Lux, LuxCUDA, CUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, + Random, ComponentArrays +import Flux + +CUDA.allowscalar(false) + +rng = Random.default_rng() + +const gdev = gpu_device() +const cdev = cpu_device() + +@testset "[CUDA] Neural DE: $(nnlib)" for nnlib in ("Flux", "Lux") + mp = Float32[0.1, 0.1] |> gdev + x = Float32[2.0; 0.0] |> gdev + xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) |> gdev + tspan = (0.0f0, 1.0f0) + + dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) + else + Chain(Dense(2 => 50, tanh), Dense(50 => 2)) + end + + aug_dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) + else + Chain(Dense(4 => 50, tanh), Dense(50 => 4)) + end + + @testset "Neural ODE" begin + @testset "u0: $(typeof(u0))" for u0 in (x, xs) + @testset "kwargs: $(kwargs))" for kwargs in ((; save_everystep = false, + save_start = false), + (; save_everystep = false, save_start = false, sensealg = TrackerAdjoint()), + (; save_everystep = false, save_start = false, + sensealg = BacksolveAdjoint()), + (; saveat = 0.0f0:0.1f0:1.0f0), + (; saveat = 0.1f0), + (; saveat = 0.0f0:0.1f0:1.0f0, sensealg = TrackerAdjoint()), + (; saveat = 0.1f0, sensealg = TrackerAdjoint())) + node = NeuralODE(dudt, tspan, Tsit5(); kwargs...) + pd, st = Lux.setup(rng, node) + pd = ComponentArray(pd) |> gdev + st = st |> gdev + grads = Zygote.gradient(sum ∘ first ∘ node, u0, pd, st) + CUDA.@allowscalar begin + @test !iszero(grads[1]) + @test !iszero(grads[2]) + end + + anode = AugmentedNDELayer(NeuralODE(aug_dudt, tspan, Tsit5(); kwargs...), 2) + pd, st = Lux.setup(rng, anode) + pd = ComponentArray(pd) |> gdev + st = st |> gdev + grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) + CUDA.@allowscalar begin + @test !iszero(grads[1]) + @test !iszero(grads[2]) + end + end + end + end + + diffusion = if nnlib == "Flux" + Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) + else + Chain(Dense(2 => 50, tanh), Dense(50 => 2)) + end + + aug_diffusion = if nnlib == "Flux" + Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) + else + Chain(Dense(4 => 50, tanh), Dense(50 => 4)) + end + + tspan = (0.0f0, 0.1f0) + @testset "NeuralDSDE u0: $(typeof(u0)), solver: $(solver)" for u0 in (xs,), + solver in (SOSRI(),) + # CuVector seems broken on CI but I can't reproduce the failure locally + + sode = NeuralDSDE(dudt, diffusion, tspan, solver; saveat = 0.0f0:0.01f0:0.1f0, + dt = 0.01f0) + pd, st = Lux.setup(rng, sode) + pd = ComponentArray(pd) |> gdev + st = st |> gdev + + grads = Zygote.gradient(sum ∘ first ∘ sode, u0, pd, st) + CUDA.@allowscalar begin + @test !iszero(grads[1]) + @test !iszero(grads[2]) + @test !iszero(grads[2][end]) + end + end +end diff --git a/test/neural_de_lux.jl b/test/neural_de_lux.jl deleted file mode 100644 index b261e95ca2..0000000000 --- a/test/neural_de_lux.jl +++ /dev/null @@ -1,143 +0,0 @@ -using ComponentArrays, DiffEqFlux, Zygote, Lux, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Test, Random - -rng = Random.default_rng() - -mp = Float32[0.1,0.1] -x = Float32[2.; 0.] -xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.])) -tspan = (0.0f0,1.0f0) -luxdudt = Lux.Chain(Lux.Dense(2,50,tanh),Lux.Dense(50,2)) - -## Lux - -@info "Test some Lux layers" - -node = NeuralODE(luxdudt,tspan,Tsit5(),save_everystep=false,save_start=false) -pd, st = Lux.setup(rng, node) -pd = ComponentArray(pd) -grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) - -grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) - -#test with low tolerance ode solver -node = NeuralODE(luxdudt, tspan, Tsit5(), abstol=1e-12, reltol=1e-12, save_everystep=false, save_start=false) -grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) - -node = NeuralODE(luxdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint()) -grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) - -grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) - -node = NeuralODE(luxdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint()) -grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) - -grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) - -@info "Test some adjoints" - -# Adjoint -@testset "adjoint mode" begin - node = NeuralODE(luxdudt,tspan,Tsit5(),save_everystep=false,save_start=false) - grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) - @test ! iszero(grads[1]) - @test ! iszero(grads[2]) - - grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - - node = NeuralODE(luxdudt,tspan,Tsit5(),saveat=0.0:0.1:1.0) - grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) - @test ! iszero(grads[1]) - @test ! iszero(grads[2]) - - grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - - node = NeuralODE(luxdudt,tspan,Tsit5(),saveat=0.1) - grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) - @test ! iszero(grads[1]) - @test ! iszero(grads[2]) - - grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) -end - -@info "Test Tracker" - -# RD -@testset "Tracker mode" begin - node = NeuralODE(luxdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint()) - grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) - @test ! iszero(grads[1]) - @test ! iszero(grads[2]) - - grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) - @test ! iszero(grads[1]) - @test ! iszero(grads[2]) - - node = NeuralODE(luxdudt,tspan,Tsit5(),saveat=0.0:0.1:1.0,sensealg=TrackerAdjoint()) - @test_broken grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) - @test ! iszero(grads[1]) - @test ! iszero(grads[2]) - - @test_throws Any grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) - #@test_broken ! iszero(grads[1]) - #@test_broken ! iszero(grads[2]) - - node = NeuralODE(luxdudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint()) - @test_throws Any grad = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) - #@test ! iszero(grads[1]) - #@test ! iszero(grads[2]) - - @test_throws Any grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) - #@test_broken ! iszero(grads[1]) - #@test_broken ! iszero(grads[2]) -end - -@info "Test non-ODEs" - -luxdudt2 = Lux.Chain(Lux.Dense(2,50,tanh),Lux.Dense(50,2)) - -sode = NeuralDSDE(luxdudt,luxdudt2,(0.0f0,.1f0),EulerHeun(),saveat=0.0:0.01:0.1,dt=0.1) -pd, st = Lux.setup(rng, sode) -pd = ComponentArray(pd) - -grads = Zygote.gradient((x,p,st)->sum(sode(x,p,st)[1]),x,pd,st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) -@test ! iszero(grads[2][end]) - -grads = Zygote.gradient((x,p,st)->sum(sode(x,p,st)[1]),xs,pd,st) -@test ! iszero(grads[1]) -@test ! iszero(grads[2]) -@test ! iszero(grads[2][end]) - -luxdudt22 = Lux.Chain(Lux.Dense(2,50,tanh),Lux.Dense(50,4),x->reshape(x,2,2)) - -sode = NeuralSDE(luxdudt,luxdudt22,(0.0f0,0.1f0),2,EulerHeun(),saveat=0.0:0.01:0.1,dt=0.01) -pd,st = Lux.setup(rng, sode) -pd = ComponentArray(pd) - -grads = Zygote.gradient((x,p,st)->sum(sode(x,p,st)[1]),x,pd,st) -@test_broken ! iszero(grads[1]) -@test_broken ! iszero(grads[2]) -@test_broken ! iszero(grads[2][end]) - -@test_throws Any grads = Zygote.gradient((x,p,st)->sum(sode(x,p,st)),xs,pd,st) \ No newline at end of file diff --git a/test/neural_gde.jl b/test/neural_gde.jl index 5cab7dac4f..e8e0e9a2b5 100644 --- a/test/neural_gde.jl +++ b/test/neural_gde.jl @@ -1,35 +1,44 @@ -using DiffEqFlux, GeometricFlux, GraphSignals, OrdinaryDiffEq, Test +using DiffEqFlux, ComponentArrays, GeometricFlux, GraphSignals, OrdinaryDiffEq, Random, + Test, OptimizationOptimisers, Optimization, Statistics +import Flux # Fully Connected Graph adj_mat = FeaturedGraph(Float32[0 1 1 1 - 1 0 1 1 - 1 1 0 1 - 1 1 1 0]) + 1 0 1 1 + 1 1 0 1 + 1 1 1 0]) features = [-10.0f0 -9.0f0 9.0f0 10.0f0 - 0.0f0 0.0f0 0.0f0 0.0f0] - -target = [1.0 1.0 0.0 0.0 - 0.0 0.0 1.0 1.0] - -model = Flux.Chain( - NeuralODE( - GCNConv(adj_mat, 2=>2), - (0.f0, 1.f0), Tsit5(), save_everystep = false, - reltol = 1e-3, abstol = 1e-3, save_start = false - ), - x -> reshape(cpu(x), size(x)[1:2]) -) - -ps = Flux.params(model) -opt = Adam(0.1) - -initial_loss = Flux.Losses.logitcrossentropy(model(features), target) - -# for i in 1:100 -# gs = gradient(() -> Flux.Losses.logitcrossentropy(model(features), target), ps) -# Flux.Optimise.update!(opt, ps, gs) -# end -updated_loss = Flux.Losses.logitcrossentropy(model(features), target) + 0.0f0 0.0f0 0.0f0 0.0f0] + +target = Float32[1.0 1.0 0.0 0.0 + 0.0 0.0 1.0 1.0] + +model = Chain(NeuralODE(WithGraph(adj_mat, GCNConv(2 => 2)), (0.0f0, 1.0f0), Tsit5(); + save_everystep = false, reltol = 1e-3, abstol = 1e-3, save_start = false), + x -> reshape(Array(x), size(x)[1:2])) + +ps, st = Lux.setup(Xoshiro(0), model) +ps = ComponentArray(ps) + +logitcrossentropy(ŷ, y; dims = 1) = mean(.-sum(y .* logsoftmax(ŷ; dims); dims)) + +lux_model = Lux.Experimental.StatefulLuxLayer(model, ps, st) + +initial_loss = logitcrossentropy(lux_model(features, ps), target) + +loss_function(p) = logitcrossentropy(lux_model(features, p), target) + +function callback(p, l) + @info "[NeuralGraphODE] Loss: $l" + return false +end + +optfunc = Optimization.OptimizationFunction((x, p) -> loss_function(x), + Optimization.AutoZygote()) +optprob = Optimization.OptimizationProblem(optfunc, ps) +res = Optimization.solve(optprob, Adam(0.1); callback, maxiters = 100) + +updated_loss = logitcrossentropy(lux_model(features, ps), target) @test_broken updated_loss < initial_loss diff --git a/test/neural_ode_mm.jl b/test/neural_ode_mm.jl index 20470ffa83..3f57c9f5e0 100644 --- a/test/neural_ode_mm.jl +++ b/test/neural_ode_mm.jl @@ -1,47 +1,47 @@ -using ComponentArrays, DiffEqFlux, Lux, Zygote, Random, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Test -rng = Random.default_rng() - -#A desired MWE for now, not a test yet. -function f(du,u,p,t) - y₁,y₂,y₃ = u - k₁,k₂,k₃ = p - du[1] = -k₁*y₁ + k₃*y₂*y₃ - du[2] = k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2 - du[3] = y₁ + y₂ + y₃ - 1 - nothing -end -u₀ = [1.0, 0, 0] -M = [1. 0 0 - 0 1. 0 - 0 0 0] -tspan = (0.0,1.0) -p = [0.04,3e7,1e4] -func = ODEFunction(f,mass_matrix=M) -prob = ODEProblem(func,u₀,tspan,p) -sol = solve(prob,Rodas5(),saveat=0.1) - -dudt2 = Lux.Chain(Lux.Dense(3,64,tanh),Lux.Dense(64,2)) -p,st = Lux.setup(rng, dudt2) -p = ComponentArray(p) -ndae = NeuralODEMM(dudt2, (u,p,t) -> [u[1] + u[2] + u[3] - 1], tspan, M, Rodas5(autodiff=false),saveat=0.1) -ndae(u₀,p,st) - -function predict_n_dae(p) - ndae(u₀,p,st)[1] -end -function loss(p) - pred = predict_n_dae(p) - loss = sum(abs2,Array(sol) .- pred) - loss,pred -end - -cb = function (p,l,pred) #callback function to observe training - display(l) - return false -end - -l1 = first(loss(p)) -optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote()) -optprob = Optimization.OptimizationProblem(optfunc, p) -res = Optimization.solve(optprob, BFGS(initial_stepnorm = 0.001), callback = cb, maxiters = 100) -@test res.minimum < l1 +using ComponentArrays, + DiffEqFlux, Lux, Zygote, Random, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Test +rng = Random.default_rng() + +#A desired MWE for now, not a test yet. +function f(du, u, p, t) + y₁, y₂, y₃ = u + k₁, k₂, k₃ = p + du[1] = -k₁ * y₁ + k₃ * y₂ * y₃ + du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2 + du[3] = y₁ + y₂ + y₃ - 1 + nothing +end +u₀ = [1.0, 0, 0] +M = [1.0 0 0 + 0 1.0 0 + 0 0 0] +tspan = (0.0, 1.0) +p = [0.04, 3e7, 1e4] +func = ODEFunction(f; mass_matrix = M) +prob = ODEProblem(func, u₀, tspan, p) +sol = solve(prob, Rodas5(); saveat = 0.1) + +dudt2 = Chain(Dense(3 => 64, tanh), Dense(64 => 2)) +p, st = Lux.setup(rng, dudt2) +p = ComponentArray{Float64}(p) +ndae = NeuralODEMM(dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan, M, + Rodas5(; autodiff = false); saveat = 0.1) +ndae(u₀, p, st) + +function loss(p) + pred = first(ndae(u₀, p, st)) + loss = sum(abs2, Array(sol) .- pred) + return loss, pred +end + +cb = function (p, l, pred) + @info "[NeuralODEMM] Loss: $l" + return false +end + +l1 = first(loss(p)) +optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote()) +optprob = Optimization.OptimizationProblem(optfunc, p) +res = Optimization.solve(optprob, BFGS(; initial_stepnorm = 0.001); callback = cb, + maxiters = 100) +@test res.minimum < l1 diff --git a/test/newton_neural_ode.jl b/test/newton_neural_ode.jl index bd3189d919..b4a576a18c 100644 --- a/test/newton_neural_ode.jl +++ b/test/newton_neural_ode.jl @@ -1,44 +1,61 @@ -using DiffEqFlux, Flux, Zygote, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Random, Test - -Random.seed!(100) - -n = 1 # number of ODEs -tspan = (0f0, 1f0) - -d = 5 # number of data pairs -x = rand(Float32, n, 5) -y = rand(Float32, n, 5) - -cb = function (p,l) - @show l - false -end - -NN = Flux.Chain(Flux.Dense(n, 5n, tanh), - Flux.Dense(5n, n)) - -@info "ROCK4" -nODE = NeuralODE(NN, tspan, ROCK4(), reltol=1f-4, saveat=[tspan[end]]) - -loss_function(θ) = Flux.Losses.mse(y, nODE(x, θ)[end]) -l1 = loss_function(nODE.p) -optf = Optimization.OptimizationFunction((x,p)->loss_function(x), Optimization.AutoZygote()) -optprob = Optimization.OptimizationProblem(optf, nODE.p) - -res = Optimization.solve(optprob, NewtonTrustRegion(), maxiters=100, callback=cb) -@test loss_function(res.minimizer) < l1 -res = Optimization.solve(optprob, OptimizationOptimJL.Optim.KrylovTrustRegion(), maxiters=100, callback=cb) -@test loss_function(res.minimizer) < l1 - -@info "ROCK2" -nODE = NeuralODE(NN, tspan, ROCK2(), reltol=1f-4, saveat=[tspan[end]]) - -loss_function(θ) = Flux.Losses.mse(y, nODE(x, θ)[end]) -l1 = loss_function(nODE.p) -optfunc = Optimization.OptimizationFunction((x, p) -> loss_function(x), Optimization.AutoZygote()) -optprob = Optimization.OptimizationProblem(optfunc, nODE.p) - -res = Optimization.solve(optprob, NewtonTrustRegion(), maxiters = 100, callback=cb) -@test loss_function(res.minimizer) < l1 -res = Optimization.solve(optprob, OptimizationOptimJL.Optim.KrylovTrustRegion(), maxiters = 100, callback=cb) -@test loss_function(res.minimizer) < l1 +using DiffEqFlux, ComponentArrays, + Lux, Zygote, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Random, Test + +Random.seed!(100) + +n = 1 # number of ODEs +tspan = (0.0f0, 1.0f0) + +d = 5 # number of data pairs +x = rand(Float32, n, 5) +y = rand(Float32, n, 5) + +cb = function (p, l) + @info "[Newton NeuralODE] Loss: $l" + false +end + +NN = Chain(Dense(n => 5n, tanh), Dense(5n => n)) + +@info "ROCK4" +nODE = NeuralODE(NN, tspan, ROCK4(); reltol = 1.0f-4, saveat = [tspan[end]]) + +ps, st = Lux.setup(Xoshiro(0), nODE) +ps = ComponentArray(ps) +stnODE = Lux.Experimental.StatefulLuxLayer(nODE, ps, st) + +# KrylovTrustRegion is hardcoded to use `Array` +psd, psax = getdata(ps), getaxes(ps) + +loss_function(θ) = sum(abs2, y .- stnODE(x, ComponentArray(θ, psax))[end]) +l1 = loss_function(psd) +optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), + Optimization.AutoZygote()) +optprob = Optimization.OptimizationProblem(optf, psd) + +res = Optimization.solve(optprob, NewtonTrustRegion(); maxiters = 100, callback = cb) +@test loss_function(res.minimizer) < l1 +res = Optimization.solve(optprob, OptimizationOptimJL.Optim.KrylovTrustRegion(); + maxiters = 100, callback = cb) +@test loss_function(res.minimizer) < l1 + +@info "ROCK2" +nODE = NeuralODE(NN, tspan, ROCK2(); reltol = 1.0f-4, saveat = [tspan[end]]) +ps, st = Lux.setup(Xoshiro(0), nODE) +ps = ComponentArray(ps) +stnODE = Lux.Experimental.StatefulLuxLayer(nODE, ps, st) + +# KrylovTrustRegion is hardcoded to use `Array` +psd, psax = getdata(ps), getaxes(ps) + +loss_function(θ) = sum(abs2, y .- stnODE(x, ComponentArray(θ, psax))[end]) +l1 = loss_function(psd) +optfunc = Optimization.OptimizationFunction((x, p) -> loss_function(x), + Optimization.AutoZygote()) +optprob = Optimization.OptimizationProblem(optfunc, psd) + +res = Optimization.solve(optprob, NewtonTrustRegion(); maxiters = 100, callback = cb) +@test loss_function(res.minimizer) < l1 +res = Optimization.solve(optprob, OptimizationOptimJL.Optim.KrylovTrustRegion(); + maxiters = 100, callback = cb) +@test loss_function(res.minimizer) < l1 diff --git a/test/runtests.jl b/test/runtests.jl index 0bf6c449e6..7fdc08228a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,19 +15,11 @@ const is_CI = haskey(ENV, "CI") end if GROUP == "All" || GROUP == "DiffEqFlux" || GROUP == "BasicNeuralDE" - @safetestset "Neural DE Tests with Lux" begin - include("neural_de_lux.jl") - end @safetestset "Neural DE Tests" begin include("neural_de.jl") end - @safetestset "Augmented Neural DE Tests" begin - include("augmented_nde.jl") - end - #@safetestset "Neural Graph DE" begin include("neural_gde.jl") end - - @safetestset "Neural ODE MM Tests" begin - include("neural_ode_mm.jl") + @safetestset "Neural Graph DE" begin + include("neural_gde.jl") end @safetestset "Tensor Product Layer" begin include("tensor_product_test.jl") @@ -38,6 +30,13 @@ const is_CI = haskey(ENV, "CI") @safetestset "Multiple shooting" begin include("multiple_shoot.jl") end + @safetestset "Neural ODE MM Tests" begin + include("neural_ode_mm.jl") + end + # DAE Tests were never included + # @safetestset "Neural DAE Tests" begin + # include("neural_dae.jl") + # end end if GROUP == "All" || GROUP == "AdvancedNeuralDE" @@ -52,7 +51,7 @@ const is_CI = haskey(ENV, "CI") end end - if GROUP == "Newton" + if GROUP == "All" || GROUP == "Newton" @safetestset "Newton Neural ODE Tests" begin include("newton_neural_ode.jl") end @@ -69,4 +68,20 @@ const is_CI = haskey(ENV, "CI") include("mnist_conv_gpu.jl") end end -end \ No newline at end of file + + if GROUP == "All" || GROUP == "Aqua" + @safetestset "Aqua Q/A" begin + using Aqua, DiffEqFlux, LinearAlgebra + + # TODO: Enable persistent tasks once the downstream PRs are merged + Aqua.test_all(DiffEqFlux; ambiguities = false, piracies = false, + persistent_tasks = false) + + Aqua.test_ambiguities(DiffEqFlux; recursive = false) + + # FIXME: Remove Tridiagonal piracy after + # https://github.com/JuliaDiff/ChainRules.jl/issues/713 is merged! + Aqua.test_piracies(DiffEqFlux; treat_as_own = [LinearAlgebra.Tridiagonal]) + end + end +end diff --git a/test/second_order_ode.jl b/test/second_order_ode.jl index 680f2a531e..73fd84eeb5 100644 --- a/test/second_order_ode.jl +++ b/test/second_order_ode.jl @@ -1,93 +1,79 @@ -using ComponentArrays, DiffEqFlux, Lux, Zygote, Random, Optimization, OrdinaryDiffEq, RecursiveArrayTools -rng = Random.default_rng() - -u0 = Float32[0.; 2.] -du0 = Float32[0.; 0.] -tspan = (0.0f0, 1.0f0) -t = range(tspan[1], tspan[2], length=20) - -model = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2)) -p, st = Lux.setup(rng, model) -p = ComponentArray(p) -ff(du,u,p,t) = model(u,p,st)[1] -prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, p) - -function predict(p) - Array(solve(prob, Tsit5(), p=p, saveat=t, sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP()))) -end - -correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:1)[2:end]))) - -function loss_n_ode(p) - pred = predict(p) - sum(abs2, correct_pos .- pred[1:2, :]), pred -end - -data = Iterators.repeated((), 1000) -opt = Adam(0.01) - -l1 = loss_n_ode(p) - -callback = function (p,l,pred) - @show l - l < 0.01 && Flux.stop() -end - -optfunc = Optimization.OptimizationFunction((x, p) -> loss_n_ode(x), Optimization.AutoZygote()) -optprob = Optimization.OptimizationProblem(optfunc, p) -res = Optimization.solve(optprob, opt, callback=callback, maxiters = 100) -l2 = loss_n_ode(res.minimizer) -@test l2 < l1 - -function predict(p) - Array(solve(prob, Tsit5(), p=p, saveat=t, sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP()))) -end - -correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:1)[2:end]))) - -function loss_n_ode(p) - pred = predict(p) - sum(abs2, correct_pos .- pred[1:2, :]), pred -end - -data = Iterators.repeated((), 1000) -opt = Adam(0.01) - -loss_n_ode(p) - -callback = function (p,l,pred) - @show l - l < 0.01 && Flux.stop() -end -optfunc = Optimization.OptimizationFunction((x, p) -> loss_n_ode(x), Optimization.AutoZygote()) -optprob = Optimization.OptimizationProblem(optfunc, p) -res = Optimization.solve(optprob, opt, callback=callback, maxiters = 100) -l2 = loss_n_ode(res.minimizer) -@test l2 < l1 - -function predict(p) - Array(solve(prob, Tsit5(), p=p, saveat=t, sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP()))) -end - -correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:1)[2:end]))) - -function loss_n_ode(p) - pred = predict(p) - sum(abs2, correct_pos .- pred[1:2, :]), pred -end - -data = Iterators.repeated((), 1000) -opt = Adam(0.01) - -loss_n_ode(p) - -callback = function (p,l,pred) - @show l - l < 0.01 && Flux.stop() -end - -optfunc = Optimization.OptimizationFunction((x, p) -> loss_n_ode(x), Optimization.AutoZygote()) -optprob = Optimization.OptimizationProblem(optfunc, p) -res = Optimization.solve(optprob, opt, callback=callback, maxiters = 100) -l2 = loss_n_ode(res.minimizer) -@test l2 < l1 +using ComponentArrays, + DiffEqFlux, Lux, Zygote, Random, Optimization, OptimizationOptimisers, OrdinaryDiffEq + +rng = Random.default_rng() + +u0 = Float32[0.0; 2.0] +du0 = Float32[0.0; 0.0] +tspan = (0.0f0, 1.0f0) +t = range(tspan[1], tspan[2]; length = 20) + +model = Chain(Dense(2, 50, tanh), Dense(50, 2)) +p, st = Lux.setup(rng, model) +p = ComponentArray(p) +ff(du, u, p, t) = first(model(u, p, st)) +prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, p) + +function predict(p) + return Array(solve(prob, Tsit5(); p, saveat = t, + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()))) +end + +correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:1)[2:end]))) + +function loss_n_ode(p) + pred = predict(p) + return sum(abs2, correct_pos .- pred[1:2, :]), pred +end + +l1 = loss_n_ode(p) + +function callback(p, l, pred) + @info "[SecondOrderODE] Loss: $l" + return l < 0.01 +end + +optfunc = Optimization.OptimizationFunction((x, p) -> loss_n_ode(x), + Optimization.AutoZygote()) +optprob = Optimization.OptimizationProblem(optfunc, p) +res = Optimization.solve(optprob, Adam(0.01f0); callback = callback, maxiters = 100) +l2 = loss_n_ode(res.minimizer) +@test l2 < l1 + +function predict(p) + return Array(solve(prob, Tsit5(); p, saveat = t, + sensealg = QuadratureAdjoint(; autojacvec = ZygoteVJP()))) +end + +correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:1)[2:end]))) + +function loss_n_ode(p) + pred = predict(p) + return sum(abs2, correct_pos .- pred[1:2, :]), pred +end + +optfunc = Optimization.OptimizationFunction((x, p) -> loss_n_ode(x), + Optimization.AutoZygote()) +optprob = Optimization.OptimizationProblem(optfunc, p) +res = Optimization.solve(optprob, Adam(0.01f0); callback = callback, maxiters = 100) +l2 = loss_n_ode(res.minimizer) +@test l2 < l1 + +function predict(p) + return Array(solve(prob, Tsit5(); p, saveat = t, + sensealg = BacksolveAdjoint(; autojacvec = ZygoteVJP()))) +end + +correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:1)[2:end]))) + +function loss_n_ode(p) + pred = predict(p) + return sum(abs2, correct_pos .- pred[1:2, :]), pred +end + +optfunc = Optimization.OptimizationFunction((x, p) -> loss_n_ode(x), + Optimization.AutoZygote()) +optprob = Optimization.OptimizationProblem(optfunc, p) +res = Optimization.solve(optprob, Adam(0.01f0); callback = callback, maxiters = 100) +l2 = loss_n_ode(res.minimizer) +@test l2 < l1 diff --git a/test/spline_layer_test.jl b/test/spline_layer_test.jl index f9330bf76f..535c0fab53 100644 --- a/test/spline_layer_test.jl +++ b/test/spline_layer_test.jl @@ -1,58 +1,58 @@ -using DiffEqFlux, Zygote, DataInterpolations, Distributions, Optimization, LinearAlgebra, Test +using DiffEqFlux, ComponentArrays, Zygote, DataInterpolations, Distributions, Optimization, + OptimizationOptimisers, LinearAlgebra, Random, Test function run_test(f, layer, atol) + ps, st = Lux.setup(Xoshiro(0), layer) + ps = ComponentArray(ps) + model = Lux.Experimental.StatefulLuxLayer(layer, ps, st) data_train_vals = rand(500) data_train_fn = f.(data_train_vals) function loss_function(θ) - data_pred = [layer(x, θ) for x in data_train_vals] - loss = sum(abs.(data_pred.-data_train_fn))/length(data_train_fn) + data_pred = [model(x, θ) for x in data_train_vals] + loss = sum(abs.(data_pred .- data_train_fn)) / length(data_train_fn) return loss end - function callback(p,l) - @show l + function callback(p, l) + @info "[SplineLayer] Loss: $l" return false end - optfunc = Optimization.OptimizationFunction((x, p) -> loss_function(x), Optimization.AutoZygote()) - optprob = Optimization.OptimizationProblem(optfunc, layer.saved_points) - res = Optimization.solve(optprob, Adam(0.1), callback=callback, maxiters = 100) + optfunc = Optimization.OptimizationFunction((x, p) -> loss_function(x), + Optimization.AutoZygote()) + optprob = Optimization.OptimizationProblem(optfunc, ps) + res = Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters = 100) optprob = Optimization.OptimizationProblem(optfunc, res.minimizer) - res = Optimization.solve(optprob, Adam(0.1), callback=callback, maxiters = 100) + res = Optimization.solve(optprob, Adam(0.1); callback = callback, maxiters = 100) opt = res.minimizer data_validate_vals = rand(100) data_validate_fn = f.(data_validate_vals) - data_validate_pred = [layer(x,opt) for x in data_validate_vals] + data_validate_pred = [model(x, opt) for x in data_validate_vals] - output = sum(abs.(data_validate_pred.-data_validate_fn))/length(data_validate_fn) - @show output + output = sum(abs.(data_validate_pred .- data_validate_fn)) / length(data_validate_fn) return output < atol end ##test 01: affine function, Linear Interpolation a, b = rand(2) -f = x -> a*x + b -layer = SplineLayer((0.0,1.0),0.01,LinearInterpolation) -@test run_test(f, layer, 0.1) +layer = SplineLayer((0.0, 1.0), 0.01, LinearInterpolation) +@test run_test(x -> a * x + b, layer, 0.1) ##test 02: non-linear function, Quadratic Interpolation a, b, c = rand(3) -f = x -> a*x^2+ b*x + x -layer = SplineLayer((0.0,1.0),0.01,QuadraticInterpolation) -@test run_test(f, layer, 0.1) +layer = SplineLayer((0.0, 1.0), 0.01, QuadraticInterpolation) +@test run_test(x -> a * x^2 + b * x + x, layer, 0.1) ##test 03: non-linear function, Quadratic Spline a, b, c = rand(3) -f = x -> a*sin(b*x+c) -layer = SplineLayer((0.0,1.0),0.1,QuadraticSpline) -@test run_test(f, layer, 0.1) +layer = SplineLayer((0.0, 1.0), 0.1, QuadraticSpline) +@test run_test(x -> a * sin(b * x + c), layer, 0.1) ##test 04: non-linear function, Cubic Spline -f = x -> exp(x)*x^2 -layer = SplineLayer((0.0,1.0),0.1,CubicSpline) -@test run_test(f, layer, 0.1) +layer = SplineLayer((0.0, 1.0), 0.1, CubicSpline) +@test run_test(x -> exp(x) * x^2, layer, 0.1) diff --git a/test/stiff_nested_ad.jl b/test/stiff_nested_ad.jl index b228d64b37..9c3bbddbc4 100644 --- a/test/stiff_nested_ad.jl +++ b/test/stiff_nested_ad.jl @@ -1,53 +1,43 @@ -using DiffEqFlux, Flux, Zygote, OrdinaryDiffEq, Test - -u0 = [2.; 0.] -datasize = 30 -tspan = (0.0f0,1.5f0) - -function trueODEfunc(du,u,p,t) - true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' -end -t = range(tspan[1],tspan[2],length=datasize) -prob = ODEProblem(trueODEfunc,u0,tspan) -ode_data = Array(solve(prob,Tsit5(),saveat=t)) - -model = Flux.Chain(x -> x.^3, - Flux.Dense(2,50,tanh), - Flux.Dense(50,2)) |> f64 -neuralde = NeuralODE(model,tspan,Rodas5(),saveat=t,reltol=1e-7,abstol=1e-9) - -function predict_n_ode() - neuralde(u0) -end -loss_n_ode() = sum(abs2,ode_data .- predict_n_ode()) - -data = Iterators.repeated((), 10) -opt = Adam(0.1) -cb = function () #callback function to observe training - display(loss_n_ode()) -end - -# Display the ODE with the initial parameter values. -cb() - -neuralde = NeuralODE(model,tspan,Rodas5(),saveat=t,reltol=1e-7,abstol=1e-9) -ps = Flux.params(neuralde) -loss1 = loss_n_ode() -Flux.train!(loss_n_ode, ps, data, opt, cb = cb) -loss2 = loss_n_ode() -@test loss2 < loss1 - -neuralde = NeuralODE(model,tspan,KenCarp4(),saveat=t,reltol=1e-7,abstol=1e-9) -ps = Flux.params(neuralde) -loss1 = loss_n_ode() -Flux.train!(loss_n_ode, ps, data, opt, cb = cb) -loss2 = loss_n_ode() -@test loss2 < loss1 - -neuralde = NeuralODE(model,tspan,RadauIIA5(),saveat=t,reltol=1e-7,abstol=1e-9) -ps = Flux.params(neuralde) -loss1 = loss_n_ode() -Flux.train!(loss_n_ode, ps, data, opt, cb = cb) -loss2 = loss_n_ode() -@test loss2 < loss1 +using DiffEqFlux, ComponentArrays, Zygote, OrdinaryDiffEq, Test, Optimization, + OptimizationOptimisers, Random +import Flux + +u0 = [2.0; 0.0] +datasize = 30 +tspan = (0.0f0, 1.5f0) + +function trueODEfunc(du, u, p, t) + true_A = [-0.1 2.0; -2.0 -0.1] + du .= ((u .^ 3)'true_A)' +end +t = range(tspan[1], tspan[2]; length = datasize) +prob = ODEProblem(trueODEfunc, u0, tspan) +ode_data = Array(solve(prob, Tsit5(); saveat = t)) + +model = Chain(x -> x .^ 3, Dense(2 => 50, tanh), Dense(50 => 2)) + +predict_n_ode(lux_model, p) = lux_model(u0, p) +loss_n_ode(lux_model, p) = sum(abs2, ode_data .- predict_n_ode(lux_model, p)) + +function callback(solver) + return function (p, l) + @info "[StiffNestedAD $(nameof(typeof(solver)))] Loss: $l" + return false + end +end + +@testset "Solver: $(nameof(typeof(solver)))" for solver in (KenCarp4(), + Rodas5(), RadauIIA5()) + neuralde = NeuralODE(model, tspan, solver; saveat = t, reltol = 1e-7, abstol = 1e-9) + ps, st = Lux.setup(Xoshiro(0), neuralde) + ps = ComponentArray(ps) + lux_model = Lux.Experimental.StatefulLuxLayer(neuralde, ps, st) + loss1 = loss_n_ode(lux_model, ps) + optfunc = Optimization.OptimizationFunction((x, p) -> loss_n_ode(lux_model, x), + Optimization.AutoZygote()) + optprob = Optimization.OptimizationProblem(optfunc, ps) + res = Optimization.solve(optprob, Adam(0.1); callback = callback(solver), + maxiters = 100) + loss2 = loss_n_ode(lux_model, res.minimizer) + @test loss2 < loss1 +end diff --git a/test/tensor_product_test.jl b/test/tensor_product_test.jl index dafce2a4a5..9981232f73 100644 --- a/test/tensor_product_test.jl +++ b/test/tensor_product_test.jl @@ -1,49 +1,52 @@ using DiffEqFlux, Distributions, Zygote, Optimization, OptimizationOptimJL, - OptimizationOptimisers, LinearAlgebra, Test + OptimizationOptimisers, LinearAlgebra, Random, ComponentArrays, Test -function run_test(f, layer, atol) +function run_test(f, layer, atol, N) + ps, st = Lux.setup(Xoshiro(0), layer) + ps = ComponentArray(ps) + model = Lux.Experimental.StatefulLuxLayer(layer, ps, st) - data_train_vals = [rand(length(layer.model)) for k in 1:500] + data_train_vals = [rand(N) for k in 1:500] data_train_fn = f.(data_train_vals) - function loss_function(component) - data_pred = [layer(x,component) for x in data_train_vals] - loss = sum(norm.(data_pred.-data_train_fn))/length(data_train_fn) + function loss_function(p) + data_pred = [model(x, p) for x in data_train_vals] + loss = sum(norm.(data_pred .- data_train_fn)) / length(data_train_fn) return loss end - function cb(p,l) - @show l + function cb(p, l) + @info "[TensorProductLayer] Loss: $l" return false end - optfunc = Optimization.OptimizationFunction((x, p) -> loss_function(x), Optimization.AutoZygote()) - optprob = Optimization.OptimizationProblem(optfunc, layer.p) - res = Optimization.solve(optprob, Adam(0.1), callback=cb, maxiters = 100) + optfunc = Optimization.OptimizationFunction((x, p) -> loss_function(x), + Optimization.AutoZygote()) + optprob = Optimization.OptimizationProblem(optfunc, ps) + res = Optimization.solve(optprob, Adam(0.1); callback = cb, maxiters = 100) optprob = Optimization.OptimizationProblem(optfunc, res.minimizer) - res = Optimization.solve(optprob, Adam(0.01), callback=cb, maxiters = 100) + res = Optimization.solve(optprob, Adam(0.01); callback = cb, maxiters = 100) optprob = Optimization.OptimizationProblem(optfunc, res.minimizer) - res = Optimization.solve(optprob, BFGS(), callback=cb, maxiters = 200) + res = Optimization.solve(optprob, BFGS(); callback = cb, maxiters = 200) opt = res.minimizer - data_validate_vals = [rand(length(layer.model)) for k in 1:100] + data_validate_vals = [rand(N) for k in 1:100] data_validate_fn = f.(data_validate_vals) - data_validate_pred = [layer(x,opt) for x in data_validate_vals] + data_validate_pred = [model(x, opt) for x in data_validate_vals] - return sum(norm.(data_validate_pred.-data_validate_fn))/length(data_validate_fn) < atol + return sum(norm.(data_validate_pred .- data_validate_fn)) / length(data_validate_fn) < + atol end ##test 01: affine function, Chebyshev and Polynomial basis -A = rand(2,2) +A = rand(2, 2) b = rand(2) -f = x -> A*x + b layer = TensorLayer([ChebyshevBasis(10), PolynomialBasis(10)], 2) -@test run_test(f, layer, 0.05) +@test run_test(x -> A * x + b, layer, 0.05, 2) ##test 02: non-linear function, Chebyshev and Legendre basis -A = rand(2,2) +A = rand(2, 2) b = rand(2) -f = x -> A*x*norm(x)+ b*sin(norm(x)) layer = TensorLayer([ChebyshevBasis(7), FourierBasis(7)], 2) -@test run_test(f, layer, 0.10) +@test run_test(x -> A * x * norm(x) + b * sin(norm(x)), layer, 0.10, 2)