Skip to content

Commit 13c47cf

Browse files
authored
Add HELION_FORCE_AUTOTUNE=1 and update readme (#132)
1 parent 10ad012 commit 13c47cf

File tree

4 files changed

+51
-22
lines changed

4 files changed

+51
-22
lines changed

README.md

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,17 +222,30 @@ Changing these options results in often significantly different
222222
output Triton code, allowing the autotuner to explore a wide range of
223223
implementations from a single Helion kernel.
224224

225-
## Logs and Debugging
225+
## Settings for Development and Debugging
226226

227-
`HELION_LOGS` is the recommended way to emit debugging information from Helion.
227+
When developing kernels with Helion, you might prefer skipping autotuning for faster iteration. To
228+
do this, set the environment variable `HELION_USE_DEFAULT_CONFIG=1` or use the decorator argument
229+
`@helion.kernel(use_default_config=True)`. **Warning:** The default configuration is slow and not intended for
230+
production or performance testing.
231+
232+
To view the generated Triton code, set the environment variable `HELION_PRINT_OUTPUT_CODE=1` or include
233+
`print_output_code=True` in the `@helion.kernel` decorator. This prints the Triton code to `stderr`, which is
234+
helpful for debugging and understanding Helion's compilation process. One can also use
235+
`foo_kernel.bind(args).to_triton_code(config)` to get the Triton code as a string.
236+
237+
To force autotuning, bypassing provided configurations, set `HELION_FORCE_AUTOTUNE=1` or invoke `foo_kernel.autotune(args,
238+
force=True)`.
239+
240+
Additional settings are available in
241+
[settings.py](https://github.com/pytorch-labs/helion/blob/main/helion/runtime/settings.py). If both an environment
242+
variable and a kernel decorator argument are set, the kernel decorator argument takes precedence, and the environment
243+
variable will be ignored.
244+
245+
Enable logging by setting the environment variable `HELION_LOGS=all` for INFO-level logs, or `HELION_LOGS=+all`
246+
for DEBUG-level logs. Alternatively, you can specify logging for specific modules using a comma-separated list
247+
(e.g., `HELION_LOGS=+helion.runtime.kernel`).
228248

229-
An example to this is
230-
```
231-
HELION_LOGS=helion.runtime.kernel python examples/add.py
232-
```
233-
will emit the generated Triton kernels at INFO level logging.
234-
Adding `+` in front of path like `+helion.runtime.kernel` will emit logs at
235-
DEBUG level.
236249

237250
## Requirements
238251

helion/_logging/_internal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class LogRegistery:
3131
log_levels: dict[str, int] = field(default_factory=dict)
3232

3333

34-
_LOG_REGISTERY = LogRegistery()
34+
_LOG_REGISTERY = LogRegistery({"all": ["helion"]})
3535

3636

3737
def parse_log_value(value: str) -> None:

helion/runtime/kernel.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -194,22 +194,28 @@ def normalize_args(self, *args: object, **kwargs: object) -> tuple[object, ...]:
194194
def autotune(
195195
self,
196196
args: Sequence[object],
197+
*,
198+
force: bool = False,
197199
**options: object,
198200
) -> Config:
199201
"""
200-
Perform autotuning to find the optimal configuration for
201-
the kernel. This uses the default setting, you can call
202-
helion.autotune.* directly for more customization.
202+
Perform autotuning to find the optimal configuration for the kernel. This uses the
203+
default setting; you can call helion.autotune.* directly for more customization.
204+
205+
If config= or configs= is provided to helion.kernel(), the search will be restricted to
206+
the provided configs. Use force=True to ignore the provided configs.
203207
204208
Mutates (the bound version of) self so that `__call__` will run the best config found.
205209
206210
:param args: Example arguments used for benchmarking during autotuning.
207211
:type args: list[object]
212+
:param force: If True, force full autotuning even if a config is provided.
213+
:type force: bool
208214
:return: The best configuration found during autotuning.
209215
:rtype: Config
210216
"""
211217
args = self.normalize_args(*args)
212-
return self.bind(args).autotune(args, **options)
218+
return self.bind(args).autotune(args, force=force, **options)
213219

214220
def __call__(self, *args: object, **kwargs: object) -> object:
215221
"""
@@ -278,8 +284,6 @@ def __init__(self, kernel: Kernel, args: tuple[object, ...]) -> None:
278284
self.host_function: HostFunction = HostFunction(
279285
self.kernel.fn, self.fake_args, constexpr_args
280286
)
281-
if len(kernel.configs) == 1:
282-
self.set_config(kernel.configs[0])
283287

284288
@property
285289
def settings(self) -> Settings:
@@ -370,24 +374,34 @@ def _debug_str(self) -> str:
370374
def autotune(
371375
self,
372376
args: Sequence[object],
377+
*,
378+
force: bool = False,
373379
**kwargs: object,
374380
) -> Config:
375381
"""
376-
Perform autotuning to find the optimal configuration for
377-
the kernel. This uses the default setting, you can call
378-
helion.autotune.* directly for more customization.
382+
Perform autotuning to find the optimal configuration for the kernel. This uses the
383+
default setting; you can call helion.autotune.* directly for more customization.
384+
385+
If config= or configs= is provided to helion.kernel(), the search will be restricted to
386+
the provided configs. Use force=True to ignore the provided configs.
379387
380388
Mutates self so that `__call__` will run the best config found.
381389
382390
:param args: Example arguments used for benchmarking during autotuning.
383391
:type args: list[object]
392+
:param force: If True, force full autotuning even if a config is provided.
393+
:type force: bool
384394
:return: The best configuration found during autotuning.
385395
:rtype: Config
386396
"""
387-
if self.kernel.configs:
388-
from ..autotuner import FiniteSearch
397+
force = force or self.settings.force_autotune
398+
if not force and self.kernel.configs:
399+
if len(self.kernel.configs) == 1:
400+
(config,) = self.kernel.configs
401+
else:
402+
from ..autotuner import FiniteSearch
389403

390-
config = FiniteSearch(self, args, self.configs).autotune()
404+
config = FiniteSearch(self, args, self.configs).autotune()
391405
else:
392406
from ..autotuner import DifferentialEvolutionSearch
393407

helion/runtime/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class _Settings:
6161
)
6262
autotune_precompile: bool = sys.platform != "win32"
6363
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
64+
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
6465

6566

6667
class Settings(_Settings):
@@ -79,6 +80,7 @@ class Settings(_Settings):
7980
"autotune_compile_timeout": "Timeout for Triton compilation in seconds used for autotuning. Default is 60 seconds.",
8081
"autotune_precompile": "If True, precompile the kernel before autotuning. Requires fork-safe environment.",
8182
"print_output_code": "If True, print the output code of the kernel to stderr.",
83+
"force_autotune": "If True, force autotuning even if a config is provided.",
8284
}
8385
assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)}
8486

0 commit comments

Comments
 (0)