Skip to content

Commit 2ee4c0f

Browse files
superbobryjax authors
authored andcommitted
Added installation instructions to the error in _pallas_call_lowering
PiperOrigin-RevId: 621168804
1 parent 4c41c12 commit 2ee4c0f

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

jax/_src/pallas/pallas_call.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,15 @@ def _extract_function_name(f: Callable, name: str | None) -> str:
496496
return name
497497

498498

499+
def _unsupported_lowering_error(platform: str) -> Exception:
500+
return ValueError(
501+
f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU,"
502+
" install jaxlib GPU 0.4.24 or newer. To use Pallas on TPU, install"
503+
" jaxlib TPU and libtpu. See"
504+
" https://jax.readthedocs.io/en/latest/installation.html."
505+
)
506+
507+
499508
def _pallas_call_lowering(
500509
ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
501510
):
@@ -526,7 +535,7 @@ def _pallas_call_lowering(
526535
ctx, *in_nodes, interpret=interpret, **params
527536
)
528537

529-
raise ValueError(f"Cannot lower pallas_call on platform: {platform}.")
538+
raise _unsupported_lowering_error(platform)
530539

531540

532541
mlir.register_lowering(pallas_call_p, _pallas_call_lowering)

0 commit comments

Comments
 (0)