diff --git a/FastVision/src/FastVision.jl b/FastVision/src/FastVision.jl index 293e088e5d..4a1599117e 100644 --- a/FastVision/src/FastVision.jl +++ b/FastVision/src/FastVision.jl @@ -76,6 +76,7 @@ include("blocks/bounded.jl") include("blocks/image.jl") include("blocks/mask.jl") include("blocks/keypoints.jl") +include("blocks/convfeatures.jl") include("encodings/onehot.jl") include("encodings/imagepreprocessing.jl") diff --git a/FastVision/src/blocks/convfeatures.jl b/FastVision/src/blocks/convfeatures.jl new file mode 100644 index 0000000000..d5c4352e52 --- /dev/null +++ b/FastVision/src/blocks/convfeatures.jl @@ -0,0 +1,52 @@ + +""" + ConvFeatures{N}(n) <: Block + ConvFeatures(n, size) + +Block representing features from a convolutional neural network backbone +with `n` feature channels and `N` spatial dimensions. + +For example, a 2D ResNet's convolutional layers may produce a `h`x`w`x`ch` output +that is passed further to the classifier head. + +## Examples + +A feature block with 512 channels and variable spatial dimensions: + +```julia +FastVision.ConvFeatures{2}(512) +# or equivalently +FastVision.ConvFeatures(512, (:, :)) +``` + +A feature block with 512 channels and fixed spatial dimensions: + +```julia +FastVision.ConvFeatures(512, (4, 4)) +``` + +""" +struct ConvFeatures{N} <: Block + n::Int + size::NTuple{N, DimSize} +end + +ConvFeatures{N}(n) where {N} = ConvFeatures{N}(n, ntuple(_ -> :, N)) + +function FastAI.checkblock(block::ConvFeatures{N}, a::AbstractArray{T, M}) where {M, N, T} + M == N + 1 || return false + return checksize(block.size, size(a)[begin:N]) +end + +function FastAI.mockblock(block::ConvFeatures) + rand(Float32, map(l -> l isa Colon ? 8 : l, block.size)..., block.n) +end + + +@testset "ConvFeatures [block]" begin + @test ConvFeatures(16, (:, :)) == ConvFeatures{2}(16) + @test checkblock(ConvFeatures(16, (:, :)), rand(Float32, 2, 2, 16)) + @test checkblock(ConvFeatures(16, (:, :)), rand(Float32, 3, 2, 16)) + @test checkblock(ConvFeatures(16, (2, 2)), rand(Float32, 2, 2, 16)) + @test !checkblock(ConvFeatures(16, (2, :)), rand(Float32, 3, 2, 16)) +end