Commit 2859d41

Add HELION_DEV_LOW_VRAM env var for low GPU memory machines
Some dev machines (e.g. GPU laptops) have low VRAM, which causes some tritonbench inputs to OOM. This PR adds the HELION_DEV_LOW_VRAM env var and uses smaller inputs when it is set. Users now opt into this mode explicitly by setting the env var, rather than getting smaller inputs automatically whenever low VRAM is detected.
1 parent dcfa500 commit 2859d41
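
The opt-in check described above can be sketched in isolation as follows. This is a minimal illustration, not code from the commit: the helper names `low_vram_enabled` and `tritonbench_args` are hypothetical (the commit inlines the check in each example file), and returning an empty dict for the default case is an assumption made only to keep the sketch self-contained.

```python
import os


def low_vram_enabled(environ=os.environ) -> bool:
    """Return True only when HELION_DEV_LOW_VRAM is exactly "1".

    Hypothetical helper; the commit writes this comparison inline.
    """
    return environ.get("HELION_DEV_LOW_VRAM", "0") == "1"


# Smaller shapes copied from examples/jagged_mean.py in this commit.
LOW_VRAM_ARGS = {"B": 32, "M": 8, "seqlen": 64}


def tritonbench_args(environ=os.environ) -> dict:
    # Assumption: an empty dict stands in for "use the default inputs".
    return dict(LOW_VRAM_ARGS) if low_vram_enabled(environ) else {}
```

Because the comparison is against the literal string "1", values such as "true" or "yes" leave low-VRAM mode off. In a POSIX shell the variable can be set for a single command by prefixing it with `HELION_DEV_LOW_VRAM=1`.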

3 files changed: 8 additions, 39 deletions


examples/jagged_mean.py

Lines changed: 4 additions & 3 deletions
@@ -1,14 +1,15 @@
 from __future__ import annotations

+import os
+
 import torch

 import helion
 from helion._testing import run_example
 import helion.language as hl
-from helion.utils import get_gpu_memory_info

-# TritonBench configuration - adjust based on available GPU memory
-if get_gpu_memory_info()[0] < 16.0:
+# TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable
+if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
     # Low memory configuration
     TRITONBENCH_ARGS = {"B": 32, "M": 8, "seqlen": 64}

examples/rms_norm.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,7 @@
 from __future__ import annotations

+import os
+
 import torch

 import helion
@@ -8,7 +10,8 @@

 # TritonBench configuration
 # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg.
-TRITONBENCH_ARGS = {"num_inputs": 3}
+if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
+    TRITONBENCH_ARGS = {"num_inputs": 3}


 @helion.kernel(static_shapes=True)
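
In the lines shown in this hunk, examples/rms_norm.py assigns TRITONBENCH_ARGS only when HELION_DEV_LOW_VRAM is "1", so any code that reads the attribute would need to tolerate its absence. How the benchmark runner actually reads it is not visible in this diff; the sketch below is a hypothetical consumer-side lookup (`read_tritonbench_args` is an illustrative name, and `types.SimpleNamespace` stands in for an imported example module).

```python
import types


def read_tritonbench_args(module) -> dict:
    # Hypothetical lookup: fall back to no extra args when the module
    # does not define TRITONBENCH_ARGS (e.g. low-VRAM mode is off).
    return dict(getattr(module, "TRITONBENCH_ARGS", {}))


# Stand-ins for an example module with and without the attribute.
with_args = types.SimpleNamespace(TRITONBENCH_ARGS={"num_inputs": 3})
without_args = types.SimpleNamespace()

print(read_tritonbench_args(with_args))
print(read_tritonbench_args(without_args))
```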

helion/utils.py

Lines changed: 0 additions & 35 deletions
This file was deleted.
