diff --git a/FastTabular/Project.toml b/FastTabular/Project.toml index 949ebf8a47..fb32287f25 100644 --- a/FastTabular/Project.toml +++ b/FastTabular/Project.toml @@ -11,6 +11,8 @@ FastAI = "5d0beca9-ade8-49ae-ad0b-a3cf890e669f" FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6" +Invariants = "115d0255-0791-41d2-b533-80bc4cbe6c10" +InvariantsCore = "69cbffe8-09de-43b1-81db-93034495284f" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" diff --git a/FastTabular/src/encodings/tabularpreprocessing.jl b/FastTabular/src/encodings/tabularpreprocessing.jl index e3be1473bb..d4f868ce44 100644 --- a/FastTabular/src/encodings/tabularpreprocessing.jl +++ b/FastTabular/src/encodings/tabularpreprocessing.jl @@ -14,6 +14,13 @@ function EncodedTableRow(catcols, contcols, categorydict) EncodedTableRow{length(catcols), length(contcols)}(catcols, contcols, categorydict) end +function mockblock(block::EncodedTableRow) + b = TableRow(block.catcols, block.contcols, block.categorydict) + obs = mockblock(b) + enc = setup(TabularPreprocessing, b, TableDataset(map(x -> [x], obs))) + return encode(enc, Validation(), b, obs) +end + function checkblock(::EncodedTableRow{M, N}, x::Tuple{Vector, Vector}) where {M, N} length(x[1]) == M && length(x[2]) == N end diff --git a/FastVision/Project.toml b/FastVision/Project.toml index 2990f061ee..1f09868825 100644 --- a/FastVision/Project.toml +++ b/FastVision/Project.toml @@ -15,6 +15,7 @@ ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19" ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254" IndirectArrays = "9b13fd28-a010-5f03-acff-a1bbcff69959" InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6" +Invariants = "115d0255-0791-41d2-b533-80bc4cbe6c10" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" 
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" diff --git a/FastVision/src/FastVision.jl b/FastVision/src/FastVision.jl index 293e088e5d..2f3981e0e3 100644 --- a/FastVision/src/FastVision.jl +++ b/FastVision/src/FastVision.jl @@ -58,11 +58,13 @@ import DataAugmentation: apply, Identity, ToEltype, ImageToTensor, Normalize, AdjustBrightness, AdjustContrast, Maybe, FlipX, FlipY, WarpAffine, Rotate, Zoom, ResizePadDivisible, itemdata +import Invariants: Invariants, md, invariant, check import ImageInTerminal import IndirectArrays: IndirectArray import MakieCore import MakieCore: @recipe import MakieCore.Observables: @map + import ProgressMeter: Progress, next! import StaticArrays: SVector import Statistics: mean, std diff --git a/FastVision/src/blocks/bounded.jl b/FastVision/src/blocks/bounded.jl index 6df09cd485..6ef71c1c4e 100644 --- a/FastVision/src/blocks/bounded.jl +++ b/FastVision/src/blocks/bounded.jl @@ -49,9 +49,6 @@ will update the bounds: block = Image{2}() Bounded(Bounded(block, (16, 16)), (8, 8)) == Bounded(block, (8, 8)) ``` - - - """ struct Bounded{N, B <: AbstractBlock} <: WrapperBlock block::B @@ -67,6 +64,18 @@ function checkblock(bounded::Bounded{N}, a::AbstractArray{N}) where {N} return checksize(bounded.size, size(a)) && checkblock(parent(bounded), a) end + +function FastAI.invariant_checkblock(block::Bounded{N}; blockvar = "block", obsvar = "obs", kwargs...) where N + return invariant( + FastAI.__inv_checkblock_title(block, blockvar, obsvar), + [ + FastAI.invariant_checkblock(parent(block)), + ]; + kwargs... 
+ ) +end + + @testset "Bounded [block, wrapper]" begin @test_nowarn Bounded(Image{2}(), (16, 16)) bounded = Bounded(Image{2}(), (16, 16)) @@ -75,3 +84,12 @@ end # composition @test Bounded(bounded, (16, 16)) == bounded end + +@testset "checksize" begin + @test checksize((10, 1), (10, 1)) + @test !checksize((100, 1), (10, 1)) + @test checksize((:, :, :), (1, 2, 3)) + @test !checksize((:, :, :), (1, 2)) + @test checksize((10, :, 1), (10, 20, 1)) + @test !checksize((10, :, 2), (10, 20, 1)) +end diff --git a/FastVision/src/blocks/image.jl b/FastVision/src/blocks/image.jl index b2e8096aeb..bbbb78966a 100644 --- a/FastVision/src/blocks/image.jl +++ b/FastVision/src/blocks/image.jl @@ -55,6 +55,65 @@ setup(::Type{Image}, data) = Image{ndims(getobs(data, 1))}() # Visualization -function showblock!(io, ::ShowText, block::Image{2}, obs) +showblock!(io, ::ShowText, block::Image{2}, obs::AbstractMatrix{<:Colorant}) = ImageInTerminal.imshow(io, obs) +showblock!(io, ::ShowText, block::Image{2}, obs::AbstractMatrix{<:Real}) = + ImageInTerminal.imshow(io, colorview(Gray, obs)) + + +function FastAI.invariant_checkblock(block::Image{N}; blockvar = "block", obsvar = "obs", kwargs...) where N + return invariant( + FastAI.__inv_checkblock_title(block, blockvar, obsvar), + [ + invariant("`$obsvar` is an `AbstractArray`", + description = md("`$obsvar` should be of type `AbstractArray`.")) do obs + if !(obs isa AbstractArray) + return "Instead, got invalid type `$(nameof(typeof(obs)))`." |> md + end + end, + invariant("`$obsvar` is `$N`-dimensional") do obs + if ndims(obs) != N + return "Instead, got invalid dimensionality `$(ndims(obs))`." |> md + end + end, + invariant("`$obsvar` should have a color or numerical element type") do obs + if !((eltype(obs) <: Color) ||(eltype(obs) <: Real)) + return "Instead, got invalid element type `$(eltype(obs))`." |> md + end + end, + ]; + kwargs... 
+ ) end + +#= + +function isblockinvariant(block::Image{N}; obsvar = "data", blockvar = "block") where {N} + return SequenceInvariant( + [ + BooleanInvariant( + obs -> obs isa AbstractArray, + name = "Image data is an array", + messagefn = obs -> """Expected `$obsvar` to be a subtype of + `AbstractArray`, but instead got type `$(typeof(obs))`.""", + ), + BooleanInvariant( + obs -> ndims(obs) == N, + name = "Image data is `$N`-dimensional", + messagefn = obs -> """Expected `$obsvar` to be an `$N`-dimensional array, + but instead got a `$(ndims(obs))`-dimensional array.""", + ), + BooleanInvariant( + obs -> eltype(obs) <: Color || eltype(obs) <: Number, + name = "Image data has a color or numerical type.", + messagefn = obs -> """Expected `$obsvar` to have an element type that is a + color (`eltype($obsvar) <: Color`) or a number (`eltype($obsvar) + <: Color`), but instead found `eltype($obsvar) == $(eltype(obs)).` + """ + ) + ], + "`$obsvar` is a valid `$(typeof(block))`", + "" + ) +end +=# diff --git a/FastVision/src/blocks/mask.jl b/FastVision/src/blocks/mask.jl index 7e85473231..e3973264d4 100644 --- a/FastVision/src/blocks/mask.jl +++ b/FastVision/src/blocks/mask.jl @@ -15,10 +15,41 @@ function checkblock(block::Mask{N, T}, a::AbstractArray{T, N}) where {N, T} return all(map(x -> x ∈ block.classes, a)) end -function mockblock(mask::Mask{N, T}) where {N, T} - rand(mask.classes, ntuple(_ -> 16, N))::AbstractArray{T, N} +mockblock(mask::Mask{N, T}) where {N, T} = rand(mask.classes, ntuple(_ -> 16, N))::AbstractArray{T, N} + +function FastAI.invariant_checkblock(block::Mask{N}; blockvar = "block", obsvar = "obs", kwargs...) where N + return invariant( + FastAI.__inv_checkblock_title(block, blockvar, obsvar), + [ + invariant("`$obsvar` is an `AbstractArray`", + description = md("`$obsvar` should be of type `AbstractArray`.")) do obs + if !(obs isa AbstractArray) + return "Instead, got invalid type `$(nameof(typeof(obs)))`." 
|> md + end + end, + invariant("`$obsvar` is `$N`-dimensional") do obs + if ndims(obs) != N + return "Instead, got invalid dimensionality `$(ndims(obs))`." |> md + end + end, + invariant("All elements are valid labels") do obs + valid = ∈(block.classes).(obs) + if !(all(valid)) + unknown = unique(obs[valid .== false]) + return md("""`$obsvar` should contain only valid labels, + i.e. `∀ y ∈ $obsvar: y ∈ $blockvar.classes`, but `$obsvar` includes + unknown labels: `$(sprint(show, unknown))`. + + Valid classes are: + `$(sprint(show, block.classes, context=:limit => true))`""") + end + end, + ]; + kwargs... + ) end + # Visualization function showblock!(io, ::ShowText, block::Mask{2}, obs) diff --git a/FastVision/src/encodings/imagepreprocessing.jl b/FastVision/src/encodings/imagepreprocessing.jl index 495c0dfbbe..910ed00317 100644 --- a/FastVision/src/encodings/imagepreprocessing.jl +++ b/FastVision/src/encodings/imagepreprocessing.jl @@ -21,7 +21,19 @@ function checkblock(block::ImageTensor{N}, a::AbstractArray{T, M}) where {M, N, return (N + 1 == M) && (size(a, M) == block.nchannels) end -Base.summary(io::IO, ::ImageTensor{N}) where {N} = print(io, "ImageTensor{$N}") +FastAI.blockname(::ImageTensor{N}) where {N} = "ImageTensor{$N}" + +function FastAI.mockblock(block::ImageTensor{N}) where {N} + return randn(Float32, ntuple(n -> n == N+1 ? block.nchannels : 16, N + 1)) +end + +function FastAI.invariant_checkblock(block::ImageTensor{N}; blockvar = "block", obsvar = "obs", kwargs...) where N + return invariant( + Invariants.hastype_invariant(AbstractArray{<:Number, N+1}), + title = FastAI.__inv_checkblock_title(block, blockvar, obsvar); + kwargs... 
+ ) +end """ ImagePreprocessing([; kwargs...]) <: Encoding diff --git a/FastVision/src/encodings/keypointpreprocessing.jl b/FastVision/src/encodings/keypointpreprocessing.jl index 60965db8bd..269f984204 100644 --- a/FastVision/src/encodings/keypointpreprocessing.jl +++ b/FastVision/src/encodings/keypointpreprocessing.jl @@ -8,9 +8,13 @@ struct KeypointTensor{N, T, M} <: Block sz::NTuple{M, Int} end -mockblock(block::KeypointTensor{N}) where {N} = rand(SVector{N, Float32}, block.sz) -function checkblock(block::KeypointTensor{N, T}, obs::AbstractArray{T}) where {N, T} - return length(obs) == (prod(block.sz) * N) +function mockblock(block::KeypointTensor{N}) where {N} + enc = KeypointPreprocessing((16, 16)) + b = Keypoints{N}(block.sz) + return encode(enc, Validation(), b, mockblock(b)) +end +function checkblock(block::KeypointTensor{N, T, M}, obs::AbstractArray{U, M}) where {N, T, U, M} + return length(obs) == prod(block.sz) * N end """ diff --git a/FastVision/src/tasks/segmentation.jl b/FastVision/src/tasks/segmentation.jl index f018e2e696..9a4e13eac8 100644 --- a/FastVision/src/tasks/segmentation.jl +++ b/FastVision/src/tasks/segmentation.jl @@ -72,6 +72,8 @@ _tasks["imagesegmentation"] = (id = "vision/imagesegmentation", @testset "taskdataloaders" begin data, blocks = load(datarecipes()["camvid_tiny"]) traindl, _ = taskdataloaders(data, ImageSegmentation(blocks)) + # Iterate once so that precompilation does not print when testing + for batch in traindl end @test_nowarn for batch in traindl end end diff --git a/Project.toml b/Project.toml index b280c164cf..ea37b6847c 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,8 @@ FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FluxTraining = "7bf95e4d-ca32-48da-9824-f0dc5310474f" InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6" +Invariants = "115d0255-0791-41d2-b533-80bc4cbe6c10" +InvariantsCore = "69cbffe8-09de-43b1-81db-93034495284f" JLD2 = 
"033835bb-8acc-5ee8-8aae-3f567f8a3819" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" diff --git a/src/FastAI.jl b/src/FastAI.jl index 42cbbff4cb..c1a6bbb57c 100644 --- a/src/FastAI.jl +++ b/src/FastAI.jl @@ -13,6 +13,7 @@ using Flux.Optimise import Flux.Optimise: apply!, Optimiser, WeightDecay using FluxTraining: Learner, handle using FluxTraining.Events +import Invariants: Invariants, invariant, check, check_throw, md using JLD2: jldsave, jldopen using Markdown using PrettyTables @@ -79,11 +80,13 @@ include("serialization.jl") # submodules include("datasets/Datasets.jl") -@reexport using .Datasets +using .Datasets include("Registries/Registries.jl") @reexport using .Registries +include("invariants.jl") + export # submodules Datasets, diff --git a/src/blocks/continuous.jl b/src/blocks/continuous.jl index be7a6aed77..001df9a356 100644 --- a/src/blocks/continuous.jl +++ b/src/blocks/continuous.jl @@ -19,3 +19,35 @@ function blocklossfn(outblock::Continuous, yblock::Continuous) outblock.size == yblock.size || error("Sizes of $outblock and $yblock differ!") return Flux.Losses.mse end + + +function invariant_checkblock(block::Continuous; blockvar = "block", obsvar = "obs", kwargs...) + return invariant( + __inv_checkblock_title(block, blockvar, obsvar), + [ + Invariants.hastype_invariant(AbstractVector; var = obsvar), + invariant("length(`$obsvar`) should be $(block.size)") do obs + if !(length(obs) == block.size) + return """`$obsvar` should have `$(block.size)` features, instead + found a vector with `$(length(obs))` features.""" |> md + end + end, + Invariants.hastype_invariant( + Number, + title = "`eltype($obsvar)` should be a subtype of number", + inputfn = eltype, + ), + ]; + kwargs... 
+ ) +end + + +@testset "Continuous [block]" begin + inv = invariant_checkblock(Continuous(5)) + + @test check(Bool, inv, zeros(5)) + @test !check(Bool, inv, "hi") + @test !check(Bool, inv, ["hi"]) + @test !check(Bool, inv, [5]) +end diff --git a/src/blocks/label.jl b/src/blocks/label.jl index 41f23c368a..fa15f2636a 100644 --- a/src/blocks/label.jl +++ b/src/blocks/label.jl @@ -35,23 +35,46 @@ mockblock(label::Label{T}) where {T} = rand(label.classes)::T setup(::Type{Label}, data) = Label(unique(eachobs(data))) + +function invariant_checkblock(block::Label; blockvar = "block", obsvar = "obs", kwargs...) + inv = invariant( + __inv_checkblock_title(block, blockvar, obsvar) + ) do obs + if !(obs ∈ block.classes) + return "\n" * ("""`$obsvar` should be a valid label, i.e. one of + `$blockvar.classes = $(sprint(show, block.classes, context=:limit => true))`. + Instead, got invalid value `$(sprint(show, obs))`. + """ |> Invariants.md) + end + end + invariant(inv; kwargs...) +end + + """ - LabelMulti(classes) + LabelMulti(classes) <: Block setup(LabelMulti, data) -`Block` for a categorical label in a multi-class context. -`data` is valid for `Label(classes)` if `data ∈ classes`. +`Block` for a categorical label in a multi-class context where multiple +labels can be associated with an input. Each label must be in `classes`. +For example, for a block `LabelMulti([1, 2, 3])`, `[1, 2]` is a valid +observation, unlike `[0, 2]` (invalid label) or `1` (not a vector of +labels). + +Use [`is_block`](#) to make sure you have valid observations. 
## Examples +An observation can contain all or none of the listed classes: + ```julia -block = Label(["cat", "dog"]) # an observation can be either "cat" or "dog" -@test FastAI.checkblock(block, "cat") -@test !(FastAI.checkblock(block, "horsey")) +block = LabelMulti(["cat", "dog", "person"]) +@test FastAI.checkblock(block, ["cat", "person"]) +@test !(FastAI.checkblock(block, ["horsey"])) ``` -You can use `setup` to create a `Label` instance from a data container containing -possible classes: +You can use `setup` to create a `Label` instance from a data container +containing possible classes: ```julia targets = ["cat", "dog", "dog", "dog", "cat", "dog"] @@ -67,28 +90,74 @@ function checkblock(label::LabelMulti{T}, v::AbstractVector{T}) where {T} return all(map(x -> x ∈ label.classes, v)) end -function mockblock(label::LabelMulti) - unique([rand(label.classes) for _ in 1:rand(1:length(label.classes))]) -end +mockblock(label::LabelMulti) = + unique([rand(label.classes) for _ = 1:rand(1:length(label.classes))]) + setup(::Type{LabelMulti}, data) = LabelMulti(unique(eachobs(data))) -InlineTest.@testset "Label [block]" begin + +Base.summary(io::IO, ::LabelMulti{T}) where {T} = print(io, "LabelMulti{", T, "}") + + + +function invariant_checkblock(block::LabelMulti; blockvar = "block", obsvar = "obs", kwargs...) + return invariant( + __inv_checkblock_title(block, blockvar, obsvar), + [ + invariant("`$obsvar` is an `AbstractVector`", + description = md("`$obsvar` should be of type `AbstractVector`.")) do obs + if !(obs isa AbstractVector) + return md("Instead, got invalid type `$(typeof(obs))`.") + end + end, + invariant("All elements are valid labels") do obs + valid = ∈(block.classes).(obs) + if !(all(valid)) + unknown = unique(obs[valid .== false]) + return md("""`$obsvar` should contain only valid labels, + i.e. `∀ y ∈ $obsvar: y ∈ $blockvar.classes`, but `$obsvar` includes + unknown labels: `$(sprint(show, unknown))`. 
+ + Valid classes are: + `$(sprint(show, block.classes, context=:limit => true))`""") + end + end + ]; kwargs... + ) +end + + +# ## Tests + +@testset "Label [block]" begin block = Label(["cat", "dog"]) - InlineTest.@test FastAI.checkblock(block, "cat") - InlineTest.@test !(FastAI.checkblock(block, "horsey")) + @test FastAI.checkblock(block, "cat") + @test !(FastAI.checkblock(block, "horsey")) targets = ["cat", "dog", "dog", "dog", "cat", "dog"] block = setup(Label, targets) - InlineTest.@test block.classes == ["cat", "dog"] + @test block.classes == ["cat", "dog"] + + inv = invariant_checkblock(Label([1, 2, 3])) + @test check(Bool, inv, 1) + @test !(check(Bool, inv, 0)) end -InlineTest.@testset "LabelMulti [block]" begin + +@testset "LabelMulti [block]" begin block = LabelMulti(["cat", "dog"]) - InlineTest.@test FastAI.checkblock(block, ["cat"]) - InlineTest.@test !(FastAI.checkblock(block, ["horsey", "cat"])) + @test FastAI.checkblock(block, ["cat"]) + @test !(FastAI.checkblock(block, ["horsey", "cat"])) targets = ["cat", "dog", "dog", "dog", "cat", "dog"] block = setup(LabelMulti, targets) - InlineTest.@test block.classes == ["cat", "dog"] + @test block.classes == ["cat", "dog"] + + inv = invariant_checkblock(block) + @test_nowarn check(Exception, inv, ["cat", "dog"]) + @test check(Bool, inv, []) + @test !(check(Bool, inv, "cat")) + @test !(check(Bool, inv, ["mouse"])) + @test !(check(Bool, inv, ["mouse", "cat"])) end diff --git a/src/blocks/many.jl b/src/blocks/many.jl index 20bd69cd24..e96176d49c 100644 --- a/src/blocks/many.jl +++ b/src/blocks/many.jl @@ -35,5 +35,6 @@ end @testset "Many [block]" begin enc = OneHot() block = Many(Label(1:10)) + @test encodedblock(enc, block) isa Many{<:OneHotTensor} FastAI.testencoding(enc, block) end diff --git a/src/datablock/block.jl b/src/datablock/block.jl index 462d3ece10..eaad5c40d0 100644 --- a/src/datablock/block.jl +++ b/src/datablock/block.jl @@ -123,6 +123,71 @@ typify(T::Type) = T typify(t::Tuple) = 
Tuple{map(typify, t)...} typify(block::FastAI.AbstractBlock) = typeof(block) + +# ## Invariants +# +# Invariants allow specifying properties that an instance of a data for a block +# should have in more detail and such that actionable error messages can be given. + +""" + invariant_checkblock(block; kwargs...) + invariant_checkblock(blocks; kwargs...) + +Create an `Invariants.Invariant` that can be used to check whether an +observation is a valid instance of `block`. This should always agree +with `checkblock` (i.e. `checkblock(block, obs)` implies that +`check(invariant_checkblock(block), obs)`). The invariant can however +be used to give much more detailed information about the problem and +be used to throw helpful error messages from functions that depend +on these properties. +""" +function invariant_checkblock end + + +# If `invariant_checkblock` is not implemented for a block, default to +# checking that `checkblock` returns `true`. + +function invariant_checkblock(block::AbstractBlock; obsvar = "obs", blockvar = "block", kwargs...) + return invariant(__inv_checkblock_title(block, blockvar, obsvar); kwargs...) do obs + if !checkblock(block, obs) + """Expected `$obsvar` to be a valid observation for block `$(blockname(block))`, + but `checkblock($blockvar, $obsvar)` returned `false`. + This probably means that `$obsvar` is not a valid instance of the + block. Check `?$(__typename_qualified(block))` for more information on + the block and what data is valid. + """ |> md + end + end +end + +__typename_qualified(::T) where T = "$(string(parentmodule(T))).$(nameof(T))" +__inv_checkblock_title(b, bname, oname) = "`$oname` is a valid observation for `$(bname) <: $(blockname(b))`" +# For tuples of blocks, the invariant is composed of the individuals' blocks +# invariants, passing only if all the child invariants pass. + +function invariant_checkblock(blocks::Tuple; obsvar = "obss", blockvar = "blocks", kwargs...) 
+ return invariant( + "`$obsvar` are valid instances of blocks `$blockvar`", + [ + invariant( + invariant_checkblock( + blocks[i], + obsvar = "$obsvar[$i]", + blockvar = "$blockvar[$i]" + ), + inputfn = obss -> obss[i] + ) for (i, block) in enumerate(blocks) + ], + description = md("""The given observations `$obsvar` should be valid instances of the + blocks `$blockvar`. Since `$blockvar` is a tuple of blocks, each observation + `$obsvar[i]` should be a valid instance of the block `$blockvar[i]`. + See `?Block` for more background on blocks."""); + kwargs... + ) +end + +invariant_checkblock((title, block)::Pair; kwargs...) = invariant_checkblock(block; kwargs...) + """ blockname(block) @@ -131,3 +196,13 @@ and other diagrams. """ blockname(block::Block) = string(nameof(typeof(block))) blockname(blocks::Tuple) = "(" * join(map(blockname, blocks), ", ") * ")" + + +@testset "block invariants" begin + @testset "is_block" begin + @test is_block(Bool, Label(1:10), 1) + @test !is_block(Bool, Label(1:10), 0) + @test is_block(Bool, OneHotLabel{Float32}([1, 2]), [0, 1]) + @test is_block(Bool, (Label([1]), Label([2])), (1, 2)) + end +end diff --git a/src/datablock/encoding.jl b/src/datablock/encoding.jl index cc6d789ab1..e858ebd555 100644 --- a/src/datablock/encoding.jl +++ b/src/datablock/encoding.jl @@ -3,7 +3,18 @@ abstract type Encoding Transformation of `Block`s. Can encode some `Block`s ([`encode`]), and optionally -decode them [`decode`] +decode them [`decode`]. `Encoding`s describe data transformations that are applied +to `Block` data. Together, `Encoding`s and `Block`s are used to construct complex +data preprocessing pipelines for training loops. + +Encodings operate on two levels: + +- On the value level, an encoding transforms an observation. +- On the `Block` level, applying an encoding to a block tells you what the output + block is. For example, the [`OneHot`](#) encoding turns a [`Label`](#) block + into a [`OneHotLabel`](#) block. 
+ + By introspecting the block-level transformation ## Interface @@ -29,6 +40,12 @@ decode them [`decode`] """ abstract type Encoding end +invertible(enc::Encoding, block::AbstractBlock) = + !isnothing(decodedblock(enc, encodedblockfilled(enc, block))) + +encodingname(::E) where E<:Encoding = nameof(E) +encodingname(t::Tuple) = map(encodingname, t) + """ fillblock(inblocks, outblocks) @@ -215,6 +232,11 @@ Performs some tests that the encoding interface is set up properly for and that the block is identical to `block` """ function testencoding(encoding, block, obs = mockblock(block)) + Test.@testset "Encoding `$(typeof(encoding))` for block `$block`" begin + inv = invariant_encoding(encoding, block) + @test_nowarn inv(Exception, obs) + end + return Test.@testset "Encoding `$(typeof(encoding))` for block `$block`" begin # Test that `obs` is a valid instance of `block` Test.@test checkblock(block, obs) @@ -244,3 +266,89 @@ function testencoding(encoding, block, obs = mockblock(block)) end end end + + +function invariant_encoding(encoding, block; context = Validation(), encvar = "encoding", blockvar = "block", obsvar = "obs") + B = blockname(block) + E = encodingname(encoding) + + encobs(obs) = encode(encoding, context, block, obs) + encblock() = encodedblock(encoding, block) + encblockfilled() = encodedblockfilled(encoding, block) + decblock() = decodedblock(encoding, encblockfilled()) + decblockfilled() = decodedblockfilled(encoding, encblockfilled()) + decobs(obs) = decode(encoding, context, encblockfilled(), encobs(obs)) + + return invariant( + "Encoding `$E` is implemented for block `$B`", + [ + invariant_checkblock(block; blockvar, obsvar, description=""" + Before checking that the encoding is properly implemented for the block, + we need to check that the observation `$obsvar` is a valid instance of + `$blockvar <: $B.` + + """ |> md), + invariant("`$encvar <: $E` is implemented for `$blockvar <: $B`") do _ + if isnothing(encblock()) + return """Expected 
`encodedblock($encvar::$E, $blockvar::$B)` to return a block, + indicating that the encoding does transform observations for block + `$blockvar`. Instead, it returned `nothing` which indicates that the + encoding does not transform observations of block `$B`. + + If the encoding should modify the block, this may mean that a method + for `FastAI.encodedblock` is missing. To fix this, implement the following + method, returning a block from it: + + ```julia + FastAI.encodedblock(::$E, ::$B) + ``` + """ |> md + end + end, + invariant(invariant_checkblock(encblockfilled(); + blockvar = "enc$blockvar", obsvar = "enc$obsvar"); + title = "Encoded `$obsvar` is a valid instance of encoded `$blockvar`", + inputfn = encobs, + description = """The encoded observation + `encobs = encode($encvar, $context, $blockvar, $obsvar)` + should be a valid observation for the encoded block + `enc$blockvar = encodedblock($encvar, $blockvar)`. + """ |> md), + invariant( + "If `$encvar <: $E` is invertible, decoding is implemented", + [ + invariant("`$encvar <: $E` is not invertible") do _ + if invertible(encoding, block) + return "The encoding *is* invertible." |> md + end + end, + invariant("Decoding is implemented") do obs + if isnothing(decblock()) + return """ + `decodedblock(encoding, encodedblock(encoding, block))` returned + `nothing`, indicating that the encoding `$encvar <: $E` does not implement + a decoding step. + + This can mean that either the encoding is not invertible, or `decodedblock` + was not implemented for block `$blockvar <: $B`. To fix this, implement EITHER + + - `decodedblock(::$E, ::$B)` and return a non-`nothing` block value from it; OR + - `invertible(::$E, ::$B) = false` if the encoding is not invertible. 
+ """ |> md + end + end, + invariant_checkblock( + decblockfilled(), + inputfn = decobs, + title = "Decoded `encobs` is a valid instance of `$blockvar <: $B`", + description="""Decoding the encoded observation should return a valid observation.""") + ], + any, + ) + ], + description = """ + This invariant checks that the encoding `$encvar <: $E` is properly implemented + for `$blockvar <: $B`. Type `?FastAI.Encoding` to get an overview of the + interface for `Encoding`s. + """ |> md) +end diff --git a/src/datablock/wrappers.jl b/src/datablock/wrappers.jl index b30ab56126..f44d6842e5 100644 --- a/src/datablock/wrappers.jl +++ b/src/datablock/wrappers.jl @@ -1,5 +1,14 @@ # # Wrapper blocks +""" + abstract type WrapperBlock + +Supertype for blocks that "wrap" an existing block, inheriting its +functionality, allowing you to override just parts of its interface. + +For examples of `WrapperBlock`, see [`Bounded`](#) + +""" abstract type WrapperBlock <: AbstractBlock end Base.parent(w::WrapperBlock) = w.block @@ -7,10 +16,11 @@ Base.parent(b::Block) = b wrapped(w::WrapperBlock) = wrapped(parent(w)) wrapped(b::Block) = b function setwrapped(w::WrapperBlock, b) + # TODO: make recursive return Setfield.@set w.block = b end -mockblock(w::WrapperBlock) = mockblock(wrapped(w)) -checkblock(w::WrapperBlock, obs) = checkblock(wrapped(w), obs) +mockblock(w::WrapperBlock) = mockblock(parent(w)) +checkblock(w::WrapperBlock, obs) = checkblock(parent(w), obs) function blockname(wrapper::WrapperBlock) w = string(nameof(typeof(wrapper))) @@ -72,6 +82,8 @@ struct PropagateNever <: PropagateWrapper end propagate(::PropagateNever, _, _) = false propagatedecode(::PropagateNever, _, _) = false +# If not overwritten, encodings are applied to the wrapped block +propagatewrapper(::WrapperBlock) = PropagateAlways() """ struct PropagateSameBlock <: PropagateWrapper end diff --git a/src/encodings/only.jl b/src/encodings/only.jl index cb5b4699ac..15dae073aa 100644 --- a/src/encodings/only.jl +++ 
b/src/encodings/only.jl @@ -31,6 +31,8 @@ struct Only{E <: Encoding} <: StatefulEncoding encoding::E end +encodingname(only::Only) = "Only{$(encodingname(only.encoding))}" + function Only(name::Symbol, encoding::Encoding) return Only(Named{name}, encoding) end diff --git a/src/interpretation/showinterpretable.jl b/src/interpretation/showinterpretable.jl index f216307cf0..04400e78c8 100644 --- a/src/interpretation/showinterpretable.jl +++ b/src/interpretation/showinterpretable.jl @@ -21,6 +21,7 @@ showblockinterpretable(ShowText(), encodings, block, x) # will decode to an `Im """ function showblockinterpretable(backend::ShowBackend, encodings, block, obs) + invariant_checkblock(block)(Exception, obs) res = decodewhile(block -> !isshowable(backend, block), encodings, Validation(), @@ -53,7 +54,7 @@ end # Helpers function isshowable(backend::S, block::B) where {S <: ShowBackend, B <: AbstractBlock} - hasmethod(FastAI.showblock!, (Any, S, B, Any)) + hasmethod(FastAI.showblock!, (Any, S, B, typeof(mockblock(block)))) end """ diff --git a/src/invariants.jl b/src/invariants.jl new file mode 100644 index 0000000000..eb9bb7784b --- /dev/null +++ b/src/invariants.jl @@ -0,0 +1,192 @@ +#= +This file implements functions that allow checking common interfaces in the package using +[Invariants.jl](https://github.com/lorenzoh/Invariants.jl). + +=# + +#= +`is_block(block, obs)` checks that a value `obs` is a valid observation for a block +`block`. +=# + +""" + is_block(block, obs) + +Check whether `obs` is a valid observation for `block` and give +detailed output. 
+ +## Examples + +{cell=is_block, show=false resultshow=false} +```julia +using FastAI +``` + +Basic check with a valid observation: + +{cell=is_block} +```julia +FastAI.is_block(LabelMulti([1, 2, 3]), [1]) +``` + +An invalid observation will show an error, detailing why the observation is not valid for +the block: + +{cell=is_block} +```julia +FastAI.is_block(LabelMulti([1, 2, 3]), [2, "invalid:("]) +``` + +As a tuple of blocks is also a valid block, we can check that too. For example, a sample for +a supervised learning task is usually a tuple `(input, label)`. Using `is_block` to +check observations for tuples of blocks (or nested tuples) details which specific +observations are valid. + +{cell=is_block} +```julia +using FastAI.Vision: RGB +FastAI.is_block(( + Image{2}(), # input block + Label(["cat", "dog"]), # target block + ), ( + rand(RGB, 100, 100), # valid input + "mouse", # invalid label + )) +``` + +## Extending + +To extend this check to work on a new block type `B`, implement +[`invariant_checkblock`](#)`(::B)`. For help implementing invariants see the documentation +of [Invariants.jl](https://github.com/lorenzoh/Invariants.jl). + +If `invariant_checkblock` is not implemented for `B`, it will fall back to checking +[`checkblock`](#) which is correct, but doesn't yield helpful output. +""" +function is_block(block, obs; kwargs...) + inv = invariant_checkblock(block; kwargs...) + check(inv, obs) +end + +function is_block(::Type{Bool}, block, obs; kwargs...) + inv = invariant_checkblock(block; kwargs...) + check(Bool, inv, obs) +end + +#= +`is_data(data)` checks that `data` is a valid data container. +=# + + +""" + is_data(data) + is_data(Bool, data)::Bool + +Check that `data` implements the data container interface and give detailed info on missing +functionality if not. + +Pass `Bool` as a first argument to return a `Bool`. + +""" +function is_data(data; kwargs...) + inv = invariant_datacontainer(; kwargs...) 
+ return check(inv, data) +end + + +""" + is_data(data, block) + is_data(Bool, data, block)::Bool + +Check that `data` implements the data container interface and its observations are valid +instances of `block`, giving detailed errors if not. + +Pass `Bool` as a first argument to return a `Bool`. +""" +function is_data(data, block; kwargs...) + inv = invariant_datacontainer_block(block; kwargs...) + return check(inv, data) +end + +is_data(::Type{Bool}, args...; kwargs...) = convert(Bool, is_data(args...; kwargs...)) + + +function invariant_datacontainer(; var = :data) + invariant( + "`$var` implements the data container interface", + [ + __invariant_numobs(; var), + invariant("`$var` contains at least one observation") do data + n = numobs(data) + if n <= 0 + return "Instead, got a data container with $n observations." + end + end, + __invariant_getobs(; var), + ], + all; + description="""A data container stores observations and allows (1) getting + the number of observation and (2) loading an observation. 
+ See [the tutorial](/documents/docs/tutorials/data_containers.md) for more + information.""" |> md) +end + +function invariant_datacontainer_block(block; + datavar = "data", blockvar = "block", obsvar = "obs") + return invariant( + "`$datavar` is a data container with valid observations for block `$(blockname(block))`", + [ + + invariant_datacontainer(; var = datavar), + invariant( + invariant_checkblock(block; blockvar = blockvar, obsvar = obsvar); + inputfn = data -> getobs(data, 1) + ) + ], + all) +end + +__invariant_getobs(; var = :data) = invariant( + "`$var` implements the `getobs` interface", + [ + Invariants.hasmethod_invariant(Base.getindex, :data, :idx => 1) + Invariants.hasmethod_invariant(MLUtils.getobs, :data, :idx => 1) + ], + any; + description=Invariants.md(""" + `$var` must provide a way to load an observation by implementing **either** + (1) `Base.getindex($var, idx::Int)` (preferred) or (2) `MLUtils.getobs($var, idx::Int)` + (if regular indexing is already used and has different semantics). 
+ """), + inputfn=data -> (; data), +) + +__invariant_numobs(; var = :data) = invariant( + "`$var` implements the `numobs` interface", + [ + Invariants.hasmethod_invariant(Base.length, :data) + Invariants.hasmethod_invariant(MLUtils.numobs, :data) + ], + any; + description=Invariants.md(""" + `$var` must provide a way to get the number of observations it contains by implementing either + `Base.length($var)` (preferred) or `MLUtils.numobs($var)` + """), + inputfn=data -> (; data), +) + + +@testset "data container invariants" begin + @testset "is_data" begin + @test is_data(Bool, 1:10) + @test !is_data(Bool, nothing) + @test is_data(Bool, [1]) + @test !is_data(Bool, []) + + @test is_data(Bool, 1:10, Label(1:10)) + @test !is_data(Bool, 0:10, Label(1:10)) + @test is_data(Bool, [[0, 1]], OneHotLabel{Float32}([1, 2])) + + @test is_data(Bool, [(1, 2)], (Label([1]), Label([2]))) + end +end diff --git a/src/learner.jl b/src/learner.jl index 6896118fa9..535e481ff5 100644 --- a/src/learner.jl +++ b/src/learner.jl @@ -49,16 +49,16 @@ function tasklearner(task::LearningTask, backbone = nothing, model = nothing, callbacks = [], - pctgval = 0.2, batchsize = 16, optimizer = Adam(), lossfn = tasklossfn(task), + usedefaultcallbacks = true, kwargs...) if isnothing(model) model = isnothing(backbone) ? taskmodel(task) : taskmodel(task, backbone) end dls = taskdataloaders(traindata, validdata, task, batchsize; kwargs...) - return Learner(model, dls, optimizer, lossfn, callbacks...) + return Learner(model, dls, optimizer, lossfn, callbacks...; usedefaultcallbacks) end function tasklearner(task, data; pctgval = 0.2, kwargs...)