
[Bug] object has no attribute _differentiable_kwargs #101

@jaghili

Description


🐛 Bug

  • Install torch==2.0.1
  • Install linear_operator 0.5.3 with pip

To reproduce

I took the following snippet from the README:

import linear_operator
import torch

class DiagLinearOperator(linear_operator.LinearOperator):
    r"""
    A LinearOperator representing a diagonal matrix.
    """
    def __init__(self, diag):
        # diag: the vector that defines the diagonal of the matrix
        self.diag = diag

    def _matmul(self, v):
        return self.diag.unsqueeze(-1) * v

    def _size(self):
        return torch.Size([*self.diag.shape, self.diag.size(-1)])

    def _transpose_nonbatch(self):
        return self  # Diagonal matrices are symmetric

    # this function is optional, but it will accelerate computation
    def logdet(self):
        return self.diag.log().sum(dim=-1)
# ...

D = DiagLinearOperator(torch.tensor([1., 2., 3.]))
# Represents the matrix
#   [[1., 0., 0.],
#    [0., 2., 0.],
#    [0., 0., 3.]]
torch.matmul(D, torch.tensor([4., 5., 6.]))
# Returns [4., 10., 18.]

Stack trace/error message

Traceback (most recent call last):
  File "/home/jagh/codes/ng/src/a.py", line 31, in <module>
    torch.matmul(D, torch.tensor([4., 5., 6.]))
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2970, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 1839, in matmul
    return Matmul.apply(self.representation_tree(), other, *self.representation())
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2072, in representation_tree
    return LinearOperatorRepresentationTree(self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/linear_operator_representation_tree.py", line 8, in __init__
    self._differentiable_kwarg_names = linear_op._differentiable_kwargs.keys()
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'DiagLinearOperator' object has no attribute '_differentiable_kwargs'

Expected Behavior

The snippet should return [4., 10., 18.].
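For reference, the expected result can be cross-checked without torch. Multiplying the dense 3 x 3 matrix that D represents by [4., 5., 6.] gives the same numbers as the elementwise product of the diagonal with the vector. That elementwise product is what the snippet's _matmul computes through broadcasting (diag.unsqueeze(-1) * v), assuming the library hands the vector to _matmul as a single n x 1 column. The helper names below are hypothetical and exist only for this check.

```python
def dense_from_diag(diag):
    """Build the full n x n matrix that a diagonal operator represents."""
    n = len(diag)
    return [[diag[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


def dense_matvec(matrix, v):
    """Ordinary matrix-vector product: dot each row with v."""
    return [sum(a * b for a, b in zip(row, v)) for row in matrix]


def diag_matvec(diag, v):
    """Diagonal shortcut: row i only picks up diag[i] * v[i], so the product is elementwise."""
    return [d * x for d, x in zip(diag, v)]


diag = [1.0, 2.0, 3.0]
v = [4.0, 5.0, 6.0]
print(dense_matvec(dense_from_diag(diag), v))  # [4.0, 10.0, 18.0]
print(diag_matvec(diag, v))                    # [4.0, 10.0, 18.0]
```

Both routes agree, which is why a diagonal operator never needs to materialize the full matrix.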

Additional context

I added self._differentiable_kwargs = { some dict }, which seems to bypass the problem, but then I get another error about self._nondifferentiable_kwargs, which I don't know how to set up. Did I miss something?
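A likely cause (not verified against the library source here): in linear_operator, _args, _differentiable_kwargs and _nondifferentiable_kwargs appear to be created inside LinearOperator.__init__, and the README subclass never calls it, so representation_tree() has nothing to read. If that is right, adding super().__init__(diag) as the first line of DiagLinearOperator.__init__ should resolve both errors, instead of setting the dictionaries by hand. The sketch below reproduces the mechanism without torch. BaseOperator, BrokenDiag and FixedDiag are hypothetical stand-ins, and the simplified keyword handling is an assumption, not the library's actual code.

```python
# Hypothetical stand-ins: BaseOperator mimics what LinearOperator.__init__ is
# assumed to do (record constructor arguments); it is not the library's code.
class BaseOperator:
    def __init__(self, *args, **kwargs):
        self._args = args
        # Simplified assumption: treat every keyword argument as differentiable.
        self._differentiable_kwargs = dict(kwargs)
        self._nondifferentiable_kwargs = {}

    def representation_tree(self):
        # Like the failing line in the traceback, this reads an attribute
        # that only exists if BaseOperator.__init__ has run.
        return list(self._differentiable_kwargs.keys())


class BrokenDiag(BaseOperator):
    def __init__(self, diag):
        self.diag = diag  # super().__init__ is never called, as in the README snippet


class FixedDiag(BaseOperator):
    def __init__(self, diag):
        super().__init__(diag)  # lets the base class record diag in self._args
        self.diag = diag


try:
    BrokenDiag([1.0, 2.0, 3.0]).representation_tree()
except AttributeError as err:
    print("BrokenDiag:", err)

print("FixedDiag:", FixedDiag([1.0, 2.0, 3.0])._args)
```

Calling the parent constructor lets the base class set up every attribute it later relies on, so there is no need to guess what the two kwargs dictionaries should contain.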

Metadata

  • Assignees: no one assigned
  • Labels: bug (Something isn't working)
  • Type, projects, milestone, relationships: none
  • Development: no branches or pull requests