Skip to content

Commit f9e7130

Browse files
committed
add trilinear Upsample layer
1 parent bb88c55 commit f9e7130

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

src/layers/upsample.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Currently supported upsampling `mode`s
1212
and corresponding NNlib's methods are:
1313
- `:nearest` -> [`NNlib.upsample_nearest`](@ref)
1414
- `:bilinear` -> [`NNlib.upsample_bilinear`](@ref)
15+
- `:trilinear` -> [`NNlib.upsample_trilinear`](@ref)
1516
1617
# Examples
1718
@@ -35,7 +36,7 @@ struct Upsample{mode, S, T}
3536
end
3637

3738
function Upsample(mode::Symbol = :nearest; scale = nothing, size = nothing)
38-
mode in [:nearest, :bilinear] ||
39+
mode in [:nearest, :bilinear, :trilinear] ||
3940
throw(ArgumentError("mode=:$mode is not supported."))
4041
if !(isnothing(scale) isnothing(size))
4142
throw(ArgumentError("Either scale or size should be specified (but not both)."))
@@ -58,6 +59,11 @@ end
5859
(m::Upsample{:bilinear, Nothing})(x::AbstractArray) =
5960
NNlib.upsample_bilinear(x; size=m.size)
6061

62+
(m::Upsample{:trilinear})(x::AbstractArray) =
63+
NNlib.upsample_trilinear(x, m.scale)
64+
(m::Upsample{:trilinear, Nothing})(x::AbstractArray) =
65+
NNlib.upsample_trilinear(x; size=m.size)
66+
6167
function Base.show(io::IO, u::Upsample{mode}) where {mode}
6268
print(io, "Upsample(")
6369
print(io, ":", mode)

test/layers/upsample.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,26 @@
1818
@test size(y) == (4, 6, 2, 3)
1919
end
2020

21+
@testset "upsample trilinear" begin
22+
m = Upsample(:trilinear, scale=(2, 3, 2))
23+
x = rand(Float32, 3, 4, 2, 3, 4)
24+
y = m(x)
25+
@test y isa Array{Float32, 5}
26+
@test size(y) == (6, 12, 4, 3, 4)
27+
28+
m = Upsample(:trilinear, scale=3)
29+
x = rand(Float32, 3, 4, 2, 3, 4)
30+
y = m(x)
31+
@test y isa Array{Float32, 5}
32+
@test size(y) == (9, 12, 6, 3, 4)
33+
34+
m = Upsample(:trilinear, size=(4, 6, 4))
35+
x = rand(Float32, 3, 4, 2, 3, 4)
36+
y = m(x)
37+
@test y isa Array{Float32, 5}
38+
@test size(y) == (4, 6, 4, 3, 4)
39+
end
40+
2141
@testset "upsample nearest" begin
2242
x = rand(Float32, 3, 2, 3)
2343
m = Upsample(:nearest, scale=(2,))

0 commit comments

Comments
 (0)