Commit edf83e0

Merge pull request #159 from theabhirath/repeat-fix-2
2 parents 2f39fd3 + f53fd94 commit edf83e0

13 files changed: +49 -42 lines changed


Project.toml
Lines changed: 2 additions & 4 deletions

@@ -10,7 +10,6 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-NeuralAttentionlib = "12afc1b8-fad6-47e1-9132-84abc478905f"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

 [compat]
@@ -20,15 +19,14 @@ Functors = "0.2"
 MLUtils = "0.2"
 NNlib = "0.7.34, 0.8"
 julia = "1.6"
-NeuralAttentionlib = "0.0"

 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [publish]
-title = "Metalhead.jl"
-theme = "_flux-theme"
 ignore = ["^(gh-pages|juliamnt|julia.dmg)$"]
+theme = "_flux-theme"
+title = "Metalhead.jl"

 [targets]
 test = ["Test"]

src/Metalhead.jl
Lines changed: 0 additions & 1 deletion

@@ -7,7 +7,6 @@ using BSON
 using Artifacts, LazyArtifacts
 using Statistics
 using MLUtils
-using NeuralAttentionlib

 import Functors

src/convnets/densenet.jl
Lines changed: 4 additions & 3 deletions

@@ -114,7 +114,7 @@ struct DenseNet
 end

 function DenseNet(nblocks::NTuple{N, <:Integer};
-                  growth_rate = 32, reduction = 0.5, nclasses = 1000) where N
+                  growth_rate = 32, reduction = 0.5, nclasses = 1000) where {N}
     layers = densenet(nblocks; growth_rate = growth_rate,
                       reduction = reduction,
                       nclasses = nclasses)
@@ -135,7 +135,8 @@ const densenet_config = Dict(121 => (6, 12, 24, 16),
                              201 => (6, 12, 48, 32))

 """
-    DenseNet(config::Int = 121; pretrain = false, nclasses = 1000)
+    DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
+    DenseNet(transition_config::NTuple{N,Integer})

 Create a DenseNet model with specified configuration. Currently supported values are (121, 161, 169, 201)
 ([reference](https://arxiv.org/abs/1608.06993)).
@@ -146,7 +147,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet.

 See also [`Metalhead.densenet`](#).
 """
-function DenseNet(config::Int = 121; pretrain = false, nclasses = 1000)
+function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
     @assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))."
     model = DenseNet(densenet_config[config]; nclasses = nclasses)
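For context, a minimal usage sketch of the two constructors documented above (not part of the diff): the integer form looks up one of the standard configurations, while the tuple form passes the per-block sizes directly.

```julia
using Metalhead

model  = DenseNet(121)               # standard DenseNet-121 via the config lookup
custom = DenseNet((6, 12, 24, 16))   # the same block counts passed explicitly as an NTuple
```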

src/convnets/mobilenet.jl
Lines changed: 6 additions & 6 deletions

@@ -28,9 +28,9 @@ function mobilenetv1(width_mult, config;
                      nclasses = 1000,
                      fcsize = 1024)
     layers = []
-    for (dw, outch, stride, repeats) in config
+    for (dw, outch, stride, nrepeats) in config
         outch = Int(outch * width_mult)
-        for _ in 1:repeats
+        for _ in 1:nrepeats
             layer = dw ? depthwise_sep_conv_bn((3, 3), inchannels, outch, activation;
                                                stride = stride, pad = 1) :
                          conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
@@ -198,11 +198,11 @@ Create a MobileNetv3 model.
   (with 1.0 being the default in the paper;
   this is usually a value between 0.1 and 1.4)
 - `configs`: a "list of tuples" configuration for each layer that details:
-  - `k::Int` - The size of the convolutional kernel
+  - `k::Integer` - The size of the convolutional kernel
   - `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer
-  - `t::Int` - The number of output feature maps for a given block
-  - `r::Int` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers
-  - `s::Int` - The stride of the convolutional kernel
+  - `t::Integer` - The number of output feature maps for a given block
+  - `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers
+  - `s::Integer` - The stride of the convolutional kernel
   - `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`)
 - `max_width`: The maximum number of feature maps in any layer of the network
 - `nclasses`: the number of output classes
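As a hypothetical illustration of the `(dw, outch, stride, nrepeats)` tuples that the `mobilenetv1` loop in the first hunk iterates over (the values below are made up, not the package defaults):

```julia
# Each entry: use a depthwise-separable block?, output channels, stride, repetitions.
example_config = [
    (false,  64, 1, 1),   # plain conv_bn block
    (true,  128, 2, 2),   # depthwise_sep_conv_bn block, repeated twice
]
```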

src/convnets/resnet.jl
Lines changed: 1 addition & 1 deletion

@@ -236,7 +236,7 @@ as shown below:
 resnet50_v1 = ResNet([1, 1, 4], [3, 4, 6, 3], :B; block = Metalhead.bottleneck_v1)
 ```
 """
-function ResNet(depth::Int = 50; pretrain = false, nclasses = 1000)
+function ResNet(depth::Integer = 50; pretrain = false, nclasses = 1000)
     @assert depth in keys(resnet_config) "`depth` must be one of $(sort(collect(keys(resnet_config))))"

     config, block = resnet_config[depth]
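The `Int` to `Integer` widening here (and in the other constructors touched by this commit) lets the depth argument be any integer type rather than only the machine `Int`, for example:

```julia
m1 = ResNet(50)          # machine Int, as before
m2 = ResNet(Int32(50))   # now also dispatches, since Int32 <: Integer
```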

src/convnets/resnext.jl
Lines changed: 2 additions & 2 deletions

@@ -99,7 +99,7 @@ const resnext_config = Dict(
 )

 """
-    ResNeXt(config::Int = 50; cardinality = 32, width = 4, pretrain = false, nclasses = 1000)
+    ResNeXt(config::Integer = 50; cardinality = 32, width = 4, pretrain = false, nclasses = 1000)

 Create a ResNeXt model with specified configuration. Currently supported values for `config` are (50, 101).
 ([reference](https://arxiv.org/abs/1611.05431)).
@@ -110,7 +110,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet.

 See also [`Metalhead.resnext`](#).
 """
-function ResNeXt(config::Int = 50; cardinality = 32, width = 4, pretrain = false, nclasses = 1000)
+function ResNeXt(config::Integer = 50; cardinality = 32, width = 4, pretrain = false, nclasses = 1000)
     @assert config in keys(resnext_config) "`config` must be one of $(sort(collect(keys(resnext_config))))"

     model = ResNeXt(cardinality, width; block_config = resnext_config[config], nclasses)

src/convnets/vgg.jl
Lines changed: 4 additions & 4 deletions

@@ -115,9 +115,9 @@ Construct a VGG model with the specified input image size. Typically, the image

 ## Keyword Arguments:
 - `config` : VGG convolutional block configuration. It is defined as a vector of tuples `(output_channels, num_convolutions)` for each block
-- `inchannels`::Int : number of input channels
+- `inchannels`::Integer : number of input channels
 - `batchnorm`::Bool : set to `true` to use batch normalization after each convolution
-- `nclasses`::Int : number of output classes
+- `nclasses`::Integer : number of output classes
 - `fcsize`: intermediate fully connected layer size
   (see [`Metalhead.vgg_classifier_layers`](#))
 - `dropout`: dropout level between fully connected layers
@@ -142,7 +142,7 @@ backbone(m::VGG) = m.layers[1]
 classifier(m::VGG) = m.layers[2]

 """
-    VGG(depth::Int = 16; pretrain = false, batchnorm = false)
+    VGG(depth::Integer = 16; pretrain = false, batchnorm = false)

 Create a VGG style model with specified `depth`. Available values include (11, 13, 16, 19).
 ([reference](https://arxiv.org/abs/1409.1556v6)).
@@ -154,7 +154,7 @@ See also [`VGG`](#).
 # Arguments
 - `pretrain`: set to `true` to load pre-trained model weights for ImageNet
 """
-function VGG(depth::Int = 16; pretrain = false, batchnorm = false, nclasses = 1000)
+function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses = 1000)
     @assert depth in keys(vgg_config) "depth must be from one in $(sort(collect(keys(vgg_config))))"

     model = VGG((224, 224); config = vgg_conv_config[vgg_config[depth]],
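A one-line usage sketch for the constructor above, using only keyword arguments that appear in this diff:

```julia
model = VGG(16; batchnorm = true, nclasses = 1000)   # VGG-16 with batch normalization
```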

src/layers/Layers.jl
Lines changed: 0 additions & 1 deletion

@@ -5,7 +5,6 @@ using Flux: outputsize, Zygote
 using Functors
 using Statistics
 using MLUtils
-using NeuralAttentionlib

 include("../utilities.jl")

src/layers/attention.jl
Lines changed: 19 additions & 9 deletions

@@ -1,5 +1,5 @@
 """
-    MHAttention(nheads::Int, qkv_layer, attn_drop, projection)
+    MHAttention(nheads::Integer, qkv_layer, attn_drop, projection)

 Multi-head self-attention layer.

@@ -17,7 +17,7 @@ struct MHAttention{P, Q, R}
 end

 """
-    MHAttention(planes, nheads = 8; qkv_bias = false, attn_drop = 0., proj_drop = 0.)
+    MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop = 0., proj_drop = 0.)

 Multi-head self-attention layer.

@@ -28,7 +28,7 @@ Multi-head self-attention layer.
 - `attn_drop`: dropout rate after the self-attention layer
 - `proj_drop`: dropout rate after the projection layer
 """
-function MHAttention(planes, nheads = 8; qkv_bias = false, attn_drop = 0., proj_drop = 0.)
+function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop = 0., proj_drop = 0.)
     @assert planes % nheads == 0 "planes should be divisible by nheads"
     qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
     attn_drop = Dropout(attn_drop)
@@ -39,10 +39,20 @@ end

 @functor MHAttention

-function (m::MHAttention)(x::AbstractArray{T, 3}) where T
-    features, len_seq, batch_size = size(x)
-    q, k, v = chunk(reshape(m.qkv_layer(x), features ÷ m.nheads, m.nheads, len_seq, 3 * batch_size), 3; dims = 4)
-    scale = convert(T, sqrt(size(q, 1) / m.nheads))
-    attn = m.attn_drop(softmax(NeuralAttentionlib.matmul(q, permutedims(k, (2, 1, 3, 4))) * scale))
-    x = m.projection(reshape(NeuralAttentionlib.matmul(attn, v), (features, len_seq, batch_size)))
+function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
+    nfeatures, seq_len, batch_size = size(x)
+    x_reshaped = reshape(x, nfeatures, seq_len * batch_size)
+    qkv = m.qkv_layer(x_reshaped)
+    qkv_reshaped = reshape(qkv, nfeatures ÷ m.nheads, m.nheads, seq_len, 3 * batch_size)
+    query, key, value = chunk(qkv_reshaped, 3; dims = 4)
+    scale = convert(T, sqrt(size(query, 1) / m.nheads))
+    key_reshaped = reshape(
+        permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads, seq_len * batch_size
+    )
+    query_reshaped = reshape(query, nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
+    attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
+    value_reshaped = reshape(value, nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
+    pre_projection = reshape(batched_mul(attention, value_reshaped), (nfeatures, seq_len, batch_size))
+    y = m.projection(reshape(pre_projection, size(pre_projection, 1), :))
+    return reshape(y, :, seq_len, batch_size)
 end
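The rewritten forward pass swaps `NeuralAttentionlib.matmul` on 4-d arrays for NNlib's `batched_mul` on reshaped 3-d arrays. For reference, `batched_mul` multiplies the leading two dimensions of its arguments slice-by-slice along the third (batch) dimension; a minimal sketch with made-up sizes:

```julia
using NNlib: batched_mul

A = rand(Float32, 2, 3, 5)   # 5 slices of 2×3 matrices
B = rand(Float32, 3, 4, 5)   # 5 slices of 3×4 matrices
C = batched_mul(A, B)        # 5 slices of 2×4 matrices

@assert size(C) == (2, 4, 5)
@assert C[:, :, 1] ≈ A[:, :, 1] * B[:, :, 1]   # each slice is an ordinary matrix product
```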

src/layers/embeddings.jl
Lines changed: 5 additions & 5 deletions

@@ -17,7 +17,7 @@ patches.
   single argument constructor for a normalization layer like LayerNorm or BatchNorm
 - `flatten`: set true to flatten the input spatial dimensions after the embedding
 """
-function PatchEmbedding(imsize::Dims{2} = (224, 224); inchannels = 3,
+function PatchEmbedding(imsize::Dims{2} = (224, 224); inchannels::Integer = 3,
                         patch_size::Dims{2} = (16, 16), embedplanes = 768,
                         norm_layer = planes -> identity, flatten = true)

@@ -33,15 +33,15 @@ function PatchEmbedding(imsize::Dims{2} = (224, 224); inchannels = 3,
 end

 """
-    ViPosEmbedding(embedsize, npatches; init = (dims::Dims{2}) -> rand(Float32, dims))
+    ViPosEmbedding(embedsize::Integer, npatches::Integer; init = (dims::Dims{2}) -> rand(Float32, dims))

 Positional embedding layer used by many vision transformer-like models.
 """
 struct ViPosEmbedding{T}
     vectors::T
 end

-ViPosEmbedding(embedsize, npatches; init = (dims::Dims{2}) -> rand(Float32, dims)) =
+ViPosEmbedding(embedsize::Integer, npatches::Integer; init = (dims::Dims{2}) -> rand(Float32, dims)) =
     ViPosEmbedding(init((embedsize, npatches)))

 (p::ViPosEmbedding)(x) = x .+ p.vectors
@@ -59,8 +59,8 @@ end

 ClassTokens(dim::Integer; init = Flux.zeros32) = ClassTokens(init(dim, 1, 1))

-function (m::ClassTokens)(x)
-    tokens = repeat(m.token, 1, 1, size(x, 3))
+function (m::ClassTokens)(x::AbstractArray{T, 3}) where {T}
+    tokens = m.token .* fill!(similar(x, 1, 1, size(x, 3)), one(T))
     return hcat(tokens, x)
 end
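The `ClassTokens` change replaces `repeat` with a broadcast against a ones-buffer built by `fill!(similar(x, ...), one(T))`; since `similar` mirrors the array type of `x`, the expanded tokens stay on the same kind of array as the input. A small sketch of the same trick with plain arrays (sizes made up):

```julia
token = reshape(Float32[1, 2, 3], 3, 1, 1)   # stands in for a learned class token
x = rand(Float32, 3, 7, 4)                   # features × sequence length × batch

tokens = token .* fill!(similar(x, 1, 1, size(x, 3)), one(Float32))   # 3×1×4: one token per batch element
y = hcat(tokens, x)                          # token prepended along the sequence dimension

@assert size(y) == (3, 8, 4)
```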
