BUG: Evaluating pytensor.gradient of logp containing MinimizeOp attempts to concatenate tensors of incompatible rank #7872

@Michal-Novomestsky

Description

Describe the issue:

When backpropagating through the logp of a graph that contains a MinimizeOp, the gradient code tries to concatenate tensors of different ranks and raises a TypeError. I believe this can be remedied by flattening the atleast_2d output in lines 333-337 of pytensor/tensor/optimize.py:

df_dtheta = concatenate(
    [
        atleast_2d(jac_col, left=False).flatten()
        for jac_col in cast(list[TensorVariable], df_dtheta_columns)
    ],
    axis=-1,
)
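
For intuition, here is a minimal sketch of why the ranks disagree (hypothetical shapes, not the actual MinimizeOp internals): for a vector-valued implicit function f, jacobian(f, theta) has shape (f.size, *theta.shape), so parameters of different rank produce jacobian columns of different ndim, which atleast_2d alone cannot reconcile. Flattening each column, as suggested above, does give Join equal-rank inputs:

import pytensor.tensor as pt
from pytensor.gradient import jacobian

# Hypothetical stand-ins: a rank-1 and a rank-2 parameter.
mu = pt.vector("mu")
cov = pt.matrix("cov")
f = mu + cov.sum(axis=1)  # vector-valued implicit function

cols = jacobian(f, [mu, cov])
print([col.ndim for col in cols])  # [2, 3] -- mismatched ranks

# atleast_2d(..., left=False) still leaves a 2-D and a 3-D tensor, so Join raises:
try:
    pt.concatenate([pt.atleast_2d(c, left=False) for c in cols], axis=-1)
except TypeError as e:
    print(e)  # Only tensors with the same number of dimensions can be joined...

# With the suggested .flatten(), every column is 1-D and the join succeeds
# (whether downstream code still gets the matrix shape it expects is a separate question):
flat = pt.concatenate([c.flatten() for c in cols], axis=-1)
print(flat.ndim)  # 1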

Reproducible code example:

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytensor.gradient as tg
import pytensor
from pytensor.tensor.optimize import minimize

rng = np.random.default_rng(12345)
n = 10000
d = 10

mu = np.ones(d)
cov = np.diag(np.ones(d))

# A simple Gaussian likelihood whose mean x has a Gaussian prior
with pm.Model() as model:
    x = pm.MvNormal("x", mu=mu, cov=cov)

    y_obs = rng.multivariate_normal(mean=mu, cov=cov, size=n)

    y = pm.MvNormal(
        "y",
        mu=x,
        cov=cov,
        observed=y_obs,
    )

    logp = model.logp()

    # Find the mean that maximizes the logp (i.e. minimizes -logp)
    x0, _ = minimize(
        objective=-logp,
        x=model.rvs_to_values[x],
        method="BFGS",
        optimizer_kwargs={"tol": 1e-8},
    )

    y = pytensor.graph.replace.graph_replace(y, {x: x0})
    
    # tg.grad throws the error
    for var in model.value_vars:
        var = model.rvs_to_values[x]
        logp = pt.sum(pm.logp(y, var))
        tg.grad(logp, var)

Error message:

Cell In[5], line 41
     39 var = model.rvs_to_values[x]
     40 logp = pt.sum(pm.logp(y, var))
---> 41 tg.grad(logp, var)

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:747, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    744     if hasattr(g.type, "dtype"):
    745         assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 747 _rval: Sequence[Variable] = _populate_grad_dict(
    748     var_to_app_to_idx, grad_dict, _wrt, cost_name
    749 )
    751 rval: MutableSequence[Variable | None] = list(_rval)
    753 for i in range(len(_rval)):

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1541, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1538     # end if cache miss
   1539     return grad_dict[var]
-> 1541 rval = [access_grad_cache(elem) for elem in wrt]
   1543 return rval

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1496, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1494 for node in node_to_idx:
   1495     for idx in node_to_idx[node]:
-> 1496         term = access_term_cache(node)[idx]
   1498         if not isinstance(term, Variable):
   1499             raise TypeError(
   1500                 f"{node.op}.grad returned {type(term)}, expected"
   1501                 " Variable instance."
   1502             )

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1326, in _populate_grad_dict.<locals>.access_term_cache(node)
   1318         if o_shape != g_shape:
   1319             raise ValueError(
   1320                 "Got a gradient of shape "
   1321                 + str(o_shape)
   1322                 + " on an output of shape "
   1323                 + str(g_shape)
   1324             )
-> 1326 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1328 if input_grads is None:
   1329     raise TypeError(
   1330         f"{node.op}.grad returned NoneType, expected iterable."
   1331     )

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/optimize.py:565, in MinimizeOp.L_op(self, inputs, outputs, output_grads)
    560 implicit_f = grad(inner_fx, inner_x)
    562 df_dx, *df_dtheta_columns = jacobian(
    563     implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
    564 )
--> 565 grad_wrt_args = implict_optimization_grads(
    566     df_dx=df_dx,
    567     df_dtheta_columns=df_dtheta_columns,
    568     args=args,
    569     x_star=x_star,
    570     output_grad=output_grad,
    571     fgraph=self.fgraph,
    572 )
    574 return [zeros_like(x), *grad_wrt_args]

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/optimize.py:333, in implict_optimization_grads(df_dx, df_dtheta_columns, args, x_star, output_grad, fgraph)
    290 r"""
    291 Compute gradients of an optimization problem with respect to its parameters.
    292 
   (...)    329     The function graph that contains the inputs and outputs of the optimization problem.
    330 """
    331 df_dx = cast(TensorVariable, df_dx)
--> 333 df_dtheta = concatenate(
    334     [
    335         atleast_2d(jac_col, left=False)
    336         for jac_col in cast(list[TensorVariable], df_dtheta_columns)
    337     ],
    338     axis=-1,
    339 )
    341 replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
    343 df_dx_star, df_dtheta_star = cast(
    344     list[TensorVariable],
    345     graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
    346 )

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2980, in concatenate(tensor_list, axis)
   2973 if not isinstance(tensor_list, tuple | list):
   2974     raise TypeError(
   2975         "The 'tensors' argument must be either a tuple "
   2976         "or a list, make sure you did not forget () or [] around "
   2977         "arguments of concatenate.",
   2978         tensor_list,
   2979     )
-> 2980 return join(axis, *tensor_list)

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2794, in join(axis, *tensors_list)
   2792     return tensors_list[0]
   2793 else:
-> 2794     return _join(axis, *tensors_list)

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
    249 def __call__(
    250     self, *inputs: Any, name=None, return_list=False, **kwargs
    251 ) -> Variable | list[Variable]:
    252     r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    253 
    254     This method is just a wrapper around :meth:`Op.make_node`.
   (...)    291 
    292     """
--> 293     node = self.make_node(*inputs, **kwargs)
    294     if name is not None:
    295         if len(node.outputs) == 1:

File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2487, in Join.make_node(self, axis, *tensors)
   2484 ndim = tensors[0].type.ndim
   2486 if not builtins.all(x.ndim == ndim for x in tensors):
-> 2487     raise TypeError(
   2488         "Only tensors with the same number of dimensions can be joined. "
   2489         f"Input ndims were: {[x.ndim for x in tensors]}"
   2490     )
   2492 try:
   2493     static_axis = int(get_scalar_constant_value(axis))

TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 2, 3, 4, 2, 2, 2, 3, 3, 4]

PyMC version information:

NumPy: 1.26.4
PyMC: 0+untagged.10301.gdc7cfee.dirty
PyTensor: 2.31.7

Context for the issue:

The bug originally surfaced when running pm.sample on a model whose logp contains a MinimizeOp.
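
For completeness, a hypothetical sketch of that original trigger (model_with_minimize is an assumed model whose logp graph embeds the MinimizeOp; not the author's actual code). Gradient-based samplers differentiate the logp, so pm.sample reaches the same MinimizeOp.L_op path:

# Hypothetical: model_with_minimize is assumed to embed a MinimizeOp in its logp.
with model_with_minimize:
    # NUTS differentiates the logp, so gradient compilation reaches
    # MinimizeOp.L_op and fails with the same TypeError as above.
    idata = pm.sample()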
