
Commit 9d11b1e

Merge branch 'refine' of https://github.com/theabhirath/Metalhead.jl into refine
2 parents: 5aece44 + e9306c3


3 files changed (+7, −9 lines)


src/layers/mlp.jl

Lines changed: 2 additions & 2 deletions
@@ -47,7 +47,7 @@ end
 gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...)

 """
-    create_classifier(inplanes::Integer, nclasses::Integer, activation = relu;
+    create_classifier(inplanes::Integer, nclasses::Integer, activation = identity;
                       pool_layer = AdaptiveMeanPool((1, 1)),
                       dropout_rate = 0.0, use_conv::Bool = false)
@@ -64,7 +64,7 @@ Creates a classifier head to be used for models.
 - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer.
 """
 function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity;
-                           use_conv::Bool = falsepool_layer = AdaptiveMeanPool((1, 1)),
+                           use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)),
                            dropout_rate = nothing)
     # Pooling
     flatten_in_pool = !use_conv && pool_layer !== identity
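
The fix above separates `use_conv` and `pool_layer` into two distinct keyword arguments (the old line had them fused as `falsepool_layer`). A minimal usage sketch under the new signature, assuming the helper is reachable as `Metalhead.create_classifier` (it is internal, not exported) and using illustrative sizes not taken from the commit:

    using Flux, Metalhead

    # Hedged sketch: build a classifier head with the corrected keyword order.
    # The sizes (512 feature channels, 1000 classes, 7×7 spatial map) are
    # illustrative only.
    head = Metalhead.create_classifier(512, 1000, relu;
                                       use_conv = false,
                                       pool_layer = AdaptiveMeanPool((1, 1)),
                                       dropout_rate = 0.2)
    x = rand(Float32, 7, 7, 512, 1)   # H × W × C × N feature map
    y = head(x)                       # expected output size: (1000, 1)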

src/mixers/mlpmixer.jl

Lines changed: 2 additions & 2 deletions
@@ -34,8 +34,8 @@ function mixerblock(planes::Integer, npatches::Integer; mlp_layer = mlp_block,
 end

 """
-    MLPMixer(size::Symbol; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224),
-             inchannels::Integer = 3, nclasses::Integer = 1000)
+    MLPMixer(size::Symbol; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224),
+             inchannels::Integer = 3, nclasses::Integer = 1000)

 Creates a model with the MLPMixer architecture.
 ([reference](https://arxiv.org/pdf/2105.01601)).
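
The docstring change here is whitespace-only, so the constructor's behaviour is unchanged. For reference, a minimal construction sketch matching the documented signature, assuming `:base` is one of the accepted size symbols:

    using Metalhead

    # Hedged sketch: the :base size symbol is an assumption; check the model's
    # accepted configurations before relying on it.
    model = MLPMixer(:base; patch_size = (16, 16), imsize = (224, 224),
                     inchannels = 3, nclasses = 1000)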

src/vit-based/vit.jl

Lines changed: 3 additions & 5 deletions
@@ -1,5 +1,5 @@
 """
-    transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.)
+    transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.)

 Transformer as used in the base ViT architecture.
 ([reference](https://arxiv.org/abs/2010.11929)).
@@ -99,12 +99,10 @@ struct ViT
 end
 @functor ViT

-function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256),
-             patch_size::Dims{2} = (16, 16),
+function ViT(mode::Symbol; imsize::Dims{2} = (256, 256), patch_size::Dims{2} = (16, 16),
              inchannels::Integer = 3, nclasses::Integer = 1000)
     _checkconfig(mode, keys(VIT_CONFIGS))
-    kwargs = VIT_CONFIGS[mode]
-    layers = vit(imsize; inchannels, patch_size, nclasses, kwargs...)
+    layers = vit(imsize; inchannels, patch_size, nclasses, VIT_CONFIGS[mode]...)
     return ViT(layers)
 end
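
With this change `mode` no longer defaults to `:base`, so a configuration symbol must always be passed, and the selected config is splatted directly from `VIT_CONFIGS`. A minimal construction sketch, assuming `:base` (the removed default) is still a key of `VIT_CONFIGS`:

    using Metalhead

    # Hedged sketch: mode is now a required positional argument.
    model = ViT(:base; imsize = (256, 256), patch_size = (16, 16),
                inchannels = 3, nclasses = 1000)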
