Skip to content

Commit f7c93de

Browse files
authored
Fix import of transformmodindex on nightly (#52)
stack-info: PR: #52, branch: drisspg/stack/1
1 parent 7586742 commit f7c93de

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

attn_gym/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
_vmap_for_bhqkv,
1111
_ModificationType,
1212
)
13-
from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
13+
# TODO This was moved on nightly, this enables 2.5 and 2.6 | we should remove this once 2.5 is no longer supported
14+
try:
15+
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
16+
except ImportError:
17+
from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
1418
from contextlib import nullcontext
1519

1620
Tensor = torch.Tensor

0 commit comments

Comments
 (0)