tensor[tile] when tile size is 1 returns a 1D tensor, instead of a scalar #275

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion examples/template_via_closure.py
@@ -60,7 +60,7 @@ def check(n: int, k: int, m: int) -> None:

def epilogue(acc: torch.Tensor, tile: list[torch.Tensor]) -> torch.Tensor:
# The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
-return torch.relu(acc + bias[tile])
+return torch.relu(acc + bias[0, tile[1]])

def kernel_wrapper(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return matmul_with_epilogue(x, y, epilogue)
5 changes: 2 additions & 3 deletions helion/_compiler/indexing_strategy.py
@@ -378,9 +378,8 @@ def create(
assert len(index_values) == fake_value.ndim
index_expr = []
for i, idx in enumerate(index_values):
-if fake_value.size(i) != 1:
-    stride = state.device_function.tensor_stride(fake_value, i).name
-    index_expr.append(f"{idx} * {stride}")
+stride = state.device_function.tensor_stride(fake_value, i).name
+index_expr.append(f"{idx} * {stride}")
Comment on lines -381 to +382

Contributor

What is the reason for this change?

Contributor Author

Something like this:

N = x.size(0)
for tile in hl.tile(N):
    x_tile = x[tile]

When block_size=1, the if statement evaluates to False, so the indexing ignores the N dimension and generates

x_tile = tl.load(tile + tl.zeros([1], ...))

I'll add a test case for this.
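Roughly the shape I have in mind, as a minimal sketch (the `helion.kernel` / `helion.Config(block_sizes=[1])` spelling is borrowed from the existing tests and may need adjusting):

```python
import torch
import helion
import helion.language as hl


# Force a block size of 1 so the size-1 indexing path is exercised.
@helion.kernel(config=helion.Config(block_sizes=[1]))
def copy_1d(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        # Before this fix, block_size=1 dropped the N dimension from the index
        # expression, so the load here ignored the tile offset and read element 0
        # on every iteration.
        out[tile] = x[tile]
    return out


x = torch.randn(8, device="cuda")
torch.testing.assert_close(copy_1d(x), x)
```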

Contributor

Isn't this checking the tensor size, not the block size?

if not index_expr:
shape_str = tile_strategy.shape_str(output_size)
index_expr.append(f"tl.zeros({shape_str}, {dtype})")
18 changes: 13 additions & 5 deletions helion/_compiler/tile_strategy.py
@@ -670,11 +670,19 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
type_comment=None,
)
assert for_node.body is body
-extra_body = [
-    statement_from_string(
-        f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})"
-    ),
-]
+extra_body = []
+if block_size == 1:
+    extra_body.append(
+        statement_from_string(
+            f"{index_var} = {offset_var} + tl.zeros([1], {dtype})"
Contributor

Doesn't this do the same thing as arange? I'd expect we would need shape=[] or even just offset_var directly?

Contributor Author

Yes, this does the same thing. We don't need to make this change to fix tile indexing when block_size=1.
However, why does grid_codegen handle block_size == 1 differently, with tl.zeros instead of tl.arange?
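
Just to spell out the equivalence for anyone reading along, a standalone sketch (not part of this PR) checking that the two spellings give the same index when the block size is 1:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _compare_index_forms(out_ptr):
    offset = tl.program_id(0)
    idx_arange = offset + tl.arange(0, 1).to(tl.int32)  # the tl.arange(0, 1) spelling
    idx_zeros = offset + tl.zeros([1], tl.int32)        # the tl.zeros([1]) spelling
    # The difference should be zero for every program id if the two forms agree.
    tl.store(out_ptr + offset, tl.sum(idx_arange - idx_zeros))


out = torch.zeros(4, dtype=torch.int32, device="cuda")
_compare_index_forms[(4,)](out)
assert int(out.abs().sum().item()) == 0
```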

+        ),
+    )
+else:
+    extra_body.append(
+        statement_from_string(
+            f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})"
+        ),
+    )
mask_statement = self._setup_mask( # pyright: ignore[reportAttributeAccessIssue]
state, block_idx, block_size, index_var, end
)
144 changes: 87 additions & 57 deletions test/test_associative_scan.expected

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions test/test_examples.expected
@@ -1384,13 +1384,14 @@ def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: t
load = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
load_1 = tl.load(y + (indices_2[:, None] * 1024 + indices_1[None, :] * 1), None)
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
-load_2 = tl.load(epilogue_closure_0 + indices_1[None, :] * 1, None)
-v_0 = load_2.to(tl.float32)
-v_1 = acc + v_0
-v_2 = tl.full([], 0, tl.int32)
-v_3 = triton_helpers.maximum(v_2, v_1)
-v_4 = v_3.to(tl.float16)
-tl.store(out + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_4, None)
+load_2 = tl.load(epilogue_closure_0 + (0 * 1024 + indices_1 * 1), None)
+v_0 = load_2[None, :]
+v_1 = v_0.to(tl.float32)
+v_2 = acc + v_1
+v_3 = tl.full([], 0, tl.int32)
+v_4 = triton_helpers.maximum(v_3, v_2)
+v_5 = v_4.to(tl.float16)
+tl.store(out + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_5, None)

def matmul_with_epilogue(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]):
m, k = x.size()
@@ -1528,4 +1529,3 @@ def _matmul_with_epilogue_make_precompiler(x: Tensor, y: Tensor, epilogue: Calla
_BLOCK_SIZE_2 = 16
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_matmul_with_epilogue_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)

2 changes: 1 addition & 1 deletion test/test_examples.py
@@ -109,7 +109,7 @@ def test_template_via_closure0(self):
args = (
torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16),
torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16),
-lambda acc, tile: torch.relu(acc + bias[tile]),
+lambda acc, tile: torch.relu(acc + bias[0, tile[1]]),
)
self.assertExpectedJournal(
check_example(
50 changes: 25 additions & 25 deletions test/test_loops.expected
@@ -66,7 +66,7 @@ def _device_loop_3d_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, out
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < b
for offset_3 in tl.range(0, d.to(tl.int32), step=1):
indices_3 = offset_3 + tl.arange(0, 1).to(tl.int32)
indices_3 = offset_3 + tl.zeros([1], tl.int32)
load = tl.load(x + (indices_0[:, None, None, None] * x_stride_0 + indices_1[None, :, None, None] * x_stride_1 + indices_2[None, None, :, None] * x_stride_2 + indices_3[None, None, None, :] * x_stride_3), mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None], other=0)
v_0 = tl_math.sin(load)
tl.store(out + (indices_0[:, None, None, None] * out_stride_0 + indices_1[None, :, None, None] * out_stride_1 + indices_2[None, None, :, None] * out_stride_2 + indices_3[None, None, None, :] * out_stride_3), v_0, mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None])
@@ -197,7 +197,7 @@ def _chebyshev_kernel_kernel(x, w, out, out_stride_0, out_stride_1, w_stride_0,
v_3 = 2.0
v_4 = in_x * v_3
for offset_2 in tl.range(2, 5, step=1):
indices_2 = offset_2 + tl.arange(0, 1).to(tl.int32)
indices_2 = offset_2 + tl.zeros([1], tl.int32)
v_4_copy = v_4
in_x_0_copy = in_x_0
T0_copy = T0
@@ -245,13 +245,13 @@ import triton
import triton.language as tl

@triton.jit
def _fn_kernel(x, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
def _fn_kernel(x, end, out, x_size_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_1 = pid_0 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < x_size_0
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
load = tl.load(end + tl.zeros([], tl.int32), None)
load = tl.load(end + 0 * end_stride_0, None)
for offset_0 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_0):
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < load
@@ -267,7 +267,7 @@ def fn(x: torch.Tensor, end: torch.Tensor):
bs = 32
_BLOCK_SIZE_1 = 32
_BLOCK_SIZE_0 = 32
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, end, out, x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
@@ -276,7 +276,7 @@ def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
_BLOCK_SIZE_1 = 32
_BLOCK_SIZE_0 = 32
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_fn_kernel)(x, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return make_precompiler(_fn_kernel)(x, end, out, x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestLoops.test_data_dependent_bounds2)
from __future__ import annotations
@@ -286,13 +286,13 @@ import triton
import triton.language as tl

@triton.jit
def _fn_kernel(x, end, out, out_size_0, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
def _fn_kernel(x, end, out, out_size_0, x_size_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
load = tl.load(end + tl.zeros([], tl.int32), None)
load = tl.load(end + 0 * end_stride_0, None)
for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < load
@@ -307,15 +307,15 @@ def fn(x: torch.Tensor, end: torch.Tensor):
out = x.new_empty([x.size(0)])
_BLOCK_SIZE_0 = 32
_BLOCK_SIZE_1 = 32
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end, out, out.size(0), x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end, out, out.size(0), x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return out

def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
out = x.new_empty([x.size(0)])
_BLOCK_SIZE_0 = 32
_BLOCK_SIZE_1 = 32
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_fn_kernel)(x, end, out, out.size(0), x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return make_precompiler(_fn_kernel)(x, end, out, out.size(0), x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestLoops.test_data_dependent_bounds3)
from __future__ import annotations
@@ -325,14 +325,14 @@ import triton
import triton.language as tl

@triton.jit
def _fn_kernel(x, end0, end1, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, x_stride_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
def _fn_kernel(x, end0, end1, out, x_size_0, end0_stride_0, end1_stride_0, out_stride_0, x_stride_0, x_stride_1, x_stride_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float64)
load = tl.load(end0 + tl.zeros([], tl.int32), None)
load_1 = tl.load(end1 + tl.zeros([], tl.int32), None)
load = tl.load(end0 + 0 * end0_stride_0, None)
load_1 = tl.load(end1 + 0 * end1_stride_0, None)
for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < load
@@ -352,7 +352,7 @@ def fn(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor):
_BLOCK_SIZE_0 = 32
_BLOCK_SIZE_2 = 32
_BLOCK_SIZE_1 = 32
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end0, end1, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end0, end1, out, x.size(0), end0.stride(0), end1.stride(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return out

def _fn_make_precompiler(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor):
@@ -361,7 +361,7 @@ def _fn_make_precompiler(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor
_BLOCK_SIZE_2 = 32
_BLOCK_SIZE_1 = 32
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_fn_kernel)(x, end0, end1, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return make_precompiler(_fn_kernel)(x, end0, end1, out, x.size(0), end0.stride(0), end1.stride(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestLoops.test_data_dependent_bounds4)
from __future__ import annotations
@@ -371,14 +371,14 @@ import triton
import triton.language as tl

@triton.jit
def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
def _fn_kernel(x, begin, end, out, x_size_0, begin_stride_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_1 = pid_0 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < x_size_0
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
load = tl.load(begin + tl.zeros([], tl.int32), None)
load_1 = tl.load(end + tl.zeros([], tl.int32), None)
load = tl.load(begin + 0 * begin_stride_0, None)
load_1 = tl.load(end + 0 * end_stride_0, None)
for offset_0 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_0):
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < load_1
@@ -394,7 +394,7 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
bs = 32
_BLOCK_SIZE_1 = 32
_BLOCK_SIZE_0 = 32
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
@@ -403,7 +403,7 @@ def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor
_BLOCK_SIZE_1 = 32
_BLOCK_SIZE_0 = 32
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestLoops.test_data_dependent_bounds5)
from __future__ import annotations
@@ -413,14 +413,14 @@ import triton
import triton.language as tl

@triton.jit
def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
def _fn_kernel(x, begin, end, out, x_size_0, begin_stride_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
load = tl.load(begin + tl.zeros([], tl.int32), None)
load_1 = tl.load(end + tl.zeros([], tl.int32), None)
load = tl.load(begin + 0 * begin_stride_0, None)
load_1 = tl.load(end + 0 * end_stride_0, None)
for offset_1 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < load_1
@@ -435,15 +435,15 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
out = x.new_empty([x.size(0)])
_BLOCK_SIZE_0 = 32
_BLOCK_SIZE_1 = 32
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return out

def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
out = x.new_empty([x.size(0)])
_BLOCK_SIZE_0 = 32
_BLOCK_SIZE_1 = 32
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestLoops.test_l2_grouping_with_register_block_size)
from __future__ import annotations