Skip to content

Commit 777b6bf

Browse files
authored
Fix WrapperBlock behavior and add Many (#158)
* Add dataset recipes and dataset registry * Add fastai dataset registry with some recipes * move `typify` helepr * add learning method registry * add test for `ImageSegmentationFolders` recipe * import method registry * move file * fix default data registry definitin * `WrapperBlock` now inherits from `WrapperBlock` to make wrapping work * Fix `en|decodedblock` for `ImagePreprocessing` to reflect dimensionality of `ImageTensor` * add missing `mockblock` for `OneHotMulti` * update learning methods * add convenience `plotpredictions` method * Improve `WrapperBlock` behavior * fix `WrapperBlock`s and implement `Many`
1 parent 523d555 commit 777b6bf

25 files changed

+724
-77
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ MosaicViews = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389"
2929
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
3030
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
3131
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
32+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
3233
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
3334
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3435
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -53,8 +54,8 @@ Flux = "0.12"
5354
FluxTraining = "0.2"
5455
Glob = "1"
5556
IndirectArrays = "0.5"
56-
LearnBase = "0.3, 0.4"
5757
JLD2 = "0.4"
58+
LearnBase = "0.3, 0.4"
5859
MLDataPattern = "0.5"
5960
Makie = "0.15"
6061
MosaicViews = "0.2, 0.3"

src/FastAI.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ using MLDataPattern
3030
using Parameters
3131
using PrettyTables
3232
using StaticArrays
33+
using Setfield
3334
using ShowCases
3435
using Statistics: mean
3536
using Test: @testset, @test, @test_nowarn
@@ -57,7 +58,6 @@ include("datablock/loss.jl")
5758
include("datablock/plot.jl")
5859

5960

60-
6161
# submodules
6262
include("datasets/Datasets.jl")
6363
@reexport using .Datasets
@@ -78,6 +78,11 @@ include("training/metrics.jl")
7878
include("serialization.jl")
7979

8080

81+
include("fasterai/methodregistry.jl")
82+
include("fasterai/learningmethods.jl")
83+
include("fasterai/defaults.jl")
84+
85+
8186

8287

8388
export
@@ -115,6 +120,7 @@ export
115120
Label,
116121
LabelMulti,
117122
Keypoints,
123+
Many,
118124

119125
# encodings
120126
encode,
@@ -131,6 +137,14 @@ export
131137
describemethod,
132138
checkblock,
133139

140+
# learning methods
141+
findlearningmethods,
142+
ImageClassificationSingle,
143+
ImageClassificationMulti,
144+
ImageSegmentation,
145+
ImageKeypointRegression,
146+
147+
134148
# training
135149
methodlearner,
136150
Learner,

src/datablock/block.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ Randomly generate an instance of `block`.
6060
mockblock(blocks::Tuple) = map(mockblock, blocks)
6161

6262

63+
# ## Utilities
64+
65+
typify(T::Type) = T
66+
typify(t::Tuple) = Tuple{map(typify, t)...}
67+
typify(block::FastAI.AbstractBlock) = typeof(block)
68+
69+
6370
# ## Block implementations
6471

6572
abstract type AbstractLabel{T} <: Block end

src/datablock/encoding.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ Replaces all `nothing`s in outblocks with the corresponding block in `inblocks`.
3939
`outblocks` may be obtained by
4040
"""
4141
fillblock(inblocks::Tuple, outblocks::Tuple) = map(fillblock, inblocks, outblocks)
42-
fillblock(inblock::Block, ::Nothing) = inblock
43-
fillblock(::Block, outblock::Block) = outblock
42+
fillblock(inblock::AbstractBlock, ::Nothing) = inblock
43+
fillblock(::AbstractBlock, outblock::AbstractBlock) = outblock
4444

4545
function encodedblock(enc, block, fill::Bool)
4646
outblock = encodedblock(enc, block)

src/datablock/wrappers.jl

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,32 @@
1-
abstract type WrapperBlock <: Block end
1+
# # Wrapper blocks
2+
3+
abstract type WrapperBlock <: AbstractBlock end
24

35
wrapped(w::WrapperBlock) = w.block
6+
function setwrapped(w::WrapperBlock, b)
7+
return Setfield.@set w.block = b
8+
end
49
mockblock(w::WrapperBlock) = mockblock(wrapped(w))
510
checkblock(w::WrapperBlock, data) = checkblock(wrapped(w), data)
611

712
# If not overwritten, encodings are applied to the wrapped block
13+
14+
function encodedblock(enc::Encoding, wrapper::WrapperBlock)
15+
inner = encodedblock(enc, wrapped(wrapper))
16+
return isnothing(inner) ? nothing : setwrapped(wrapper, inner)
17+
end
18+
function decodedblock(enc::Encoding, wrapper::WrapperBlock)
19+
inner = decodedblock(enc, wrapped(wrapper))
20+
return isnothing(inner) ? nothing : setwrapped(wrapper, inner)
21+
end
822
function encode(enc::Encoding, ctx, wrapper::WrapperBlock, data; kwargs...)
9-
return encode(enc, ctx, wrapper.block, data; kwargs...)
23+
return encode(enc, ctx, wrapped(wrapper), data; kwargs...)
1024
end
1125
function decode(enc::Encoding, ctx, wrapper::WrapperBlock, data; kwargs...)
12-
return decode(enc, ctx, wrapper.block, data; kwargs...)
26+
return decode(enc, ctx, wrapped(wrapper), data; kwargs...)
1327
end
1428

29+
# ## Named
1530

1631
"""
1732
Named(name, block)
@@ -36,7 +51,39 @@ function decodedblock(enc::Encoding, named::Named{Name}) where Name
3651
return isnothing(outblock) ? nothing : Named(Name, outblock)
3752
end
3853

39-
# Wrapper encodings
54+
# ## Many
55+
56+
"""
57+
Many(block) <: WrapperBlock
58+
59+
`Many` indicates that you can variable number of data instances for
60+
`block`. Consider a bounding box detection task where there may be any
61+
number of targets in an image and this number varies for different
62+
samples. The blocks `(Image{2}(), BoundingBox{2}()` imply that there is exactly
63+
one bounding box for every image, which is not the case. Instead you
64+
would want to use `(Image{2}(), Many(BoundingBox{2}())`.
65+
"""
66+
struct Many{B<:AbstractBlock} <: WrapperBlock
67+
block::B
68+
end
69+
70+
FastAI.checkblock(many::Many, datas) = all(checkblock(wrapped(many), data) for data in datas)
71+
FastAI.mockblock(many::Many) = [mockblock(wrapped(many)), mockblock(wrapped(many))]
72+
73+
function FastAI.encode(enc::Encoding, ctx, many::Many, datas)
74+
return map(datas) do data
75+
encode(enc, ctx, wrapped(many), data)
76+
end
77+
end
78+
79+
function FastAI.decode(enc::Encoding, ctx, many::Many, datas)
80+
return map(datas) do data
81+
decode(enc, ctx, wrapped(many), data)
82+
end
83+
end
84+
85+
86+
# # Wrapper encodings
4087

4188
"""
4289
Only(name, encoding)

src/datasets/Datasets.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ module Datasets
1414

1515

1616
using ..FastAI
17+
using ..FastAI: typify
1718

1819
using DataDeps
1920
using Glob
@@ -41,6 +42,9 @@ include("containers.jl")
4142
include("transformations.jl")
4243

4344
include("load.jl")
45+
include("recipes.jl")
46+
include("registry.jl")
47+
include("fastairegistry.jl")
4448

4549

4650
export
@@ -71,6 +75,12 @@ export
7175
# datasets
7276
DATASETS,
7377
loadfolderdata,
74-
datasetpath
78+
datasetpath,
79+
80+
# recipes
81+
loadrecipe,
82+
finddatasets,
83+
listdatasources,
84+
loaddataset
7585

7686
end # module

src/datasets/fastairegistry.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
2+
const FASTAI_DATA_RECIPES = Dict{String, Vector{DatasetRecipe}}(
3+
# Image classification datasets
4+
[name => [ImageClassificationFolders()] for name in (
5+
"imagenette", "imagenette-160", "imagenette-320",
6+
"imagenette2", "imagenette2-160", "imagenette2-320",
7+
"imagewoof", "imagewoof-160", "imagewoof-320",
8+
"imagewoof2", "imagewoof2-160", "imagewoof2-320",
9+
)]...,
10+
11+
"camvid_tiny" => [ImageSegmentationFolders()],
12+
)
13+
14+
15+
"""
16+
const FASTAI_DATA_REGISTRY
17+
18+
The default `DataRegistry` containing every dataset in
19+
the fastai dataset collection.
20+
"""
21+
const FASTAI_DATA_REGISTRY = DatasetRegistry(
22+
Dict(d => () -> datasetpath(d) for d in DATASETS),
23+
FASTAI_DATA_RECIPES,
24+
)

src/datasets/load.jl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,29 +51,6 @@ function getclassessegmentation(dir::AbstractPath)
5151
end
5252
getclassessegmentation(name::String) = getclassessegmentation(datasetpath(name))
5353

54-
#=
55-
"""
56-
loadtaskdata(dir, ImageSegmentation; [split = false])
57-
58-
Load a data container for `ImageSegmentation` with observations
59-
`(input = image, target = mask)`.
60-
61-
If `split` is `true`, returns a tuple of the data containers split by
62-
the name of the grandparent folder.
63-
64-
"""
65-
function loadtaskdata(
66-
dir,
67-
::Type{FastAI.ImageSegmentation};
68-
split=false,
69-
kwargs...)
70-
imagedata = mapobs(loadfile, filterobs(isimagefile, FileDataset(joinpath(dir, "images"))))
71-
maskdata = mapobs(maskfromimage ∘ loadfile, filterobs(isimagefile, FileDataset(joinpath(dir, "labels"))))
72-
return mapobs((input = obs -> obs[1], target = obs -> obs[2]), (imagedata, maskdata))
73-
end
74-
=#
75-
76-
7754

7855

7956
maskfromimage(a::AbstractArray{<:Gray{T}}, classes) where T = maskfromimage(reinterpret(T, a), classes)

src/datasets/recipes.jl

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""
2+
abstract type DatasetRecipe
3+
4+
A recipe that contains configuration for loading a data container. Calling it with a path returns a data container and the blocks that each sample is made of.
5+
6+
#### Interface
7+
8+
- `loadrecipe(::DatasetRecipe, args...) -> (data, blocks)`
9+
- `recipeblocks(::Type{DatasetRecipe}) -> TBlocks`
10+
11+
#### Invariants
12+
13+
- `data` must be a data container of samples that are valid `blocks`, i.e. `checkblock(blocks, getobs(data, 1)) == true`
14+
"""
15+
abstract type DatasetRecipe end
16+
17+
18+
"""
19+
loadrecipe(recipe, path)
20+
21+
Load a recipe from a path. Return a data container `data` and concrete
22+
`blocks`.
23+
"""
24+
function loadrecipe end
25+
26+
27+
"""
28+
recipeblocks(TRecipe) -> TBlocks
29+
recipeblocks(recipe) -> TBlocks
30+
31+
Return the `Block` _types_ for the data container that recipe
32+
type `TRecipe` creates. Does not return `Block` instances as the exact
33+
configuration may not be known until the dataset is being
34+
loaded.
35+
36+
#### Examples
37+
38+
```julia
39+
recipeblocks(ImageLabelClf) == Tuple{Image{2}, Label}
40+
```
41+
"""
42+
recipeblocks(::R) where {R<:DatasetRecipe} = recipeblocks(R)
43+
44+
45+
# ## Implementations
46+
47+
# ImageClfFolders
48+
49+
"""
50+
ImageClfFolders(; labelfn = parentname, split = false)
51+
52+
Recipe for loading a single-label image classification dataset
53+
stored in a hierarchical folder format. If `split == true`, split
54+
the data container on the name of the grandparent folder. The label
55+
defaults to the name of the parent folder but a custom function can
56+
be passed as `labelfn`.
57+
58+
```julia
59+
julia> recipeblocks(ImageClassificationFolders)
60+
Tuple{Image{2}, Label}
61+
```
62+
"""
63+
Base.@kwdef struct ImageClassificationFolders <: DatasetRecipe
64+
labelfn = parentname
65+
split::Bool = false
66+
end
67+
68+
function loadrecipe(recipe::ImageClassificationFolders, path)
69+
isdir(path) || error("$path is not a directory")
70+
data = loadfolderdata(
71+
path,
72+
filterfn=isimagefile,
73+
loadfn=(loadfile, recipe.labelfn),
74+
splitfn=recipe.split ? grandparentname : nothing)
75+
76+
(recipe.split ? length(data) > 0 : nobs(data) > 0) || error("No image files found in $path")
77+
78+
labels = recipe.split ? first(values(data))[2] : data[2]
79+
blocks = Image{2}(), Label(unique(eachobs(labels)))
80+
length(blocks[2].classes) > 1 || error("Expected multiple different labels, got: $(blocks[2].classes))")
81+
return data, blocks
82+
end
83+
84+
recipeblocks(::Type{ImageClassificationFolders}) = Tuple{Image{2}, Label}
85+
86+
87+
# ImageSegmentationFolders
88+
89+
90+
"""
91+
ImageSegmentationFolders(; imagefolder="images", maskfolder="labels", labelfile="codes.txt")
92+
93+
Dataset recipe for loading 2D image segmentation datasets from a common format
94+
where images and masks are stored as images in two different subfolders
95+
"<root>/<imagefolder>" and "<root>/<maskfolder>"
96+
The class labels should be in a newline-delimited file "<root>/<labelfile>".
97+
"""
98+
Base.@kwdef struct ImageSegmentationFolders <: DatasetRecipe
99+
imagefolder::String = "images"
100+
maskfolder::String = "labels"
101+
labelfile::String = "codes.txt"
102+
end
103+
104+
function loadrecipe(recipe::ImageSegmentationFolders, path)
105+
isdir(path) || error("$path is not a directory")
106+
imagepath = joinpath(path, recipe.imagefolder)
107+
maskpath = joinpath(path, recipe.maskfolder)
108+
classespath = joinpath(path, recipe.labelfile)
109+
110+
isdir(imagepath) || error("Image folder $imagepath is not a directory")
111+
isdir(maskpath) || error("Mask folder $maskpath is not a directory")
112+
113+
isfile(classespath) || error("Classes file $classespath does not exist")
114+
classes = readlines(open(joinpath(path, recipe.labelfile)))
115+
length(classes) > 1 || error("Expected multiple different labels, got: $(blocks[2].classes))")
116+
117+
images = loadfolderdata(imagepath, filterfn=isimagefile, loadfn=loadfile)
118+
masks = loadfolderdata(maskpath, filterfn=isimagefile, loadfn=f -> loadmask(f, classes))
119+
nobs(images) == nobs(masks) || error("Expected the same number of images and masks, but found $(nobs(images)) images and $(nobs(masks)) masks")
120+
nobs(images) > 0 || error("No images or masks found in folders $imagepath and $maskpath")
121+
122+
blocks = Image{2}(), Mask{2}(classes)
123+
return (images, masks), blocks
124+
end
125+
126+
recipeblocks(::Type{ImageSegmentationFolders}) = Tuple{Image{2}, Mask{2}}

0 commit comments

Comments
 (0)