Skip to content

Commit 15f90a6

Browse files
committed
grow_at_start instead of grow_first
1 parent 72f5566 commit 15f90a6

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/convnets/inception.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ classifier(m::InceptionResNetv2) = m.layers[2]
489489

490490
"""
491491
xception_block(inchannels, outchannels, nrepeats; stride = 1, start_with_relu = true,
492-
grow_first = true)
492+
grow_at_start = true)
493493
494494
Create an Xception block.
495495
([reference](https://arxiv.org/abs/1610.02357))
@@ -501,11 +501,11 @@ Create an Xception block.
501501
- `nrepeats`: number of repeats of depthwise separable convolution layers.
502502
- `stride`: stride by which to downsample the input.
503503
- `start_with_relu`: if true, start the block with a ReLU activation.
504-
- `grow_first`: if true, increase the number of channels at the first convolution.
504+
- `grow_at_start`: if true, increase the number of channels at the first convolution.
505505
"""
506506
function xception_block(inchannels, outchannels, nrepeats; stride = 1,
507507
start_with_relu = true,
508-
grow_first = true)
508+
grow_at_start = true)
509509
if outchannels != inchannels || stride != 1
510510
skip = conv_bn((1, 1), inchannels, outchannels, identity; stride = stride,
511511
bias = false)
@@ -514,7 +514,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1,
514514
end
515515
layers = []
516516
for i in 1:nrepeats
517-
if grow_first
517+
if grow_at_start
518518
inc = i == 1 ? inchannels : outchannels
519519
outc = outchannels
520520
else
@@ -551,7 +551,7 @@ function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
551551
xception_block(128, 256, 2; stride = 2),
552552
xception_block(256, 728, 2; stride = 2),
553553
[xception_block(728, 728, 3) for _ in 1:8]...,
554-
xception_block(728, 1024, 2; stride = 2, grow_first = false),
554+
xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
555555
depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)...,
556556
depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...)
557557
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(2048, nclasses))

0 commit comments

Comments
 (0)