diff --git a/FastVision/Project.toml b/FastVision/Project.toml
index 2a3e9b5d11..4f5f8bf0ec 100644
--- a/FastVision/Project.toml
+++ b/FastVision/Project.toml
@@ -17,6 +17,7 @@ IndirectArrays = "9b13fd28-a010-5f03-acff-a1bbcff69959"
 InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
+Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -33,14 +34,15 @@ FastAI = "0.5"
 FixedPointNumbers = "0.8"
 Flux = "0.12, 0.13"
 ImageIO = "0.6"
-ImageInTerminal = "0.4"
+ImageInTerminal = "0.4, 0.5"
 IndirectArrays = "0.5, 1"
 InlineTest = "0.2"
 MLUtils = "0.2, 0.3, 0.4"
 MakieCore = "0.3, 0.4, 0.5, 0.6"
+Metalhead = "0.8"
 ProgressMeter = "1"
 ShowCases = "0.1"
 StaticArrays = "1.1"
-UnicodePlots = "2"
+UnicodePlots = "2, 3"
 Zygote = "0.6"
 julia = "1.6"
diff --git a/FastVision/src/FastVision.jl b/FastVision/src/FastVision.jl
index 293e088e5d..96578e1aef 100644
--- a/FastVision/src/FastVision.jl
+++ b/FastVision/src/FastVision.jl
@@ -38,6 +38,7 @@ using FastAI:
     # blocks
     Context, Training, Validation, Inference, Datasets
 using FastAI.Datasets
+import FastAI.Registries: ModelVariant, compatibleblocks, loadvariant

 # extending
 import FastAI:
@@ -45,7 +46,7 @@ import FastAI:
     encodedblock, decodedblock, showblock!, mockblock, setup, encodestate,
     decodestate
-import Flux
+using Flux: Flux, Chain, Conv, Dense
 import MLUtils: getobs, numobs, mapobs, eachobs
 import Colors: colormaps_sequential, Colorant, Color, Gray, Normed, RGB, alphacolor,
     deuteranopic, distinguishable_colors
@@ -63,6 +64,7 @@ import IndirectArrays: IndirectArray
 import MakieCore
 import MakieCore: @recipe
 import MakieCore.Observables: @map
+import Metalhead: Metalhead
 import ProgressMeter: Progress, next!
 import StaticArrays: SVector
 import Statistics: mean, std
@@ -76,6 +78,7 @@ include("blocks/bounded.jl")
 include("blocks/image.jl")
 include("blocks/mask.jl")
 include("blocks/keypoints.jl")
+include("blocks/convfeatures.jl")

 include("encodings/onehot.jl")
 include("encodings/imagepreprocessing.jl")
@@ -93,6 +96,7 @@ include("tasks/keypointregression.jl")
 include("datasets.jl")
 include("recipes.jl")
 include("makie.jl")
+include("modelregistry.jl")

 include("tests.jl")

@@ -103,6 +107,11 @@ function __init__()
             push!(FastAI.learningtasks(), t)
         end
     end
+    foreach(values(_models)) do t
+        if !haskey(FastAI.models(), t.id)
+            push!(FastAI.models(), t)
+        end
+    end
 end

 export Image, Mask, Keypoints, Bounded,
diff --git a/FastVision/src/blocks/convfeatures.jl b/FastVision/src/blocks/convfeatures.jl
new file mode 100644
index 0000000000..d5c4352e52
--- /dev/null
+++ b/FastVision/src/blocks/convfeatures.jl
@@ -0,0 +1,52 @@
+
+"""
+    ConvFeatures{N}(n) <: Block
+    ConvFeatures(n, size)
+
+Block representing features from a convolutional neural network backbone
+with `n` feature channels and `N` spatial dimensions.
+
+For example, a 2D ResNet's convolutional layers may produce an `h`x`w`x`ch` output
+that is then passed on to the classifier head.
+
+## Examples
+
+A feature block with 512 channels and variable spatial dimensions:
+
+```julia
+FastVision.ConvFeatures{2}(512)
+# or equivalently
+FastVision.ConvFeatures(512, (:, :))
+```
+
+A feature block with 512 channels and fixed spatial dimensions:
+
+```julia
+FastVision.ConvFeatures(512, (4, 4))
+```
+
+"""
+struct ConvFeatures{N} <: Block
+    n::Int
+    size::NTuple{N, DimSize}
+end
+
+ConvFeatures{N}(n) where {N} = ConvFeatures{N}(n, ntuple(_ -> :, N))
+
+function FastAI.checkblock(block::ConvFeatures{N}, a::AbstractArray{T, M}) where {M, N, T}
+    M == N + 1 || return false
+    return checksize(block.size, size(a)[begin:N])
+end
+
+function FastAI.mockblock(block::ConvFeatures)
+    rand(Float32, map(l -> l isa Colon ? 8 : l, block.size)..., block.n)
+end
+
+@testset "ConvFeatures [block]" begin
+    @test ConvFeatures(16, (:, :)) == ConvFeatures{2}(16)
+    @test checkblock(ConvFeatures(16, (:, :)), rand(Float32, 2, 2, 16))
+    @test checkblock(ConvFeatures(16, (:, :)), rand(Float32, 3, 2, 16))
+    @test checkblock(ConvFeatures(16, (2, 2)), rand(Float32, 2, 2, 16))
+    @test !checkblock(ConvFeatures(16, (2, :)), rand(Float32, 3, 2, 16))
+end
diff --git a/FastVision/src/modelregistry.jl b/FastVision/src/modelregistry.jl
new file mode 100644
index 0000000000..d8c8e94388
--- /dev/null
+++ b/FastVision/src/modelregistry.jl
@@ -0,0 +1,133 @@
+
+# ## Model variants for Metalhead.jl models
+
+struct MetalheadClassifierVariant <: ModelVariant
+    fn
+end
+compatibleblocks(::MetalheadClassifierVariant) = (ImageTensor{2}, FastAI.OneHotTensor{0})
+function loadvariant(v::MetalheadClassifierVariant, xblock::ImageTensor{2}, yblock::FastAI.OneHotTensor{0}, checkpoint; kwargs...)
+    return v.fn(; pretrain = checkpoint == "imagenet1k", inchannels=xblock.nchannels,
+                nclasses=length(yblock.classes), kwargs...)
+end
+function loadvariant(v::MetalheadClassifierVariant, xblock, yblock, checkpoint; kwargs...)
+    return v.fn(; pretrain = checkpoint == "imagenet1k", kwargs...)
+end
+
+struct MetalheadBackboneVariant <: ModelVariant
+    fn
+    nfeatures::Int
+end
+compatibleblocks(variant::MetalheadBackboneVariant) = (ImageTensor{2}, ConvFeatures{2}(variant.nfeatures))
+function loadvariant(v::MetalheadBackboneVariant, xblock::ImageTensor{2}, yblock::ConvFeatures{2}, checkpoint; kwargs...)
+    model = v.fn(; pretrain = checkpoint == "imagenet1k", inchannels=xblock.nchannels,
+                 kwargs...)
+    return model.layers[1]
+end
+function loadvariant(v::MetalheadBackboneVariant, xblock, yblock, checkpoint; kwargs...)
+    model = v.fn(; pretrain = checkpoint == "imagenet1k", kwargs...)
+    return model.layers[1]
+end
+
+function metalheadvariants(modelfn, nfeatures)
+    return [
+        "classifier" => MetalheadClassifierVariant(modelfn),
+        "backbone" => MetalheadBackboneVariant(modelfn, nfeatures),
+    ]
+end
+
+const _models = Dict{String, Any}()
+
+fix(fn, args...; kwargs...) = (_args...; _kwargs...) -> fn(args..., _args...; kwargs..., _kwargs...)
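+# For example, `fix(Metalhead.ResNet, 18)` behaves like
+# `(args...; kwargs...) -> Metalhead.ResNet(18, args...; kwargs...)`, so calling it as
+# `fix(Metalhead.ResNet, 18)(; pretrain = true)` is equivalent to
+# `Metalhead.ResNet(18; pretrain = true)`.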
+
+# model config: id, description, basefn, hasweights, nfeatures
+const METALHEAD_CONFIGS = [
+    ("metalhead/resnet18", "ResNet18", fix(Metalhead.ResNet, 18), true, 512),
+    ("metalhead/resnet34", "ResNet34", fix(Metalhead.ResNet, 34), true, 512),
+    ("metalhead/resnet50", "ResNet50", fix(Metalhead.ResNet, 50), true, 2048),
+    ("metalhead/resnet101", "ResNet101", fix(Metalhead.ResNet, 101), true, 2048),
+    ("metalhead/resnet152", "ResNet152", fix(Metalhead.ResNet, 152), true, 2048),
+    ("metalhead/wideresnet50", "WideResNet50", fix(Metalhead.WideResNet, 50), true, 2048),
+    ("metalhead/wideresnet101", "WideResNet101", fix(Metalhead.WideResNet, 101), true, 2048),
+    ("metalhead/wideresnet152", "WideResNet152", fix(Metalhead.WideResNet, 152), true, 2048),
+    ("metalhead/googlenet", "GoogLeNet", Metalhead.GoogLeNet, false, 1024),
+    ("metalhead/inceptionv3", "InceptionV3", Metalhead.Inceptionv3, false, 2048),
+    ("metalhead/inceptionv4", "InceptionV4", Metalhead.Inceptionv4, false, 1536),
+    ("metalhead/squeezenet", "SqueezeNet", Metalhead.SqueezeNet, true, 512),
+    ("metalhead/densenet-121", "DenseNet121", fix(Metalhead.DenseNet, 121), false, 1024),
+    ("metalhead/densenet-161", "DenseNet161", fix(Metalhead.DenseNet, 161), false, 1472),
+    ("metalhead/densenet-169", "DenseNet169", fix(Metalhead.DenseNet, 169), false, 1664),
+    ("metalhead/densenet-201", "DenseNet201", fix(Metalhead.DenseNet, 201), false, 1920),
+    ("metalhead/resnext50", "ResNeXt50", fix(Metalhead.ResNeXt, 50), true, 2048),
+    ("metalhead/resnext101", "ResNeXt101", fix(Metalhead.ResNeXt, 101), true, 2048),
+    ("metalhead/resnext152", "ResNeXt152", fix(Metalhead.ResNeXt, 152), true, 2048),
+    ("metalhead/mobilenetv1", "MobileNetV1", Metalhead.MobileNetv1, false, 1024),
+    ("metalhead/mobilenetv2", "MobileNetV2", Metalhead.MobileNetv2, false, 1280),
+    ("metalhead/mobilenetv3-small", "MobileNetV3 Small", fix(Metalhead.MobileNetv3, :small), false, 576),
+    ("metalhead/mobilenetv3-large", "MobileNetV3 Large", fix(Metalhead.MobileNetv3, :large), false, 960),
+    ("metalhead/efficientnet-b0", "EfficientNet-B0", fix(Metalhead.EfficientNet, :b0), false, 1280),
+    ("metalhead/efficientnet-b1", "EfficientNet-B1", fix(Metalhead.EfficientNet, :b1), false, 1280),
+    ("metalhead/efficientnet-b2", "EfficientNet-B2", fix(Metalhead.EfficientNet, :b2), false, 1280),
+    ("metalhead/efficientnet-b3", "EfficientNet-B3", fix(Metalhead.EfficientNet, :b3), false, 1280),
+    ("metalhead/efficientnet-b4", "EfficientNet-B4", fix(Metalhead.EfficientNet, :b4), false, 1280),
+    ("metalhead/efficientnet-b5", "EfficientNet-B5", fix(Metalhead.EfficientNet, :b5), false, 1280),
+    ("metalhead/efficientnet-b6", "EfficientNet-B6", fix(Metalhead.EfficientNet, :b6), false, 1280),
+    ("metalhead/efficientnet-b7", "EfficientNet-B7", fix(Metalhead.EfficientNet, :b7), false, 1280),
+    ("metalhead/efficientnet-b8", "EfficientNet-B8", fix(Metalhead.EfficientNet, :b8), false, 1280),
+]
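+
+# For example, the "metalhead/resnet18" row above becomes a registry entry whose `loadfn`
+# builds `Metalhead.ResNet(18)` (optionally with the "imagenet1k" checkpoint) and whose
+# "backbone" variant advertises `ConvFeatures{2}(512)` as its output block.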
["imagenet1k"] : String[], + backend = :flux) +end + + +@testset "Metalhead models" begin + @test_nowarn load(models()["metalhead/resnet18"]; variant = "backbone") + @test_nowarn load(models()["metalhead/resnet18"]; variant = "classifier") + @test_nowarn load(models()["metalhead/resnet18"]; output = FastAI.OneHotLabel) + @test_nowarn load(models()["metalhead/resnet18"]; input = FastVision.ImageTensor) + @test_throws FastAI.Registries.NoModelVariantFoundError load(models()["metalhead/resnet18"]; output = FastAI.Label) +end diff --git a/src/Registries/Registries.jl b/src/Registries/Registries.jl index 49d38ff7c4..ac20e93e5d 100644 --- a/src/Registries/Registries.jl +++ b/src/Registries/Registries.jl @@ -1,6 +1,6 @@ module Registries -using ..FastAI +using ..FastAI: FastAI, BlockLike, Label, LabelMulti, issubblock using ..FastAI.Datasets using ..FastAI.Datasets: DatasetLoader, DataDepLoader, isavailable, loaddata, typify @@ -48,10 +48,12 @@ end include("datasets.jl") include("tasks.jl") include("recipes.jl") +include("models.jl") export datasets, learningtasks, datarecipes, + models, find, info, load diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 8b13789179..5f284dd030 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -1 +1,272 @@ +# # Model registry +# +# This file defines [`models`](#), a feature registry for models. +# ## Registry definition + +const _MODELS_DESCRIPTION = """ +A `FeatureRegistry` for models. Allows you to find and load models for various learning +tasks using a unified interface. Call `models()` to see a table view of available models: + +```julia +using FastAI +models() +``` + +Which models are available depends on the loaded packages. For example, FastVision.jl adds +vision models from Metalhead to the registry. Index the registry with a model ID to get more +information about that model: + +```julia +using FastAI: models +using FastVision # loading the package extends the list of available models + +models()["metalhead/resnet18"] +``` + +If you've selected a model, call `load` to then instantiate a model: + +```julia +model = load("metalhead/resnet18") +``` + +By default, `load` loads a default version of the model without any pretrained weights. + +`load(model)` also accepts keyword arguments that allow you to specify variants of the model and +weight checkpoints that should be loaded. + +Loading a checkpoint of pretrained weights: + +- `load(entry; pretrained = true)`: Use any pretrained weights, if they are available. +- `load(entry; checkpoint = "checkpoint-name")`: Use the weights with given name. See + `entry.checkpoints` for available checkpoints (if any). +- `load(entry; pretrained = false)`: Don't use pretrained weights + +Loading a model variant for a specific task: + +- `load(entry; input = ImageTensor, output = OneHotLabel)`: Load a model variant matching + an input and output block. +- `load(entry; variant = "backbone"): Load a model variant by name. See `entry.variants` for + available variants. +""" + + +# ## `ModelVariant` interface +""" + abstract type ModelVariant + +A `ModelVariant` handles loading a model, optionally with pretrained weights and +transforming it so that it can be used for specific learning tasks. + + +are subblocks (see [`issubblock`](#)) of `blocks = (xblock, yblock)`. + +## Interface + +- [`compatibleblocks`](#)`(variant)` returns a tuple `(xblock, yblock)` of [`BlockLike`](#) that + are compatible with the model. 
+"""
+abstract type ModelVariant end
+
+"""
+    compatibleblocks(::ModelVariant)
+
+Indicate the compatible input and output blocks for a model variant.
+"""
+function compatibleblocks end
+
+function loadvariant end
+
+# ## Model registry creation
+
+function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION)
+    fields = (;
+        id = Field(String; name = "ID", formatfn = FeatureRegistries.string_format),
+        description = Field(String;
+            name = "Description",
+            optional = true,
+            description = "More information about the model",
+            formatfn = FeatureRegistries.md_format),
+        backend = Field(Symbol,
+            name = "Backend",
+            default = :flux,
+            description = "The backend deep learning framework that the model uses. The default is `:flux`."),
+        loadfn = Field(Any,
+            name = "Load function",
+            optional = false,
+            description = "A function `loadfn(checkpoint)` that loads a default version of the model, possibly with `checkpoint` weights."),
+        variants = Field(Vector{Pair{String, ModelVariant}},
+            name = "Variants",
+            default = Pair{String, ModelVariant}[],
+            description = "Model variants suitable for different learning tasks. See `?ModelVariant` for more details.",
+            formatfn = d -> join(first.(d), ", ")),
+        checkpoints = Field(Vector{String};
+            name = "Checkpoints",
+            description = """
+                Pretrained weight checkpoints that can be loaded for the model. Checkpoints
+                are listed as a `Vector{String}` and `loadfn` should take care of loading
+                the selected checkpoint.""",
+            formatfn = cs -> join(cs, ", "),
+            defaultfn = (row, key) -> String[]),
+    )
+    return Registry(fields; name, loadfn = _loadmodel, description = description)
+end
+
+"""
+    _loadmodel(row)
+
+Load a model specified by `row` from a model registry.
+"""
+function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoint = nothing,
+                    pretrained = !isnothing(checkpoint), kwargs...)
+    checkpoints, variants = row.checkpoints, row.variants  # 1.6 support
+
+    # Find a matching checkpoint. `foundcheckpoint` is `nothing` if no checkpoint was
+    # requested or the requested one is not available.
+    foundcheckpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint)
+    if pretrained && isnothing(foundcheckpoint)
+        throw(NoCheckpointFoundError(checkpoints, checkpoint))
+    end
+
+    if isnothing(variant) && input === Any && output === Any
+        # If no variant is asked for, use the base model loading function that only takes
+        # care of the checkpoint.
+        return row.loadfn(foundcheckpoint)
+    else
+        # If a variant is specified, either by name (through `variant`) or through block
+        # constraints `input` or `output`, try to find a matching variant.
+        foundvariant = _findvariant(variants, variant, input, output)
+        isnothing(foundvariant) &&
+            throw(NoModelVariantFoundError(variants, input, output, variant))
+        return loadvariant(foundvariant, input, output, foundcheckpoint; kwargs...)
+    end
+end
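+
+# For example, with the mock registry entry defined in the tests at the bottom of this file:
+#
+#   _loadmodel(entry)                     # == (nothing, 1): default `loadfn`, no checkpoint
+#   _loadmodel(entry; pretrained = true)  # == ("checkpoint", 1): first available checkpoint
+#   _loadmodel(entry; output = Label)     # == (nothing, 2): resolved via the "ext" variant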
+
+# ### Errors
+
+# TODO: Implement `Base.showerror`
+struct NoModelVariantFoundError <: Exception
+    variants::Vector
+    input::BlockLike
+    output::BlockLike
+    variant::Union{String, Nothing}
+end
+
+# TODO: Implement `Base.showerror`
+struct NoCheckpointFoundError <: Exception
+    checkpoints::Vector{String}
+    checkpoint::Union{String, Nothing}
+end
+
+# ## Create the default registry instance
+
+const MODELS = _modelregistry()
+
+"""
+    models()
+
+$_MODELS_DESCRIPTION
+"""
+models(; kwargs...) = isempty(kwargs) ? MODELS : filter(MODELS; kwargs...)
+
+# ## Helpers
+
+function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = nothing)
+    if isempty(checkpoints)
+        nothing
+    elseif !isnothing(name)
+        i = findfirst(==(name), checkpoints)
+        isnothing(i) ? nothing : checkpoints[i]
+    elseif pretrained
+        first(values(checkpoints))
+    else
+        nothing
+    end
+end
+
+function _findvariant(variants::Vector,
+                      variantname::Union{String, Nothing}, xblock, yblock)
+    if !isnothing(variantname)
+        variants = filter(variants) do (name, _)
+            name == variantname
+        end
+    end
+    i = findfirst(variants) do (_, variant)
+        v_xblock, v_yblock = compatibleblocks(variant)
+        issubblock(v_xblock, xblock) && issubblock(v_yblock, yblock)
+    end
+    isnothing(i) ? nothing : variants[i][2]
+end
+
+# ## Tests
+
+struct MockVariant <: ModelVariant
+    model
+    blocks
+end
+
+compatibleblocks(variant::MockVariant) = variant.blocks
+loadvariant(variant::MockVariant, _, _, ch) = (ch, variant.model)
+
+@testset "Model registry" begin
+    @testset "Basic" begin
+        @test_nowarn _modelregistry()
+        reg = _modelregistry()
+        push!(reg, (; id = "test", loadfn = (checkpoint,) -> checkpoint))
+
+        @test load(reg["test"]) === nothing
+        @test_throws NoCheckpointFoundError load(reg["test"], pretrained = true)
+    end
+
+    @testset "_loadmodel" begin
+        reg = _modelregistry()
+        @test_nowarn push!(reg,
+                           (;
+                            id = "test",
+                            loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg),
+                            checkpoints = ["checkpoint", "checkpoint2"],
+                            variants = [
+                                "base" => MockVariant(1, (Any, Any)),
+                                "ext" => MockVariant(2, (Any, Label)),
+                            ]))
+        entry = reg["test"]
+        @test _loadmodel(entry) == (nothing, 1)
+        @test _loadmodel(entry; pretrained = true) == ("checkpoint", 1)
+        @test _loadmodel(entry; checkpoint = "checkpoint2") == ("checkpoint2", 1)
+        @test_throws NoCheckpointFoundError _loadmodel(entry; checkpoint = "checkpoint3")
+
+        @test _loadmodel(entry; output = Label) == (nothing, 2)
+        @test _loadmodel(entry; variant = "ext") == (nothing, 2)
+        @test _loadmodel(entry; pretrained = true, output = Label) == ("checkpoint", 2)
+        @test_throws NoModelVariantFoundError _loadmodel(entry; input = Label)
+    end
+
+    @testset "_findvariant" begin
+        vars = [
+            "1" => MockVariant(1, (Any, Any)),
+            "2" => MockVariant(1, (Any, Label)),
+        ]
+        # no restrictions => select first variant
+        @test _findvariant(vars, nothing, Any, Any) == vars[1][2]
+        # name => select named variant
+        @test _findvariant(vars, "2", Any, Any) == vars[2][2]
+        # name not found => nothing
+        @test _findvariant(vars, "3", Any, Any) === nothing
+        # restrict block => select matching
+        @test _findvariant(vars, nothing, Any, Label) == vars[2][2]
+        # restrict block not found => nothing
+        @test _findvariant(vars, nothing, Any, LabelMulti) === nothing
+    end
+
+    @testset "_findcheckpoint" begin
+        chs = ["check1", "check2"]
+        @test _findcheckpoint(chs) === nothing
+        @test _findcheckpoint(chs, pretrained = true) === "check1"
+        @test _findcheckpoint(chs, pretrained = true, name = "check2") === "check2"
+        @test _findcheckpoint(chs, pretrained = true, name = "check3") === nothing
+    end
+end
diff --git a/src/datablock/block.jl b/src/datablock/block.jl
index 462d3ece10..c2423df49f 100644
--- a/src/datablock/block.jl
+++ b/src/datablock/block.jl
@@ -131,3 +131,39 @@ and other diagrams.
 """
 blockname(block::Block) = string(nameof(typeof(block)))
 blockname(blocks::Tuple) = "(" * join(map(blockname, blocks), ", ") * ")"
+
+const BlockLike = Union{<:AbstractBlock, Type{<:AbstractBlock}, <:Tuple, Type{Any}}
+
+"""
+    issubblock(subblock, superblock)
+
+Predicate that checks whether `subblock` is a subblock of `superblock`. This holds if
+`subblock` is
+
+- a subtype of `superblock`, when `superblock` is a block type,
+- an instance of a subtype of `superblock`, when `superblock` is a block type, or
+- equal to `superblock`, when `superblock` is a block instance.
+
+Both arguments can also be tuples. In that case, each element of the tuple `subblock` is
+compared recursively against the corresponding element of the tuple `superblock`.
+"""
+function issubblock end
+
+issubblock(_, _) = false
+issubblock(sub::BlockLike, super::Type{Any}) = true
+issubblock(sub::Tuple, super::Tuple) =
+    (length(sub) == length(super)) && all(map(issubblock, sub, super))
+issubblock(sub::Type{<:AbstractBlock}, super::Type{<:AbstractBlock}) = sub <: super
+issubblock(sub::AbstractBlock, super::Type{<:AbstractBlock}) = issubblock(typeof(sub), super)
+issubblock(sub::AbstractBlock, super::AbstractBlock) = sub == super
+
+@testset "issubblock" begin
+    @test issubblock(Label, Any)
+    @test issubblock((Label,), (Any,))
+    @test issubblock((Label,), Any)
+    @test !issubblock(Label, (Any,))
+    @test issubblock(Label{String}, Label)
+    @test !issubblock(Label, Label{String})
+    @test issubblock(Label{Int}(1:10), Label{Int})
+    @test issubblock(Label{Int}(1:10), Label{Int}(1:10))
+    @test !issubblock(Label{Int}, Label{Int}(1:10))
+end
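+
+# The model registry (`src/Registries/models.jl`) builds on this predicate: `_findvariant`
+# matches a variant's `compatibleblocks` pair element-wise against the `input`/`output`
+# constraints passed to `load`.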