Error differentiating ResNet from torchvision #24

@lorenzoh

In trying to get an image classification example working for FastAI.jl, I tried training a pretrained ResNet model from torchvision. The forward pass works fine, but when differentiating, I get an error.

I think this is actually a limitation of functorch, but I figured I'd report it here.

Minimal working example (the last line fails on both CPU and GPU):

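using CUDA, Flux, PyCall, Zygote
using PyCallChainRules.Torch: TorchModuleWrapper

torchvision = pyimport("torchvision")

# Wrap a pretrained ResNet-18 on the GPU so it can be called and differentiated from Julia
model = TorchModuleWrapper(torchvision.models.resnet18(pretrained=true).to("cuda:0"))
xs = randn(Float32, 128, 128, 3, 1) |> cu  # one 128×128 RGB image
ys = model(xs)  # the forward pass works
Zygote.gradient(() -> Flux.mse(model(xs), ys))  # fails with the PyError below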

Stacktrace

julia> Zygote.gradient(() -> Flux.mae(model(xs), ys))
ERROR: PyError ($(Expr(:escape, :(ccall(#= /home/lorenz/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:43 =# @pysym(:PyObject_Call), PyPtr, (PyPtr, PyPtr, PyPtr), o, pyargsptr, kw))))) 
RuntimeError('During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::add_.Tensor) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.')
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 243, in vjp
    try:
  File "/home/lorenz/.julia/packages/PyCall/7a7w0/src/pyeval.jl", line 3, in newfn
    const Py_eval_input = 258
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/functorch/_src/make_functional.py", line 259, in forward
    @staticmethod
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torchvision/models/resnet.py", line 283, in forward
    return self._forward_impl(x)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torchvision/models/resnet.py", line 267, in _forward_impl
    x = self.bn1(x)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 148, in forward
    self.num_batches_tracked.add_(1)  # type: ignore[has-type]

Stacktrace:
[1] pyerr_check
@ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:62 [inlined]
[2] pyerr_check
@ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:66 [inlined]
[3] _handle_error(msg::String)
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/exception.jl:83
[4] macro expansion
@ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:97 [inlined]
[5] #107
@ ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:43 [inlined]
[6] disable_sigint
@ ./c.jl:458 [inlined]
[7] __pycall!
@ ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:42 [inlined]
[8] _pycall!(ret::PyObject, o::PyObject, args::Tuple{PyObject, NTuple{62, PyObject}, PyObject}, nargs::Int64, kw::Ptr{Nothing})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:29
[9] _pycall!(ret::PyObject, o::PyObject, args::Tuple{PyObject, NTuple{62, PyObject}, PyObject}, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:11
[10] (::PyObject)(::PyObject, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:86
[11] (::PyObject)(::PyObject, ::Vararg{Any})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:86
[12] rrule(wrap::TorchModuleWrapper, args::Array{Float32, 4}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PyCallChainRules.Torch ~/.julia/packages/PyCallChainRules/ebIKG/src/pytorch.jl:62
[13] rrule
@ ~/.julia/packages/PyCallChainRules/ebIKG/src/pytorch.jl:57 [inlined]
[14] rrule
@ ~/.julia/packages/ChainRulesCore/RbX5a/src/rules.jl:134 [inlined]
[15] chain_rrule
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:216 [inlined]
[16] macro expansion
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0 [inlined]
[17] _pullback(ctx::Zygote.Context, f::TorchModuleWrapper, args::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:9
[18] _pullback
@ ~/.julia/dev/_InteractiveSessions/22_03/03_25_pychain_fastai.jl:86 [inlined]
[19] _pullback(::Zygote.Context, ::var"#27#28")
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[20] _pullback(::Function)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:34
[21] pullback(::Function)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:40
[22] gradient(::Function)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:75
[23] top-level scope
@ ~/.julia/dev/_InteractiveSessions/22_03/03_25_pychain_fastai.jl:86
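
For context, the Python part of the traceback ends in `batchnorm.py` at `self.num_batches_tracked.add_(1)`. That is the in-place update functorch rejects during the `vjp` transform. The following is a minimal sketch in plain PyTorch (not part of the original report, and it does not involve functorch or PyCallChainRules) showing that a `BatchNorm2d` layer mutates this buffer on every forward pass in training mode but not in eval mode. The tensor shape and variable names are made up for illustration.

```python
import torch
from torch import nn

# A fresh module starts in training mode, with track_running_stats=True by default.
bn = nn.BatchNorm2d(3)
x = torch.randn(2, 3, 8, 8)  # (N, C, H, W), arbitrary illustrative shape

before = int(bn.num_batches_tracked)

# Training mode: forward() increments the num_batches_tracked buffer in place.
bn.train()
bn(x)
after_train = int(bn.num_batches_tracked)

# Eval mode: forward() uses the stored running statistics and leaves the buffer alone.
bn.eval()
bn(x)
after_eval = int(bn.num_batches_tracked)

print(before, after_train, after_eval)  # prints: 0 1 1
```

Since torchvision models are also returned in training mode, calling `.eval()` on the ResNet before wrapping it might sidestep the error. That is an untested assumption here, and it only fits inference-style use, because eval mode switches batch norm to its running statistics instead of per-batch statistics.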
