Skip to content

Commit d90a6ae

Browse files
committed
Unify higher level DenseNet API
1 parent 9d11b1e commit d90a6ae

File tree

3 files changed

+17
-36
lines changed

3 files changed

+17
-36
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ jobs:
3434
- '"Inception"'
3535
- '"DenseNet"'
3636
- '["ConvNeXt", "ConvMixer"]'
37-
- '[r"ViTs", r"Mixers"]'
37+
- 'r"Mixers"'
38+
- 'r"ViTs"'
3839
steps:
3940
- uses: actions/checkout@v2
4041
- uses: julia-actions/setup-julia@v1

src/convnets/densenet.jl

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -105,44 +105,13 @@ function densenet(nblocks::Vector{<:Integer}; growth_rate::Integer = 32, reducti
105105
reduction, inchannels, nclasses)
106106
end
107107

108-
"""
109-
DenseNet(nblocks::Vector{<:Integer}; growth_rate::Integer = 32, reduction = 0.5,
110-
inchannels = 3, nclasses::Integer = 1000)
111-
112-
Create a DenseNet model
113-
([reference](https://arxiv.org/abs/1608.06993)).
114-
See also [`densenet`](#).
115-
116-
# Arguments
117-
118-
- `nblocks`: number of dense blocks between transitions
119-
- `growth_rate`: the output feature map growth rate of dense blocks (i.e. `k` in the paper)
120-
- `reduction`: the factor by which the number of feature maps is scaled across each transition
121-
- `nclasses`: the number of output classes
122-
"""
123-
struct DenseNet
124-
layers::Any
125-
end
126-
@functor DenseNet
127-
128-
function DenseNet(nblocks::Vector{<:Integer}; growth_rate::Integer = 32, reduction = 0.5,
129-
inchannels = 3, nclasses::Integer = 1000)
130-
layers = densenet(nblocks; growth_rate, reduction, inchannels, nclasses)
131-
return DenseNet(layers)
132-
end
133-
134-
(m::DenseNet)(x) = m.layers(x)
135-
136-
backbone(m::DenseNet) = m.layers[1]
137-
classifier(m::DenseNet) = m.layers[2]
138-
139108
const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16],
140109
161 => [6, 12, 36, 24],
141110
169 => [6, 12, 32, 32],
142111
201 => [6, 12, 48, 32])
143112

144113
"""
145-
DenseNet(config::Integer = 121; pretrain::Bool = false, nclasses::Integer = 1000)
114+
DenseNet(config::Integer; pretrain::Bool = false, nclasses::Integer = 1000)
146115
DenseNet(transition_configs::NTuple{N,Integer})
147116
148117
Create a DenseNet model with specified configuration. Currently supported values are (121, 161, 169, 201)
@@ -155,11 +124,22 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet.
155124
156125
See also [`Metalhead.densenet`](#).
157126
"""
158-
function DenseNet(config::Integer = 121; pretrain::Bool = false, nclasses::Integer = 1000)
127+
struct DenseNet
128+
layers::Any
129+
end
130+
@functor DenseNet
131+
132+
function DenseNet(config::Integer; pretrain::Bool = false, growth_rate::Integer = 32,
133+
reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
159134
_checkconfig(config, keys(DENSENET_CONFIGS))
160-
model = DenseNet(DENSENET_CONFIGS[config]; nclasses = nclasses)
135+
model = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses)
161136
if pretrain
162137
loadpretrain!(model, string("DenseNet", config))
163138
end
164139
return model
165140
end
141+
142+
(m::DenseNet)(x) = m.layers(x)
143+
144+
backbone(m::DenseNet) = m.layers[1]
145+
classifier(m::DenseNet) = m.layers[2]

test/vits.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "ViT" begin
2-
for mode in [:tiny, :small, :base, :large, :huge] #:giant, #:gigantic
2+
for mode in [:tiny, :small, :base, :large, :huge] # :giant, :gigantic]
33
m = ViT(mode)
44
@test size(m(x_256)) == (1000, 1)
55
@test gradtest(m, x_256)

0 commit comments

Comments (0)