
Is `Flux.huber_loss` type-unstable? #2459

@filchristou


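It looks like `Flux.huber_loss` is type-unstable when differentiated with Zygote. Below, the same gradient computation is run once with `Flux.mse` and once with `Flux.huber_loss`, and the `@code_warntype` output of the two is compared.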

```julia
using Flux, Zygote
import Statistics: mean

# Loss computed with Flux.mse
function internfunc_nobroad(m, x, y)
    modelvals = m(x)
    Flux.mse(modelvals, y)
end

# Same computation, but with Flux.huber_loss
function internfunc_nobroad_huberloss(m, x, y)
    modelvals = m(x)
    Flux.huber_loss(modelvals, y)
end

# Gradient of `func` with respect to the model; the `let` block
# captures the data in the closure passed to Zygote
function wrapfunc(model, xdata, ydata, func)
    grad = let xdata=xdata, ydata=ydata
        Zygote.gradient(m -> func(m, xdata, ydata), model)
    end
    return grad
end

fc = Flux.Chain(Flux.Dense(5=>3, Flux.relu), Flux.Dense(3=>3, Flux.relu), Flux.Dense(3=>1))

fobs_ar = fill(5f0, 5, 10)    # 5 features × 10 samples
labels_ar = fill(2f0, 1, 10)  # 1 target × 10 samples
```
With `Flux.mse`:

```julia
julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad)
```

[Screenshot: `@code_warntype` output for the `Flux.mse` version]

With `Flux.huber_loss`:

```julia
julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)
```

[Screenshot: `@code_warntype` output for the `Flux.huber_loss` version]
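A text-based version of the same comparison can be easier to paste and diff than screenshots. The following is a minimal sketch, assuming Flux, Zygote and the `Test` standard library are available: `Test.@inferred` throws an error when the inferred return type of a call differs from the type of the value actually returned. The helper `model_grad` is hypothetical (not part of Flux); it mirrors `wrapfunc` above but passes the data as function arguments rather than globals, so only the choice of loss should affect inference.

```julia
using Flux, Zygote, Test

# Same model and data as in the report above
model = Flux.Chain(Flux.Dense(5 => 3, Flux.relu), Flux.Dense(3 => 3, Flux.relu), Flux.Dense(3 => 1))
x = fill(5f0, 5, 10)
y = fill(2f0, 1, 10)

# Hypothetical helper: gradient of loss(m(x), y) with respect to the model.
# x and y are arguments, not globals, so they do not add instability of their own.
model_grad(loss, model, x, y) = Zygote.gradient(m -> loss(m(x), y), model)

for loss in (Flux.mse, Flux.huber_loss)
    try
        # Throws if the inferred return type does not match the actual one
        @inferred model_grad(loss, model, x, y)
        println(loss, ": return type inferred")
    catch err
        println(loss, ": not inferred (", sprint(showerror, err), ")")
    end
end
```

The comparison is only informative if the `mse` call infers; if both calls fail, the instability is more likely in the surrounding Zygote machinery than in `huber_loss` itself.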
