Commit 1ac78b5

Merge #1393
1393: remove implicit conversions r=DhairyaLGandhi a=CarloLucibello

For some layers, we currently downcast the input type from Float64 to Float32 if the weights are Float32. I think we should follow Julia's promotion rules: users should provide Float32 inputs if they want Float32 outputs. This also simplifies the layers' definitions.

This change may have a performance impact on some people's code that could go unnoticed, so while we want to promote good practice, I can understand if someone opposes this change.

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [ ] Documentation, if applicable
- [x] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
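A minimal sketch of the behaviour change described above (a hypothetical REPL snippet, assuming this PR is merged and Flux's default Float32 initialisation for `Dense`):

```julia
using Flux

m = Dense(3, 2)          # weights are initialised as Float32 by default
x32 = rand(Float32, 3)
x64 = rand(Float64, 3)

eltype(m(x32))           # Float32: input and weights already match
eltype(m(x64))           # Float64: the input is no longer downcast to Float32;
                         # Julia's ordinary promotion rules decide the output eltype

eltype(f64(m)(x64))      # Float64: convert the model with f64 if you want
                         # Float64 weights throughout
```

To keep the fast Float32 path, convert the data (e.g. `Float32.(x)` or `f32`) rather than relying on the layer to do it implicitly.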
2 parents 33f99ef + 3b10434 commit 1ac78b5

5 files changed (+17 / -43 lines)

NEWS.md

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,8 @@

  ## v0.12.0

- * The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405)
+ * The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
+ * Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394).
  * Excise datasets in favour of other providers in the julia ecosystem.
  * Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to eliminating `bias` from being trained.
  * Removed kwarg only constructors for [`convolutional layers`](https://github.com/FluxML/Flux.jl/pull/1379).

docs/src/performance.md

Lines changed: 8 additions & 8 deletions
@@ -13,16 +13,15 @@ not because the operations are faster, but because the memory usage is halved.
  Which means allocations occur much faster.
  And you use less memory.

-
  ## Preserve inputs' types

  Not only should your activation and loss functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
  they should also preserve the type of their inputs.

  A very artificial example using an activation function like

- ```
- my_tanh(x) = Float64(tanh(x))
+ ```julia
+ my_tanh(x) = Float64(tanh(x))
  ```

  will result in performance on `Float32` input orders of magnitude slower than the normal `tanh` would,
@@ -35,20 +34,21 @@ you will see a large slow-down.
  This can occur sneakily, because you can cause type-promotion by interacting with a numeric literals.
  E.g. the following will have run into the same problem as above:

- ```
- leaky_tanh(x) = 0.01*x + tanh(x)
+ ```julia
+ leaky_tanh(x) = 0.01*x + tanh(x)
  ```

  While one could change the activation function (e.g. to use `0.01f0*x`), the idiomatic (and safe way) to avoid type casts whenever inputs changes is to use `oftype`:
- ```
- leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
- ```

+ ```julia
+ leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
+ ```

  ## Evaluate batches as Matrices of features

  While it can sometimes be tempting to process your observations (feature vectors) one at a time
  e.g.
+
  ```julia
  function loss_total(xs::AbstractVector{<:Vector}, ys::AbstractVector{<:Vector})
    sum(zip(xs, ys)) do (x, y_target)
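The `oftype` guidance in this doc change is easy to verify; a small standalone sketch (not part of the diff, with illustrative names `leaky_tanh_bad`/`leaky_tanh_good`):

```julia
# The literal 0.01 is a Float64, so Float32 input is promoted to Float64.
leaky_tanh_bad(x)  = 0.01 * x + tanh(x)

# oftype(x / 1, 0.01) first converts the constant to x's floating-point type.
leaky_tanh_good(x) = oftype(x / 1, 0.01) * x + tanh(x)

x = 0.5f0
typeof(leaky_tanh_bad(x))   # Float64 (type-unstable for Float32 pipelines)
typeof(leaky_tanh_good(x))  # Float32 (input type preserved)
```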

src/layers/basic.jl

Lines changed: 1 addition & 10 deletions
@@ -121,9 +121,8 @@ end

  function (a::Dense)(x::AbstractArray)
    W, b, σ = a.W, a.b, a.σ
-   # reshape to handle dims > 1 as batch dimensions
    sz = size(x)
-   x = reshape(x, sz[1], :)
+   x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions
    x = σ.(W*x .+ b)
    return reshape(x, :, sz[2:end]...)
  end
@@ -134,14 +133,6 @@ function Base.show(io::IO, l::Dense)
    print(io, ")")
  end

- # Try to avoid hitting generic matmul in some simple cases
- # Base's matmul is so slow that it's worth the extra conversion to hit BLAS
- (a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-   invoke(a, Tuple{AbstractArray}, x)
-
- (a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-   a(T.(x))
-
  """
      Diagonal(in::Integer)

src/layers/conv.jl

Lines changed: 0 additions & 21 deletions
@@ -164,11 +164,6 @@ function Base.show(io::IO, l::Conv)
    print(io, ")")
  end

- (a::Conv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-   invoke(a, Tuple{AbstractArray}, x)
-
- (a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-   a(T.(x))

  """
      ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1)
@@ -265,11 +260,6 @@ function Base.show(io::IO, l::ConvTranspose)
    print(io, ")")
  end

- (a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-   invoke(a, Tuple{AbstractArray}, x)
-
- (a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-   a(T.(x))

  function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T}
    calc_padding(Conv, pad, k .- stride .+ 1, dilation, stride)
@@ -363,11 +353,6 @@ function Base.show(io::IO, l::DepthwiseConv)
    print(io, ")")
  end

- (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-   invoke(a, Tuple{AbstractArray}, x)
-
- (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-   a(T.(x))

  """
      CrossCor(filter, in => out, σ=identity; stride=1, pad=0, dilation=1)
@@ -449,12 +434,6 @@ function Base.show(io::IO, l::CrossCor)
    print(io, ")")
  end

- (a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-   invoke(a, Tuple{AbstractArray}, x)
-
- (a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
-   a(T.(x))
-
  """
      AdaptiveMaxPool(out::NTuple)
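Analogously for the convolutional layers, a hedged sketch of what removing these methods means in practice (mixed-precision inputs are promoted rather than silently converted, which may be slower):

```julia
using Flux

c = Conv((3, 3), 1 => 2, relu)    # Float32 weights by default
x64 = rand(Float64, 8, 8, 1, 1)   # WHCN-ordered input

# Previously x64 would have been converted to Float32 before the convolution;
# with the methods above removed, the eltypes are promoted instead.
eltype(c(x64))                    # Float64
```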

test/utils.jl

Lines changed: 6 additions & 3 deletions
@@ -145,10 +145,13 @@ end

  @testset "Precision" begin
    m = Chain(Dense(10, 5, relu), Dense(5, 2))
-   x = rand(10)
+   x64 = rand(Float64, 10)
+   x32 = rand(Float32, 10)
    @test eltype(m[1].W) == Float32
-   @test eltype(m(x)) == Float32
-   @test eltype(f64(m)(x)) == Float64
+   @test eltype(m(x32)) == Float32
+   @test eltype(m(x64)) == Float64
+   @test eltype(f64(m)(x32)) == Float64
+   @test eltype(f64(m)(x64)) == Float64
    @test eltype(f64(m)[1].W) == Float64
    @test eltype(f32(f64(m))[1].W) == Float32
  end
