Skip to content

Commit 6d0e123

Browse files
bors[bot]tknopp
andauthored
Merge #1792
1792: Add trilinear Upsample layer r=ToucheSir a=tknopp The trilinear Upsample layer was missing although it was actually already implemented in NNLib. This PR makes it accessible through the `Upsample` layer. ### PR Checklist - [x] Tests are added - [x] Entry in NEWS.md - [x] Documentation, if applicable - [x] API changes require approval from a committer (different from the author, if applicable) Co-authored-by: Tobias Knopp <tobias@knoppweb.de>
2 parents cbc1275 + 95c8435 commit 6d0e123

File tree

4 files changed

+29
-1
lines changed

4 files changed

+29
-1
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## v0.12.9
44
* Fixed incorrect output and added GPU compatibility for [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/1781).
5+
* Add trilinear [Upsample layer](https://github.com/FluxML/Flux.jl/pull/1792).
56

67
## v0.12.8
78
* Optimized inference and gradient calculation of OneHotMatrix[pr](https://github.com/FluxML/Flux.jl/pull/1756)

docs/src/models/nnlib.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ NNlib.depthwiseconv
5656
```@docs
5757
NNlib.upsample_nearest
5858
NNlib.upsample_bilinear
59+
NNlib.upsample_trilinear
5960
NNlib.pixel_shuffle
6061
```
6162

src/layers/upsample.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Currently supported upsampling `mode`s
1212
and corresponding NNlib's methods are:
1313
- `:nearest` -> [`NNlib.upsample_nearest`](@ref)
1414
- `:bilinear` -> [`NNlib.upsample_bilinear`](@ref)
15+
- `:trilinear` -> [`NNlib.upsample_trilinear`](@ref)
1516
1617
# Examples
1718
@@ -35,7 +36,7 @@ struct Upsample{mode, S, T}
3536
end
3637

3738
function Upsample(mode::Symbol = :nearest; scale = nothing, size = nothing)
38-
mode in [:nearest, :bilinear] ||
39+
mode in [:nearest, :bilinear, :trilinear] ||
3940
throw(ArgumentError("mode=:$mode is not supported."))
4041
if !(isnothing(scale) isnothing(size))
4142
throw(ArgumentError("Either scale or size should be specified (but not both)."))
@@ -58,6 +59,11 @@ end
5859
(m::Upsample{:bilinear, Nothing})(x::AbstractArray) =
5960
NNlib.upsample_bilinear(x; size=m.size)
6061

62+
(m::Upsample{:trilinear})(x::AbstractArray) =
63+
NNlib.upsample_trilinear(x, m.scale)
64+
(m::Upsample{:trilinear, Nothing})(x::AbstractArray) =
65+
NNlib.upsample_trilinear(x; size=m.size)
66+
6167
function Base.show(io::IO, u::Upsample{mode}) where {mode}
6268
print(io, "Upsample(")
6369
print(io, ":", mode)

test/layers/upsample.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,26 @@
1818
@test size(y) == (4, 6, 2, 3)
1919
end
2020

21+
@testset "upsample trilinear" begin
22+
m = Upsample(:trilinear, scale=(2, 3, 2))
23+
x = rand(Float32, 3, 4, 2, 3, 4)
24+
y = m(x)
25+
@test y isa Array{Float32, 5}
26+
@test size(y) == (6, 12, 4, 3, 4)
27+
28+
m = Upsample(:trilinear, scale=3)
29+
x = rand(Float32, 3, 4, 2, 3, 4)
30+
y = m(x)
31+
@test y isa Array{Float32, 5}
32+
@test size(y) == (9, 12, 6, 3, 4)
33+
34+
m = Upsample(:trilinear, size=(4, 6, 4))
35+
x = rand(Float32, 3, 4, 2, 3, 4)
36+
y = m(x)
37+
@test y isa Array{Float32, 5}
38+
@test size(y) == (4, 6, 4, 3, 4)
39+
end
40+
2141
@testset "upsample nearest" begin
2242
x = rand(Float32, 3, 2, 3)
2343
m = Upsample(:nearest, scale=(2,))

0 commit comments

Comments
 (0)