Skip to content

Commit c076c6b

Browse files
committed
Add docs
1. Loosen type constraints on `inputscale` 2. Updated README
1 parent 97fc329 commit c076c6b

File tree

3 files changed

+51
-4
lines changed

3 files changed

+51
-4
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNet.html) | N |
2121
| [GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.GoogLeNet.html) | N |
2222
| [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.Inceptionv3.html) | N |
23+
| [Inception-v4](https://arxiv.org/abs/1602.07261) | [`Inceptionv4`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.Inceptionv4.html) | N |
24+
| [InceptionResNet-v2](https://arxiv.org/abs/1602.07261) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.InceptionResNetv2.html) | N |
2325
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.SqueezeNet.html) | N |
2426
| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.DenseNet.html) | N |
2527
| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNeXt.html) | N |

src/convnets/inception.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,15 @@ end
286286

287287
"""
288288
inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
289+
290+
Create an Inceptionv4 model.
291+
([reference](https://arxiv.org/abs/1602.07261))
292+
293+
# Arguments
294+
295+
- inchannels: number of input channels.
296+
- dropout: rate of dropout in classifier head.
297+
- nclasses: the number of output classes.
289298
"""
290299
function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
291300
body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)...,
@@ -314,6 +323,18 @@ function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
314323
return Chain(body, head)
315324
end
316325

326+
"""
327+
Inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
328+
329+
Creates an Inceptionv4 model.
330+
([reference](https://arxiv.org/abs/1602.07261))
331+
332+
# Arguments
333+
334+
- inchannels: number of input channels.
335+
- dropout: rate of dropout in classifier head.
336+
- nclasses: the number of output classes.
337+
"""
317338
struct Inceptionv4
318339
layers::Any
319340
end
@@ -398,6 +419,18 @@ function block8(scale = 1.0f0; no_relu = false)
398419
branch3, inputscale(scale; activation = activation)), +)
399420
end
400421

422+
"""
423+
inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
424+
425+
Creates an InceptionResNetv2 model.
426+
([reference](https://arxiv.org/abs/1602.07261))
427+
428+
# Arguments
429+
430+
- inchannels: number of input channels.
431+
- dropout: rate of dropout in classifier head.
432+
- nclasses: the number of output classes.
433+
"""
401434
function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
402435
body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)...,
403436
conv_bn((3, 3), 32, 32)...,
@@ -418,6 +451,18 @@ function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
418451
return Chain(body, head)
419452
end
420453

454+
"""
455+
InceptionResNetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
456+
457+
Creates an InceptionResNetv2 model.
458+
([reference](https://arxiv.org/abs/1602.07261))
459+
460+
# Arguments
461+
462+
- inchannels: number of input channels.
463+
- dropout: rate of dropout in classifier head.
464+
- nclasses: the number of output classes.
465+
"""
421466
struct InceptionResNetv2
422467
layers::Any
423468
end

src/utilities.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ cat_channels(xy...) = cat(xy...; dims = Val(3))
3939
"""
4040
inputscale(λ; activation = identity)
4141
42-
Scale the input (assumed to be an `AbstractArray`) by a scalar λ and applies an activation
43-
function to it. Equivalent to `activation.(λ .* x)`.
42+
Scale the input by a scalar λ and applies an activation function to it.
43+
Equivalent to `activation.(λ .* x)`.
4444
"""
4545
inputscale(λ; activation = identity) = x -> _input_scale(x, λ, activation)
46-
_input_scale(x::AbstractArray, λ, activation) = activation.(λ .* x)
47-
_input_scale(x::AbstractArray, λ, ::typeof(identity)) = λ .* x
46+
_input_scale(x, λ, activation) = activation.(λ .* x)
47+
_input_scale(x, λ, ::typeof(identity)) = λ .* x
4848

4949
"""
5050
swapdims(perm)

0 commit comments

Comments
 (0)