Skip to content

Commit 1931653

Browse files
committed
fix einsum/einsum dense regression issue
1 parent 6f284ec commit 1931653

File tree

5 files changed

+27
-33
lines changed

5 files changed

+27
-33
lines changed

hls4ml/backends/fpga/fpga_backend.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ def generate_conv2d_line_buffer_fn(
914914
return generated_code
915915

916916
@staticmethod
917-
def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]):
917+
def transpose_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]):
918918
"""
919919
Generate new shape and perm_strides for a permute operation. Operates by mapping the output index
920920
to input input index by:
@@ -933,12 +933,20 @@ def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...])
933933
perm (tuple[int, ...]): The permutation of the dimensions.
934934
935935
Returns:
936-
(new_shape, perm_strides) (tuple, tuple): the output shape and permutation strides.
936+
dict: Dictionary containing the configuration.
937937
"""
938938
new_shape = tuple(shape[i] for i in perm)
939939
strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1]
940940
perm_strides = tuple(int(strides[i]) for i in perm)
941-
return (new_shape, perm_strides)
941+
return dict(
942+
dims=len(shape),
943+
N=math.prod(shape),
944+
from_shape=', '.join(str(x) for x in shape),
945+
perm=', '.join(str(x) for x in perm),
946+
perm_strides=', '.join(str(x) for x in perm_strides),
947+
to_shape=', '.join(str(x) for x in new_shape),
948+
config_name=name,
949+
)
942950

943951
@model_optimizer()
944952
def write_hls(self, model):

hls4ml/backends/oneapi/passes/reshaping_templates.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,8 @@ def format(self, node):
185185
perm = tuple(node.get_attr('perm'))
186186
name = f'config{node.index}'
187187

188-
new_shape, perm_strides = node.model.config.backend.permute_config_gen(name, shape, perm)
189-
return transpose_config_template.format(
190-
dims=len(shape),
191-
N=int(np.prod(shape)),
192-
from_shape=', '.join(str(x) for x in shape),
193-
perm=', '.join(str(x) for x in perm),
194-
perm_strides=', '.join(str(x) for x in perm_strides),
195-
to_shape=', '.join(str(x) for x in new_shape),
196-
config_name=name,
197-
)
188+
conf = node.model.config.backend.transpose_config_gen(name, shape, perm)
189+
return transpose_config_template.format(**conf)
198190

199191

200192
class TransposeFunctionTemplate(FunctionCallTemplate):

hls4ml/backends/vivado/passes/einsum.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
55
from hls4ml.model.layers import Einsum
66

7-
from .reshaping_templates import transpose_config_gen
7+
from .reshaping_templates import transpose_config_template
88

99
# Shared Dense template
1010
# Einsum template
@@ -81,9 +81,12 @@ def format(self, node: Einsum):
8181
tpose_inp1_conf_name = f'config{node.index}_tpose_inp1'
8282
tpose_out_conf_name = f'config{node.index}_tpose_out'
8383

84-
inp0_tpose_conf = transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs)
85-
inp1_tpose_conf = transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs)
86-
out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs)
84+
conf = node.model.config.backend.transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs)
85+
inp0_tpose_conf = transpose_config_template.format(**conf)
86+
conf = node.model.config.backend.transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs)
87+
inp1_tpose_conf = transpose_config_template.format(**conf)
88+
conf = node.model.config.backend.transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs)
89+
out_tpose_conf = transpose_config_template.format(**conf)
8790

8891
return '\n\n'.join((inp0_tpose_conf, inp1_tpose_conf, out_tpose_conf, einsum_conf))
8992

hls4ml/backends/vivado/passes/einsum_dense.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
33
from hls4ml.model.layers import EinsumDense
44

5-
from .reshaping_templates import transpose_config_gen
5+
from .reshaping_templates import transpose_config_template
66

77
# Shared Dense template
88

@@ -118,8 +118,10 @@ def format(self, node: EinsumDense):
118118
tpose_inp_conf_name = f'config{node.index}_tpose_inp'
119119
tpose_out_conf_name = f'config{node.index}_tpose_out'
120120

121-
inp_tpose_conf = transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs)
122-
out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs)
121+
conf = node.model.config.backend.transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs)
122+
inp_tpose_conf = transpose_config_template.format(**conf)
123+
conf = node.model.config.backend.transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs)
124+
out_tpose_conf = transpose_config_template.format(**conf)
123125

124126
if strategy.lower() == 'distributed_arithmetic':
125127
return '\n\n'.join((inp_tpose_conf, out_tpose_conf, einsum_conf))

hls4ml/backends/vivado/passes/reshaping_templates.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import numpy as np
2-
31
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
42
from hls4ml.model.layers import Resize, Transpose, ZeroPadding1D, ZeroPadding2D
53

@@ -128,22 +126,13 @@ def format(self, node):
128126
class TransposeConfigTemplate(LayerConfigTemplate):
129127
def __init__(self):
130128
super().__init__(Transpose)
131-
self.template = transpose_config_template
132129

133130
def format(self, node):
134131
shape = tuple(node.get_input_variable().shape)
135132
perm = tuple(node.get_attr('perm'))
136133
name = f'config{node.index}'
137-
new_shape, perm_strides = node.model.config.backend.permute_config_gen(name, shape, perm)
138-
return transpose_config_template.format(
139-
dims=len(shape),
140-
N=np.prod(shape),
141-
from_shape=', '.join(str(x) for x in shape),
142-
perm=', '.join(str(x) for x in perm),
143-
perm_strides=', '.join(str(x) for x in perm_strides),
144-
to_shape=', '.join(str(x) for x in new_shape),
145-
config_name=name,
146-
)
134+
conf = node.model.config.backend.transpose_config_gen(name, shape, perm)
135+
return transpose_config_template.format(**conf)
147136

148137

149138
class TransposeFunctionTemplate(FunctionCallTemplate):

0 commit comments

Comments
 (0)