From 60b4bea083094edcf8d6b8b1aec3d2494c80f678 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 26 Nov 2022 12:48:42 +0100 Subject: [PATCH 01/14] Add a feature registry for models --- src/Registries/Registries.jl | 4 +- src/Registries/models.jl | 246 +++++++++++++++++++++++++++++++++++ src/datablock/block.jl | 36 +++++ 3 files changed, 285 insertions(+), 1 deletion(-) diff --git a/src/Registries/Registries.jl b/src/Registries/Registries.jl index 49d38ff7c4..ac20e93e5d 100644 --- a/src/Registries/Registries.jl +++ b/src/Registries/Registries.jl @@ -1,6 +1,6 @@ module Registries -using ..FastAI +using ..FastAI: FastAI, BlockLike, Label, LabelMulti, issubblock using ..FastAI.Datasets using ..FastAI.Datasets: DatasetLoader, DataDepLoader, isavailable, loaddata, typify @@ -48,10 +48,12 @@ end include("datasets.jl") include("tasks.jl") include("recipes.jl") +include("models.jl") export datasets, learningtasks, datarecipes, + models, find, info, load diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 8b13789179..0b0f549a9f 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -1 +1,247 @@ + +const _MODELS_DESCRIPTION = """ +A `FeatureRegistry` for models. Allows you to find and load models for various learning +tasks using a unified interface. Call `models()` to see a table view of available models: + +```julia +using FastAI +models() +``` + +Which models are available depends on the loaded packages. For example, FastVision.jl adds +vision models from Metalhead to the registry. Index the registry with a model ID to get more +information about that model: + +```julia +using FastAI: models +using FastVision # loading the package extends the list of available models + +models()["metalhead/resnet18"] +``` + +If you've selected a model, call `load` to then instantiate a model: + +```julia +model = load("metalhead/resnet18") +``` + +By default, `load` loads a default version of the model without any pretrained weights. + +`load(model)` also accepts keyword arguments that allow you to specify variants of the model and +weight checkpoints that should be loaded. + +Loading a checkpoint of pretrained weights: + +- `load(entry; pretrained = true)`: Use any pretrained weights, if they are available. +- `load(entry; checkpoint = "checkpoint-name")`: Use the weights with given name. See + `entry.checkpoints` for available checkpoints (if any). +- `load(entry; pretrained = false)`: Don't use pretrained weights + +Loading a model variant for a specific task: + +- `load(entry; input = ImageTensor, output = OneHotLabel)`: Load a model variant matching + an input and output block. +- `load(entry; variant = "backbone"): Load a model variant by name. See `entry.variants` for + available variants. +""" + + +""" + struct ModelVariant(; transform, input, output) + +A `ModelVariant` is a model transformation that changes a model so that its input and output +are subblocks (see [`issubblock`](#)) of `blocks = (inblock, outblock)`. + +""" +struct ModelVariant + transformfn::Any # callable + xblock::BlockLike + yblock::BlockLike +end +_default_transform(model, xblock, yblock; kwargs...) = model +ModelVariant(; transform = _default_transform, input = Any, output = Any) = + ModelVariant(transform, input, output) + + +# Registry definition + +function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) + fields = (; + id = Field(String; name = "ID", formatfn = FeatureRegistries.string_format), + description = Field( + String; + name = "Description", + optional = true, + description = "More information about the model", + formatfn = FeatureRegistries.md_format, + ), + backend = Field( + Symbol, + name = "Backend", + default = :flux, + description = "The backend deep learning framework that the model uses. The default is `:flux`.", + ), + variants = Field( + Vector{Pair{String,ModelVariant}}, + name = "Variants", + description = "Model variants suitable for different learning tasks", + defaultfn = (row, key) -> Pair{String, ModelVariant}[], + formatfn = d -> join(collect(keys(d)), ", "), + ), + checkpoints = Field( + Vector{String}; + name = "Checkpoints", + description = "Pretrained weight checkpoints that can be loaded for the model", + formatfn = cs -> join(cs, ", "), + defaultfn = (row, key) -> String[], + ), + loadfn = Field( + Any; + name = "Load function", + description = """ + Function that loads the base version of the model, optionally with weights. + It is called with the name of the selected checkpoint fro `checkpoints`, + i.e. `loadfn(checkpoint)`. If no checkpoint is selected, it is called with + `nothing`, i.e. loadfn(`nothing`). + + Any unknown keyword arguments passed to `load`, i.e. + `load(registry[id]; kwargs...)` will be passed along to `loadfn`. + """, + optional = false, + ) + ) + return Registry(fields; name, loadfn = identity, description = description) +end + +""" + _loadmodel(row) + +Load a model specified by `row` from a model registry. + + +""" +function _loadmodel(row; input=Any, output=Any, variant = nothing, checkpoint = nothing, pretrained = !isnothing(checkpoint), kwargs...) + (; loadfn, checkpoints, variants) = row + + # Finding matching configuration + checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) + + pretrained && isnothing(checkpoint) && throw(NoCheckpointFoundError(checkpoints, checkpoint)) + variant = _findvariant(variants, variant, input, output) + isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) + + # Loading + basemodel = loadfn(checkpoint, kwargs...) + model = variant.transformfn(basemodel, input, output) + + return model +end + +struct NoModelVariantFoundError <: Exception + variants::Vector{Pair{String, ModelVariant}} + input::BlockLike + output::BlockLike + variant::Union{String, Nothing} +end + +struct NoCheckpointFoundError <: Exception + checkpoints::Vector{String} + checkpoint::Union{String, Nothing} +end + + + +const MODELS = _modelregistry() + + +""" + models() + +$_MODELS_DESCRIPTION +""" +models(; kwargs...) = isempty(kwargs) ? MODELS : filter(MODELS; kwargs...) + + + +function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = nothing) + if isempty(checkpoints) + nothing + elseif !isnothing(name) + i = findfirst(==(name), checkpoints) + isnothing(i) ? nothing : checkpoints[i] + elseif pretrained + first(values(checkpoints)) + else + nothing + end +end + +function _findvariant(variants::Vector{Pair{String,ModelVariant}}, variantname::Union{String, Nothing}, xblock, yblock) + if !isnothing(variantname) + variants = filter(variants) do (name, _) + name == variantname + end + end + i = findfirst(variants) do (_, variant) + issubblock(variant.xblock, xblock) && issubblock(variant.yblock, yblock) + end + isnothing(i) ? nothing : variants[i][2] +end + + +@testset "Model registry" begin + @testset "Basic" begin + @test_nowarn _modelregistry() + reg = _modelregistry() + push!(reg, (; + id = "test", + loadfn = _ -> 1, + )) + end + + @testset "_loadmodel" begin + reg = _modelregistry() + @test_nowarn push!(reg, (; + id = "test", + loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg), + checkpoints = ["checkpoint", "checkpoint2"], + variants = [ + "base" => ModelVariant(), + "ext" => ModelVariant(((ch, k), i, o; kwargs...) -> (ch, k + 1), Any, Label), + ] + )) + entry = reg["test"] + @test _loadmodel(entry) == (nothing, 1) + @test _loadmodel(entry; pretrained = true) == ("checkpoint", 1) + @test _loadmodel(entry; checkpoint = "checkpoint2") == ("checkpoint2", 1) + @test_throws NoCheckpointFoundError _loadmodel(entry; checkpoint = "checkpoint3") + + @test _loadmodel(entry; output = Label) == (nothing, 2) + @test _loadmodel(entry; variant = "ext") == (nothing, 2) + @test _loadmodel(entry; pretrained = true, output = Label) == ("checkpoint", 2) + @test_throws NoModelVariantFoundError _loadmodel(entry; input = Label) + end + + @testset "_findvariant" begin + vars = ["1" => ModelVariant(identity, Any, Any), "2" => ModelVariant(identity, Any, Label)] + # no restrictions => select first variant + @test _findvariant(vars, nothing, Any, Any) == vars[1][2] + # name => select named variant + @test _findvariant(vars, "2", Any, Any) == vars[2][2] + # name not found => nothing + @test _findvariant(vars, "3", Any, Any) === nothing + # restrict block => select matching + @test _findvariant(vars, nothing, Any, Label) == vars[2][2] + # restrict block not found => nothing + @test _findvariant(vars, nothing, Any, LabelMulti) === nothing + end + + @testset "_findcheckpoint" begin + chs = ["check1", "check2"] + @test _findcheckpoint(chs) === nothing + @test _findcheckpoint(chs, pretrained = true) === "check1" + @test _findcheckpoint(chs, pretrained = true, name = "check2") === "check2" + @test _findcheckpoint(chs, pretrained = true, name = "check3") === nothing + end +end diff --git a/src/datablock/block.jl b/src/datablock/block.jl index 462d3ece10..c2423df49f 100644 --- a/src/datablock/block.jl +++ b/src/datablock/block.jl @@ -131,3 +131,39 @@ and other diagrams. """ blockname(block::Block) = string(nameof(typeof(block))) blockname(blocks::Tuple) = "(" * join(map(blockname, blocks), ", ") * ")" + +const BlockLike = Union{<:AbstractBlock, Type{<:AbstractBlock}, <:Tuple, Type{Any}} + +""" + function issubblock(subblock, superblock) + +Predicate whether `subblock` is a subblock of `superblock`. This means that `subblock` is + +- a subtype of a type `superblock <: Type{AbstractBlock}` +- an instance of a subtype of `superblock <: Type{AbstractBlock}` +- equal to `superblock` + +Both arguments can also be tuples. In that case, each element of the tuple `subblock` is +compared recursively against the elements of the tuple `superblock`. +""" +function issubblock end + +issubblock(_, _) = false +issubblock(sub::BlockLike, super::Type{Any}) = true +issubblock(sub::Tuple, super::Tuple) = + (length(sub) == length(super)) && all(map(issubblock, sub, super)) +issubblock(sub::Type{<:AbstractBlock}, super::Type{<:AbstractBlock}) = sub <: super +issubblock(sub::AbstractBlock, super::Type{<:AbstractBlock}) = issubblock(typeof(sub), super) +issubblock(sub::AbstractBlock, super::AbstractBlock) = sub == super + +@testset "issubblock" begin + @test issubblock(Label, Any) + @test issubblock((Label,), (Any,)) + @test issubblock((Label,), Any) + @test !issubblock(Label, (Any,)) + @test issubblock(Label{String}, Label) + @test !issubblock(Label, Label{String}) + @test issubblock(Label{Int}(1:10), Label{Int}) + @test issubblock(Label{Int}(1:10), Label{Int}(1:10)) + @test !issubblock(Label{Int}, Label{Int}(1:10)) +end From 936cd1c479ebad7881fb4557ba85b30aa1b18c11 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 26 Nov 2022 12:58:01 +0100 Subject: [PATCH 02/14] Use 1.6 supported syntax --- src/Registries/models.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 0b0f549a9f..49735e38b5 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -122,7 +122,7 @@ Load a model specified by `row` from a model registry. """ function _loadmodel(row; input=Any, output=Any, variant = nothing, checkpoint = nothing, pretrained = !isnothing(checkpoint), kwargs...) - (; loadfn, checkpoints, variants) = row + loadfn, checkpoints, variants = row.loadfn, row.checkpoints, row.variants # 1.6 support # Finding matching configuration checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) From 89c8a6133a78423097f07dae928deee96224052f Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 26 Nov 2022 19:15:07 +0100 Subject: [PATCH 03/14] Fix model variant printing --- src/Registries/models.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 49735e38b5..4f4b896f4d 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -87,7 +87,7 @@ function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) name = "Variants", description = "Model variants suitable for different learning tasks", defaultfn = (row, key) -> Pair{String, ModelVariant}[], - formatfn = d -> join(collect(keys(d)), ", "), + formatfn = d -> join(first.(d), ", "), ), checkpoints = Field( Vector{String}; From 3919a1d6d00b7a622cfba290bee685fb741bf227 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 26 Nov 2022 19:16:39 +0100 Subject: [PATCH 04/14] WIP: Add Metalhead.jl models to registry --- FastVision/Project.toml | 12 ++- FastVision/src/FastVision.jl | 13 +++- FastVision/src/blocks/convfeatures.jl | 23 ++++++ FastVision/src/modelregistry.jl | 102 ++++++++++++++++++++++++++ 4 files changed, 145 insertions(+), 5 deletions(-) create mode 100644 FastVision/src/blocks/convfeatures.jl create mode 100644 FastVision/src/modelregistry.jl diff --git a/FastVision/Project.toml b/FastVision/Project.toml index da9b3c8bf6..f7f551961e 100644 --- a/FastVision/Project.toml +++ b/FastVision/Project.toml @@ -17,7 +17,10 @@ IndirectArrays = "9b13fd28-a010-5f03-acff-a1bbcff69959" InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" +Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -33,14 +36,15 @@ FastAI = "0.5" FixedPointNumbers = "0.8" Flux = "0.12, 0.13" ImageIO = "0.6" -ImageInTerminal = "0.4" +ImageInTerminal = "0.4, 0.5" IndirectArrays = "0.5, 1" InlineTest = "0.2" -MLUtils = "0.2" -MakieCore = "0.3" +MLUtils = "0.2, 0.3" +MakieCore = "0.3, 0.4, 0.5" +Metalhead = "0.8" ProgressMeter = "1" ShowCases = "0.1" StaticArrays = "1.1" -UnicodePlots = "2" +UnicodePlots = "2, 3" Zygote = "0.6" julia = "1.6" diff --git a/FastVision/src/FastVision.jl b/FastVision/src/FastVision.jl index 293e088e5d..c7a148d34a 100644 --- a/FastVision/src/FastVision.jl +++ b/FastVision/src/FastVision.jl @@ -38,6 +38,7 @@ using FastAI: # blocks Context, Training, Validation, Inference, Datasets using FastAI.Datasets +using FastAI.Registries: ModelVariant # extending import FastAI: @@ -45,7 +46,7 @@ import FastAI: encodedblock, decodedblock, showblock!, mockblock, setup, encodestate, decodestate -import Flux +using Flux: Flux, Chain, Conv, Dense import MLUtils: getobs, numobs, mapobs, eachobs import Colors: colormaps_sequential, Colorant, Color, Gray, Normed, RGB, alphacolor, deuteranopic, distinguishable_colors @@ -63,10 +64,13 @@ import IndirectArrays: IndirectArray import MakieCore import MakieCore: @recipe import MakieCore.Observables: @map +import Metalhead: Metalhead import ProgressMeter: Progress, next! +using Setfield: @set import StaticArrays: SVector import Statistics: mean, std import UnicodePlots +using Random: Random using InlineTest using ShowCases @@ -76,6 +80,7 @@ include("blocks/bounded.jl") include("blocks/image.jl") include("blocks/mask.jl") include("blocks/keypoints.jl") +include("blocks/convfeatures.jl") include("encodings/onehot.jl") include("encodings/imagepreprocessing.jl") @@ -93,6 +98,7 @@ include("tasks/keypointregression.jl") include("datasets.jl") include("recipes.jl") include("makie.jl") +include("modelregistry.jl") include("tests.jl") @@ -103,6 +109,11 @@ function __init__() push!(FastAI.learningtasks(), t) end end + foreach(values(_models)) do t + if !haskey(FastAI.models(), t.id) + push!(FastAI.models(), t) + end + end end export Image, Mask, Keypoints, Bounded, diff --git a/FastVision/src/blocks/convfeatures.jl b/FastVision/src/blocks/convfeatures.jl new file mode 100644 index 0000000000..3425373ec4 --- /dev/null +++ b/FastVision/src/blocks/convfeatures.jl @@ -0,0 +1,23 @@ + + +""" + ConvFeatures{N}(n) <: Block + ConvFeatures(n, size) + +Block representing features from a convolutional neural network backbone +with `n` feature channels and `N` spatial dimensions. +""" +struct ConvFeatures{N} <: Block + n::Int + size::NTuple{N, DimSize} +end + + +ConvFeatures{N}(n) where N = ConvFeatures{N}(n, ntuple(_ -> :, N)) + +function FastAI.checkblock(block::ConvFeatures{N}, a::AbstractArray{T, M}) where {M,N,T} + M == N+1 || return false + return checksize(block.size, size(a)) +end + +FastAI.mockblock(block::ConvFeatures) = rand(Float32, map(l -> l isa Colon ? 8 : l, block.size)..., block.n) diff --git a/FastVision/src/modelregistry.jl b/FastVision/src/modelregistry.jl new file mode 100644 index 0000000000..b5854e3cf3 --- /dev/null +++ b/FastVision/src/modelregistry.jl @@ -0,0 +1,102 @@ + + +const _models = Dict{String, Any}() + +function cnn_variants(; nfeatures = :, hasweights = false) + variants = Pair{String, ModelVariant}[] + + hasweights && push!(variants, "imagenet_1k" => ModelVariant( + input=ImageTensor{2}(3), + # TODO: use actual ImageNet classes + output=FastAI.OneHotLabel{Int}(1:1000), + )) + push!(variants, "classifier" => ModelVariant( + make_cnn_classifier, + ImageTensor{2}, + FastAI.OneHotTensor{0}, + )) + push!(variants, "backbone" => ModelVariant( + make_cnn_backbone, + ImageTensor{2}, + ConvFeatures{2}(nfeatures), + )) + + return variants +end + +function make_cnn_classifier(model, input::ImageTensor, output::OneHotTensor{0}) + backbone = _backbone_with_channels(model.layers[1], input.nchannels) + head = _head_with_classes(model.layers[2], length(output.classes)) + return Chain(backbone, head) +end + +function make_cnn_backbone(model, input::ImageTensor{N}, output::ConvFeatures{N}) where N + backbone = _backbone_with_channels(model.layers[1], input.nchannels) + return backbone +end + +function _backbone_with_channels(backbone, n) + layer = backbone.layers[1].layers[1] + layer isa Conv || throw(ArgumentError( + """To change the number of input channels, + `backbone.layers[1].layers[1]` must be a `Conv` layer.""")) + + sz = size(layer.weight) + ks = sz[begin:end-2] + in_, out = sz[end-1:end] + in_ == n && return backbone + + layer = @set layer.weight = Flux.kaiming_normal(Random.GLOBAL_RNG, ks..., n, out) + return @set backbone.layers[1].layers[1] = layer +end + +function _head_with_classes(head, n) + head.layers[end] isa Dense || throw(ArgumentError( + """To change the number of output classes, + the last layer in head must be a `Dense` layer.""")) + c, f = size(head[end].weight) + if c == n + # Already has correct number of classes + head + else + @set head.layers[end] = Dense(f, n) + end +end + +function metalhead_loadfn(modelfn, args...) + return function (checkpoint; kwargs...) + return modelfn(args...; pretrain=!isnothing(checkpoint), kwargs...) + end +end + +for depth in (18,) + hasweights = true + nfeatures = 512 + id = "metalhead/resnet$depth" + _models[id] = (; + id = id, + variants = cnn_variants(; hasweights, nfeatures), + checkpoints = hasweights ? ["imagenet1k"] : String[], + backend = :flux, + loadfn = metalhead_loadfn(Metalhead.ResNet, depth) + ) +end + + +@testset "Model variants" begin + @testset "make_cnn_classifier" begin + m = Metalhead.ResNet(18) + clf = make_cnn_classifier(m, ImageTensor{2}(3), FastAI.OneHotLabel{Int}(1:10)) + @test Flux.outputsize(clf, (256, 256, 3, 1)) == (10, 1) + + clf2 = make_cnn_classifier(m, ImageTensor{2}(5), FastAI.OneHotLabel{Int}(1:100)) + @test Flux.outputsize(clf2, (256, 256, 5, 1)) == (100, 1) + end + + @testset "make_cnn_backbone" begin + m = Metalhead.ResNet(18) + clf = make_cnn_backbone(m, ImageTensor{2}(10), ConvFeatures{2}(512)) + @test Flux.outputsize(clf, (256, 256, 10, 1)) == (8, 8, 512, 1) + + end +end From 4a0e6e01a631fc014619bdf66304ecc3754bdcf4 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sun, 27 Nov 2022 11:05:44 +0100 Subject: [PATCH 05/14] Use correct `load` function in model registry. Formats and adds more docs --- src/Registries/models.jl | 161 +++++++++++++++++++++------------------ 1 file changed, 85 insertions(+), 76 deletions(-) diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 4f4b896f4d..71abe69e6f 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -1,4 +1,8 @@ +# # Model registry +# +# This file defines [`models`](#), a feature registry for models. +# ## Registry definition const _MODELS_DESCRIPTION = """ A `FeatureRegistry` for models. Allows you to find and load models for various learning @@ -46,13 +50,23 @@ Loading a model variant for a specific task: available variants. """ - """ - struct ModelVariant(; transform, input, output) + struct ModelVariant(; transform, xblock, yblock) A `ModelVariant` is a model transformation that changes a model so that its input and output -are subblocks (see [`issubblock`](#)) of `blocks = (inblock, outblock)`. +are subblocks (see [`issubblock`](#)) of `blocks = (xblock, yblock)`. + +The model transformation function `transform` takes a model and two concrete _instances_ +of the variant's compatible blocks, returning a transformed model. + + `transform(model, xblock, yblock)` + +- `model` is the original model that is transformed +- `xblock` is the [`Block`](#) of the data that is input to the model. +- `yblock` is the [`Block`](#) of the data that the model outputs. +If you're working with a [`SupervisedTask`](#) `task`, these blocks correspond to +`inputblock = getblocks(task).x` and `outputblock = getblocks(task).y` """ struct ModelVariant transformfn::Any # callable @@ -60,74 +74,63 @@ struct ModelVariant yblock::BlockLike end _default_transform(model, xblock, yblock; kwargs...) = model -ModelVariant(; transform = _default_transform, input = Any, output = Any) = - ModelVariant(transform, input, output) - - -# Registry definition +function ModelVariant(; transform = _default_transform, xblock = Any, yblock = Any) + ModelVariant(transform, xblock, yblock) +end function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) fields = (; - id = Field(String; name = "ID", formatfn = FeatureRegistries.string_format), - description = Field( - String; - name = "Description", - optional = true, - description = "More information about the model", - formatfn = FeatureRegistries.md_format, - ), - backend = Field( - Symbol, - name = "Backend", - default = :flux, - description = "The backend deep learning framework that the model uses. The default is `:flux`.", - ), - variants = Field( - Vector{Pair{String,ModelVariant}}, - name = "Variants", - description = "Model variants suitable for different learning tasks", - defaultfn = (row, key) -> Pair{String, ModelVariant}[], - formatfn = d -> join(first.(d), ", "), - ), - checkpoints = Field( - Vector{String}; - name = "Checkpoints", - description = "Pretrained weight checkpoints that can be loaded for the model", - formatfn = cs -> join(cs, ", "), - defaultfn = (row, key) -> String[], - ), - loadfn = Field( - Any; - name = "Load function", - description = """ - Function that loads the base version of the model, optionally with weights. - It is called with the name of the selected checkpoint fro `checkpoints`, - i.e. `loadfn(checkpoint)`. If no checkpoint is selected, it is called with - `nothing`, i.e. loadfn(`nothing`). - - Any unknown keyword arguments passed to `load`, i.e. - `load(registry[id]; kwargs...)` will be passed along to `loadfn`. - """, - optional = false, - ) - ) - return Registry(fields; name, loadfn = identity, description = description) + id = Field(String; name = "ID", formatfn = FeatureRegistries.string_format), + description = Field(String; + name = "Description", + optional = true, + description = "More information about the model", + formatfn = FeatureRegistries.md_format), + backend = Field(Symbol, + name = "Backend", + default = :flux, + description = "The backend deep learning framework that the model uses. The default is `:flux`."), + variants = Field(Vector{Pair{String, ModelVariant}}, + name = "Variants", + optional = false, + description = "Model variants suitable for different learning tasks. See `?ModelVariant` for more details.", + formatfn = d -> join(first.(d), ", ")), + checkpoints = Field(Vector{String}; + name = "Checkpoints", + description = """ + Pretrained weight checkpoints that can be loaded for the model. Checkpoints are listed as a + `Vector{String}` and `loadfn` should take care of loading the selected checkpoint""", + formatfn = cs -> join(cs, ", "), + defaultfn = (row, key) -> String[]), + loadfn = Field(Any; + name = "Load function", + description = """ + Function that loads the base version of the model, optionally with weights. + It is called with the name of the selected checkpoint fro `checkpoints`, + i.e. `loadfn(checkpoint)`. If no checkpoint is selected, it is called with + `nothing`, i.e. loadfn(`nothing`). + + Any unknown keyword arguments passed to `load`, i.e. + `load(registry[id]; kwargs...)` will be passed along to `loadfn`. + """, + optional = false)) + return Registry(fields; name, loadfn = _loadmodel, description = description) end """ _loadmodel(row) Load a model specified by `row` from a model registry. - - """ -function _loadmodel(row; input=Any, output=Any, variant = nothing, checkpoint = nothing, pretrained = !isnothing(checkpoint), kwargs...) +function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoint = nothing, + pretrained = !isnothing(checkpoint), kwargs...) loadfn, checkpoints, variants = row.loadfn, row.checkpoints, row.variants # 1.6 support # Finding matching configuration checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) - pretrained && isnothing(checkpoint) && throw(NoCheckpointFoundError(checkpoints, checkpoint)) + pretrained && isnothing(checkpoint) && + throw(NoCheckpointFoundError(checkpoints, checkpoint)) variant = _findvariant(variants, variant, input, output) isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) @@ -138,6 +141,7 @@ function _loadmodel(row; input=Any, output=Any, variant = nothing, checkpoint = return model end +# ### Errors struct NoModelVariantFoundError <: Exception variants::Vector{Pair{String, ModelVariant}} input::BlockLike @@ -150,11 +154,8 @@ struct NoCheckpointFoundError <: Exception checkpoint::Union{String, Nothing} end - - const MODELS = _modelregistry() - """ models() @@ -162,8 +163,6 @@ $_MODELS_DESCRIPTION """ models(; kwargs...) = isempty(kwargs) ? MODELS : filter(MODELS; kwargs...) - - function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = nothing) if isempty(checkpoints) nothing @@ -177,7 +176,8 @@ function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = end end -function _findvariant(variants::Vector{Pair{String,ModelVariant}}, variantname::Union{String, Nothing}, xblock, yblock) +function _findvariant(variants::Vector{Pair{String, ModelVariant}}, + variantname::Union{String, Nothing}, xblock, yblock) if !isnothing(variantname) variants = filter(variants) do (name, _) name == variantname @@ -189,28 +189,34 @@ function _findvariant(variants::Vector{Pair{String,ModelVariant}}, variantname:: isnothing(i) ? nothing : variants[i][2] end +# ## Tests @testset "Model registry" begin @testset "Basic" begin @test_nowarn _modelregistry() reg = _modelregistry() push!(reg, (; - id = "test", - loadfn = _ -> 1, - )) + id = "test", + loadfn = _ -> 1, + variants = ["base" => ModelVariant()])) + + @test load(reg["test"]) == 1 + @test_throws NoCheckpointFoundError load(reg["test"], pretrained = true) end @testset "_loadmodel" begin reg = _modelregistry() - @test_nowarn push!(reg, (; - id = "test", - loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg), - checkpoints = ["checkpoint", "checkpoint2"], - variants = [ - "base" => ModelVariant(), - "ext" => ModelVariant(((ch, k), i, o; kwargs...) -> (ch, k + 1), Any, Label), - ] - )) + @test_nowarn push!(reg, + (; + id = "test", + loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg), + checkpoints = ["checkpoint", "checkpoint2"], + variants = [ + "base" => ModelVariant(), + "ext" => ModelVariant(((ch, k), i, o; kwargs...) -> (ch, + k + 1), + Any, Label), + ])) entry = reg["test"] @test _loadmodel(entry) == (nothing, 1) @test _loadmodel(entry; pretrained = true) == ("checkpoint", 1) @@ -224,7 +230,10 @@ end end @testset "_findvariant" begin - vars = ["1" => ModelVariant(identity, Any, Any), "2" => ModelVariant(identity, Any, Label)] + vars = [ + "1" => ModelVariant(identity, Any, Any), + "2" => ModelVariant(identity, Any, Label), + ] # no restrictions => select first variant @test _findvariant(vars, nothing, Any, Any) == vars[1][2] # name => select named variant @@ -233,7 +242,7 @@ end @test _findvariant(vars, "3", Any, Any) === nothing # restrict block => select matching @test _findvariant(vars, nothing, Any, Label) == vars[2][2] - # restrict block not found => nothing + # restrict block not found => nothing @test _findvariant(vars, nothing, Any, LabelMulti) === nothing end From 0a9353e49d79213239446bd5469c280b5f449632 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sun, 27 Nov 2022 13:11:34 +0100 Subject: [PATCH 06/14] Add all Metalhead models --- FastVision/src/blocks/convfeatures.jl | 12 +-- FastVision/src/modelregistry.jl | 142 +++++++++++++++++++------- 2 files changed, 110 insertions(+), 44 deletions(-) diff --git a/FastVision/src/blocks/convfeatures.jl b/FastVision/src/blocks/convfeatures.jl index 3425373ec4..c1ae6ed554 100644 --- a/FastVision/src/blocks/convfeatures.jl +++ b/FastVision/src/blocks/convfeatures.jl @@ -1,5 +1,4 @@ - """ ConvFeatures{N}(n) <: Block ConvFeatures(n, size) @@ -12,12 +11,13 @@ struct ConvFeatures{N} <: Block size::NTuple{N, DimSize} end +ConvFeatures{N}(n) where {N} = ConvFeatures{N}(n, ntuple(_ -> :, N)) -ConvFeatures{N}(n) where N = ConvFeatures{N}(n, ntuple(_ -> :, N)) - -function FastAI.checkblock(block::ConvFeatures{N}, a::AbstractArray{T, M}) where {M,N,T} - M == N+1 || return false +function FastAI.checkblock(block::ConvFeatures{N}, a::AbstractArray{T, M}) where {M, N, T} + M == N + 1 || return false return checksize(block.size, size(a)) end -FastAI.mockblock(block::ConvFeatures) = rand(Float32, map(l -> l isa Colon ? 8 : l, block.size)..., block.n) +function FastAI.mockblock(block::ConvFeatures) + rand(Float32, map(l -> l isa Colon ? 8 : l, block.size)..., block.n) +end diff --git a/FastVision/src/modelregistry.jl b/FastVision/src/modelregistry.jl index b5854e3cf3..4d0def0de7 100644 --- a/FastVision/src/modelregistry.jl +++ b/FastVision/src/modelregistry.jl @@ -1,25 +1,21 @@ - const _models = Dict{String, Any}() function cnn_variants(; nfeatures = :, hasweights = false) variants = Pair{String, ModelVariant}[] - hasweights && push!(variants, "imagenet_1k" => ModelVariant( - input=ImageTensor{2}(3), - # TODO: use actual ImageNet classes - output=FastAI.OneHotLabel{Int}(1:1000), - )) - push!(variants, "classifier" => ModelVariant( - make_cnn_classifier, - ImageTensor{2}, - FastAI.OneHotTensor{0}, - )) - push!(variants, "backbone" => ModelVariant( - make_cnn_backbone, - ImageTensor{2}, - ConvFeatures{2}(nfeatures), - )) + hasweights && push!(variants, + "imagenet_1k" => ModelVariant(xblock = ImageTensor{2}(3), + # TODO: use actual ImageNet classes + yblock = FastAI.OneHotLabel{Int}(1:1000))) + push!(variants, + "classifier" => ModelVariant(make_cnn_classifier, + ImageTensor{2}, + FastAI.OneHotTensor{0})) + push!(variants, + "backbone" => ModelVariant(make_cnn_backbone, + ImageTensor{2}, + ConvFeatures{2}(nfeatures))) return variants end @@ -30,20 +26,27 @@ function make_cnn_classifier(model, input::ImageTensor, output::OneHotTensor{0}) return Chain(backbone, head) end -function make_cnn_backbone(model, input::ImageTensor{N}, output::ConvFeatures{N}) where N +function make_cnn_classifier(model, ::Type{Any}, ::Type{Any}) + return model +end + +function make_cnn_backbone(model, input::ImageTensor{N}, ::ConvFeatures{N}) where {N} backbone = _backbone_with_channels(model.layers[1], input.nchannels) return backbone end +function make_cnn_backbone(model, ::Type{Any}, ::Type{Any}) + return model.layers[1] +end + function _backbone_with_channels(backbone, n) layer = backbone.layers[1].layers[1] - layer isa Conv || throw(ArgumentError( - """To change the number of input channels, - `backbone.layers[1].layers[1]` must be a `Conv` layer.""")) + layer isa Conv || throw(ArgumentError("""To change the number of input channels, + `backbone.layers[1].layers[1]` must be a `Conv` layer.""")) sz = size(layer.weight) - ks = sz[begin:end-2] - in_, out = sz[end-1:end] + ks = sz[begin:(end - 2)] + in_, out = sz[(end - 1):end] in_ == n && return backbone layer = @set layer.weight = Flux.kaiming_normal(Random.GLOBAL_RNG, ks..., n, out) @@ -51,9 +54,9 @@ function _backbone_with_channels(backbone, n) end function _head_with_classes(head, n) - head.layers[end] isa Dense || throw(ArgumentError( - """To change the number of output classes, - the last layer in head must be a `Dense` layer.""")) + head.layers[end] isa Dense || + throw(ArgumentError("""To change the number of output classes, + the last layer in head must be a `Dense` layer.""")) c, f = size(head[end].weight) if c == n # Already has correct number of classes @@ -65,24 +68,84 @@ end function metalhead_loadfn(modelfn, args...) return function (checkpoint; kwargs...) - return modelfn(args...; pretrain=!isnothing(checkpoint), kwargs...) + return modelfn(args...; pretrain = !isnothing(checkpoint), kwargs...) end end -for depth in (18,) - hasweights = true - nfeatures = 512 - id = "metalhead/resnet$depth" +# model config: id, description, basefn, variant, hasweights, nfeatures +const METALHEAD_CONFIGS = [ + ("metalhead/resnet18", "ResNet18", metalhead_loadfn(Metalhead.ResNet, 18), true, 512), + ("metalhead/resnet34", "ResNet34", metalhead_loadfn(Metalhead.ResNet, 34), true, 512), + ("metalhead/resnet50", "ResNet50", metalhead_loadfn(Metalhead.ResNet, 50), true, 2048), + ("metalhead/resnet101", "ResNet101", metalhead_loadfn(Metalhead.ResNet, 101), true, + 2048), + ("metalhead/resnet152", "ResNet152", metalhead_loadfn(Metalhead.ResNet, 152), true, + 2048), + ("metalhead/wideresnet50", "WideResNet50", metalhead_loadfn(Metalhead.WideResNet, 50), + true, 2048), + ("metalhead/wideresnet101", "WideResNet101", + metalhead_loadfn(Metalhead.WideResNet, 101), true, 2048), + ("metalhead/wideresnet152", "WideResNet152", + metalhead_loadfn(Metalhead.WideResNet, 152), true, 2048), + ("metalhead/googlenet", "GoogLeNet", metalhead_loadfn(Metalhead.GoogLeNet), false, + 1024), + ("metalhead/inceptionv3", "InceptionV3", metalhead_loadfn(Metalhead.Inceptionv3), false, + 2048), + ("metalhead/inceptionv4", "InceptionV4", metalhead_loadfn(Metalhead.Inceptionv4), false, + 1536), + ("metalhead/squeezenet", "SqueezeNet", metalhead_loadfn(Metalhead.SqueezeNet), true, + 512), + ("metalhead/densenet-121", "DenseNet121", metalhead_loadfn(Metalhead.DenseNet, 121), + false, 1024), + ("metalhead/densenet-161", "DenseNet161", metalhead_loadfn(Metalhead.DenseNet, 161), + false, 1472), + ("metalhead/densenet-169", "DenseNet169", metalhead_loadfn(Metalhead.DenseNet, 169), + false, 1664), + ("metalhead/densenet-201", "DenseNet201", metalhead_loadfn(Metalhead.DenseNet, 201), + false, 1920), + ("metalhead/resnext50", "ResNeXt50", metalhead_loadfn(Metalhead.ResNeXt, 50), true, + 2048), + ("metalhead/resnext101", "ResNeXt101", metalhead_loadfn(Metalhead.ResNeXt, 101), true, + 2048), + ("metalhead/resnext152", "ResNeXt152", metalhead_loadfn(Metalhead.ResNeXt, 152), true, + 2048), + ("metalhead/mobilenetv1", "MobileNetV1", metalhead_loadfn(Metalhead.MobileNetv1), false, + 1024), + ("metalhead/mobilenetv2", "MobileNetV2", metalhead_loadfn(Metalhead.MobileNetv2), false, + 1280), + ("metalhead/mobilenetv3-small", "MobileNetV3 Small", + metalhead_loadfn(Metalhead.MobileNetv3, :small), false, 576), + ("metalhead/mobilenetv3-large", "MobileNetV3 Large", + metalhead_loadfn(Metalhead.MobileNetv3, :large), false, 960), + ("metalhead/efficientnet-b0", "EfficientNet-B0", + metalhead_loadfn(Metalhead.EfficientNet, :b0), false, 1280), + ("metalhead/efficientnet-b0", "EfficientNet-B0", + metalhead_loadfn(Metalhead.EfficientNet, :b0), false, 1280), + ("metalhead/efficientnet-b1", "EfficientNet-B1", + metalhead_loadfn(Metalhead.EfficientNet, :b1), false, 1280), + ("metalhead/efficientnet-b2", "EfficientNet-B2", + metalhead_loadfn(Metalhead.EfficientNet, :b2), false, 1280), + ("metalhead/efficientnet-b3", "EfficientNet-B3", + metalhead_loadfn(Metalhead.EfficientNet, :b3), false, 1280), + ("metalhead/efficientnet-b4", "EfficientNet-B4", + metalhead_loadfn(Metalhead.EfficientNet, :b4), false, 1280), + ("metalhead/efficientnet-b5", "EfficientNet-B5", + metalhead_loadfn(Metalhead.EfficientNet, :b5), false, 1280), + ("metalhead/efficientnet-b6", "EfficientNet-B6", + metalhead_loadfn(Metalhead.EfficientNet, :b6), false, 1280), + ("metalhead/efficientnet-b7", "EfficientNet-B7", + metalhead_loadfn(Metalhead.EfficientNet, :b7), false, 1280), + ("metalhead/efficientnet-b8", "EfficientNet-B8", + metalhead_loadfn(Metalhead.EfficientNet, :b8), false, 1280), +] +for (id, description, loadfn, hasweights, nfeatures) in METALHEAD_CONFIGS _models[id] = (; - id = id, - variants = cnn_variants(; hasweights, nfeatures), - checkpoints = hasweights ? ["imagenet1k"] : String[], - backend = :flux, - loadfn = metalhead_loadfn(Metalhead.ResNet, depth) - ) + id, loadfn, description, + variants = cnn_variants(; hasweights, nfeatures), + checkpoints = hasweights ? ["imagenet1k"] : String[], + backend = :flux) end - @testset "Model variants" begin @testset "make_cnn_classifier" begin m = Metalhead.ResNet(18) @@ -97,6 +160,9 @@ end m = Metalhead.ResNet(18) clf = make_cnn_backbone(m, ImageTensor{2}(10), ConvFeatures{2}(512)) @test Flux.outputsize(clf, (256, 256, 10, 1)) == (8, 8, 512, 1) - end end + +@testset "Metalhead models" begin for id in models(id = "metalhead")[:, :id] + @test_nowarn load(models()[id]; variant = "backbone") +end end From 9aef6d268e877d5f32d355fa35af4d4ff4777427 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sun, 4 Dec 2022 12:12:17 +0100 Subject: [PATCH 07/14] Change `ModelVariant` API Now handles both loading checkpoints and possible transformations. This makes it easier to ntegrate with third-party model libraries that likewise handle both with a single function. --- src/Registries/models.jl | 101 ++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 71abe69e6f..f236f8a06b 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -50,33 +50,39 @@ Loading a model variant for a specific task: available variants. """ + +# ## `ModelVariant` interface """ - struct ModelVariant(; transform, xblock, yblock) + abstract type ModelVariant + +A `ModelVariant` handles loading a model, optionally with pretrained weights and +transforming it so that it can be used for specific learning tasks. + -A `ModelVariant` is a model transformation that changes a model so that its input and output are subblocks (see [`issubblock`](#)) of `blocks = (xblock, yblock)`. -The model transformation function `transform` takes a model and two concrete _instances_ -of the variant's compatible blocks, returning a transformed model. +## Interface - `transform(model, xblock, yblock)` +- [`compatibleblocks`](#)`(variant)` returns a tuple `(xblock, yblock)` of [`BlockLike`](#) that + are compatible with the model. This means that a variant can be used for a task with + input and output blocks `blocks`, if [`issubblock`](#)`(blocks, compatibleblocks(variant))`. +- [`loadvariant`](#)`(::ModelVariant, xblock, yblock, checkpoint; kwargs...)` loads a model + compatible with block instances `xblock` and `yblock`, with (optionally) weights + from `checkpoint`. +""" +abstract type ModelVariant end -- `model` is the original model that is transformed -- `xblock` is the [`Block`](#) of the data that is input to the model. -- `yblock` is the [`Block`](#) of the data that the model outputs. +""" + compatibleblocks(::ModelVariant) -If you're working with a [`SupervisedTask`](#) `task`, these blocks correspond to -`inputblock = getblocks(task).x` and `outputblock = getblocks(task).y` +Indicate compatible input and output block for a model variant. """ -struct ModelVariant - transformfn::Any # callable - xblock::BlockLike - yblock::BlockLike -end -_default_transform(model, xblock, yblock; kwargs...) = model -function ModelVariant(; transform = _default_transform, xblock = Any, yblock = Any) - ModelVariant(transform, xblock, yblock) -end +function compatibleblocks end + +function loadvariant end + + +# ## Model registry creation function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) fields = (; @@ -102,18 +108,7 @@ function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) `Vector{String}` and `loadfn` should take care of loading the selected checkpoint""", formatfn = cs -> join(cs, ", "), defaultfn = (row, key) -> String[]), - loadfn = Field(Any; - name = "Load function", - description = """ - Function that loads the base version of the model, optionally with weights. - It is called with the name of the selected checkpoint fro `checkpoints`, - i.e. `loadfn(checkpoint)`. If no checkpoint is selected, it is called with - `nothing`, i.e. loadfn(`nothing`). - - Any unknown keyword arguments passed to `load`, i.e. - `load(registry[id]; kwargs...)` will be passed along to `loadfn`. - """, - optional = false)) + ) return Registry(fields; name, loadfn = _loadmodel, description = description) end @@ -124,7 +119,7 @@ Load a model specified by `row` from a model registry. """ function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoint = nothing, pretrained = !isnothing(checkpoint), kwargs...) - loadfn, checkpoints, variants = row.loadfn, row.checkpoints, row.variants # 1.6 support + checkpoints, variants = row.checkpoints, row.variants # 1.6 support # Finding matching configuration checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) @@ -135,25 +130,27 @@ function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoin isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) # Loading - basemodel = loadfn(checkpoint, kwargs...) - model = variant.transformfn(basemodel, input, output) - - return model + return loadvariant(variant, input, output, checkpoint; kwargs...) end # ### Errors + +# TODO: Implement Base.showerror struct NoModelVariantFoundError <: Exception - variants::Vector{Pair{String, ModelVariant}} + variants::Vector input::BlockLike output::BlockLike variant::Union{String, Nothing} end +# TODO: Implement Base.showerror struct NoCheckpointFoundError <: Exception checkpoints::Vector{String} checkpoint::Union{String, Nothing} end +# ## Create the default registry instance + const MODELS = _modelregistry() """ @@ -163,6 +160,8 @@ $_MODELS_DESCRIPTION """ models(; kwargs...) = isempty(kwargs) ? MODELS : filter(MODELS; kwargs...) +# ## Helpers + function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = nothing) if isempty(checkpoints) nothing @@ -176,7 +175,7 @@ function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = end end -function _findvariant(variants::Vector{Pair{String, ModelVariant}}, +function _findvariant(variants::Vector, variantname::Union{String, Nothing}, xblock, yblock) if !isnothing(variantname) variants = filter(variants) do (name, _) @@ -184,23 +183,31 @@ function _findvariant(variants::Vector{Pair{String, ModelVariant}}, end end i = findfirst(variants) do (_, variant) - issubblock(variant.xblock, xblock) && issubblock(variant.yblock, yblock) + v_xblock, v_yblock = compatibleblocks(variant) + issubblock(v_xblock, xblock) && issubblock(v_yblock, yblock) end isnothing(i) ? nothing : variants[i][2] end # ## Tests +struct MockVariant <: ModelVariant + model + blocks +end + +compatibleblocks(variant::MockVariant) = variant.blocks +loadvariant(variant::MockVariant, x, y, ch) = (ch, variant.model) + @testset "Model registry" begin @testset "Basic" begin @test_nowarn _modelregistry() reg = _modelregistry() push!(reg, (; id = "test", - loadfn = _ -> 1, - variants = ["base" => ModelVariant()])) + variants = ["base" => MockVariant(1, (Any, Any))])) - @test load(reg["test"]) == 1 + @test load(reg["test"]) == (nothing, 1) @test_throws NoCheckpointFoundError load(reg["test"], pretrained = true) end @@ -212,10 +219,8 @@ end loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg), checkpoints = ["checkpoint", "checkpoint2"], variants = [ - "base" => ModelVariant(), - "ext" => ModelVariant(((ch, k), i, o; kwargs...) -> (ch, - k + 1), - Any, Label), + "base" => MockVariant(1, (Any, Any)), + "ext" => MockVariant(2, (Any, Label)), ])) entry = reg["test"] @test _loadmodel(entry) == (nothing, 1) @@ -231,8 +236,8 @@ end @testset "_findvariant" begin vars = [ - "1" => ModelVariant(identity, Any, Any), - "2" => ModelVariant(identity, Any, Label), + "1" => MockVariant(1, (Any, Any)), + "2" => MockVariant(1, (Any, Label)), ] # no restrictions => select first variant @test _findvariant(vars, nothing, Any, Any) == vars[1][2] From bc7acd0fe16fbe469392f3a2403a9863ce9190c0 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sun, 11 Dec 2022 16:47:50 +0100 Subject: [PATCH 08/14] WIP: adapt based on changes to model variant API --- FastVision/src/FastVision.jl | 2 +- FastVision/src/modelregistry.jl | 170 ++++++++++++++------------------ 2 files changed, 76 insertions(+), 96 deletions(-) diff --git a/FastVision/src/FastVision.jl b/FastVision/src/FastVision.jl index c7a148d34a..792dcb2e83 100644 --- a/FastVision/src/FastVision.jl +++ b/FastVision/src/FastVision.jl @@ -38,7 +38,7 @@ using FastAI: # blocks Context, Training, Validation, Inference, Datasets using FastAI.Datasets -using FastAI.Registries: ModelVariant +import FastAI.Registries: ModelVariant, compatibleblocks, loadvariant # extending import FastAI: diff --git a/FastVision/src/modelregistry.jl b/FastVision/src/modelregistry.jl index 4d0def0de7..a7ca1e4c09 100644 --- a/FastVision/src/modelregistry.jl +++ b/FastVision/src/modelregistry.jl @@ -1,147 +1,127 @@ -const _models = Dict{String, Any}() +# ## Model variants for Metalhead.jl models -function cnn_variants(; nfeatures = :, hasweights = false) - variants = Pair{String, ModelVariant}[] - - hasweights && push!(variants, - "imagenet_1k" => ModelVariant(xblock = ImageTensor{2}(3), - # TODO: use actual ImageNet classes - yblock = FastAI.OneHotLabel{Int}(1:1000))) - push!(variants, - "classifier" => ModelVariant(make_cnn_classifier, - ImageTensor{2}, - FastAI.OneHotTensor{0})) - push!(variants, - "backbone" => ModelVariant(make_cnn_backbone, - ImageTensor{2}, - ConvFeatures{2}(nfeatures))) - - return variants +struct MetalheadClassifierVariant <: ModelVariant + fn end - -function make_cnn_classifier(model, input::ImageTensor, output::OneHotTensor{0}) - backbone = _backbone_with_channels(model.layers[1], input.nchannels) - head = _head_with_classes(model.layers[2], length(output.classes)) - return Chain(backbone, head) +compatibleblocks(::MetalheadClassifierVariant) = (ImageTensor{2}, FastAI.OneHotTensor{0}) +function loadvariant(v::MetalheadClassifierVariant, xblock::ImageTensor{2}, yblock::FastAI.OneHotTensor{0}, checkpoint; kwargs...) + return v.fn(; pretrain = checkpoint == "imagenet1k", inchannels=xblock.nchannels, + nclasses=length(yblock.classes), kwargs...) end - -function make_cnn_classifier(model, ::Type{Any}, ::Type{Any}) - return model +function loadvariant(v::MetalheadClassifierVariant, xblock, yblock, checkpoint; kwargs...) + return v.fn(; pretrain = checkpoint == "imagenet1k", kwargs...) end -function make_cnn_backbone(model, input::ImageTensor{N}, ::ConvFeatures{N}) where {N} - backbone = _backbone_with_channels(model.layers[1], input.nchannels) - return backbone +struct MetalheadImageNetVariant <: ModelVariant + fn +end +compatibleblocks(::MetalheadImageNetVariant) = (ImageTensor{2}(3), FastAI.OneHotTensor{0, Int}(1:1000)) +function loadvariant(v::MetalheadImageNetVariant, xblock, yblock, checkpoint; kwargs...) + return v.fn(; pretrain = checkpoint == "imagenet1k", kwargs...) end -function make_cnn_backbone(model, ::Type{Any}, ::Type{Any}) +struct MetalheadBackboneVariant <: ModelVariant + fn + nfeatures::Int +end +compatibleblocks(variant::MetalheadBackboneVariant) = (ImageTensor{2}, ConvFeatures{2}(variant.nfeatures)) +function loadvariant(v::MetalheadBackboneVariant, xblock::ImageTensor{2}, yblock::ConvFeatures{2}, checkpoint; kwargs...) + model = v.fn(; pretrain = checkpoint == "imagenet1k", inchannels=xblock.nchannels, + kwargs...) + return model.layers[1] +end +function loadvariant(v::MetalheadBackboneVariant, xblock, yblock, checkpoint; kwargs...) + model = v.fn(; pretrain = checkpoint == "imagenet1k", kwargs...) return model.layers[1] end -function _backbone_with_channels(backbone, n) - layer = backbone.layers[1].layers[1] - layer isa Conv || throw(ArgumentError("""To change the number of input channels, - `backbone.layers[1].layers[1]` must be a `Conv` layer.""")) +function metalheadvariants(modelfn, nfeatures) + return [ + "imagenet1k" => MetalheadImageNetVariant(modelfn), + "classifier" => MetalheadClassifierVariant(modelfn), + "backbone" => MetalheadBackboneVariant(modelfn, nfeatures), + ] +end + - sz = size(layer.weight) - ks = sz[begin:(end - 2)] - in_, out = sz[(end - 1):end] - in_ == n && return backbone +const _models = Dict{String, Any}() - layer = @set layer.weight = Flux.kaiming_normal(Random.GLOBAL_RNG, ks..., n, out) - return @set backbone.layers[1].layers[1] = layer -end -function _head_with_classes(head, n) - head.layers[end] isa Dense || - throw(ArgumentError("""To change the number of output classes, - the last layer in head must be a `Dense` layer.""")) - c, f = size(head[end].weight) - if c == n - # Already has correct number of classes - head - else - @set head.layers[end] = Dense(f, n) - end -end +fix(fn, args...; kwargs...) = (_args...; _kwargs...) -> fn(args..., _args...; kwargs..., _kwargs...) + -function metalhead_loadfn(modelfn, args...) - return function (checkpoint; kwargs...) - return modelfn(args...; pretrain = !isnothing(checkpoint), kwargs...) - end -end # model config: id, description, basefn, variant, hasweights, nfeatures const METALHEAD_CONFIGS = [ - ("metalhead/resnet18", "ResNet18", metalhead_loadfn(Metalhead.ResNet, 18), true, 512), - ("metalhead/resnet34", "ResNet34", metalhead_loadfn(Metalhead.ResNet, 34), true, 512), - ("metalhead/resnet50", "ResNet50", metalhead_loadfn(Metalhead.ResNet, 50), true, 2048), - ("metalhead/resnet101", "ResNet101", metalhead_loadfn(Metalhead.ResNet, 101), true, + ("metalhead/resnet18", "ResNet18", fix(Metalhead.ResNet, 18), true, 512), + ("metalhead/resnet34", "ResNet34", fix(Metalhead.ResNet, 34), true, 512), + ("metalhead/resnet50", "ResNet50", fix(Metalhead.ResNet, 50), true, 2048), + ("metalhead/resnet101", "ResNet101", fix(Metalhead.ResNet, 101), true, 2048), - ("metalhead/resnet152", "ResNet152", metalhead_loadfn(Metalhead.ResNet, 152), true, + ("metalhead/resnet152", "ResNet152", fix(Metalhead.ResNet, 152), true, 2048), - ("metalhead/wideresnet50", "WideResNet50", metalhead_loadfn(Metalhead.WideResNet, 50), + ("metalhead/wideresnet50", "WideResNet50", fix(Metalhead.WideResNet, 50), true, 2048), ("metalhead/wideresnet101", "WideResNet101", - metalhead_loadfn(Metalhead.WideResNet, 101), true, 2048), + fix(Metalhead.WideResNet, 101), true, 2048), ("metalhead/wideresnet152", "WideResNet152", - metalhead_loadfn(Metalhead.WideResNet, 152), true, 2048), - ("metalhead/googlenet", "GoogLeNet", metalhead_loadfn(Metalhead.GoogLeNet), false, + fix(Metalhead.WideResNet, 152), true, 2048), + ("metalhead/googlenet", "GoogLeNet", Metalhead.GoogLeNet, false, 1024), - ("metalhead/inceptionv3", "InceptionV3", metalhead_loadfn(Metalhead.Inceptionv3), false, + ("metalhead/inceptionv3", "InceptionV3", Metalhead.Inceptionv3, false, 2048), - ("metalhead/inceptionv4", "InceptionV4", metalhead_loadfn(Metalhead.Inceptionv4), false, + ("metalhead/inceptionv4", "InceptionV4", Metalhead.Inceptionv4, false, 1536), - ("metalhead/squeezenet", "SqueezeNet", metalhead_loadfn(Metalhead.SqueezeNet), true, + ("metalhead/squeezenet", "SqueezeNet", Metalhead.SqueezeNet, true, 512), - ("metalhead/densenet-121", "DenseNet121", metalhead_loadfn(Metalhead.DenseNet, 121), + ("metalhead/densenet-121", "DenseNet121", fix(Metalhead.DenseNet, 121), false, 1024), - ("metalhead/densenet-161", "DenseNet161", metalhead_loadfn(Metalhead.DenseNet, 161), + ("metalhead/densenet-161", "DenseNet161", fix(Metalhead.DenseNet, 161), false, 1472), - ("metalhead/densenet-169", "DenseNet169", metalhead_loadfn(Metalhead.DenseNet, 169), + ("metalhead/densenet-169", "DenseNet169", fix(Metalhead.DenseNet, 169), false, 1664), - ("metalhead/densenet-201", "DenseNet201", metalhead_loadfn(Metalhead.DenseNet, 201), + ("metalhead/densenet-201", "DenseNet201", fix(Metalhead.DenseNet, 201), false, 1920), - ("metalhead/resnext50", "ResNeXt50", metalhead_loadfn(Metalhead.ResNeXt, 50), true, + ("metalhead/resnext50", "ResNeXt50", fix(Metalhead.ResNeXt, 50), true, 2048), - ("metalhead/resnext101", "ResNeXt101", metalhead_loadfn(Metalhead.ResNeXt, 101), true, + ("metalhead/resnext101", "ResNeXt101", fix(Metalhead.ResNeXt, 101), true, 2048), - ("metalhead/resnext152", "ResNeXt152", metalhead_loadfn(Metalhead.ResNeXt, 152), true, + ("metalhead/resnext152", "ResNeXt152", fix(Metalhead.ResNeXt, 152), true, 2048), - ("metalhead/mobilenetv1", "MobileNetV1", metalhead_loadfn(Metalhead.MobileNetv1), false, + ("metalhead/mobilenetv1", "MobileNetV1", Metalhead.MobileNetv1, false, 1024), - ("metalhead/mobilenetv2", "MobileNetV2", metalhead_loadfn(Metalhead.MobileNetv2), false, + ("metalhead/mobilenetv2", "MobileNetV2", Metalhead.MobileNetv2, false, 1280), ("metalhead/mobilenetv3-small", "MobileNetV3 Small", - metalhead_loadfn(Metalhead.MobileNetv3, :small), false, 576), + fix(Metalhead.MobileNetv3, :small), false, 576), ("metalhead/mobilenetv3-large", "MobileNetV3 Large", - metalhead_loadfn(Metalhead.MobileNetv3, :large), false, 960), + fix(Metalhead.MobileNetv3, :large), false, 960), ("metalhead/efficientnet-b0", "EfficientNet-B0", - metalhead_loadfn(Metalhead.EfficientNet, :b0), false, 1280), + fix(Metalhead.EfficientNet, :b0), false, 1280), ("metalhead/efficientnet-b0", "EfficientNet-B0", - metalhead_loadfn(Metalhead.EfficientNet, :b0), false, 1280), + fix(Metalhead.EfficientNet, :b0), false, 1280), ("metalhead/efficientnet-b1", "EfficientNet-B1", - metalhead_loadfn(Metalhead.EfficientNet, :b1), false, 1280), + fix(Metalhead.EfficientNet, :b1), false, 1280), ("metalhead/efficientnet-b2", "EfficientNet-B2", - metalhead_loadfn(Metalhead.EfficientNet, :b2), false, 1280), + fix(Metalhead.EfficientNet, :b2), false, 1280), ("metalhead/efficientnet-b3", "EfficientNet-B3", - metalhead_loadfn(Metalhead.EfficientNet, :b3), false, 1280), + fix(Metalhead.EfficientNet, :b3), false, 1280), ("metalhead/efficientnet-b4", "EfficientNet-B4", - metalhead_loadfn(Metalhead.EfficientNet, :b4), false, 1280), + fix(Metalhead.EfficientNet, :b4), false, 1280), ("metalhead/efficientnet-b5", "EfficientNet-B5", - metalhead_loadfn(Metalhead.EfficientNet, :b5), false, 1280), + fix(Metalhead.EfficientNet, :b5), false, 1280), ("metalhead/efficientnet-b6", "EfficientNet-B6", - metalhead_loadfn(Metalhead.EfficientNet, :b6), false, 1280), + fix(Metalhead.EfficientNet, :b6), false, 1280), ("metalhead/efficientnet-b7", "EfficientNet-B7", - metalhead_loadfn(Metalhead.EfficientNet, :b7), false, 1280), + fix(Metalhead.EfficientNet, :b7), false, 1280), ("metalhead/efficientnet-b8", "EfficientNet-B8", - metalhead_loadfn(Metalhead.EfficientNet, :b8), false, 1280), + fix(Metalhead.EfficientNet, :b8), false, 1280), ] for (id, description, loadfn, hasweights, nfeatures) in METALHEAD_CONFIGS _models[id] = (; - id, loadfn, description, - variants = cnn_variants(; hasweights, nfeatures), + id, description, + variants = metalheadvariants(loadfn, nfeatures), checkpoints = hasweights ? ["imagenet1k"] : String[], backend = :flux) end @@ -149,10 +129,10 @@ end @testset "Model variants" begin @testset "make_cnn_classifier" begin m = Metalhead.ResNet(18) - clf = make_cnn_classifier(m, ImageTensor{2}(3), FastAI.OneHotLabel{Int}(1:10)) + clf = make_cnn_classifier(m, ImageTensor{2}(3), FastAI.OneHotTensor{0, Int}(1:10)) @test Flux.outputsize(clf, (256, 256, 3, 1)) == (10, 1) - clf2 = make_cnn_classifier(m, ImageTensor{2}(5), FastAI.OneHotLabel{Int}(1:100)) + clf2 = make_cnn_classifier(m, ImageTensor{2}(5), FastAI.OneHotTensor{0, Int}(1:100)) @test Flux.outputsize(clf2, (256, 256, 5, 1)) == (100, 1) end From 85c88c36012eaaa59dab86e46a28ade0a71a42f2 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Fri, 3 Feb 2023 16:12:59 +0100 Subject: [PATCH 09/14] Model registry now has a field :loadfn A `loadfn([checkpoint])` holds the default loading function for a model. As a result, the :variants field no longer has to be populated. --- src/Registries/models.jl | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/Registries/models.jl b/src/Registries/models.jl index f236f8a06b..5f284dd030 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -96,9 +96,14 @@ function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) name = "Backend", default = :flux, description = "The backend deep learning framework that the model uses. The default is `:flux`."), + loadfn = Field(Any, + name = "Load function", + optional = false, + description = "A function `loadfn(checkpoint)` that loads a default version of the model, possibly with `checkpoint` weights.", + ), variants = Field(Vector{Pair{String, ModelVariant}}, name = "Variants", - optional = false, + default = Pair{String, ModelVariant}[], description = "Model variants suitable for different learning tasks. See `?ModelVariant` for more details.", formatfn = d -> join(first.(d), ", ")), checkpoints = Field(Vector{String}; @@ -123,14 +128,22 @@ function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoin # Finding matching configuration checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) - - pretrained && isnothing(checkpoint) && + if (pretrained && isnothing(checkpoint)) throw(NoCheckpointFoundError(checkpoints, checkpoint)) - variant = _findvariant(variants, variant, input, output) - isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) + end - # Loading - return loadvariant(variant, input, output, checkpoint; kwargs...) + # If no variant is asked for, use the base model loading function that only takes + # care of the checkpoint. + if isnothing(variant) && input === Any && output === Any + return row.loadfn(checkpoint) + # If a variant is specified, either by name (through `variant`) or through block + # constraints `input` or `output`, try to find a matching variant. + # care of the checkpoint. + else + variant = _findvariant(variants, variant, input, output) + isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) + return loadvariant(variant, input, output, checkpoint; kwargs...) + end end # ### Errors @@ -197,17 +210,15 @@ struct MockVariant <: ModelVariant end compatibleblocks(variant::MockVariant) = variant.blocks -loadvariant(variant::MockVariant, x, y, ch) = (ch, variant.model) +loadvariant(variant::MockVariant, _, _, ch) = (ch, variant.model) @testset "Model registry" begin @testset "Basic" begin @test_nowarn _modelregistry() reg = _modelregistry() - push!(reg, (; - id = "test", - variants = ["base" => MockVariant(1, (Any, Any))])) + push!(reg, (; id = "test", loadfn = (checkpoint,) -> checkpoint)) - @test load(reg["test"]) == (nothing, 1) + @test load(reg["test"]) === nothing @test_throws NoCheckpointFoundError load(reg["test"], pretrained = true) end From ea4e1538c6228a4c0b1aa9589d665805ad053841 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 4 Feb 2023 15:31:43 +0100 Subject: [PATCH 10/14] Use base `loadfn` for default model instead of variant --- FastVision/src/modelregistry.jl | 43 +++++++++++++-------------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/FastVision/src/modelregistry.jl b/FastVision/src/modelregistry.jl index a7ca1e4c09..17001accc1 100644 --- a/FastVision/src/modelregistry.jl +++ b/FastVision/src/modelregistry.jl @@ -13,14 +13,6 @@ function loadvariant(v::MetalheadClassifierVariant, xblock, yblock, checkpoint; return v.fn(; pretrain = checkpoint == "imagenet1k", kwargs...) end -struct MetalheadImageNetVariant <: ModelVariant - fn -end -compatibleblocks(::MetalheadImageNetVariant) = (ImageTensor{2}(3), FastAI.OneHotTensor{0, Int}(1:1000)) -function loadvariant(v::MetalheadImageNetVariant, xblock, yblock, checkpoint; kwargs...) - return v.fn(; pretrain = checkpoint == "imagenet1k", kwargs...) -end - struct MetalheadBackboneVariant <: ModelVariant fn nfeatures::Int @@ -38,7 +30,6 @@ end function metalheadvariants(modelfn, nfeatures) return [ - "imagenet1k" => MetalheadImageNetVariant(modelfn), "classifier" => MetalheadClassifierVariant(modelfn), "backbone" => MetalheadBackboneVariant(modelfn, nfeatures), ] @@ -118,31 +109,31 @@ const METALHEAD_CONFIGS = [ ("metalhead/efficientnet-b8", "EfficientNet-B8", fix(Metalhead.EfficientNet, :b8), false, 1280), ] + +metalheadloadfn(fn, hasweights) = function loadfn(ckpt; kwargs...) + hasweights ? fn(; pretrain = ckpt !== nothing, kwargs...) : fn(; kwargs...) +end + for (id, description, loadfn, hasweights, nfeatures) in METALHEAD_CONFIGS _models[id] = (; id, description, + loadfn = metalheadloadfn(loadfn, hasweights), variants = metalheadvariants(loadfn, nfeatures), checkpoints = hasweights ? ["imagenet1k"] : String[], backend = :flux) end @testset "Model variants" begin - @testset "make_cnn_classifier" begin - m = Metalhead.ResNet(18) - clf = make_cnn_classifier(m, ImageTensor{2}(3), FastAI.OneHotTensor{0, Int}(1:10)) - @test Flux.outputsize(clf, (256, 256, 3, 1)) == (10, 1) - - clf2 = make_cnn_classifier(m, ImageTensor{2}(5), FastAI.OneHotTensor{0, Int}(1:100)) - @test Flux.outputsize(clf2, (256, 256, 5, 1)) == (100, 1) - end - - @testset "make_cnn_backbone" begin - m = Metalhead.ResNet(18) - clf = make_cnn_backbone(m, ImageTensor{2}(10), ConvFeatures{2}(512)) - @test Flux.outputsize(clf, (256, 256, 10, 1)) == (8, 8, 512, 1) - end end -@testset "Metalhead models" begin for id in models(id = "metalhead")[:, :id] - @test_nowarn load(models()[id]; variant = "backbone") -end end +@testset "Metalhead models" begin + @test_nowarn load(models()["metalhead/resnet18"]; variant = "backbone") + @test_nowarn load(models()["metalhead/resnet18"]; variant = "classifier") + @test_nowarn load(models()["metalhead/resnet18"]; output = FastAI.OneHotLabel) + @test_nowarn load(models()["metalhead/resnet18"]; input = FastVision.ImageTensor) + @test_throws FastAI.Registries.NoModelVariantFoundError load(models()["metalhead/resnet18"]; output = FastAI.Label) + #= for id in models(id = "metalhead")[:, :id] + @test_nowarn load(models()[id]; variant = "backbone") + end =# + +end From 16715ae701684a460b00713725ccb837a2e443b5 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 4 Feb 2023 15:32:20 +0100 Subject: [PATCH 11/14] Remove dead code --- FastVision/src/modelregistry.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/FastVision/src/modelregistry.jl b/FastVision/src/modelregistry.jl index 17001accc1..d8c8e94388 100644 --- a/FastVision/src/modelregistry.jl +++ b/FastVision/src/modelregistry.jl @@ -123,8 +123,6 @@ for (id, description, loadfn, hasweights, nfeatures) in METALHEAD_CONFIGS backend = :flux) end -@testset "Model variants" begin -end @testset "Metalhead models" begin @test_nowarn load(models()["metalhead/resnet18"]; variant = "backbone") @@ -132,8 +130,4 @@ end @test_nowarn load(models()["metalhead/resnet18"]; output = FastAI.OneHotLabel) @test_nowarn load(models()["metalhead/resnet18"]; input = FastVision.ImageTensor) @test_throws FastAI.Registries.NoModelVariantFoundError load(models()["metalhead/resnet18"]; output = FastAI.Label) - #= for id in models(id = "metalhead")[:, :id] - @test_nowarn load(models()[id]; variant = "backbone") - end =# - end From c7d1073b136e513add7544ba5f23655c3394e198 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 4 Feb 2023 15:44:09 +0100 Subject: [PATCH 12/14] Remove unneeded deps --- FastVision/Project.toml | 2 -- FastVision/src/FastVision.jl | 2 -- 2 files changed, 4 deletions(-) diff --git a/FastVision/Project.toml b/FastVision/Project.toml index f7f551961e..be4b29d09a 100644 --- a/FastVision/Project.toml +++ b/FastVision/Project.toml @@ -19,8 +19,6 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/FastVision/src/FastVision.jl b/FastVision/src/FastVision.jl index 792dcb2e83..96578e1aef 100644 --- a/FastVision/src/FastVision.jl +++ b/FastVision/src/FastVision.jl @@ -66,11 +66,9 @@ import MakieCore: @recipe import MakieCore.Observables: @map import Metalhead: Metalhead import ProgressMeter: Progress, next! -using Setfield: @set import StaticArrays: SVector import Statistics: mean, std import UnicodePlots -using Random: Random using InlineTest using ShowCases From f3e9bb56c17472591424337d622edb57c7844bb1 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 4 Feb 2023 15:56:03 +0100 Subject: [PATCH 13/14] Add `ConvFeatures` block to represent bakbone outputs --- FastVision/src/FastVision.jl | 1 + FastVision/src/blocks/convfeatures.jl | 52 +++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 FastVision/src/blocks/convfeatures.jl diff --git a/FastVision/src/FastVision.jl b/FastVision/src/FastVision.jl index 293e088e5d..4a1599117e 100644 --- a/FastVision/src/FastVision.jl +++ b/FastVision/src/FastVision.jl @@ -76,6 +76,7 @@ include("blocks/bounded.jl") include("blocks/image.jl") include("blocks/mask.jl") include("blocks/keypoints.jl") +include("blocks/convfeatures.jl") include("encodings/onehot.jl") include("encodings/imagepreprocessing.jl") diff --git a/FastVision/src/blocks/convfeatures.jl b/FastVision/src/blocks/convfeatures.jl new file mode 100644 index 0000000000..8d9fcfd769 --- /dev/null +++ b/FastVision/src/blocks/convfeatures.jl @@ -0,0 +1,52 @@ + +""" + ConvFeatures{N}(n) <: Block + ConvFeatures(n, size) + +Block representing features from a convolutional neural network backbone +with `n` feature channels and `N` spatial dimensions. + +For example, a 2D ResNet's convolutional layers may produce a `h`x`w`x`ch` output +that is passed + +## Examples + +A feature block with 512 channels and variable spatial dimensions: + +```julia +FastVision.ConvFeatures{2}(512) +# or equivalently +FastVision.ConvFeatures(512, (:, :)) +``` + +A feature block with 512 channels and fixed spatial dimensions: + +```julia +FastVision.ConvFeatures(512, (4, 4)) +``` + +""" +struct ConvFeatures{N} <: Block + n::Int + size::NTuple{N, DimSize} +end + +ConvFeatures{N}(n) where {N} = ConvFeatures{N}(n, ntuple(_ -> :, N)) + +function FastAI.checkblock(block::ConvFeatures{N}, a::AbstractArray{T, M}) where {M, N, T} + M == N + 1 || return false + return checksize(block.size, size(a)[begin:N]) +end + +function FastAI.mockblock(block::ConvFeatures) + rand(Float32, map(l -> l isa Colon ? 8 : l, block.size)..., block.n) +end + + +@testset "ConvFeatures [block]" begin + @test ConvFeatures(16, (:, :)) == ConvFeatures{2}(16) + @test checkblock(ConvFeatures(16, (:, :)), rand(Float32, 2, 2, 16)) + @test checkblock(ConvFeatures(16, (:, :)), rand(Float32, 3, 2, 16)) + @test checkblock(ConvFeatures(16, (2, 2)), rand(Float32, 2, 2, 16)) + @test !checkblock(ConvFeatures(16, (2, :)), rand(Float32, 3, 2, 16)) +end From 7d69cef6c9351620299e975ffa57215a4cc31fe4 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 4 Feb 2023 15:57:46 +0100 Subject: [PATCH 14/14] Finish docstring. --- FastVision/src/blocks/convfeatures.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FastVision/src/blocks/convfeatures.jl b/FastVision/src/blocks/convfeatures.jl index 8d9fcfd769..d5c4352e52 100644 --- a/FastVision/src/blocks/convfeatures.jl +++ b/FastVision/src/blocks/convfeatures.jl @@ -7,7 +7,7 @@ Block representing features from a convolutional neural network backbone with `n` feature channels and `N` spatial dimensions. For example, a 2D ResNet's convolutional layers may produce a `h`x`w`x`ch` output -that is passed +that is passed further to the classifier head. ## Examples