Skip to content

Commit 59c707c

Browse files
authored
Add the core properties to Config object (#49)
1 parent 2f600d2 commit 59c707c

File tree

8 files changed

+86
-21
lines changed

8 files changed

+86
-21
lines changed

.pre-commit-config.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v4.6.0
4+
hooks:
5+
- id: check-yaml
6+
- id: end-of-file-fixer
7+
- id: trailing-whitespace
8+
- repo: https://github.com/astral-sh/ruff-pre-commit
9+
# Ruff version.
10+
rev: v0.11.9
11+
hooks:
12+
# Run the linter.
13+
- id: ruff
14+
args: [--fix]
15+
# Run the formatter.
16+
- id: ruff-format

README.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ refers to hydrogen-3.
1212

1313
[Triton]: https://github.com/triton-lang/triton
1414

15-
> ⚠️ **Early Development Warning**
15+
> ⚠️ **Early Development Warning**
1616
> Helion is currently in an experimental stage. You should expect bugs, incomplete features, and APIs that may change in future versions. Feedback and bug reports are welcome and appreciated!
1717
1818
## Example
@@ -27,13 +27,13 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
2727
m, k = x.size()
2828
k, n = y.size()
2929
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
30-
30+
3131
for tile_m, tile_n in hl.tile([m, n]):
3232
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
3333
for tile_k in hl.tile(k):
3434
acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
3535
out[tile_m, tile_n] = acc
36-
36+
3737
return out
3838
```
3939

@@ -204,7 +204,8 @@ Alternatively, you may install from source for development purposes:
204204
```bash
205205
git clone https://github.com/pytorch-labs/helion.git
206206
cd helion
207-
python setup.py develop
207+
# To install in editable w/ required dev packages
208+
pip install -e .'[dev]'
208209
````
209210
This installs Helion in "editable" mode so that changes to the source
210211
code take effect without needing to reinstall.

helion/_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def code_and_output(
3737
**kwargs: object,
3838
) -> tuple[str, object]:
3939
if kwargs:
40-
config = Config(**kwargs)
40+
config = Config(**kwargs) # pyre-ignore[6]
4141
elif fn.configs:
4242
(config,) = fn.configs
4343
else:

helion/autotuner/config_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
249249
use_yz_grid = fn(BooleanFragment())
250250
if config.get("l2_grouping", 1) == 1 and isinstance(block_sizes[0], list):
251251
config["use_yz_grid"] = use_yz_grid
252-
return helion.Config(config)
252+
return helion.Config(**config) # pyre-ignore[6]
253253

254254

255255
class BlockSizeSpec:

helion/runtime/config.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,50 @@
1616
class Config(Mapping[str, object]):
1717
config: dict[str, object]
1818

19-
def __init__(self, config: object = None, **kwargs: object) -> None:
20-
if config is not None:
21-
assert not kwargs
22-
assert isinstance(config, (dict, Config))
23-
self.config = {**config}
24-
else:
25-
self.config = kwargs
19+
def __init__(
20+
self,
21+
*,
22+
# Core properties
23+
block_sizes: list[int | list[int]] | None = None,
24+
loop_orders: list[list[int]] | None = None,
25+
reduction_loops: list[int | None] | None = None,
26+
num_warps: int | None = None,
27+
num_stages: int | None = None,
28+
l2_grouping: int | None = None,
29+
use_yz_grid: bool | None = None,
30+
indexing: IndexingLiteral | None = None,
31+
# For user-defined properties
32+
**kwargs: object,
33+
) -> None:
34+
"""
35+
Initialize a Config object.
36+
37+
Args:
38+
block_sizes: Controls tile sizes for hl.tile invocations.
39+
loop_orders: Permutes iteration order of tiles.
40+
reduction_loops: Configures reduction loop behavior.
41+
num_warps: Number of warps per block.
42+
num_stages: Number of stages for software pipelining.
43+
l2_grouping: Reorders program IDs for L2 cache locality.
44+
use_yz_grid: Whether to use yz grid dimensions.
45+
indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
46+
**kwargs: Additional user-defined configuration parameters.
47+
"""
48+
self.config = {}
49+
core_props = {
50+
"block_sizes": block_sizes,
51+
"loop_orders": loop_orders,
52+
"reduction_loops": reduction_loops,
53+
"num_warps": num_warps,
54+
"num_stages": num_stages,
55+
"indexing": indexing,
56+
"l2_grouping": l2_grouping,
57+
"use_yz_grid": use_yz_grid,
58+
}
59+
for key, value in core_props.items():
60+
if value is not None:
61+
self.config[key] = value
62+
self.config.update(kwargs)
2663

2764
def __getitem__(self, key: str) -> object:
2865
return self.config[key]
@@ -56,7 +93,7 @@ def to_json(self) -> str:
5693
def from_json(cls, json_str: str) -> Config:
5794
"""Create a Config object from a JSON string."""
5895
config_dict = json.loads(json_str)
59-
return cls(config_dict)
96+
return cls(**config_dict) # Changed to use dictionary unpacking
6097

6198
def save(self, path: str | Path) -> None:
6299
"""Save the config to a JSON file."""
@@ -92,12 +129,12 @@ def l2_grouping(self) -> int:
92129
return cast("int", self.config.get("l2_grouping", 1))
93130

94131
@property
95-
def use_yz_grid(self) -> int:
132+
def use_yz_grid(self) -> bool:
96133
return cast("bool", self.config.get("use_yz_grid", False))
97134

98135
@property
99136
def indexing(self) -> IndexingLiteral:
100-
return cast("IndexingLiteral", self.config.get("indexing", "pointer"))
137+
return self.config.get("indexing", "pointer") # type: ignore
101138

102139

103140
def _list_to_tuple(x: object) -> object:

helion/runtime/kernel.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def __init__(
6262
self.fn = fn
6363
self.signature: inspect.Signature = inspect.signature(fn)
6464
self.settings: Settings = settings or Settings.default()
65-
self.configs: list[Config] = [*map(Config, configs or ())]
65+
self.configs: list[Config] = [
66+
Config(**c) if isinstance(c, dict) else c for c in configs or []
67+
]
6668
# pyre-fixme[11]: BoundKernel undefined?
6769
self._bound_kernels: dict[Hashable, BoundKernel] = {}
6870
if any(
@@ -295,7 +297,9 @@ def to_triton_code(self, config: ConfigLike) -> str:
295297
:rtype: str
296298
"""
297299
with self.env:
298-
config = Config(config)
300+
if not isinstance(config, Config):
301+
# pyre-ignore[6]
302+
config = Config(**config)
299303
self.env.config_spec.normalize(config)
300304
root = generate_ast(self.host_fn, config)
301305
return get_needed_imports(root) + unparse(root)
@@ -310,7 +314,7 @@ def compile_config(self, config: ConfigLike) -> CompiledConfig:
310314
:rtype: Callable[..., object]
311315
"""
312316
if not isinstance(config, Config):
313-
config = Config(config)
317+
config = Config(**config) # pyre-ignore[6]
314318
if (rv := self._compile_cache.get(config)) is not None:
315319
return rv
316320
triton_code = self.to_triton_code(config)
@@ -375,7 +379,7 @@ def set_config(self, config: ConfigLike) -> None:
375379
:type config: ConfigLike
376380
"""
377381
if not isinstance(config, Config):
378-
config = Config(config)
382+
config = Config(**config) # pyre-ignore[6]
379383
self._run = self.compile_config(config)
380384

381385
def __call__(self, *args: object) -> object:

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ dependencies = [
2121
"typing_extensions>=4.0.0",
2222
]
2323

24+
[project.optional-dependencies]
25+
dev = [
26+
"expecttest",
27+
"pytest",
28+
"pre-commit"
29+
]
30+
2431
[project.urls]
2532
Homepage = "https://github.com/pytorch-labs/helion"
2633
Issues = "https://github.com/pytorch-labs/helion/issues"
@@ -67,4 +74,3 @@ force-sort-within-sections = true
6774

6875
[tool.setuptools]
6976
license-files = ["LICENSE"]
70-

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
expecttest
22
pytest
33
typing_extensions
4+
pre-commit

0 commit comments

Comments
 (0)