Commit c595c4f

squashed cosmetic and minor changes

1 parent d6c0e16 commit c595c4f

17 files changed: +210 -261 lines

hls4ml/backends/oneapi/oneapi_backend.py

Lines changed: 0 additions & 12 deletions

@@ -179,12 +179,6 @@ def compile(self, model):
         try:
             subprocess.run('which icpx', shell=True, cwd=builddir, check=True)
         except subprocess.CalledProcessError:
-            try:
-                import pytest
-
-                pytest.skip('icpx not present')
-            except ImportError:
-                pass
             raise RuntimeError('Could not find icpx. Please configure oneAPI appropriately')
         subprocess.run('cmake ..', shell=True, cwd=builddir, check=True)
         subprocess.run('make lib', shell=True, cwd=builddir, check=True)
@@ -210,12 +204,6 @@ def build(self, model, build_type='fpga_emu', run=False):
         try:
             subprocess.run('which icpx', shell=True, cwd=builddir, check=True)
         except subprocess.CalledProcessError:
-            try:
-                import pytest
-
-                pytest.skip('icpx not present')
-            except ImportError:
-                pass
             raise RuntimeError('Could not find icpx. Please configure oneAPI appropriately')
         subprocess.run('cmake ..', shell=True, cwd=builddir, check=True)
         subprocess.run(f'make {build_type}', shell=True, cwd=builddir, check=True)
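
The deleted fallback called pytest.skip from inside the backend, turning a missing compiler into a silently skipped test (and doing nothing at all outside pytest). After this change the backend always raises RuntimeError, so skipping becomes the test suite's responsibility. A minimal sketch of what that could look like in a pytest module; the marker below is illustrative, not part of this commit:

import shutil

import pytest

# Skip every test in this module when the oneAPI compiler is not on PATH,
# rather than relying on the backend to call pytest.skip internally.
pytestmark = pytest.mark.skipif(shutil.which('icpx') is None, reason='icpx not present')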

hls4ml/backends/vivado/passes/distributed_arithmetic.py

Lines changed: 41 additions & 30 deletions

@@ -11,7 +11,6 @@
 from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer
 from hls4ml.model.types import FixedPrecisionType, Source
 from hls4ml.utils.dependency import requires
-from hls4ml.utils.einsum_utils import parse_einsum  # noqa: F401
 
 if typing.TYPE_CHECKING:
     from hls4ml.model import ModelGraph
@@ -66,8 +65,8 @@ def _(node: Dense):
 @get_kernel_inp_kif.register(Conv1D)
 @get_kernel_inp_kif.register(Conv2D)
 def _(layer: Conv1D | Conv2D):
-    assert layer.attributes.attributes['data_format'] == 'channels_last', 'Only channels_last format is supported'
-    kernel = layer.attributes.attributes['weight'].data
+    assert layer.attributes['data_format'] == 'channels_last', 'Only channels_last format is supported'
+    kernel = layer.attributes['weight'].data
     k_in, i_in, f_in = _get_input_kif(layer)
     k_in, i_in, f_in = pad_arrs(layer, 0, k_in, i_in, f_in)
     k_in, i_in, f_in = im2col(kernel.shape, k_in, i_in, f_in)
@@ -149,16 +148,10 @@ def transform(self, model: 'ModelGraph', node: Layer):
         node.set_attr('da_codegen', Source(fn_str))
 
 
-dense_da_stream_template = '''struct config{index} {{
-    static const unsigned n_in = {n_in};
-    static const unsigned n_out = {n_out};
-    static const unsigned io_type = nnet::io_stream;
-    static const unsigned strategy = nnet::distributed_arithmetic;
-    constexpr static auto dense_da = nnet::dense_da_{index}<typename {inp_t}::value_type, typename {out_t}::value_type>;
-}};\n'''
-
-
 class FuseQuantizerIntoDALayers(OptimizerPass):
+    """Heterogeneous quantizers can be fused into the DA CMVM kernel in some cases.
+    This allows heterogeneous quantization for io_stream in some cases."""
+
     def match(self, node: Layer):
         if not isinstance(node, FixedPointQuantizer):
             return False
@@ -203,9 +196,18 @@ def transform(self, model: 'ModelGraph', node: FixedPointQuantizer):
         return True
 
 
+dense_da_stream_template = '''struct config{index} {{
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = {n_out};
+    static const unsigned io_type = nnet::io_stream;
+    static const unsigned strategy = nnet::distributed_arithmetic;
+    constexpr static auto dense_da = nnet::dense_da_{index}<typename {inp_t}::value_type, typename {out_t}::value_type>;
+}};\n'''
+
+
 class DALatencyDenseTemplate(OptimizerPass):
-    # For Dense, distributed arithmetic do not call the original, regardless of the io_type
-    # FOr io_stream, a minimal config will still be generated
+    # For Dense, distributed arithmetic does not call the original impl, regardless of the io_type
+    # For io_stream, a minimal config will still be generated
     def match(self, node: Layer):
         if node.class_name != 'Dense':
             return False
@@ -225,10 +227,10 @@ def transform(self, model: 'ModelGraph', node: Layer):
         if io_type == 'io_parallel':
             fn_name = f'dense_da_{node.index}<{inp_t}, {out_t}>'
             function_cpp = f'{namespace}::{fn_name}({inp_name}, {out_name});'
-            node.attributes.attributes['function_cpp'] = function_cpp
+            node.attributes['function_cpp'] = function_cpp
         else:
             assert io_type == 'io_stream'
-            config_cpp = dense_da_stream_template.format(inp_t=inp_t, out_t=out_t, **node.attributes.attributes)
+            config_cpp = dense_da_stream_template.format(inp_t=inp_t, out_t=out_t, **node.attributes)
             function_cpp = f'nnet::dense<{inp_t}, {out_t}, config{node.index}>({inp_name}, {out_name});'
             node.attributes['config_cpp'] = config_cpp
             node.attributes['function_cpp'] = function_cpp
@@ -300,8 +302,8 @@ def transform(self, model: 'ModelGraph', node: Layer):
             class_name = class_name[9:]
 
         ndim = len(ker_shape) - 2
-        function_cpp = f'nnet::conv{ndim}d_da_cl<config{node.index}, {inp_t}, {out_t}>({inp_name}, {out_name});'
-        node.attributes.attributes['function_cpp'] = function_cpp
+        function_cpp = f'nnet::conv{ndim}d_cl<config{node.index}, {inp_t}, {out_t}>({inp_name}, {out_name});'
+        node.attributes['function_cpp'] = function_cpp
 
         # config generation
         params = node.attributes.attributes.copy()
@@ -314,15 +316,15 @@ def transform(self, model: 'ModelGraph', node: Layer):
         params.setdefault('stride_height', -1 if ndim == 1 else 1)
 
         config_cpp = conv_da_parallel_template.format(inp_t=inp_t, out_t=out_t, n_pixels=n_pixels, **params)
-        node.attributes.attributes['config_cpp'] = config_cpp
+        node.attributes['config_cpp'] = config_cpp
 
         # Only unrolled header is required for io_parallel
         include_headers = [
             'nnet_utils/nnet_da_wrappers.h',
             f'nnet_utils/nnet_{class_name.lower()}.h',
             'nnet_utils/nnet_conv_stream.h',  # some properties defined in config need this
         ]
-        node.attributes.attributes['include_header'] = include_headers
+        node.attributes['include_header'] = include_headers
 
         # avoid output weights and bias; alternative entry point does not use them
         del node.attributes['weight_data']
@@ -333,6 +335,18 @@ def transform(self, model: 'ModelGraph', node: Layer):
         del node.attributes['bias_t']
 
 
+kernel_fn_template = '''
+template <typename inp_t, typename out_t>
+void einsum_dense{index}_da_kernel(
+    inp_t inp_tpose[{inp_tpose}],
+    out_t out_tpose[{out_tpose}],
+    int l0
+) {{
+    {fn_call_str}
+}}
+'''
+
+
 class DistributedArithmeticEinsumCodegen(OptimizerPass):
     '''Generates C++ code for distributed arithmetic implementation of Dense layers'''
 
@@ -373,16 +387,13 @@ def transform(self, model: 'ModelGraph', node: Layer):
             fn_call = f'{fn_name}(&inp_tpose[({i} * {L_data} + l0) * {C}], &out_tpose[({i} * {L_data} + l0) * {L_ker}]);'
             fn_calls.append(fn_call)
 
-        kernel_fn = f'''
-template <typename inp_t, typename out_t>
-void einsum_dense{node.index}_da_kernel(
-    inp_t inp_tpose[{L_data * C * I}],
-    out_t out_tpose[{L_data * L_ker * I}],
-    int l0
-) {{
-    {" ".join(fn_calls)}
-}}
-'''
+        kernel_fn = kernel_fn_template.format(
+            index=node.index,
+            inp_tpose=L_data * C * I,
+            out_tpose=L_data * L_ker * I,
+            fn_call_str=' \n'.join(fn_calls),
+        )
+
         code_gen = '\n\n'.join(fn_strs) + '\n\n' + kernel_fn
         node.attributes['da_codegen'] = Source(code_gen)
         del node.attributes['weight_data']
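
Most of the churn in this file replaces node.attributes.attributes[...] with the shorter node.attributes[...]. The two spellings read the same storage because the attribute container is a MutableMapping that wraps an inner dict named attributes; the class below is an illustrative sketch under that assumption, not the hls4ml source:

from collections.abc import MutableMapping


class AttributeDict(MutableMapping):
    def __init__(self):
        self.attributes = {}  # the inner dict the old spelling reached into

    def __getitem__(self, key):
        return self.attributes[key]

    def __setitem__(self, key, value):
        self.attributes[key] = value

    def __delitem__(self, key):
        del self.attributes[key]

    def __iter__(self):
        return iter(self.attributes)

    def __len__(self):
        return len(self.attributes)


attrs = AttributeDict()
attrs['weight'] = 'w0'
assert attrs['weight'] == attrs.attributes['weight']  # both spellings agree
assert dict(**attrs) == {'weight': 'w0'}  # ** unpacking works on the mapping, as in format(**node.attributes)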

hls4ml/converters/keras_v3/hgq2/_base.py

Lines changed: 4 additions & 2 deletions

@@ -33,8 +33,10 @@ def extract_fixed_quantizer_config(q, tensor: 'KerasTensor', is_input: bool) ->
     B = np.broadcast_to(B.astype(np.int8), (1,) + shape)  # type: ignore
     I = np.broadcast_to(I.astype(np.int8), (1,) + shape)  # noqa: E741
 
-    overflow_mode = internal_q.overflow_mode
-    round_mode = internal_q.round_mode
+    overflow_mode: str = internal_q.overflow_mode
+    round_mode: str = internal_q.round_mode
+    if round_mode.startswith('S_'):
+        round_mode = round_mode[2:]
     fusible = np.unique(k).size == 1 and np.unique(B).size == 1 and np.unique(I).size == 1
 
     input_keras_tensor_names = tensor.name if is_input else f'{tensor.name}_q'
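
The new branch strips the 'S_' prefix from the round-mode string before it reaches hls4ml's precision types. A small sketch of the intended mapping, assuming 'S_'-prefixed modes (e.g. 'S_RND') denote stochastic rounding variants that fall back to their deterministic counterparts:

def normalize_round_mode(round_mode: str) -> str:
    # 'S_RND' -> 'RND'; deterministic modes pass through unchanged
    return round_mode[2:] if round_mode.startswith('S_') else round_mode


assert normalize_round_mode('S_RND') == 'RND'
assert normalize_round_mode('TRN') == 'TRN'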

hls4ml/converters/keras_v3/hgq2/softmax.py

Lines changed: 1 addition & 3 deletions

@@ -63,7 +63,7 @@ def handle(
         assert all(ax1 - ax0 == 1 for ax0, ax1 in zip(axs[:-1], axs[1:])), 'Softmax must act on adjacent axes'
         n_outer: int = prod(in_tensors[0].shape[1 : axs[0]])  # type: ignore
         n_inner: int = prod(in_tensors[0].shape[axs[-1] + 1 :])  # type: ignore
-        ax = -1  # if n_inner == 1 else 999  # 999 as placeholder
+        ax = -1
         n_in: int = prod(in_tensors[0].shape[1:])  # type: ignore
 
         from hgq.quantizer.internal import FixedPointQuantizerBase
@@ -124,8 +124,6 @@ def handle(
 
         if layer.stable:
             inp_norm_t = fixed_quantizer_to_hls4ml_t(layer.exp_table.iq.quantizer)
-            # inp_norm_t.saturation_mode = SaturationMode.WRAP
-            # inp_norm_t.rounding_mode = RoundingMode.TRN
             config['inp_norm_t'] = inp_norm_t
 
         return (config,)

hls4ml/converters/keras_v3/hgq2/unary_lut.py

Lines changed: 2 additions & 8 deletions

@@ -2,7 +2,7 @@
 from collections.abc import Sequence
 
 import numpy as np
-from quantizers import float_quantize, get_fixed_quantizer_np
+from quantizers import get_fixed_quantizer_np
 
 from hls4ml.model.types import FixedPrecisionType
 
@@ -14,8 +14,6 @@
 
 from decimal import Decimal
 
-from hls4ml.utils.qinterval import minimal_kif
-
 
 @register
 class QUnaryLUTHandler(QLayerHandler, KerasV3LayerHandler):
@@ -78,12 +76,8 @@ def handle(
             table_t = FixedPrecisionType(b, I, k)
         else:
             assert isinstance(oq, FloatPointQuantizer)
-            m, e, e0 = (ops.convert_to_numpy(x).ravel().item() for x in (oq.m, oq.e, oq.e0))
-            table = float_quantize(table, m, e, e0)
-            k, i, f = (int(np.min(x)) for x in minimal_kif(table))
-
             raise NotImplementedError('FloatPointQuantizer is not supported yet')
-            table_t = FixedPrecisionType(k + i + f, k + i, bool(k))
+
         table = ops.convert_to_numpy(table)
 
         config.update(

hls4ml/converters/keras_v3_to_hls.py

Lines changed: 2 additions & 2 deletions

@@ -21,8 +21,8 @@
 
 def get_io_tensors(layer: 'keras.Layer', node_whitelist: set[int] | None = None):
     """Given a keras layer, return a list of tuples of input and output
-    tensors. If the layer is called only once (i.e., no shared layers),
-    the list will contain only one tuple.
+    tensors. If the layer is called only once (i.e., the layer is not used
+    multiple times in the same model), the list will contain only one tuple.
 
     The layer must have been built before calling this function.
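
The reworded docstring concerns layer sharing: every call of a Keras layer creates a call node, and get_io_tensors returns one tuple per node. A hypothetical illustration of the shared-layer case (the model below is made up for this sketch):

import keras

from hls4ml.converters.keras_v3_to_hls import get_io_tensors

inp = keras.Input((4,))
shared = keras.layers.Dense(4, name='shared')  # one layer object...
out = shared(shared(inp))                      # ...called twice in the graph
model = keras.Model(inp, out)

# One (input_tensors, output_tensors) tuple per call of the layer.
assert len(get_io_tensors(shared)) == 2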

hls4ml/model/graph.py

Lines changed: 0 additions & 6 deletions

@@ -873,12 +873,6 @@ def _compute_n_samples(self, x):
         return int(n_sample)
 
     def predict(self, x):
-        if isinstance(x, np.ndarray) and not x.flags['C_CONTIGUOUS']:
-            x = np.ascontiguousarray(x)
-
-        # Compile the model if it wasn't compiled yet
-        if self._top_function_lib is None:
-            self.compile()
         top_function, ctype = self._get_top_function(x)
         n_samples = self._compute_n_samples(x)
         n_inputs = len(self.get_input_variables())
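
With the contiguity and auto-compile fallbacks deleted, predict() now assumes the model is already compiled and the input array is C-contiguous. A caller-side sketch under that assumption; keras_model and x_test are hypothetical placeholders, not part of this commit:

import numpy as np

import hls4ml

# keras_model and x_test stand in for any model/data the converter accepts.
hls_model = hls4ml.converters.convert_from_keras_model(keras_model, output_dir='prj')
hls_model.compile()                    # predict() no longer compiles implicitly
x = np.ascontiguousarray(x_test[::2])  # slicing can break C-contiguity; restore it up front
y = hls_model.predict(x)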
