Skip to content

Commit 4a3c80b

Browse files
author
Anton Smirnov
committed
Refactor
1 parent 1893507 commit 4a3c80b

File tree

2 files changed

+58
-51
lines changed

2 files changed

+58
-51
lines changed

src/layers/normalise.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,16 +206,22 @@ function _promote_to_output(
206206
_maybe_promote_type(T, Vel), N), Wel)
207207
end
208208

209+
function _basetype(::Type{T}) where T
210+
if T <: Array
211+
return Array
212+
elseif T <: CuArray
213+
return CuArray
214+
end
215+
throw("Unsupported type $T")
216+
end
217+
209218
# For InstanceNorm, GroupNorm, and BatchNorm.
210219
# Compute the statistics on the slices specified by reduce_dims.
211220
# reduce_dims=[1,...,N-2,N] for BatchNorm
212221
# reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
213-
_norm_layer_forward(l, x; reduce_dims, affine_shape) =
214-
_norm_layer_forward(l, x, _promote_to_output(l, x); reduce_dims, affine_shape)
215-
216222
function _norm_layer_forward(
217-
l, x::Array{T, N}, ::Type{O}; reduce_dims, affine_shape,
218-
) where {T, N, O}
223+
l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
224+
) where {T, N}
219225
if !_isactive(l) && l.track_stats # testmode with tracked stats
220226
stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
221227
μ = reshape(l.μ, stats_shape)
@@ -228,7 +234,8 @@ function _norm_layer_forward(
228234
end
229235
end
230236

231-
o::Array{O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
237+
O = _promote_to_output(l, x)
238+
o::_basetype(typeof(x)){O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
232239
hasaffine(l) || return l.λ.(o)
233240

234241
γ = reshape(l.γ, affine_shape)

test/runtests.jl

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,55 +8,55 @@ using CUDA
88

99
Random.seed!(0)
1010

11-
# @testset "Utils" begin
12-
# include("utils.jl")
13-
# end
11+
@testset "Utils" begin
12+
include("utils.jl")
13+
end
1414

15-
# @testset "Onehot" begin
16-
# include("onehot.jl")
17-
# end
15+
@testset "Onehot" begin
16+
include("onehot.jl")
17+
end
1818

19-
# @testset "Optimise" begin
20-
# include("optimise.jl")
21-
# end
19+
@testset "Optimise" begin
20+
include("optimise.jl")
21+
end
2222

23-
# @testset "Data" begin
24-
# include("data.jl")
25-
# end
23+
@testset "Data" begin
24+
include("data.jl")
25+
end
2626

27-
# @testset "Losses" begin
28-
# include("losses.jl")
29-
# include("ctc.jl")
30-
# CUDA.functional() && include("ctc-gpu.jl")
31-
# end
27+
@testset "Losses" begin
28+
include("losses.jl")
29+
include("ctc.jl")
30+
CUDA.functional() && include("ctc-gpu.jl")
31+
end
3232

3333
@testset "Layers" begin
34-
# include("layers/basic.jl")
34+
include("layers/basic.jl")
3535
include("layers/normalisation.jl")
36-
# include("layers/stateless.jl")
37-
# include("layers/recurrent.jl")
38-
# include("layers/conv.jl")
39-
# include("layers/upsample.jl")
40-
# include("layers/show.jl")
41-
end
42-
43-
# @testset "outputsize" begin
44-
# using Flux: outputsize
45-
# include("outputsize.jl")
46-
# end
47-
48-
# @testset "CUDA" begin
49-
# if CUDA.functional()
50-
# include("cuda/runtests.jl")
51-
# else
52-
# @warn "CUDA unavailable, not testing GPU support"
53-
# end
54-
# end
55-
56-
# @static if VERSION == v"1.6"
57-
# using Documenter
58-
# @testset "Docs" begin
59-
# DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
60-
# doctest(Flux)
61-
# end
62-
# end
36+
include("layers/stateless.jl")
37+
include("layers/recurrent.jl")
38+
include("layers/conv.jl")
39+
include("layers/upsample.jl")
40+
include("layers/show.jl")
41+
end
42+
43+
@testset "outputsize" begin
44+
using Flux: outputsize
45+
include("outputsize.jl")
46+
end
47+
48+
@testset "CUDA" begin
49+
if CUDA.functional()
50+
include("cuda/runtests.jl")
51+
else
52+
@warn "CUDA unavailable, not testing GPU support"
53+
end
54+
end
55+
56+
@static if VERSION == v"1.6"
57+
using Documenter
58+
@testset "Docs" begin
59+
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
60+
doctest(Flux)
61+
end
62+
end

0 commit comments

Comments
 (0)