Description
It would be nice to add the option to reweight a subset of the edges quickly (per shot) without the full internal representation of the matching graph needing to be recreated each time.
This could be done via an additional argument to the decode methods, e.g. Matching.decode could be:
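```python
from typing import List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt


# Proposed signature only; the body is omitted.
def decode(self,
           z: Union[np.ndarray, List[bool], List[int]],
           *,
           return_weight: bool = False,
           enable_correlations: bool = False,
           edge_reweights: Optional[npt.NDArray[np.float64]] = None,
           **kwargs
           ) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
    r"""Decode syndrome `z`, applying `edge_reweights` for this shot only."""
    ...
```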
Here each row `edge_reweights[i, :]` of `edge_reweights` would be a reweight rule. The first two columns specify the node endpoints of the edge to reweight, and the third column is the new edge weight (as a float). For example, `edge_reweights[i, :] == np.array([4, 5, 2.4], dtype=np.float64)` means "give edge (4, 5) a new weight of 2.4 and undo it after the end of the shot". Similarly, `edge_reweights[i, :] == np.array([10, -1, 3.1], dtype=np.float64)` means "give boundary edge (10,) a new weight of 3.1", following the convention that a second endpoint of -1 denotes the boundary.
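To make the reweight-rule semantics concrete, here is a minimal sketch of "apply for one shot, then undo". It assumes a hypothetical plain-dict graph representation (`weights` mapping edge keys to float weights) and hypothetical helpers `apply_reweights` / `restore_weights`; it is not PyMatching's internal representation or API, only an illustration of the proposed `(u, v, new_weight)` rows with `-1` marking a boundary edge.

```python
import numpy as np

BOUNDARY = -1  # proposed convention: -1 as the second endpoint means a boundary edge


def edge_key(u: int, v: int) -> tuple:
    """Canonical key for an edge; boundary edges are keyed as (u, -1)."""
    if v == BOUNDARY:
        return (u, BOUNDARY)
    return (min(u, v), max(u, v))


def apply_reweights(weights: dict, edge_reweights: np.ndarray) -> dict:
    """Overwrite weights in place and return the original values for later restore.

    `weights` is a hypothetical stand-in for the decoder's internal graph.
    """
    rules = np.asarray(edge_reweights, dtype=np.float64)
    if rules.size == 0:
        return {}
    if rules.ndim != 2 or rules.shape[1] != 3:
        raise ValueError("edge_reweights must have shape (k, 3)")
    saved = {}
    for u, v, w in rules:
        key = edge_key(int(u), int(v))
        if key not in weights:
            raise KeyError(f"edge {key} is not in the graph")
        saved.setdefault(key, weights[key])  # remember only the pre-shot value
        weights[key] = float(w)
    return saved


def restore_weights(weights: dict, saved: dict) -> None:
    """Undo the reweighting at the end of the shot."""
    weights.update(saved)


weights = {(4, 5): 1.0, (10, -1): 2.0, (1, 2): 0.5}
saved = apply_reweights(weights, np.array([[4, 5, 2.4], [10, -1, 3.1]]))
# ... decode the shot with the temporary weights here ...
restore_weights(weights, saved)
```

Returning the saved originals (rather than rebuilding the graph) is what keeps the per-shot cost proportional to the number of reweighted edges, which is the motivation of this request.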
For `Matching.decode_batch`, the `edge_reweights` argument could be a Python list of arrays (one element of the list corresponding to each shot in the batch).
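A sketch of how the batch variant could pair each shot with its own reweight array. `decode_batch_with_reweights` and the single-shot callable `decode_one` are hypothetical names, and allowing a `None` entry for "no reweighting on this shot" is an assumption beyond the proposal.

```python
from typing import Callable, Optional, Sequence

import numpy as np


def decode_batch_with_reweights(
    decode_one: Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray],
    shots: np.ndarray,
    edge_reweights: Optional[Sequence[Optional[np.ndarray]]] = None,
) -> np.ndarray:
    """Decode each row of `shots`, passing that shot's reweight array (or None)."""
    shots = np.asarray(shots)
    if edge_reweights is None:
        edge_reweights = [None] * shots.shape[0]
    # One list element per shot, as proposed for decode_batch.
    if len(edge_reweights) != shots.shape[0]:
        raise ValueError("need one edge_reweights entry per shot")
    return np.array([decode_one(s, r) for s, r in zip(shots, edge_reweights)])
```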
Some subtleties remain: whether the reweighting should be applied before or after the first pass of correlated matching (when correlations are enabled), or whether this should be configurable. It probably makes sense for the new weight to be given as a float (and converted to an even integer internally), at least on the Python side.
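A minimal sketch of the float-to-even-integer conversion mentioned above, assuming a hypothetical fixed `scale` factor. This is not a description of PyMatching's actual weight normalisation; the point it illustrates is that per-shot weights would need to reuse the same scale chosen when the graph was built, so reweighted edges stay comparable with unchanged ones.

```python
import numpy as np


def to_even_int(weights, scale: float) -> np.ndarray:
    """Convert float weights to even integers using a fixed (hypothetical) scale.

    `scale` should be the factor already used for the rest of the graph.
    """
    w = np.asarray(weights, dtype=np.float64)
    # Halve, round to the nearest integer, then double: the result is always even.
    return (2 * np.rint(w * scale / 2)).astype(np.int64)
```

Note that `np.rint` rounds exact halves to the nearest even integer, so ties are resolved deterministically.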