Describe the issue:
When attempting to backprop through the logp of a graph which contains a MinimizeOp, an error is thrown when it attempts to concatenate tensors of differing ranks. I believe this can be remedied by flattening the output of atleast_2d in lines 333-337 of pytensor/tensor/optimize.py:
df_dtheta = concatenate(
    [
        atleast_2d(jac_col, left=False).flatten()
        for jac_col in cast(list[TensorVariable], df_dtheta_columns)
    ],
    axis=-1,
)
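For illustration, here is a minimal sketch of the failure mode (the variables a and b are hypothetical stand-ins for jacobian columns, not taken from the actual MinimizeOp graph): pt.concatenate refuses inputs with different numbers of dimensions, which is the Join.make_node error shown in the traceback below, while flattening each input first brings everything to the same rank:

import pytensor.tensor as pt

a = pt.matrix("a")   # ndim 2, e.g. a jacobian column w.r.t. a vector arg
b = pt.tensor3("b")  # ndim 3, e.g. a jacobian column w.r.t. a matrix arg

try:
    pt.concatenate([a, b], axis=-1)
except TypeError as e:
    # "Only tensors with the same number of dimensions can be joined. ..."
    print(e)

# Flattening each piece first makes every input rank 1, so the join succeeds
pt.concatenate([a.flatten(), b.flatten()], axis=-1)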
Reproducible code example:
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytensor.gradient as tg
import pytensor
from pytensor.tensor.optimize import minimize

rng = np.random.default_rng(12345)
n = 10000
d = 10
mu = np.ones(d)
cov = np.diag(np.ones(d))

# Make a simple gaussian with mean x of gaussian prior
with pm.Model() as model:
    x = pm.MvNormal("x", mu=mu, cov=cov)
    y_obs = rng.multivariate_normal(mean=mu, cov=cov, size=n)
    y = pm.MvNormal(
        "y",
        mu=x,
        cov=cov,
        observed=y_obs,
    )

logp = model.logp()

# Find the mean which minimizes the logp
x0, _ = minimize(
    objective=-logp,
    x=model.rvs_to_values[x],
    method="BFGS",
    optimizer_kwargs={"tol": 1e-8},
)

y = pytensor.graph.replace.graph_replace(y, {x: x0})

# tg.grad throws the error
for var in model.value_vars:
    var = model.rvs_to_values[x]
    logp = pt.sum(pm.logp(y, var))
    tg.grad(logp, var)
Error message:
Cell In[5], line 41
39 var = model.rvs_to_values[x]
40 logp = pt.sum(pm.logp(y, var))
---> 41 tg.grad(logp, var)
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:747, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
744 if hasattr(g.type, "dtype"):
745 assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 747 _rval: Sequence[Variable] = _populate_grad_dict(
748 var_to_app_to_idx, grad_dict, _wrt, cost_name
749 )
751 rval: MutableSequence[Variable | None] = list(_rval)
753 for i in range(len(_rval)):
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1541, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
1538 # end if cache miss
1539 return grad_dict[var]
-> 1541 rval = [access_grad_cache(elem) for elem in wrt]
1543 return rval
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1496, in _populate_grad_dict.<locals>.access_grad_cache(var)
1494 for node in node_to_idx:
1495 for idx in node_to_idx[node]:
-> 1496 term = access_term_cache(node)[idx]
1498 if not isinstance(term, Variable):
1499 raise TypeError(
1500 f"{node.op}.grad returned {type(term)}, expected"
1501 " Variable instance."
1502 )
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1326, in _populate_grad_dict.<locals>.access_term_cache(node)
1318 if o_shape != g_shape:
1319 raise ValueError(
1320 "Got a gradient of shape "
1321 + str(o_shape)
1322 + " on an output of shape "
1323 + str(g_shape)
1324 )
-> 1326 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
1328 if input_grads is None:
1329 raise TypeError(
1330 f"{node.op}.grad returned NoneType, expected iterable."
1331 )
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/optimize.py:565, in MinimizeOp.L_op(self, inputs, outputs, output_grads)
560 implicit_f = grad(inner_fx, inner_x)
562 df_dx, *df_dtheta_columns = jacobian(
563 implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
564 )
--> 565 grad_wrt_args = implict_optimization_grads(
566 df_dx=df_dx,
567 df_dtheta_columns=df_dtheta_columns,
568 args=args,
569 x_star=x_star,
570 output_grad=output_grad,
571 fgraph=self.fgraph,
572 )
574 return [zeros_like(x), *grad_wrt_args]
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/optimize.py:333, in implict_optimization_grads(df_dx, df_dtheta_columns, args, x_star, output_grad, fgraph)
290 r"""
291 Compute gradients of an optimization problem with respect to its parameters.
292
(...) 329 The function graph that contains the inputs and outputs of the optimization problem.
330 """
331 df_dx = cast(TensorVariable, df_dx)
--> 333 df_dtheta = concatenate(
334 [
335 atleast_2d(jac_col, left=False)
336 for jac_col in cast(list[TensorVariable], df_dtheta_columns)
337 ],
338 axis=-1,
339 )
341 replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
343 df_dx_star, df_dtheta_star = cast(
344 list[TensorVariable],
345 graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
346 )
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2980, in concatenate(tensor_list, axis)
2973 if not isinstance(tensor_list, tuple | list):
2974 raise TypeError(
2975 "The 'tensors' argument must be either a tuple "
2976 "or a list, make sure you did not forget () or [] around "
2977 "arguments of concatenate.",
2978 tensor_list,
2979 )
-> 2980 return join(axis, *tensor_list)
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2794, in join(axis, *tensors_list)
2792 return tensors_list[0]
2793 else:
-> 2794 return _join(axis, *tensors_list)
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
249 def __call__(
250 self, *inputs: Any, name=None, return_list=False, **kwargs
251 ) -> Variable | list[Variable]:
252 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
253
254 This method is just a wrapper around :meth:`Op.make_node`.
(...) 291
292 """
--> 293 node = self.make_node(*inputs, **kwargs)
294 if name is not None:
295 if len(node.outputs) == 1:
File ~/git/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:2487, in Join.make_node(self, axis, *tensors)
2484 ndim = tensors[0].type.ndim
2486 if not builtins.all(x.ndim == ndim for x in tensors):
-> 2487 raise TypeError(
2488 "Only tensors with the same number of dimensions can be joined. "
2489 f"Input ndims were: {[x.ndim for x in tensors]}"
2490 )
2492 try:
2493 static_axis = int(get_scalar_constant_value(axis))
TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 2, 3, 4, 2, 2, 2, 3, 3, 4]
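For context on the ndims reported above: the traceback shows the jacobian of the inner objective being taken with respect to inner args of mixed ranks, and each jacobian column inherits the rank of its arg, so the columns cannot all be joined directly. A minimal sketch of that effect (the names f, mu, cov here are illustrative, not the actual inner graph of the MinimizeOp):

import pytensor.tensor as pt
from pytensor.gradient import jacobian

mu = pt.vector("mu")      # rank-1 parameter
cov = pt.matrix("cov")    # rank-2 parameter
f = mu + cov.sum(axis=0)  # some vector-valued expression of both

J_mu, J_cov = jacobian(f, [mu, cov], disconnected_inputs="ignore")
print(J_mu.ndim, J_cov.ndim)  # 2 and 3: different ranks, so the join fails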
PyMC version information:
Numpy: 1.26.4
PyMC: 0+untagged.10301.gdc7cfee.dirty
PyTensor: 2.31.7
Context for the issue:
The bug was originally encountered when trying to run pm.sample on a logp containing a MinimizeOp.