From 47f3e3c59511808e3b436eda3d37b7ddd0511bed Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 7 Mar 2025 15:41:59 -0800 Subject: [PATCH 01/30] add keras v3 object parser --- hls4ml/converters/__init__.py | 3 + hls4ml/converters/keras_to_hls.py | 9 + hls4ml/converters/keras_v3_to_hls.py | 284 +++++++++++++++++++++++++++ 3 files changed, 296 insertions(+) create mode 100644 hls4ml/converters/keras_v3_to_hls.py diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 693a76f666..47569b1ad9 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -9,6 +9,7 @@ from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401 from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401 from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler +from hls4ml.converters.keras_v3_to_hls import parse_keras_v3_model # noqa: F401 from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers # noqa: F401 from hls4ml.converters.onnx_to_hls import parse_onnx_model # noqa: F401 from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler @@ -17,6 +18,8 @@ pytorch_to_hls, register_pytorch_layer_handler, ) + +# from hls4ml.converters.pytorch_to_hls import parse_pytorch_model # noqa: F401 from hls4ml.model import ModelGraph from hls4ml.utils.config import create_config from hls4ml.utils.dependency import requires diff --git a/hls4ml/converters/keras_to_hls.py b/hls4ml/converters/keras_to_hls.py index 00561e6ba8..557b7b9461 100644 --- a/hls4ml/converters/keras_to_hls.py +++ b/hls4ml/converters/keras_to_hls.py @@ -4,6 +4,8 @@ from hls4ml.model import ModelGraph +from .keras_v3_to_hls import parse_keras_v3_model + MAXMULT = 4096 @@ -323,6 +325,13 @@ def parse_keras_model(model_arch, reader): def keras_to_hls(config): + if 'KerasModel' in config: + import keras + + if keras.__version__ >= '3.0': + layer_list, input_layers, output_layers, _ = parse_keras_v3_model(config['KerasModel']) + return ModelGraph(config, layer_list, input_layers, output_layers) + model_arch, reader = get_model_arch(config) layer_list, input_layers, output_layers, _ = parse_keras_model(model_arch, reader) print('Creating HLS model') diff --git a/hls4ml/converters/keras_v3_to_hls.py b/hls4ml/converters/keras_v3_to_hls.py new file mode 100644 index 0000000000..5c0168cc1e --- /dev/null +++ b/hls4ml/converters/keras_v3_to_hls.py @@ -0,0 +1,284 @@ +import typing +from itertools import chain +from types import FunctionType +from typing import Any, Callable, Sequence + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + +import numpy as np + +from .keras_v3 import layer_handlers as v3_layer_handlers + +T_kv3_handler = Callable[ + ['keras.Layer', Sequence['keras.KerasTensor'], Sequence['keras.KerasTensor']], tuple[dict[str, Any], ...] +] + + +def get_io_tensors(layer: 'keras.Layer', node_whitelist: set[int] | None = None): + """Given a keras layer, return a list of tuples of input and output + tensors. If the layer is called only once (i.e., no shared layers), + the list will contain only one tuple. + + The layer must have been built before calling this function. + + Parameters + ---------- + layer : keras.Layer + The layer to get input and output tensors from. + node_whitelist : set[int]|None, optional + If not None, only return tensors from nodes with ids in this + set, used to filter out nodes that are not part of the model, by + default None + + + Returns + ------- + list[tuple[tuple['KerasTensor', ...], tuple['KerasTensor', ...]]] + A list of tuples of input and output tensors. + """ + in_nodes = layer._inbound_nodes + if node_whitelist is not None: + in_nodes = [node for node in in_nodes if id(node) in node_whitelist] + + ret: list[tuple[tuple['KerasTensor', ...], tuple['KerasTensor', ...]]] = [] + for node in in_nodes: + in_tensors = tuple(node.arguments.keras_tensors) + out_tensors = tuple(node.outputs) + ret.append((in_tensors, out_tensors)) + return ret + + +def resolve_dependency_relation(model: 'keras.Model'): + """Given a keras model, return the following information: + - A list of input tensor names + - A list of output tensor names + - A list of (layer_name, input_tensor_names, output_tensor_names) tuples + - A dictionary of tensor_name -> KerasTensor + + Parameters + ---------- + model : keras.Model + The keras model to analyze. + + Returns + ------- + tuple[tuple[str, ...], tuple[str, ...], list[tuple[str, tuple[str, ...], tuple[str, ...]]], dict[str, KerasTensor]] + inp_tensor_names, out_tensor_names, layer_io, tensors + """ + tensors: dict[str, 'KerasTensor'] = {} + "tensor_name -> KerasTensor" + depends_on: dict[str, tuple[str, ...]] = {} + "tensor_name -> {tensor_name}" + layer_io: list[tuple[str, tuple[str, ...], tuple[str, ...]]] = [] + "layer_name -> ((input_tensor_names), (output_tensor_names))" + + inputs = tuple(t.name for t in model.inputs) + outputs = tuple(t.name for t in model.outputs) + node_whitelist = {id(node) for v in model._nodes_by_depth.values() for node in v} + + for layer in model.layers: + for in_tensors, out_tensors in get_io_tensors(layer, node_whitelist): + in_tensor_names = tuple(t.name for t in in_tensors) + out_tensor_names = tuple(t.name for t in out_tensors) + for t in chain(in_tensors, out_tensors): + tensors[t.name] = t + for o_name in out_tensor_names: + depends_on[o_name] = in_tensor_names + layer_io.append((layer.name, in_tensor_names, out_tensor_names)) + + return inputs, outputs, layer_io, tensors + + +class UniqueName: + """Helper class to generate unique names for layers, if one being used multiple times.""" + + def __init__(self): + self.used_names: set[str] = set() + + def next_name(self, name: str): + i = 0 + if name in self.used_names: + while f'{name}_{i}' in self.used_names: + i += 1 + name = f'{name}_{i}' + self.used_names.add(name) + return name + + def __call__(self, name: str): + return self.next_name(name) + + def reset(self): + self.used_names.clear() + + +class KerasV3HandlerDispatcher: + """Dispatcher class to handle different types of keras v3 layers.""" + + def __init__(self, layer_handlers: dict[str, T_kv3_handler], v2_layer_handlers=None): + self.registry = layer_handlers + self.v2_layer_handlers = v2_layer_handlers or {} + + def __call__( + self, layer: 'keras.Layer', in_tensors: Sequence['keras.KerasTensor'], out_tensors: Sequence['keras.KerasTensor'] + ) -> tuple[dict[str, Any], ...]: + assert layer.built, f"Layer {layer.name} is not built" + + ret = self.v3_call(layer, in_tensors, out_tensors) + if ret is not None: + return ret + ret = self.v2_call(layer, in_tensors, out_tensors) + if ret is not None: + return ret + + raise ValueError( + f"Layer {layer.__class__.__module__}.{layer.__class__.__name__} not found in either v3 or v2 handlers" + ) + + def v3_call( + self, layer: 'keras.layers.Layer', inp_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'] + ): + cls_name = layer.__class__.__name__ + module = layer.__module__ + key = f"{module}.{cls_name}" + + # keras v3 handlers + handler = self.registry.get(key, None) + handler = handler or self.registry.get(cls_name, None) + + if handler is None: + return None + return handler(layer, inp_tensors, out_tensors) + + def v2_call( + self, layer: 'keras.layers.Layer', inp_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'] + ): + # keras v2 handlers fallback + print(f"v2 handler used for layer {layer.name}") + + import keras + + config = layer.get_config() + layer_dict = {'config': config, 'class_name': layer.__class__.__name__} + + class DummyReader: + def get_weights_data(self, layer_name, var_name): + assert layer_name == layer.name, f"Processing {layer.name}, but handler tried to read {layer_name}" + for w in layer.weights: + if var_name in w.name: + return np.array(w) + return None + + reader = DummyReader() + input_shapes = [list(t.shape) for t in inp_tensors] + input_names = [t.name for t in inp_tensors] + output_names = [t.name for t in out_tensors] + key = layer.__class__.__name__ + handler = self.v2_layer_handlers.get(key, None) + if handler is None: + return None + + ret, _ = handler(layer_dict, input_names, input_shapes, reader) + ret['output_keras_tensor_names'] = output_names + ret['input_keras_tensor_names'] = input_names + ret = (ret,) + + activation = getattr(layer, 'activation', None) + if activation not in (keras.activations.linear, None): + assert isinstance(activation, FunctionType), f"Activation function for layer {layer.name} is not a function" + intermediate_tensor_name = f'{output_names[0]}_activation' + ret[0]['output_keras_tensor_names'] = (intermediate_tensor_name,) + act_cls_name = activation.__name__ + act_config = { + 'class_name': 'Activation', + 'activation': act_cls_name, + 'name': f'{layer.name}_{act_cls_name}', + 'input_keras_tensor_names': (intermediate_tensor_name,), + 'output_keras_tensor_names': output_names, + } + ret = *ret, act_config + return ret + + +def parse_keras_v3_model(model: 'keras.Model'): + """Parse a keras model into a list of dictionaries, each + representing a layer in the HLS model, and a list of input and + output layer names. + + Parameters + ---------- + model : keras.Model + + Returns + ------- + tuple[list[dict[str, Any]], list[str], list[str], list[list[int]]] + layer_list, input_layer_names, output_layer_names, + batch_output_shapes + + Raises + ------ + ValueError + If a circular dependency is detected. + """ + + assert model.built, "Model must be built before parsing" + + import keras + + if isinstance(model, keras.Sequential): + model = model._functional # everything is functional under the hood lol + + from .keras_to_hls import layer_handlers as v2_layer_handlers # Delayed import to avoid circular import + + keras_v3_dispatcher = KerasV3HandlerDispatcher(v3_layer_handlers, v2_layer_handlers) + + model_inputs, model_outputs, dependency, tensors = resolve_dependency_relation(model) + + satisfied = set() + + unique_name = UniqueName() + + layer_list: list[dict[str, Any]] = [] + + while any(t not in satisfied for t in model_outputs): + # Until all tensors in the model are satisfied + for i, (layer_name, in_tensor_names, out_tensor_names) in enumerate(dependency): + if not all(t in satisfied for t in in_tensor_names): + continue # Skip layer if some inputs are not ready + if all(t in satisfied for t in out_tensor_names): + continue # Skip layer if the outputs are already satisfied + + layer: 'keras.Layer' = model.get_layer(layer_name) + inp_tensors = [tensors[t] for t in in_tensor_names] + out_tensors = [tensors[t] for t in out_tensor_names] + + _configs = keras_v3_dispatcher(layer, inp_tensors, out_tensors) + # Dispatch to v3 handler if available, else fallback to v2 handler + + # Prevent name conflicts. If a layer is used multiple times, add a suffix to the name. + # At this stage connections between modules are recorded by i/o keras tensor names + for _conf in _configs: + _conf['name'] = unique_name(_conf['name']) + + layer_list.extend(_configs) # Add the layer to the list + satisfied.update(out_tensor_names) # Mark the outputs as satisfied + dependency.pop(i) + break # Restart the loop to add another layer + else: + # If no layer was added in the loop, then there is a circular dependency + raise ValueError("Circular dependency detected") + + # Mark inputs[inp layer name] for ModelGraph to parse from i/o keras tensor names + provides: dict[str, str] = {} # tensor_name -> src_layer_name + for conf in layer_list: + for out_name in conf['output_keras_tensor_names']: + provides[out_name] = conf['name'] + inputs = [provides[tname] for tname in conf['input_keras_tensor_names']] + conf['inputs'] = inputs + + input_layer_names = [provides[tname] for tname in model_inputs] + output_layer_names = [provides[tname] for tname in model_outputs] + batch_output_shapes = [list(tensors[tname].shape) for tname in model_outputs] + + return layer_list, input_layer_names, output_layer_names, batch_output_shapes From 5debe7162902f8e647769054e87569de08cf2b8b Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 7 Mar 2025 15:42:45 -0800 Subject: [PATCH 02/30] add keras v3 layer handlers --- hls4ml/converters/keras_v3/__init__.py | 6 + hls4ml/converters/keras_v3/_base.py | 216 ++++++++++++++++++++++++ hls4ml/converters/keras_v3/conv.py | 119 +++++++++++++ hls4ml/converters/keras_v3/core.py | 222 +++++++++++++++++++++++++ 4 files changed, 563 insertions(+) create mode 100644 hls4ml/converters/keras_v3/__init__.py create mode 100644 hls4ml/converters/keras_v3/_base.py create mode 100644 hls4ml/converters/keras_v3/conv.py create mode 100644 hls4ml/converters/keras_v3/core.py diff --git a/hls4ml/converters/keras_v3/__init__.py b/hls4ml/converters/keras_v3/__init__.py new file mode 100644 index 0000000000..6dffcb71d5 --- /dev/null +++ b/hls4ml/converters/keras_v3/__init__.py @@ -0,0 +1,6 @@ +from . import conv # noqa: F401 +from . import core # noqa: F401 +from . import einsum_dense # noqa: F401 +from ._base import registry as layer_handlers + +__all__ = ['layer_handlers'] diff --git a/hls4ml/converters/keras_v3/_base.py b/hls4ml/converters/keras_v3/_base.py new file mode 100644 index 0000000000..6f50ed6523 --- /dev/null +++ b/hls4ml/converters/keras_v3/_base.py @@ -0,0 +1,216 @@ +import typing +from types import FunctionType +from typing import Any, Callable, Sequence, TypedDict, overload + + +class DefaultConfig(TypedDict, total=False): + name: str + class_name: str + module: str + input_keras_tensor_names: list[str] + input_shape: list[list[int]] + output_keras_tensor_names: list[str] + epsilon: float + use_bias: bool + data_format: str + + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + +T_kv3_handler = Callable[ + ['keras.Layer', Sequence['keras.KerasTensor'], Sequence['keras.KerasTensor']], tuple[dict[str, Any], ...] +] + +registry: dict[str, T_kv3_handler] = {} + + +@overload +def register(cls: type) -> type: ... + + +@overload +def register(cls: str) -> Callable[[T_kv3_handler], T_kv3_handler]: ... + + +def register(cls: str | type): + """Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class. + + Parameters + ---------- + cls : str|type + If str, the key to register the handler under. If type, the class to register the handler for. + + Examples + -------- + ```python + @keras_dispatcher.register + class MyLayerHandler(KerasV3LayerHandler): + handles = ('my_package.src.submodule.MyLayer', 'MyLayer2') + + def handle(self, layer, inp_tensors, out_tensors): + # handler code + + + @keras_dispatcher.register('MyLayer3') + def my_layer_handler(layer, inp_tensors, out_tensors): + # handler code + ``` + """ + + def deco(func): + if isinstance(cls, str): + registry[cls] = func + for k in getattr(func, 'handles', ()): + registry[k] = func + if isinstance(cls, type): + return cls + return func + + if isinstance(cls, type): + return deco(cls()) + return deco + + +def maybe_add_attrs(config: dict[str, Any] | DefaultConfig, obj: Any, *attrs: str): + for attr in attrs: + if attr not in config and hasattr(obj, attr): + config[attr] = getattr(obj, attr) + + +class KerasV3LayerHandler: + """Base class for keras v3 layer handlers. Subclass this class to create a handler for a specific layer type.""" + + handles = () + default_config: DefaultConfig + + def __call__( + self, + layer: 'keras.Layer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ) -> tuple[dict[str, Any], ...]: + """Handle a keras layer. Return a tuple of dictionaries, each + dictionary representing a layer (module) in the HLS model. One + layer may correspond one or more dictionaries (e.g., layers with + activation functions will be split into two layers). + + Some common attributes are automatically added to the dictionary + if the handler returns a single dictionary. If the handler + returns multiple dictionaries, the attributes must be added + manually. Anything returned by the handler will override the + automatic attributes. + + Automatic attributes: - name - class_name - module - + input_keras_tensor_names - input_shape - + output_keras_tensor_names + + If the layer has an activation function, an additional + dictionary will be added to the return value representing the + activation function. + + + Parameters + ---------- + layer : keras.Layer + The layer to be converted to HLS configuration(s). + in_tensors : Sequence[KerasTensor] + The list of input tensors to the layer. + out_tensors : Sequence[KerasTensor] + The list of output tensors from the layer. + + Returns + ------- + dict[str, Any] | tuple[dict[str, Any], ...] + layer configuration(s) for the HLS model to be consumed by + the ModelGraph constructor + """ + + name = layer.name + class_name = layer.__class__.__name__ + module = layer.__module__ + + default_config: DefaultConfig = { + 'name': name, + 'class_name': class_name, + 'module': module, + 'input_keras_tensor_names': [t.name for t in in_tensors], + 'input_shape': [list(t.shape[1:]) for t in in_tensors], # type: ignore + 'output_keras_tensor_names': [t.name for t in out_tensors], + } + + maybe_add_attrs(default_config, layer, 'epsilon', 'use_bias', 'data_format') + + mandatory_keys = ['name', 'class_name', 'output_keras_tensor_names', 'input_keras_tensor_names'] + + self.default_config = default_config + config0 = self.handle(layer, in_tensors, out_tensors) + del self.default_config + + if isinstance(config0, tuple): + for conf in config0: + for key in mandatory_keys: + assert key in conf, f"Key {key} missing from layer {name} handled by {self.__class__.__name__}" + return config0 + + config = {} + config.update(default_config) + config.update(config0) + ret = (config,) + + # If activation exists, append it + + act_config, intermediate_tensor_name = self.maybe_get_activation_config(layer, out_tensors) + if act_config is not None: + ret[0]['output_keras_tensor_names'] = [intermediate_tensor_name] + ret = *ret, act_config + + return ret + + def maybe_get_activation_config(self, layer, out_tensors): + import keras + + activation = getattr(layer, 'activation', None) + name = layer.name + if activation not in (keras.activations.linear, None): + assert len(out_tensors) == 1, f"Layer {name} has more than one output, but has an activation function" + assert isinstance(activation, FunctionType), f"Activation function for layer {name} is not a function" + intermediate_tensor_name = f'{out_tensors[0].name}_activation' + act_cls_name = activation.__name__ + act_config = { + 'class_name': 'Activation', + 'activation': act_cls_name, + 'name': f'{name}_{act_cls_name}', + 'input_keras_tensor_names': [intermediate_tensor_name], + 'output_keras_tensor_names': [out_tensors[0].name], + } + return act_config, intermediate_tensor_name + return None, None + + def handle( + self, + layer: 'keras.Layer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ) -> dict[str, Any] | tuple[dict[str, Any], ...]: + return {} + + def load_weight(self, layer: 'keras.Layer', key: str): + """Load a weight from a layer. + + Parameters + ---------- + layer : keras.Layer + The layer to load the weight from. + key : str + The key of the weight to load. + + Returns + ------- + np.ndarray + The weight. + """ + import keras + + return keras.ops.convert_to_numpy(getattr(layer, key)) diff --git a/hls4ml/converters/keras_v3/conv.py b/hls4ml/converters/keras_v3/conv.py new file mode 100644 index 0000000000..adf6221822 --- /dev/null +++ b/hls4ml/converters/keras_v3/conv.py @@ -0,0 +1,119 @@ +import typing +from math import ceil +from typing import Sequence + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + + +@register +class KV3ConvHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.convolutional.conv1d.Conv1D', + 'keras.src.layers.convolutional.conv2d.Conv2D', + 'keras.src.layers.convolutional.depthwise_conv1d.DepthwiseConv1D', + 'keras.src.layers.convolutional.depthwise_conv2d.DepthwiseConv2D', + 'keras.src.layers.convolutional.separable_conv1d.SeparableConv1D', + 'keras.src.layers.convolutional.separable_conv2d.SeparableConv2D', + ) + + def handle( + self, + layer: 'keras.layers.Conv1D|keras.layers.Conv2D|keras.layers.DepthwiseConv1D|keras.layers.DepthwiseConv2D', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + from keras.src.layers.convolutional.base_conv import BaseConv + from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv + from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv + + assert len(in_tensors) == 1, f"Layer {layer.name} has more than one input" + assert len(out_tensors) == 1, f"Layer {layer.name} has more than one output" + + in_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore + out_shape: tuple[int, ...] = out_tensors[0].shape[1:] # type: ignore + assert all(isinstance(x, int) for x in in_shape), f"Layer {layer.name} has non-fixed size input: {in_shape}" + assert all(isinstance(x, int) for x in out_shape), f"Layer {layer.name} has non-fixed size output: {out_shape}" + + kernel = self.load_weight(layer, 'kernel') + if layer.use_bias: + bias = self.load_weight(layer, 'bias') + else: + bias = None + + ker_px_shape: tuple[int, ...] = layer.kernel_size + data_format = layer.data_format + + if data_format == 'channels_last': + *px_in_shape, ch_in = in_shape + *px_out_shape, ch_out = out_shape + else: + ch_in, *px_in_shape = in_shape + ch_out, *px_out_shape = out_shape + + if layer.padding == 'same': + n_padding = [ceil(N / n) * n - N for N, n in zip(px_in_shape, ker_px_shape)] + n_padding0 = [p // 2 for p in n_padding] + n_padding1 = [p - p0 for p, p0 in zip(n_padding, n_padding0)] + elif layer.padding == 'valid': + n_padding0 = [0] * len(px_in_shape) + n_padding1 = [0] * len(px_in_shape) + elif layer.padding == 'causal': + n_padding0 = [ker_px_shape[0] - 1] + [0] * (len(px_in_shape) - 1) + n_padding1 = [0] * len(px_in_shape) + else: + raise ValueError(f"Invalid padding mode {layer.padding} for layer {layer.name}") + + config = { + 'bias_data': bias, + 'data_format': data_format, + 'weight_data': kernel, + 'n_filt': ch_out, + 'n_chan': ch_in, + } + + if layer.rank == 1: + config.update( + { + 'filt_width': ker_px_shape[0], + 'stride_width': layer.strides[0], + 'pad_left': n_padding0[0], + 'pad_right': n_padding1[0], + 'in_width': px_in_shape[0], + 'out_width': px_out_shape[0], + } + ) + elif layer.rank == 2: + config.update( + { + 'filt_height': ker_px_shape[0], + 'filt_width': ker_px_shape[1], + 'stride_height': layer.strides[0], + 'stride_width': layer.strides[1], + 'pad_top': n_padding0[0], + 'pad_bottom': n_padding1[0], + 'pad_left': n_padding0[1], + 'pad_right': n_padding1[1], + 'in_height': px_in_shape[0], + 'in_width': px_in_shape[1], + 'out_height': px_out_shape[0], + 'out_width': px_out_shape[1], + } + ) + else: + _cls = f"{layer.__class__.__module__}.{layer.__class__.__name__}" + raise ValueError(f"Only 1D and 2D conv layers are supported, got {_cls} (rank={layer.rank})") + if isinstance(layer, BaseDepthwiseConv): + config['depthwise_data'] = kernel + config['depth_multiplier'] = layer.depth_multiplier + elif isinstance(layer, BaseSeparableConv): + config['depthwise_data'] = kernel + config['pointwise_data'] = self.load_weight(layer, 'pointwise_kernel') + config['depth_multiplier'] = layer.depth_multiplier + elif isinstance(layer, BaseConv): + config['weight_data'] = kernel + + return config diff --git a/hls4ml/converters/keras_v3/core.py b/hls4ml/converters/keras_v3/core.py new file mode 100644 index 0000000000..f3ac9a0d75 --- /dev/null +++ b/hls4ml/converters/keras_v3/core.py @@ -0,0 +1,222 @@ +import inspect +import typing +from math import prod +from typing import Any, Sequence + +import numpy as np + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + from keras.src.layers.merging.base_merge import Merge + + +@register +class KV3DenseHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.dense.Dense',) + + def handle( + self, + layer: 'keras.layers.Dense', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + + kernel = self.load_weight(layer, 'kernel') + bias = self.load_weight(layer, 'bias') if layer.use_bias else None + n_in, n_out = kernel.shape + + config = { + 'data_format': 'channels_last', + 'weight_data': kernel, + 'bias_data': bias, + 'n_out': n_out, + 'n_in': n_in, + } + return config + + +@register +class KV3InputHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.input_layer.InputLayer',) + + def handle( + self, + layer: 'keras.layers.InputLayer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {'input_shape': list(layer._batch_shape[1:])} + return config + + +@register +class KV3MergeHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.merging.add.Add', + 'keras.src.layers.merging.multiply.Multiply', + 'keras.src.layers.merging.average.Average', + 'keras.src.layers.merging.maximum.Maximum', + 'keras.src.layers.merging.minimum.Minimum', + 'keras.src.layers.merging.concatenate.Concatenate', + 'keras.src.layers.merging.subtract.Subtract', + 'keras.src.layers.merging.dot.Dot', + ) + + def handle( + self, + layer: 'Merge', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + cls_name: str | None = None, + ): + assert len(out_tensors) == 1, f"Merge layer {layer.name} has more than one output" + output_shape = list(out_tensors[0].shape[1:]) + + cls_name = cls_name or layer.__class__.__name__ + config: dict[str, Any] = { + 'output_shape': output_shape, + 'op': cls_name.lower(), + } + + match cls_name.lower(): + case 'Concatenate': + rank = len(output_shape) + class_name = f'Concatenate{rank}d' + config['axis'] = layer.axis + case 'Dot': + class_name = f'Dot{len(output_shape)}d' + rank = len(output_shape) + assert rank == 1, f"Dot product only supported for 1D tensors, got {rank}D on layer {layer.name}" + case _: + class_name = 'Merge' + + config['class_name'] = class_name + return config + + +@register +class KV3ActivationHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.activation.Activation',) + + def handle( + self, + layer: 'keras.layers.Activation', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + import keras + + config = {} + config.update(self.default_config) + + activation = getattr(layer, 'activation', keras.activations.linear) + match activation: + case keras.activations.softmax: + class_name = 'Softmax' + config['axis'] = -1 + case keras.activations.hard_sigmoid: + class_name = 'HardActivation' + case keras.activations.leaky_relu: + class_name = 'LeakyReLU' + signature = inspect.signature(keras.activations.leaky_relu) + config['activ_param'] = signature.parameters['negative_slope'].default + case keras.activations.elu: + class_name = 'ELU' + signature = inspect.signature(keras.activations.elu) + config['activ_param'] = signature.parameters['alpha'].default + case _: + class_name = 'Activation' + + config['activation'] = activation.__name__ + config['class_name'] = class_name + return (config,) + + +@register +class KV3ReLUHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.activations.leaky_relu.LeakyReLU', + 'keras.src.layers.activations.prelu.PReLU', + 'keras.src.layers.activations.relu.ReLU', + ) + + def handle( + self, + layer: 'keras.layers.ReLU', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {} + config.update(self.default_config) + + if layer.__class__.__name__ == 'ReLU': + config['class_name'] = 'Activation' + config['activation'] = 'relu' + return config + + if layer.__class__.__name__ == 'PReLU': + config['class_name'] = 'PReLU' + config['param_data'] = np.array(layer.alpha) + config['activation'] = 'prelu' + else: + config['class_name'] = 'LeakyReLU' + config['activ_param'] = float(layer.negative_slope) + config['activation'] = 'leaky_relu' + + return (config,) + + +@register +class KV3SoftmaxHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.softmax.Softmax',) + + def handle( + self, + layer: 'keras.layers.Softmax', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + ax = layer.axis + ax = ax if ax >= 0 else len(in_tensors[0].shape) + ax + # io_stream asserts axis=-1, convert to -1 when it is + n_outer: int = prod(in_tensors[0].shape[1:ax]) # type: ignore + n_inner: int = prod(in_tensors[0].shape[ax + 1 :]) # type: ignore + ax = -1 if ax == len(in_tensors[0].shape) - 1 else ax + config = {} + config.update(self.default_config) + if len(in_tensors) == 2: + raise NotImplementedError("Masked softmax not supported yet") + config['class_name'] = 'MaskedSoftmax' + elif len(in_tensors) == 1: + config['class_name'] = 'Softmax' + else: + raise ValueError(f"Too many inputs for softmax layer {layer.name}: expected 1 or 2, got {len(in_tensors)}") + config['axis'] = layer.axis + config['activation'] = 'softmax' + config['n_outer'] = (n_outer,) + config['n_inner'] = n_inner + + return (config,) + + +@register +class KV3HardActivationHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.elu.ELU',) + + def handle( + self, + layer: 'keras.layers.ELU', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {} + config.update(self.default_config) + + config['class_name'] = 'ELU' + config['activ_param'] = float(layer.alpha) + config['activation'] = 'elu' + + return (config,) From 755be89ee30110760e898b4c9dc05ef7421903f8 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 7 Mar 2025 15:43:28 -0800 Subject: [PATCH 03/30] expose kv3 parser to config interface --- hls4ml/utils/config.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index b14d1ce99d..8e2d9b7011 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -156,13 +156,21 @@ def config_from_keras_model( layer_list = [] if isinstance(model, dict): + # keras v2 only model_arch = model + reader = hls4ml.converters.KerasModelReader(model) + layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader) else: - model_arch = json.loads(model.to_json()) + import keras - reader = hls4ml.converters.KerasModelReader(model) + # model is keras.Model here - layer_list, _, _, _ = hls4ml.converters.parse_keras_model(model_arch, reader) + if keras.__version__ > '3.0': + layer_list, *_ = hls4ml.converters.parse_keras_v3_model(model) + else: + model_arch = json.loads(model.to_json()) + reader = hls4ml.converters.KerasModelReader(model) + layer_list, *_ = hls4ml.converters.parse_keras_model(model_arch, reader) def make_layer_config(layer): cls_name = layer['class_name'] From a1c2227a0d62be10cae51a29d4f8a7b5ba1da143 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 7 Mar 2025 16:21:00 -0800 Subject: [PATCH 04/30] add kv3 converter test --- test/pytest/test_keras_v3_api.py | 516 +++++++++++++++++++++++++++++++ 1 file changed, 516 insertions(+) create mode 100644 test/pytest/test_keras_v3_api.py diff --git a/test/pytest/test_keras_v3_api.py b/test/pytest/test_keras_v3_api.py new file mode 100644 index 0000000000..81ac5c240c --- /dev/null +++ b/test/pytest/test_keras_v3_api.py @@ -0,0 +1,516 @@ +import math +from pathlib import Path + +import keras +import numpy as np +import pytest + +if keras.__version__ < '3.0': + pytest.skip('Keras API tests are only for Keras 3.0 and above', allow_module_level=True) + +from keras.api.layers import ( + ELU, + Activation, + AveragePooling1D, + AveragePooling2D, + Conv1D, + Conv2D, + Dense, + DepthwiseConv1D, + DepthwiseConv2D, + LeakyReLU, + MaxPooling1D, + MaxPooling2D, + PReLU, +) + +import hls4ml + +test_root_path = Path('/tmp/tests') + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_dense(backend, io_type): + model = keras.Sequential( + [ + Dense( + 2, + input_shape=(1,), + name='Dense', + use_bias=True, + kernel_initializer=keras.initializers.RandomUniform(minval=1, maxval=10), # type: ignore + bias_initializer='zeros', + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + ), + Activation(activation='elu', name='Activation'), + ] + ) + model.compile(optimizer='adam', loss='mse') + + X_input = np.random.rand(1000, 1) + + keras_prediction = model.predict(X_input, verbose=0) # type: ignore + + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_dense_{backend}_{io_type}') + + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + + hls_model.compile() + + hls_prediction = hls_model.predict(X_input) + + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.02) + + assert len(model.layers) + 1 == len(hls_model.get_layers()) + assert list(hls_model.get_layers())[0].attributes['class_name'] == "InputLayer" + assert list(hls_model.get_layers())[1].attributes["class_name"] == model.layers[0].name + assert list(hls_model.get_layers())[2].attributes['class_name'] == 'ELU' + + +# TODO: add ThresholdedReLU test when it can be made to pass +# https://github.com/fastmachinelearning/hls4ml/issues/376 + + +@pytest.mark.parametrize( + "activation_function", + [ + Activation(activation='relu', name='relu'), + LeakyReLU(negative_slope=0.5), + ELU(alpha=1.0), + PReLU( + alpha_initializer="zeros", + ), + Activation(activation='sigmoid', name='sigmoid'), + ], +) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_activations(activation_function, backend, io_type): + model = keras.models.Sequential() + model.add(Dense(64, input_shape=(1,), name='Dense', kernel_initializer='lecun_uniform', kernel_regularizer=None)) + model.add(activation_function) + + model.compile(optimizer='adam', loss='mse') + + model.summary() + + X_input = np.random.rand(1000, 1) + keras_prediction = model.predict(X_input, verbose=0) # type: ignore + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_activations_{activation_function.name}_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + hls_prediction = hls_model.predict(X_input) + + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.02) + + for layer in hls_model.get_layers(): + print(layer.attributes.attributes['class_name']) + assert len(model.layers) + 1 == len(hls_model.get_layers()) + + assert list(hls_model.get_layers())[2].attributes['class_name'] == activation_function.__class__.__name__ + + +padds_options = ['same', 'valid'] + + +@pytest.mark.parametrize('padds', padds_options) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_conv1d(padds, backend, io_type): + model = keras.models.Sequential() + input_shape = (10, 128, 4) + model.add( + Conv1D( + filters=32, + kernel_size=3, + strides=2, + padding=padds, + activation='relu', + input_shape=input_shape[1:], + kernel_initializer='normal', + use_bias=False, + data_format='channels_last', + name='conv', + ) + ) + model.add(Activation(activation='relu')) + model.compile(optimizer='adam', loss='mse') + + X_input = np.random.rand(10, 128, 4) + keras_prediction = model.predict(X_input, verbose=0) # type: ignore + + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_conv1d_{padds}_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape) # type: ignore + + # 5e-2 might be too high + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=5e-2) + + if backend in ('Vivado', 'Vitis', 'Catapult') and io_type == 'io_stream' and padds == 'same': + # Vivado/Vitis inserts and additional layer for 'same' padding in io_stream + return + + conv: keras.layers.Conv1D = model.layers[0] + ker_w, ch_in, ch_out = conv.kernel.shape + inp_shape = model.inputs[0].shape[1:] + out_shape = model.outputs[0].shape[1:] + hls_attr = hls_model.graph['conv'].attributes + _stride = conv.strides[0] + + assert len(model.layers) + 2 == len(hls_model.get_layers()) + + assert hls_attr['name'] == model.layers[0].name + assert hls_attr['class_name'] == 'Conv1D' + assert hls_attr["in_width"] == inp_shape[0] + assert hls_attr['filt_width'] == ker_w + assert hls_attr['n_chan'] == ch_in + assert hls_attr['n_filt'] == ch_out + assert hls_attr['stride_width'] == _stride + assert hls_attr['data_format'] == conv.data_format + assert hls_attr["out_width"] == out_shape[0] + + w_pad = math.ceil(inp_shape[0] / ker_w) * ker_w - inp_shape[0] + + pad_left = w_pad // 2 + pad_right = w_pad - pad_left + + if model.layers[0].padding == 'same': + assert hls_attr['pad_left'] == pad_left + assert hls_attr['pad_right'] == pad_right + elif model.layers[0].padding == 'valid': + assert hls_attr['pad_left'] == 0 + assert hls_attr['pad_right'] == 0 + + +chans_options = ['channels_last'] +padds_options = ['same', 'valid'] + + +@pytest.mark.parametrize('chans', chans_options) +@pytest.mark.parametrize('padds', padds_options) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_conv2d(chans, padds, backend, io_type): + input_shape = (32, 32, 3) + model = keras.Sequential( + [ + keras.layers.InputLayer(input_shape), + Conv2D( + filters=32, + kernel_size=(2, 3), + strides=(4, 5), + padding=padds, + kernel_initializer='normal', + use_bias=False, + data_format=chans, + name='conv', + ), + ] + ) + model.compile(optimizer='adam', loss='mse') + + X_input = np.random.rand(1000, *input_shape) + keras_prediction = model.predict(X_input) + + config = hls4ml.utils.config_from_keras_model(model) + output_dir = str(test_root_path / f'hls4ml_project_keras_api_conv2d_{backend}_{chans}_{padds}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape) # type: ignore + + # A high tolerance, simply to verify correct functionality + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=5e-2) + + hls_conv_attr = hls_model.graph['conv'].attributes + + conv: keras.layers.Conv2D = model.get_layer('conv') + + kh, kw, ch_in, ch_out = conv.kernel.shape # type: ignore + _stride = conv.strides + inp_shape = model.inputs[0].shape[1:] + out_shape = model.outputs[0].shape[1:] + + if io_type == 'io_stream' and padds == 'same' and backend in ('Vivado', 'Vitis', 'Catapult'): + return + + assert len(model.layers) + 1 == len(hls_model.get_layers()) + assert hls_conv_attr['name'] == conv.name + assert hls_conv_attr['class_name'] == 'Conv2D' + assert hls_conv_attr['filt_width'] == kw + assert hls_conv_attr['filt_height'] == kh + assert hls_conv_attr['n_filt'] == ch_out + assert hls_conv_attr['stride_width'] == _stride[1] + assert hls_conv_attr['stride_height'] == _stride[0] + assert hls_conv_attr['data_format'] == conv.data_format + + if conv.data_format == 'channels_first': + assert hls_conv_attr['n_chan'] == inp_shape[0] + assert hls_conv_attr['in_height'] == inp_shape[1] + assert hls_conv_attr['in_width'] == inp_shape[2] + assert hls_conv_attr['out_height'] == out_shape[1] + assert hls_conv_attr['out_width'] == out_shape[2] + elif model.layers[0].data_format == 'channels_last': + assert hls_conv_attr['n_chan'] == inp_shape[2] + assert hls_conv_attr['in_height'] == inp_shape[0] + assert hls_conv_attr['in_width'] == inp_shape[1] + assert hls_conv_attr['out_height'] == out_shape[0] + assert hls_conv_attr['out_width'] == out_shape[1] + + if conv.padding == 'same': + if conv.data_format == 'channels_first': + h_pad = math.ceil(inp_shape[1] / kh) * kh - inp_shape[1] + w_pad = math.ceil(inp_shape[2] / kw) * kw - inp_shape[2] + elif model.layers[0].data_format == 'channels_last': + h_pad = math.ceil(inp_shape[0] / kh) * kh - inp_shape[0] + w_pad = math.ceil(inp_shape[1] / kw) * kw - inp_shape[1] + else: + raise ValueError('Invalid data_format') + pad_top = h_pad // 2 + pad_bottom = h_pad - pad_top + pad_left = w_pad // 2 + pad_right = w_pad - pad_left + assert hls_conv_attr['pad_top'] == pad_top + assert hls_conv_attr['pad_bottom'] == pad_bottom + assert hls_conv_attr['pad_left'] == pad_left + assert hls_conv_attr['pad_right'] == pad_right + elif model.layers[0].padding == 'valid': + assert hls_conv_attr['pad_top'] == 0 + assert hls_conv_attr['pad_bottom'] == 0 + assert hls_conv_attr['pad_left'] == 0 + assert hls_conv_attr['pad_right'] == 0 + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Catapult']) +@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel']) +def test_depthwise2d(backend, io_type): + ''' + Test proper handling of DepthwiseConv2D + ''' + X = np.random.rand(10, 32, 32, 3) + X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> + model = keras.models.Sequential([keras.layers.Input((32, 32, 3)), DepthwiseConv2D(kernel_size=(3, 3))]) + model.compile() + + config = hls4ml.utils.config_from_keras_model( + model, granularity='name', default_precision='fixed<32,12>', backend=backend + ) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_depthwiseconv2d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + + y_qkeras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qkeras, y_hls4ml.reshape(y_qkeras.shape), rtol=1e-2, atol=0.01) # type: ignore + + +# Currently only Vivado and Vitis is supported for io_stream. +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize('io_type', ['io_stream']) +def test_depthwise1d(backend, io_type): + ''' + Test proper handling of DepthwiseConv1D. + ''' + X = np.random.rand(10, 32, 3) + X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6> + model = keras.Sequential([DepthwiseConv1D(kernel_size=3, input_shape=(32, 3))]) + model.compile() + + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend) + output_dir = str(test_root_path / f'hls4mlprj_keras_api_depthwiseconv1d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + + y_qkeras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qkeras, y_hls4ml.reshape(y_qkeras.shape), rtol=1e-2, atol=0.01) # type: ignore + + +pooling_layers = [MaxPooling1D, MaxPooling2D, AveragePooling1D, AveragePooling2D] + + +@pytest.mark.parametrize('pooling', pooling_layers) +@pytest.mark.parametrize('padds', padds_options) +@pytest.mark.parametrize('chans', chans_options) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI', 'Catapult']) +def test_pooling(pooling, padds, chans, backend): + assert '1D' in pooling.__name__ or '2D' in pooling.__name__ + + input_shape = (18, 15, 3) if '2D' in pooling.__name__ else (121, 3) + pool_size = (4, 2) if '2D' in pooling.__name__ else 2 + + X_input = np.random.rand(100, *input_shape) + + keras_model = keras.Sequential([pooling(pool_size, padding=padds, input_shape=input_shape)]) + keras_model.compile() + + hls_cfg = hls4ml.utils.config_from_keras_model(keras_model) + output_dir = str( + test_root_path / f'hls4mlprj_keras_api_pooling_{pooling.__name__}_channels_{chans}_padds_{padds}_backend_{backend}' + ) + hls_model = hls4ml.converters.convert_from_keras_model( + keras_model, hls_config=hls_cfg, output_dir=output_dir, backend=backend + ) + hls_model.compile() + + # Verify accuracy + keras_prediction = keras_model.predict(X_input) + hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape) # type: ignore + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=3e-2) + + # # Verify correct parsing of layer + # hls_pool = list(hls_model.get_layers())[-1] + # ker_pool = keras_model.layers[-1] + # if '2D' in pooling.__name__: + # assert hls_pool.attributes['name'] == ker_pool._name + # assert hls_pool.attributes['class_name'][-2] == str(2) + # assert hls_pool.attributes['stride_height'] == ker_pool.strides[0] + # assert hls_pool.attributes['stride_width'] == ker_pool.strides[1] + # assert hls_pool.attributes['pool_height'] == ker_pool.pool_size[1] + # assert hls_pool.attributes['pool_width'] == ker_pool.pool_size[0] + + # if hls_pool.attributes['data_format'] == 'channels_last': + # assert hls_pool.attributes['in_height'] == ker_pool.input_shape[1] + # assert hls_pool.attributes['in_width'] == ker_pool.input_shape[2] + # assert hls_pool.attributes['n_filt'] == ker_pool.input_shape[3] + # elif hls_pool.attributes['data_format'] == 'channels_first': + # assert hls_pool.attributes['in_height'] == ker_pool.input_shape[2] + # assert hls_pool.attributes['in_width'] == ker_pool.input_shape[3] + # assert hls_pool.attributes['n_filt'] == ker_pool.input_shape[1] + + # if ker_pool.padding == 'same': + # # Height + # in_height = ker_pool.input_shape[1] + # if ker_pool.data_format == 'channels_first': + # in_height = ker_pool.input_shape[2] + # out_height = int(math.ceil(float(in_height) / float(ker_pool.strides[0]))) + # assert out_height == hls_pool.attributes['out_height'] + # if in_height % ker_pool.strides[0] == 0: + # pad_along_height = max(ker_pool.pool_size[1] - ker_pool.strides[0], 0) + # else: + # pad_along_height = max(ker_pool.pool_size[1] - (in_height % ker_pool.strides[0]), 0) + # pad_top = pad_along_height // 2 + # pad_bottom = pad_along_height - pad_top + # assert pad_bottom == hls_pool.attributes['pad_bottom'] + # assert pad_top == hls_pool.attributes['pad_top'] + + # # Width + # in_width = ker_pool.input_shape[2] + # if ker_pool.data_format == 'channels_first': + # in_height = keras_model.layers[1].input_shape[-1] + # out_width = int(math.ceil(float(in_width) / float(ker_pool.strides[1]))) + # assert out_width == hls_pool.attributes['out_width'] + # if in_width % ker_pool.strides[1] == 0: + # pad_along_width = max(ker_pool.pool_size[0] - ker_pool.strides[1], 0) + # else: + # pad_along_width = max(ker_pool.pool_size[0] - (in_width % ker_pool.strides[1]), 0) + # pad_left = pad_along_width // 2 + # pad_right = pad_along_width - pad_left + # assert pad_left == hls_pool.attributes['pad_left'] + # assert pad_right == hls_pool.attributes['pad_right'] + + # elif ker_pool.padding == 'valid': + # if hls_pool.attributes['data_format'] == 'channels_first': + # in_height = ker_pool.input_shape[2] + # in_width = ker_pool.input_shape[3] + # elif hls_pool.attributes['data_format'] == 'channels_last': + # in_height = ker_pool.input_shape[1] + # in_width = ker_pool.input_shape[2] + # else: + # raise ValueError('Invalid data_format') + + # out_width = int(math.ceil(float(in_width - ker_pool.pool_size[0] + 1) / float(ker_pool.strides[1]))) + # out_height = int(math.ceil(float(in_height - ker_pool.pool_size[1] + 1) / float(ker_pool.strides[0]))) + + # assert hls_pool.attributes['out_height'] == out_height + # assert hls_pool.attributes['out_width'] == out_width + # assert hls_pool.attributes['pad_top'] == 0 + # assert hls_pool.attributes['pad_bottom'] == 0 + # assert hls_pool.attributes['pad_left'] == 0 + # assert hls_pool.attributes['pad_right'] == 0 + + # elif '1D' in pooling.__name__: + # assert hls_pool.attributes['name'] == ker_pool._name + # assert hls_pool.attributes['class_name'][-2] == str(1) + # assert hls_pool.attributes['n_in'] == ker_pool.input_shape[1] + # assert hls_pool.attributes['n_filt'] == ker_pool.input_shape[2] + # assert hls_pool.attributes['pool_width'] == ker_pool.pool_size[0] + # assert hls_pool.attributes['stride_width'] == ker_pool.strides[0] + + # out_same = math.ceil(float(ker_pool.input_shape[1]) / float(ker_pool.strides[0])) + # out_valid = math.ceil(float(ker_pool.input_shape[1] - ker_pool.pool_size[0] + 1) / ker_pool.strides[0]) + + # if ker_pool.padding == 'same': + # assert hls_pool.attributes['n_out'] == out_same + # if ker_pool.input_shape[1] % ker_pool.strides[0] == 0: + # pad_along_width = max(ker_pool.pool_size[0] - ker_pool.strides[0], 0) + # else: + # pad_along_width = max(ker_pool.pool_size[0] - (ker_pool.input_shape[1] % ker_pool.strides[0]), 0) + # assert hls_pool.attributes['pad_left'] == pad_along_width // 2 + # assert hls_pool.attributes['pad_right'] == pad_along_width - pad_along_width // 2 + + # elif ker_pool.padding == 'valid': + # assert hls_pool.attributes['n_out'] == out_valid + # assert hls_pool.attributes['pad_left'] == 0 + # assert hls_pool.attributes['pad_right'] == 0 + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_reused_layer(backend, io_type): + + inp1 = keras.layers.Input(shape=(10, 10)) + inp2 = keras.layers.Input(shape=(10, 10)) + + conv = keras.layers.Conv1D(2, 3, activation='relu') + + o1 = conv(inp1) + o2 = conv(inp2) + o3 = keras.layers.Add()([o1, o2]) + o4 = keras.layers.Dense(5)(o3) + + _ = keras.layers.Dense(5)(o3) + + model = keras.models.Model(inputs=[inp1, inp2], outputs=[o1, o2, o3, o4]) + + _ = model([inp1, inp1]) + + hls_config = {'Model': {'Precision': 'ap_fixed<32,8>', 'ReuseFactor': 1}} + output_dir = str(test_root_path / f'hls4mlprj_keras_api_conv1d_{backend}_{io_type}') + + model_hls = hls4ml.converters.convert_from_keras_model( + model, backend=backend, io_type=io_type, hls_config=hls_config, output_dir=output_dir + ) + + model_hls.compile() + + data = [np.random.rand(1000, 10, 10).astype(np.float32), np.random.rand(1000, 10, 10).astype(np.float32)] + keras_pred = model.predict(data) + hls_pred = model_hls.predict(data) + + np.testing.assert_allclose(keras_pred[0].reshape(hls_pred[0].shape), hls_pred[0], rtol=0, atol=1e-5) + np.testing.assert_allclose(keras_pred[1].reshape(hls_pred[1].shape), hls_pred[1], rtol=0, atol=1e-5) + np.testing.assert_allclose(keras_pred[2].reshape(hls_pred[2].shape), hls_pred[2], rtol=0, atol=1e-5) + np.testing.assert_allclose(keras_pred[3].reshape(hls_pred[3].shape), hls_pred[3], rtol=0, atol=1e-2) From 067ef9ebfe532019c2957852a17ddf9befde5ae2 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 7 Mar 2025 15:53:22 -0800 Subject: [PATCH 05/30] einsumdense and einsum --- hls4ml/backends/vivado/passes/einsum.py | 105 +++++++ hls4ml/backends/vivado/passes/einsum_dense.py | 145 ++++++++++ hls4ml/backends/vivado/vivado_backend.py | 29 ++ hls4ml/converters/keras_v3/einsum_dense.py | 75 +++++ hls4ml/model/layers.py | 131 ++++++++- hls4ml/utils/einsum_utils.py | 256 ++++++++++++++++++ 6 files changed, 740 insertions(+), 1 deletion(-) create mode 100644 hls4ml/backends/vivado/passes/einsum.py create mode 100644 hls4ml/backends/vivado/passes/einsum_dense.py create mode 100644 hls4ml/converters/keras_v3/einsum_dense.py create mode 100644 hls4ml/utils/einsum_utils.py diff --git a/hls4ml/backends/vivado/passes/einsum.py b/hls4ml/backends/vivado/passes/einsum.py new file mode 100644 index 0000000000..aced45425b --- /dev/null +++ b/hls4ml/backends/vivado/passes/einsum.py @@ -0,0 +1,105 @@ +from math import ceil + +from hls4ml.backends.backend import get_backend +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import Einsum + +from .reshaping_templates import transpose_config_gen + +# Shared Dense template +# Einsum template + +einsum_config_template = ''' +struct config{index} {{ + typedef config{index}_tpose_inp0 tpose_inp0_conf; + typedef config{index}_tpose_inp1 tpose_inp1_conf; + typedef config{index}_tpose_out tpose_out_conf; + + typedef {accum_t.name} accum_t; + + // Layer Sizes + static const unsigned n_free0 = {n_free0}; + static const unsigned n_free1 = {n_free1}; + static const unsigned n_contract = {n_contract}; + static const unsigned n_inplace = {n_inplace}; + + // Resource reuse info + static const unsigned io_type = nnet::{iotype}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned reuse_factor = {reuse_factor}; + static const unsigned multiplier_limit = {multiplier_limit}; + static const bool store_weights_in_bram = false; // NOT USED + + template + using product = nnet::product::{product_type}; +}}; +''' + +einsum_function_template = 'nnet::einsum<{input0_t}, {input1_t}, {output_t}, {config}>({input0}, {input1}, {output});' + +einsum_include_list = ['nnet_utils/nnet_einsum.h'] + + +class EinsumConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Einsum) + self.template = einsum_config_template + + def format(self, node: Einsum): + default_params = self._default_config_params(node) + + strategy = node.attributes.attributes['strategy'] + io_type = node.model.config.get_config_value('IOType') + + assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now' + assert strategy.lower() == 'latency', 'EinsumDense layer only supports Latency strategy for now' + + # EinsumDense config + params = default_params.copy() + params['strategy'] = strategy + params['n_free0'] = node.attributes.attributes['n_free0'] + params['n_free1'] = node.attributes.attributes['n_free1'] + params['n_contract'] = node.attributes.attributes['n_contract'] + params['n_inplace'] = node.attributes.attributes['n_inplace'] + inp0_t = node.get_input_variable(node.inputs[0]).type.precision + inp1_t = node.get_input_variable(node.inputs[1]).type.precision + params['product_type'] = get_backend('vivado').product_type(inp0_t, inp1_t) + + total_mults = params['n_free0'] * params['n_free1'] * params['n_contract'] * params['n_inplace'] + params['multiplier_limit'] = ceil(total_mults / params['reuse_factor']) + + einsum_conf = self.template.format(**params) + + # inp/out transpose config + inp0_shape = node.attributes.attributes['inp0_shape'] + inp1_shape = node.attributes.attributes['inp1_shape'] + out_interpert_shape = node.attributes.attributes['out_interpert_shape'] + inp0_tpose_idxs = node.attributes.attributes['inp0_tpose_idxs'] + inp1_tpose_idxs = node.attributes.attributes['inp1_tpose_idxs'] + out_tpose_idxs = node.attributes.attributes['out_tpose_idxs'] + tpose_inp0_conf_name = f'config{node.index}_tpose_inp0' + tpose_inp1_conf_name = f'config{node.index}_tpose_inp1' + tpose_out_conf_name = f'config{node.index}_tpose_out' + + inp0_tpose_conf = transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs) + inp1_tpose_conf = transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs) + out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + + return '\n\n'.join((inp0_tpose_conf, inp1_tpose_conf, out_tpose_conf, einsum_conf)) + + +class EinsumFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Einsum, include_header=einsum_include_list) + self.template = einsum_function_template + + def format(self, node: Einsum): + params = {} + params['config'] = f'config{node.index}' + params['input0_t'] = node.get_input_variable(node.inputs[0]).type.name + params['input1_t'] = node.get_input_variable(node.inputs[1]).type.name + params['output_t'] = node.get_output_variable().type.name + params['input0'] = node.get_input_variable(node.inputs[0]).name + params['input1'] = node.get_input_variable(node.inputs[1]).name + params['output'] = node.get_output_variable().name + return self.template.format(**params) diff --git a/hls4ml/backends/vivado/passes/einsum_dense.py b/hls4ml/backends/vivado/passes/einsum_dense.py new file mode 100644 index 0000000000..3e16ed2ad3 --- /dev/null +++ b/hls4ml/backends/vivado/passes/einsum_dense.py @@ -0,0 +1,145 @@ +from hls4ml.backends.backend import get_backend +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import EinsumDense + +from .reshaping_templates import transpose_config_gen + +# Shared Dense template + +dense_config_template = """struct config{index}_dense : nnet::dense_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned reuse_factor = {reuse}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned n_zeros = {nzeros}; + static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + template + using kernel = nnet::{dense_function}; + template + using product = nnet::product::{product_type}; +}};\n""" + +# EinsumDense template + +einsum_dense_config_template = ''' +struct config{index} {{ + typedef config{index}_tpose_inp tpose_inp_conf; + typedef config{index}_tpose_out tpose_out_conf; + {kernel_config}; + + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + + // Layer Sizes + static const unsigned n_free_data = {n_free_data}; + static const unsigned n_free_kernel = {n_free_kernel}; + static const unsigned n_contract = {n_contract}; + static const unsigned n_inplace = {n_inplace}; + + // Resource reuse info + static const unsigned io_type = nnet::{iotype}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned reuse_factor = {reuse_factor}; + static const unsigned parallelization_factor = {parallelization_factor}; // Only useful when n_inplace > 1 + static const bool store_weights_in_bram = false; // NOT USED +}}; +''' + +einsum_dense_function_template = 'nnet::einsum_dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +einsum_dense_da_function_template = 'nnet::einsum_dense<{input_t}, {output_t}, {config}>({input}, {output}, {b});' + +einsum_dense_include_list = ['nnet_utils/nnet_einsum_dense.h', 'nnet_utils/nnet_dense.h'] + + +class EinsumDenseConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(EinsumDense) + self.template = einsum_dense_config_template + self.dense_template = dense_config_template + + def dense_config(self, node: EinsumDense): + dense_params = self._default_config_params(node) + strategy = node.attributes['strategy'] + dense_params['strategy'] = strategy + dense_params['n_in'] = node.attributes.attributes['n_contract'] + dense_params['n_out'] = node.attributes.attributes['n_free_kernel'] + if node.attributes.attributes['n_inplace'] == 1: + dense_params['nzeros'] = node.get_weights('weight').nzeros # type: ignore + else: + dense_params['nzeros'] = '-1; // Not making sense when kernels are switching' + dense_params['product_type'] = get_backend('vivado').product_type( + node.get_input_variable().type.precision, node.get_weights('weight').type.precision # type: ignore + ) + + dense_params['dense_function'] = 'DenseLatency' # Latency only for now + + dense_config = self.dense_template.format(**dense_params) + return dense_config + + def format(self, node: EinsumDense): + default_params = self._default_config_params(node) + + strategy = node.attributes['strategy'] + io_type = node.model.config.get_config_value('IOType') + + assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel and distributed_arithmetic' + + # EinsumDense config + params = default_params.copy() + params['strategy'] = strategy + params['n_free_data'] = node.attributes.attributes['n_free_data'] + params['n_free_kernel'] = node.attributes.attributes['n_free_kernel'] + params['n_contract'] = node.attributes.attributes['n_contract'] + params['n_inplace'] = node.attributes.attributes['n_inplace'] + if strategy.lower() == 'latency': + params['kernel_config'] = f'typedef config{node.index}_dense dense_conf' + else: + assert strategy.lower() == 'distributed_arithmetic', 'EinsumDense layer only supports Latency strategy for now' + inp_t = node.get_input_variable().type.name + result_t = node.get_output_variable().type.name + index = node.index + conf = f'constexpr static auto da_kernel = nnet::einsum_dense{index}_da_kernel<{inp_t}, {result_t}>' + params['kernel_config'] = conf + pf = node.attributes.attributes['parallelization_factor'] + if pf < 0: + pf = params['n_inplace'] + params['parallelization_factor'] = pf + + einsum_conf = self.template.format(**params) + + # inp/out transpose config + inp_shape = node.attributes.attributes['inp_shape'] + out_interpert_shape = node.attributes.attributes['out_interpert_shape'] + inp_tpose_idxs = node.attributes.attributes['inp_tpose_idxs'] + out_tpose_idxs = node.attributes.attributes['out_tpose_idxs'] + tpose_inp_conf_name = f'config{node.index}_tpose_inp' + tpose_out_conf_name = f'config{node.index}_tpose_out' + + inp_tpose_conf = transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs) + out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + + if strategy.lower() == 'distributed_arithmetic': + return '\n\n'.join((inp_tpose_conf, out_tpose_conf, einsum_conf)) + + dense_config = self.dense_config(node) + return '\n\n'.join((inp_tpose_conf, out_tpose_conf, dense_config, einsum_conf)) + + +class EinsumDenseFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(EinsumDense, include_header=einsum_dense_include_list) + self.template = einsum_dense_function_template + + def format(self, node): + params = self._default_function_params(node) + params['b'] = node.get_weights('bias').name + + strategy = node.attributes['strategy'] + if strategy == 'distributed_arithmetic': + return einsum_dense_da_function_template.format(**params) + + params['w'] = node.get_weights('weight').name + return einsum_dense_function_template.format(**params) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 117805dd86..f8908d6011 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -17,6 +17,8 @@ Dense, DepthwiseConv1D, DepthwiseConv2D, + Einsum, + EinsumDense, Embedding, GarNet, GarNetStack, @@ -660,3 +662,30 @@ def init_garnet(self, layer): @layer_optimizer(GarNetStack) def init_garnet_stack(self, layer): self.init_garnet(layer) + + @layer_optimizer(EinsumDense) + def init_einsum_dense(self, layer: EinsumDense) -> None: + strategy: str | None = layer.model.config.get_strategy(layer) + if not strategy: + layer.set_attr('strategy', 'latency') + return + if strategy in ('latency', 'resource', 'distributed_arithmetic'): + layer.set_attr('strategy', strategy) + return + warn(f'Invalid strategy "{strategy}" for EinsumDense layer "{layer.name}". Using "latency" strategy instead.') + layer.set_attr('strategy', 'latency') + + @layer_optimizer(Einsum) + def init_einsum(self, layer: Einsum) -> None: + strategy: str | None = layer.model.config.get_strategy(layer) + if not strategy: + layer.set_attr('strategy', 'latency') + return + if strategy.lower() == 'resource': + layer.set_attr('strategy', 'resource') + return + if strategy.lower() in ('latency', 'distributed_arithmetic'): + layer.set_attr('strategy', 'latency') + return + warn(f'Invalid strategy "{strategy}" for Einsum layer "{layer.name}". Using "latency" strategy instead.') + layer.set_attr('strategy', 'latency') diff --git a/hls4ml/converters/keras_v3/einsum_dense.py b/hls4ml/converters/keras_v3/einsum_dense.py new file mode 100644 index 0000000000..8eb000fcf7 --- /dev/null +++ b/hls4ml/converters/keras_v3/einsum_dense.py @@ -0,0 +1,75 @@ +import typing +from typing import Sequence + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + + +def strip_batch_dim(equation: str, einsum_dense: bool = True): + """Remove the batch dimension from the equation. + + Args: + equation (str): The einsum equation. + einsum_dense (bool): Whether the equation is for EinsumDense layer. + + Returns: + str: The einsum equation without the batch dimension. + """ + + _inps, out = equation.split('->') + inp0, inp1 = _inps.split(',') + if einsum_dense: + if inp0.startswith('...'): + assert out.startswith('...'), f'Error in eq: {equation}: Batch dim mismatch for the input and output.' + else: + assert inp0[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the input and output.' + assert inp0[0] not in inp1, f'Error in eq: {equation}: Batch dim is used in the kernel.' + inp0, out = inp0[1:], out[1:] + else: + assert inp0[0] == inp1[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the inputs and output.' + inp0, inp1, out = inp0[1:], inp1[1:], out[1:] + return f'{inp0},{inp1}->{out}' + + +@register +class KV3EinsumDenseHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.einsum_dense.EinsumDense',) + + def handle( + self, + layer: 'keras.layers.EinsumDense', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + assert len(in_tensors) == 1, 'EinsumDense layer must have exactly one input tensor' + assert len(out_tensors) == 1, 'EinsumDense layer must have exactly one output tensor' + + inp_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore + out_shape: tuple[int, ...] = out_tensors[0].shape[1:] # type: ignore + + # fmt: off + assert all(d is not None for d in inp_shape), \ + f'Error when processing {layer.name}: EinsumDense layer requires fully inp shapes' + assert all(d is not None for d in out_shape), \ + f'Error when processing {layer.name}: EinsumDense layer requires fully out shapes' + # fmt: on + + equation = strip_batch_dim(layer.equation, True) + + kernel = self.load_weight(layer, 'kernel') + + bias = None + if layer.bias_axes: + bias = self.load_weight(layer, 'bias') + + return { + 'class_name': 'EinsumDense', + 'equation': equation, + 'weight_data': kernel, + 'bias_data': bias, + 'inp_shape': inp_shape, + 'out_shape': out_shape, + } diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 03e3d9ce8a..c8316520a5 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -27,10 +27,12 @@ find_minimum_width, ) from hls4ml.utils import attribute_descriptions as descriptions +from hls4ml.utils.einsum_utils import parse_einsum from hls4ml.utils.string_utils import convert_to_snake_case - # TODO move this to some utility module + + class classproperty: def __init__(self, func): self.func = func @@ -1621,6 +1623,131 @@ def initialize(self): self.add_output_variable([len(self.get_attr('expression'))], [f'N_OUTPUTS_{self.index}'], var_name='y') +class EinsumDense(Layer): + _expected_attributes = [ + WeightAttribute('weight'), + WeightAttribute('bias'), + TypeAttribute('weight'), + TypeAttribute('bias'), + TypeAttribute('accum'), + Attribute('equation', value_type=str), + Attribute('inp_shape', value_type=tuple), + Attribute('out_shape', value_type=tuple), + ] + + def initialize(self): + out_shape = self.attributes['out_shape'] + if len(out_shape) > 1: + dims = [f'N_LAYER_{self.index}_D{i}' for i in range(1, len(out_shape) + 1)] + else: + dims = [f'N_LAYER_{self.index}'] + self.add_output_variable(list(out_shape), dims) + + kernel: np.ndarray = self.attributes.attributes['weight_data'] + bias: np.ndarray | None = self.attributes.attributes['bias_data'] + equation = self.attributes['equation'] + inp_shape = self.attributes['inp_shape'] + out_shape = self.attributes['out_shape'] + + kernel_shape = kernel.shape + recipe = parse_einsum(equation, inp_shape, kernel_shape) + assert not any(recipe['direct_sum_axis']), ( + 'Do not put direct sum indices (e.g., only appears in one of the operands) in the equation.' + 'Use explicit addition operator before instead.' + ) + inp_tpose_idxs, ker_tpose_idxs = recipe['in_transpose_idxs'] + out_tpose_idxs = recipe['out_transpose_idxs'] + + # Pre-transpose kernel (and bias) to save a transpose in cpp. Shouldn't matter for latency strategy though. + # hls4ml dense acts like i,ij->j + # parser assumes ij,j->i, so we need to transpose the kernel to match + kernel = kernel.transpose(ker_tpose_idxs) + kernel = kernel.reshape(recipe['I'], recipe['L1'], recipe['C']).transpose(0, 2, 1) + + def to_original_kernel(tkernel: np.ndarray) -> np.ndarray: + _kernel = tkernel.transpose(0, 2, 1) + _kernel = _kernel.reshape(tuple(kernel_shape[i] for i in ker_tpose_idxs)) + return _kernel.transpose(np.argsort(ker_tpose_idxs)) + + # TODO: for weight in bram mode (resource), broadcasting bias here shall be avoided. + if bias is not None: + bias = np.broadcast_to(bias, out_shape).transpose(np.argsort(out_tpose_idxs)) + else: + # The automatically created bias is just the last dimension of the output shape + # Which is too small in general for einsum dense. + # The transpose is just to match the shape in case of have real bias, no real effect. + bias = np.zeros(out_shape).transpose(np.argsort(out_tpose_idxs)) + + self.attributes.attributes['weight_data'] = kernel + self.attributes.attributes['to_original_kernel'] = to_original_kernel + self.attributes.attributes['bias_data'] = bias + self.attributes['inp_tpose_idxs'] = inp_tpose_idxs + self.attributes['out_tpose_idxs'] = out_tpose_idxs + self.attributes['out_interpert_shape'] = recipe['out_interpert_shape'] + self.attributes['n_free_data'] = recipe['L0'] + self.attributes['n_free_kernel'] = recipe['L1'] + self.attributes['n_inplace'] = recipe['I'] + self.attributes['n_contract'] = recipe['C'] + pf = self.attributes.attributes.get('parallelization_factor', recipe['L0']) + self.attributes['parallelization_factor'] = pf + + self.add_weights(compression=self.model.config.get_compression(self)) + self.add_bias() + + +class Matmul(Layer): + _expected_attributes = [ + TypeAttribute('accum'), + Attribute('inup1_shape', value_type=tuple), + Attribute('inp2_shape', value_type=tuple), + ] + + +class Einsum(Layer): + _expected_attributes = [ + TypeAttribute('accum'), + Attribute('equation', value_type=str), + Attribute('inp0_shape', value_type=tuple), + Attribute('inp1_shape', value_type=tuple), + Attribute('out_shape', value_type=tuple), + ] + + def initialize(self): + out_shape = self.attributes['out_shape'] + if len(out_shape) > 1: + dims = [f'N_LAYER_{self.index}_D{i}' for i in range(1, len(out_shape) + 1)] + else: + dims = [f'N_LAYER_{self.index}'] + self.add_output_variable(list(out_shape), dims) + + equation = self.attributes['equation'] + inp0_shape = self.attributes['inp0_shape'] + inp1_shape = self.attributes['inp1_shape'] + out_shape = self.attributes['out_shape'] + + recipe = parse_einsum(equation, inp0_shape, inp1_shape) + assert not any(recipe['direct_sum_axis']), ( + 'Do not put direct sum indices (e.g., only appears in one of the operands) in the equation.' + 'Use explicit addition operator before instead.' + ) + inp0_tpose_idxs, inp1_tpose_idxs = recipe['in_transpose_idxs'] + out_tpose_idxs = recipe['out_transpose_idxs'] + + self.attributes.attributes.update(recipe) + self.attributes['n_free0'] = recipe['L0'] + self.attributes['n_free1'] = recipe['L1'] + self.attributes['n_inplace'] = recipe['I'] + self.attributes['n_contract'] = recipe['C'] + self.attributes['out_interpert_shape'] = recipe['out_interpert_shape'] + + self.attributes['inp0_tpose_idxs'] = inp0_tpose_idxs + self.attributes['inp1_tpose_idxs'] = inp1_tpose_idxs + self.attributes['out_tpose_idxs'] = out_tpose_idxs + + pf = self.attributes.attributes.get('parallelization_factor', recipe['L0']) + self.attributes['parallelization_factor'] = pf + + layer_map = { 'Input': Input, 'InputLayer': Input, @@ -1687,6 +1814,8 @@ def initialize(self): 'BatchNormOnnx': BatchNormOnnx, 'LayerGroup': LayerGroup, 'SymbolicExpression': SymbolicExpression, + 'EinsumDense': EinsumDense, + 'Einsum': Einsum, # TensorFlow-specific layers: 'BiasAdd': BiasAdd, } diff --git a/hls4ml/utils/einsum_utils.py b/hls4ml/utils/einsum_utils.py new file mode 100644 index 0000000000..43ceb2ba96 --- /dev/null +++ b/hls4ml/utils/einsum_utils.py @@ -0,0 +1,256 @@ +from math import prod +from typing import TypedDict + +import numpy as np + + +class EinsumRecipe(TypedDict): + direct_sum_axis: tuple[tuple[int, ...], tuple[int, ...]] + in_transpose_idxs: tuple[tuple[int, ...], tuple[int, ...]] + L0: int + L1: int + I: int + C: int + out_interpert_shape: tuple[int, ...] + out_transpose_idxs: tuple[int, ...] + + +def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, ...]): + """Validate, resolve broadcasting, and compute output shape for einsum string + + Parameters + ---------- + fn : str + einsum string, e.g. 'ij,jk->ik' + shape0 : tuple[int,...] + shape of input0 + shape1 : tuple[int,...] + shape of input1 + + Returns + ------- + tuple[str, tuple[int,...]] + einsum string w/o broadcasting, and output shape + + Raises + ------ + ValueError + If the einsum string is invalid, or if it is incompatible with the input shapes + """ + inp, out = map(str.strip, fn.split('->')) + in0, in1 = map(str.strip, inp.split(',')) + alphabets = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' + s_alphabets = set(alphabets) + + # Invalid characters + if not (s_alphabets >= set(in0.replace('...', '') + in1.replace('...', '') + out.replace('...', ''))): + raise ValueError(f"einsum string {fn} is invalid: subscripts should be in [a-zA-Z] and '...' only") + + in0 = in0.replace('...', '0') + in1 = in1.replace('...', '0') + out = out.replace('...', '0') + ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out) + sax_in0, sax_in1, sax_out = set(ax_in0), set(ax_in1), set(ax_out) + free_indices = ''.join(sorted(s_alphabets - sax_in0 - sax_in1 - sax_out)) + + # Repeated indices + if len(sax_in0) != len(ax_in0): + for a in in0: + if in0.count(a) == 1: + continue + a = a if a != '0' else '...' + raise ValueError(f"einsum string {fn} is invalid: input0 subscripts includes '{a}' multiple times") + if len(sax_in1) != len(ax_in1): + for a in in1: + if in1.count(a) == 1: + continue + a = a if a != '0' else '...' + raise ValueError(f"einsum string {fn} is invalid: input1 subscripts includes '{a}' multiple times") + if len(sax_out) != len(ax_out): + for a in out: + if out.count(a) == 1: + continue + a = a if a != '0' else '...' + raise ValueError(f"einsum string {fn} is invalid: output subscripts includes '{a}' multiple times") + + # Invalid broadcasting + if '0' in sax_in0 or '0' in sax_in1 or '0' in sax_out: + if '0' in sax_in0 and '0' in sax_in1: + raise ValueError(f"einsum string {fn} is invalid: both input0 and input1 allows broadcasting") + if '0' not in sax_out: + raise ValueError(f"einsum string {fn} is invalid: output does not allow broadcasting, but inputs do") + if '0' not in sax_in0 and '0' not in sax_in1: + raise ValueError(f"einsum string {fn} is invalid: output allows broadcasting, but inputs do not") + + # Output index out of nowhere + if remaining := sax_out - sax_in0 - sax_in1: + raise ValueError(f"einsum string {fn} is invalid: output subscripts {remaining} not found in inputs") + + _common_in = sax_in0 & sax_in1 + + # Invalid input dimensions + if '0' in sax_in0: + if len(sax_in0) - 1 > len(shape0): + raise ValueError(f"Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given") + # Replace broadcasting indices with free indices + n_broadcast = len(shape0) - len(sax_in0) + 1 + in0 = in0.replace('0', free_indices[:n_broadcast]) + out = out.replace('0', free_indices[:n_broadcast]) + ax_in0 = list(in0) + ax_out = list(out) + else: + if len(sax_in0) != len(shape0): + raise ValueError(f"Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given") + if '0' in sax_in1: + if len(sax_in1) - 1 > len(shape1): + raise ValueError(f"Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given") + # Replace broadcasting indices with free indices + n_broadcast = len(shape1) - len(sax_in1) + 1 + in1 = in1.replace('0', free_indices[:n_broadcast]) + out = out.replace('0', free_indices[:n_broadcast]) + ax_in1 = list(in1) + ax_out = list(out) + else: + if len(sax_in1) != len(shape1): + raise ValueError(f"Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given") + + # Input dimension mismatch + for a in _common_in: + ax_0 = ax_in0.index(a) + ax_1 = ax_in1.index(a) + if shape0[ax_0] != shape1[ax_1]: + raise ValueError( + f"Input dimension size mismatches for common subscript '{a}': {shape0[ax_0]} and {shape1[ax_1]}" + ) + + out_shape = tuple(shape0[ax_in0.index(a)] if a in ax_in0 else shape1[ax_in1.index(a)] for a in ax_out) + return f'{in0},{in1}->{out}', out_shape + + +def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int, ...]) -> EinsumRecipe: + """Parse einsum operation on two input arrays, return a recipe for execution + + Parameters + ---------- + fn : str + einsum string, e.g. 'ij,jk->ik' + input : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + EinsumRecipe + einsum recipe; executed by _exec_einsum + """ + + fn, _ = _validate_einsum_expr(fn, input_shape0, input_shape1) + + _in, _out = fn.split('->') + _in0, _in1 = _in.split(',') + + in0, in1, out = list(_in0), list(_in1), list(_out) + s_in0, s_in1, s_out = set(in0), set(in1), set(out) + _common = s_in0 & s_in1 + _contract = _common - s_out + _inplace = _common & s_out + contract = sorted(_contract, key=lambda x: in1.index(x)) + inplace = sorted(_inplace, key=lambda x: in1.index(x)) + invariant0 = sorted((s_out - _common) & s_in0, key=lambda x: in0.index(x)) + invariant1 = sorted((s_out - _common) & s_in1, key=lambda x: in1.index(x)) + direct_sum0 = s_in0 - s_out - _common + direct_sum1 = s_in1 - s_out - _common + direct_sum_axis = ( + tuple(sorted(in0.index(x) for x in direct_sum0)), + tuple(sorted(in1.index(x) for x in direct_sum1)), + ) + + contract_idxs = tuple(map(in0.index, contract)), tuple(map(in1.index, contract)) + inplace_idxs = tuple(map(in0.index, inplace)), tuple(map(in1.index, inplace)) + invariant_idxs = tuple(map(in0.index, invariant0)), tuple(map(in1.index, invariant1)) + + inplace_shape = tuple(input_shape0[i] for i in inplace_idxs[0]) + inplace_size = prod(inplace_shape) + contract_size = prod(input_shape0[i] for i in contract_idxs[0]) + invariant_shape0 = tuple(input_shape0[i] for i in invariant_idxs[0]) + invariant_shape1 = tuple(input_shape1[i] for i in invariant_idxs[1]) + invariant_size0, invariant_size1 = prod(invariant_shape0), prod(invariant_shape1) + + transpose_idx0 = inplace_idxs[0] + invariant_idxs[0] + contract_idxs[0] + transpose_idx1 = inplace_idxs[1] + invariant_idxs[1] + contract_idxs[1] + + out_shape_pretranspose = inplace_shape + invariant_shape0 + invariant_shape1 + _out_transpose_idx = np.argsort(tuple(map(out.index, inplace + invariant0 + invariant1))) + out_transpose_idx = tuple(int(i) for i in _out_transpose_idx) + + return EinsumRecipe( + direct_sum_axis=direct_sum_axis, + in_transpose_idxs=(transpose_idx0, transpose_idx1), + out_interpert_shape=out_shape_pretranspose, + out_transpose_idxs=out_transpose_idx, + L0=invariant_size0, + L1=invariant_size1, + I=inplace_size, + C=contract_size, + ) + + +def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) -> np.ndarray: + """Execute einsum operation on two input arrays + + Parameters + ---------- + recipe : EinsumRecipe + einsum recipe + input0 : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + np.ndarray + output array + """ + sum_axis0, sum_axis1 = recipe['direct_sum_axis'] + if sum_axis0: + input0 = np.sum(input0, axis=sum_axis0) + if sum_axis1: + input1 = np.sum(input1, axis=sum_axis1) + input0 = input0.transpose(recipe['in_transpose_idxs'][0]).ravel() + input1 = input1.transpose(recipe['in_transpose_idxs'][1]).ravel() + output = np.zeros(recipe['L0'] * recipe['L1'] * recipe['I'], dtype=input0.dtype) + + L0, L1, I, C = recipe['L0'], recipe['L1'], recipe['I'], recipe['C'] + + for l0 in range(L0): + for i in range(I): + A = input1[i * L1 * C : (i + 1) * L1 * C].reshape((L1, C)) + B = input0[(i * L0 + l0) * C : (i * L0 + l0 + 1) * C] + output[(i * L0 + l0) * L1 : (i * L0 + l0 + 1) * L1] = A @ B + + return output.reshape(recipe['out_interpert_shape']).transpose(recipe['out_transpose_idxs']) + + +def einsum(fn: str, input0: np.ndarray, input1: np.ndarray) -> np.ndarray: + """Execute einsum operation on two input arrays. + + WARNING: Order of multiplication is reversed -- watchout if you are using non-commutative operators + + Parameters + ---------- + fn : str + einsum string, e.g. 'ij,jk->ik' + input : np.ndarray + input0, the first input array + input1 : np.ndarray + input1, the second input array + + Returns + ------- + np.ndarray + output array + """ + recipe = parse_einsum(fn, input0.shape, input1.shape) + return _exec_einsum(recipe, input0, input1) From d8bb729e5d03c95695e3ac0b7fc6bc3c9e4e6526 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 7 Mar 2025 16:14:24 -0800 Subject: [PATCH 06/30] add einsum templates --- .../templates/vivado/nnet_utils/nnet_einsum.h | 83 +++++++++++++ .../vivado/nnet_utils/nnet_einsum_dense.h | 114 ++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_einsum.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h b/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h new file mode 100644 index 0000000000..cc2917783c --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h @@ -0,0 +1,83 @@ +#ifndef NNET_EINSUM_H_ +#define NNET_EINSUM_H_ + +#include "nnet_common.h" +#include "nnet_mult.h" +#include "nnet_transpose.h" + +namespace nnet { + +struct config_einsum { + typedef void tpose_inp0_conf; + typedef void tpose_inp1_conf; + typedef void tpose_out_conf; + + // Layer Sizes + static const unsigned n_free0; + static const unsigned n_free1; + static const unsigned n_contract; + static const unsigned n_inplace; + + // Resource reuse info + static const unsigned io_type; + static const unsigned strategy; + static const unsigned reuse_factor; + static const unsigned multiplier_limit; + static const bool store_weights_in_bram = false; // NOT USED + + template using product = nnet::product::mult; +}; + +template +void einsum(const data0_T data0[CONFIG_T::tpose_inp0_conf::N], const data1_T data1[CONFIG_T::tpose_inp1_conf::N], + res_T res[CONFIG_T::tpose_out_conf::N]) { + + #pragma HLS PIPELINE II = CONFIG_T::reuse_factor + #pragma HLS ALLOCATION operation instances = mul limit = CONFIG_T::multiplier_limit + + data0_T tpose_i0[CONFIG_T::tpose_inp0_conf::N]; + data1_T tpose_i1[CONFIG_T::tpose_inp1_conf::N]; + res_T tpose_o[CONFIG_T::tpose_out_conf::N]; + + #pragma HLS ARRAY_PARTITION variable = tpose_i0 complete + #pragma HLS ARRAY_PARTITION variable = tpose_i1 complete + #pragma HLS ARRAY_PARTITION variable = tpose_o complete + + nnet::transpose(data0, tpose_i0); + nnet::transpose(data1, tpose_i1); + + // for l0 in range(L0): + // for i in range(I): + // output[(i*L0+l0)*L1:(i*L0+l0+1)*L1] = input1[i*L1*C:(i+1)*L1*C].reshape((L1,C)) @ + // input0[(i*L0+l0)*C:(i*L0+l0+1)*C] + + constexpr unsigned L0 = CONFIG_T::n_free0; + constexpr unsigned L1 = CONFIG_T::n_free1; + constexpr unsigned C = CONFIG_T::n_contract; + constexpr unsigned I = CONFIG_T::n_inplace; + + typename CONFIG_T::accum_t accum_buf; + for (unsigned i = 0; i < I; i++) { + #pragma HLS UNROLL + for (unsigned l0 = 0; l0 < L0; l0++) { + #pragma HLS UNROLL + for (unsigned l1 = 0; l1 < L1; l1++) { + #pragma HLS UNROLL + accum_buf = 0; + for (unsigned c = 0; c < C; c++) { + #pragma HLS UNROLL + data0_T a = tpose_i0[(i * L0 + l0) * C + c]; + data1_T b = tpose_i1[i * L1 * C + l1 * C + c]; + accum_buf += CONFIG_T::template product::product(a, b); + } + tpose_o[(i * L0 + l0) * L1 + l1] = accum_buf; + } + } + } + + nnet::transpose(tpose_o, res); +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h new file mode 100644 index 0000000000..9f26ff0bd7 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h @@ -0,0 +1,114 @@ +#ifndef NNET_EINSUM_DENSE_H_ +#define NNET_EINSUM_DENSE_H_ + +#include "hls_stream.h" +#include "nnet_common.h" +#include "nnet_dense_latency.h" +#include "nnet_dense_resource.h" +#include "nnet_function_stubs.h" +#include "nnet_helpers.h" +#include "nnet_mult.h" +#include "nnet_transpose.h" + +namespace nnet { + +struct einsum_dense_config { + // Internal data type definitions + + typedef void tpose_inp_conf; + typedef void tpose_out_conf; + typedef void dense_conf; + + // Layer Sizes + static const unsigned n_free_data = 1; + static const unsigned n_free_kernel = 1; + static const unsigned n_contract = 1; + static const unsigned n_inplace = 1; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned strategy = latency; + static const unsigned reuse_factor = 1; + static const unsigned parallelization_factor = 1000; // Only useful when n_inplace > 1 + static const bool store_weights_in_bram = false; // NOT USED + + // Product function to use + template using product = nnet::product::mult; +}; + +template +void einsum_dense( + data_T data[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace], + res_T res[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace], + typename CONFIG_T::dense_conf::weight_t weights[CONFIG_T::n_free_kernel * CONFIG_T::n_contract * CONFIG_T::n_inplace], + typename CONFIG_T::dense_conf::bias_t biases[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]) { + data_T inp_tpose[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace]; + res_T out_tpose[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]; + res_T out_buffer[CONFIG_T::n_free_kernel]; + #pragma HLS ARRAY_PARTITION variable = inp_tpose complete + #pragma HLS ARRAY_PARTITION variable = out_tpose complete + + nnet::transpose(data, inp_tpose); + + constexpr unsigned L0 = CONFIG_T::n_free_data; + constexpr unsigned L1 = CONFIG_T::n_free_kernel; + constexpr unsigned C = CONFIG_T::n_contract; + constexpr unsigned I = CONFIG_T::n_inplace; + + for (unsigned l0 = 0; l0 < L0; l0++) { + #pragma HLS UNROLL factor = CONFIG_T::parallelization_factor + for (unsigned i = 0; i < I; i++) { + #pragma HLS UNROLL + // even w/o explicit distributed arithmetic optimization, latency kernels are partially implemented as such + // so reusing the same multiplier for different weights doesn't really help... only full unrolling for now + dense(&inp_tpose[(i * L0 + l0) * C], out_buffer, + &weights[(i * L1 * C)], &biases[((i * L0 + l0) * L1)]); + for (unsigned j = 0; j < L1; j++) { + #pragma HLS UNROLL + out_tpose[(i * L0 + l0) * L1 + j] = out_buffer[j]; + } + } + } + + nnet::transpose(out_tpose, res); +} + +template +typename std::enable_if::type +einsum_dense(data_T data[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace], + res_T res[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace], + typename CONFIG_T::bias_t biases[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]) { + + data_T inp_tpose[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace]; + res_T out_tpose[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]; + res_T out_buffer[CONFIG_T::n_free_kernel]; + + #pragma HLS ARRAY_PARTITION variable = inp_tpose complete + #pragma HLS ARRAY_PARTITION variable = out_tpose complete + + nnet::transpose(data, inp_tpose); + + constexpr unsigned L0 = CONFIG_T::n_free_data; + constexpr unsigned L1 = CONFIG_T::n_free_kernel; + constexpr unsigned C = CONFIG_T::n_contract; + constexpr unsigned I = CONFIG_T::n_inplace; + + for (unsigned l0 = 0; l0 < L0; l0++) { + #pragma HLS UNROLL factor = CONFIG_T::parallelization_factor + // for (unsigned i = 0; i < I; i++) { + // #pragma HLS UNROLL + // inp_tpose[(i * L0 + l0) * C]->out_tpose[(i * L0 + l0) * L1]; + // } + CONFIG_T::da_kernel(inp_tpose, out_tpose, l0); + } + for (unsigned ii = 0; ii < (L0 * L1 * I); ii++) { + #pragma HLS UNROLL + out_tpose[ii] = out_tpose[ii] + biases[ii]; + } + + nnet::transpose(out_tpose, res); +} + +} // namespace nnet + +#endif From 303db72eaac9baf3d2b038bbf2a038e26469c8a8 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 7 Mar 2025 16:20:08 -0800 Subject: [PATCH 07/30] einsumdense test --- test/pytest/test_einsum_dense.py | 57 ++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 test/pytest/test_einsum_dense.py diff --git a/test/pytest/test_einsum_dense.py b/test/pytest/test_einsum_dense.py new file mode 100644 index 0000000000..dbddf545ff --- /dev/null +++ b/test/pytest/test_einsum_dense.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import keras +import numpy as np +import pytest + +from hls4ml.converters import convert_from_keras_model + +if keras.__version__ < '3.0.0': + pytest.skip('Only keras v3 is supported for now', allow_module_level=True) + +from keras.api.layers import EinsumDense, Input + +test_root_path = Path(__file__).parent + + +@pytest.mark.parametrize('strategy', ['latency', 'distributed_arithmetic']) +@pytest.mark.parametrize('io_type', ['io_parallel']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize( + 'operation', + [ + # eq, inp, out + ('bi,j->bij', (8,), (8, 7), None), + ('bi,j->bij', (8,), (8, 7), 'i'), + ('bi,j->bij', (8,), (8, 7), 'j'), + ('bi,io->bo', (8,), 7, None), + ('...i,oi->...o', (4, 3), (5,), None), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), None), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'aeb'), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'ab'), + ('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'a'), + ], +) +def test_einsum_dense(backend, io_type, strategy, operation): + eq, inp_shape, out_shape, bias_axes = operation + model = keras.Sequential( + [Input(inp_shape), EinsumDense(eq, output_shape=out_shape, bias_axes=bias_axes, name='einsum_dense')] + ) + + if bias_axes is not None: + layer = model.get_layer('einsum_dense') + layer.bias.assign(keras.ops.convert_to_tensor(np.random.rand(*layer.bias.shape))) + + data = np.random.rand(1000, *inp_shape) + eq_name = eq.replace(',', '_').replace('->', '_') + ('' if bias_axes is None else f'_{bias_axes}') + output_dir = str(test_root_path / f'hls4mlprj_einsum_dense_{eq_name}_{backend}_{io_type}_{strategy}') + hls_config = {'Model': {'Precision': 'ap_fixed<32,8>', 'ReuseFactor': 1}, 'Strategy': strategy} + model_hls = convert_from_keras_model( + model, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type + ) + + model_hls.compile() + r_keras = model.predict(data, verbose=0, batch_size=1000) # type: ignore + r_hls = model_hls.predict(data).reshape(r_keras.shape) # type: ignore + + np.testing.assert_allclose(r_hls, r_keras, atol=2e-6, rtol=0) From 56c0731507e21595d81a8213b01352d8a91694df Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 7 Mar 2025 15:55:19 -0800 Subject: [PATCH 08/30] support kv3 parsed batchnorm --- hls4ml/model/layers.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index c8316520a5..ee8f44dbad 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1018,16 +1018,21 @@ def initialize(self): dims = inp.dim_names self.add_output_variable(shape, dims) - gamma = self.get_attr('gamma_data') - beta = self.get_attr('beta_data') - mean = self.get_attr('mean_data') - var = self.get_attr('variance_data') - - scale = gamma / np.sqrt(var + self.get_attr('epsilon')) - bias = beta - scale * mean + if self.get_attr('scale_data') is None: + gamma = self.get_attr('gamma_data') + var = self.get_attr('variance_data') + scale = gamma / np.sqrt(var + self.get_attr('epsilon')) + self.add_weights_variable(name='scale', var_name='s{index}', data=scale) + else: + self.add_weights_variable(name='scale', var_name='s{index}') - self.add_weights_variable(name='scale', var_name='s{index}', data=scale) - self.add_weights_variable(name='bias', var_name='b{index}', data=bias) + if self.get_attr('bias_data') is None: + beta = self.get_attr('beta_data') + mean = self.get_attr('mean_data') + bias = beta - scale * mean + self.add_weights_variable(name='bias', var_name='b{index}', data=bias) + else: + self.add_weights_variable(name='bias', var_name='b{index}') # TODO: discuss whether this should be renamed to soemthing more descriptive, and whether the class hierarchy makes sense From fe3fcd05d7cc05c8b984212de3c81b99a209a09b Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 7 Mar 2025 18:55:13 -0800 Subject: [PATCH 09/30] fix einsum/einsum dense regression issue --- hls4ml/backends/fpga/fpga_backend.py | 14 +++++++++++--- .../backends/oneapi/passes/reshaping_templates.py | 12 ++---------- hls4ml/backends/vivado/passes/einsum.py | 11 +++++++---- hls4ml/backends/vivado/passes/einsum_dense.py | 8 +++++--- .../backends/vivado/passes/reshaping_templates.py | 15 ++------------- 5 files changed, 27 insertions(+), 33 deletions(-) diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index bd85937d89..95d900fd62 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -914,7 +914,7 @@ def generate_conv2d_line_buffer_fn( return generated_code @staticmethod - def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]): + def transpose_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]): """ Generate new shape and perm_strides for a permute operation. Operates by mapping the output index to input input index by: @@ -933,12 +933,20 @@ def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]) perm (tuple[int, ...]): The permutation of the dimensions. Returns: - (new_shape, perm_strides) (tuple, tuple): the output shape and permutation strides. + dict: Dictionary containing the configuration. """ new_shape = tuple(shape[i] for i in perm) strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1] perm_strides = tuple(int(strides[i]) for i in perm) - return (new_shape, perm_strides) + return dict( + dims=len(shape), + N=math.prod(shape), + from_shape=', '.join(str(x) for x in shape), + perm=', '.join(str(x) for x in perm), + perm_strides=', '.join(str(x) for x in perm_strides), + to_shape=', '.join(str(x) for x in new_shape), + config_name=name, + ) @model_optimizer() def write_hls(self, model): diff --git a/hls4ml/backends/oneapi/passes/reshaping_templates.py b/hls4ml/backends/oneapi/passes/reshaping_templates.py index 462758c228..80b467b944 100644 --- a/hls4ml/backends/oneapi/passes/reshaping_templates.py +++ b/hls4ml/backends/oneapi/passes/reshaping_templates.py @@ -185,16 +185,8 @@ def format(self, node): perm = tuple(node.get_attr('perm')) name = f'config{node.index}' - new_shape, perm_strides = node.model.config.backend.permute_config_gen(name, shape, perm) - return transpose_config_template.format( - dims=len(shape), - N=int(np.prod(shape)), - from_shape=', '.join(str(x) for x in shape), - perm=', '.join(str(x) for x in perm), - perm_strides=', '.join(str(x) for x in perm_strides), - to_shape=', '.join(str(x) for x in new_shape), - config_name=name, - ) + conf = node.model.config.backend.transpose_config_gen(name, shape, perm) + return transpose_config_template.format(**conf) class TransposeFunctionTemplate(FunctionCallTemplate): diff --git a/hls4ml/backends/vivado/passes/einsum.py b/hls4ml/backends/vivado/passes/einsum.py index aced45425b..4f976c63af 100644 --- a/hls4ml/backends/vivado/passes/einsum.py +++ b/hls4ml/backends/vivado/passes/einsum.py @@ -4,7 +4,7 @@ from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate from hls4ml.model.layers import Einsum -from .reshaping_templates import transpose_config_gen +from .reshaping_templates import transpose_config_template # Shared Dense template # Einsum template @@ -81,9 +81,12 @@ def format(self, node: Einsum): tpose_inp1_conf_name = f'config{node.index}_tpose_inp1' tpose_out_conf_name = f'config{node.index}_tpose_out' - inp0_tpose_conf = transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs) - inp1_tpose_conf = transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs) - out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + conf = node.model.config.backend.transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs) + inp0_tpose_conf = transpose_config_template.format(**conf) + conf = node.model.config.backend.transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs) + inp1_tpose_conf = transpose_config_template.format(**conf) + conf = node.model.config.backend.transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + out_tpose_conf = transpose_config_template.format(**conf) return '\n\n'.join((inp0_tpose_conf, inp1_tpose_conf, out_tpose_conf, einsum_conf)) diff --git a/hls4ml/backends/vivado/passes/einsum_dense.py b/hls4ml/backends/vivado/passes/einsum_dense.py index 3e16ed2ad3..1b4b183039 100644 --- a/hls4ml/backends/vivado/passes/einsum_dense.py +++ b/hls4ml/backends/vivado/passes/einsum_dense.py @@ -2,7 +2,7 @@ from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate from hls4ml.model.layers import EinsumDense -from .reshaping_templates import transpose_config_gen +from .reshaping_templates import transpose_config_template # Shared Dense template @@ -118,8 +118,10 @@ def format(self, node: EinsumDense): tpose_inp_conf_name = f'config{node.index}_tpose_inp' tpose_out_conf_name = f'config{node.index}_tpose_out' - inp_tpose_conf = transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs) - out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + conf = node.model.config.backend.transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs) + inp_tpose_conf = transpose_config_template.format(**conf) + conf = node.model.config.backend.transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + out_tpose_conf = transpose_config_template.format(**conf) if strategy.lower() == 'distributed_arithmetic': return '\n\n'.join((inp_tpose_conf, out_tpose_conf, einsum_conf)) diff --git a/hls4ml/backends/vivado/passes/reshaping_templates.py b/hls4ml/backends/vivado/passes/reshaping_templates.py index ff16d15c9d..69944e4497 100644 --- a/hls4ml/backends/vivado/passes/reshaping_templates.py +++ b/hls4ml/backends/vivado/passes/reshaping_templates.py @@ -1,5 +1,3 @@ -import numpy as np - from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate from hls4ml.model.layers import Resize, Transpose, ZeroPadding1D, ZeroPadding2D @@ -128,22 +126,13 @@ def format(self, node): class TransposeConfigTemplate(LayerConfigTemplate): def __init__(self): super().__init__(Transpose) - self.template = transpose_config_template def format(self, node): shape = tuple(node.get_input_variable().shape) perm = tuple(node.get_attr('perm')) name = f'config{node.index}' - new_shape, perm_strides = node.model.config.backend.permute_config_gen(name, shape, perm) - return transpose_config_template.format( - dims=len(shape), - N=np.prod(shape), - from_shape=', '.join(str(x) for x in shape), - perm=', '.join(str(x) for x in perm), - perm_strides=', '.join(str(x) for x in perm_strides), - to_shape=', '.join(str(x) for x in new_shape), - config_name=name, - ) + conf = node.model.config.backend.transpose_config_gen(name, shape, perm) + return transpose_config_template.format(**conf) class TransposeFunctionTemplate(FunctionCallTemplate): From 54a297e3dc7ba9e465cad490ef55643f051011d9 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 11 Mar 2025 13:12:00 -0700 Subject: [PATCH 10/30] preemptive distributed_arithmetic flag for einsum ops --- hls4ml/templates/vivado/nnet_utils/nnet_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_common.h b/hls4ml/templates/vivado/nnet_utils/nnet_common.h index 6db3f62f6e..308892ba49 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_common.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_common.h @@ -24,7 +24,7 @@ namespace nnet { // Common type definitions enum io_type { io_parallel = 0, io_stream }; -enum strategy { latency, resource, resource_unrolled }; +enum strategy { latency, resource, resource_unrolled, distributed_arithmetic }; /* --- * Balanced tree reduce implementation. From 3509666178177ae2b23f7b519f12f791dc332a18 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 11 Mar 2025 20:39:01 +0000 Subject: [PATCH 11/30] update doc for kv3 --- docs/frontend/keras.rst | 4 ++-- docs/intro/setup.rst | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/frontend/keras.rst b/docs/frontend/keras.rst index d6d42cb4b8..31adf21efc 100644 --- a/docs/frontend/keras.rst +++ b/docs/frontend/keras.rst @@ -2,9 +2,9 @@ Keras and QKeras ================ -Keras and the quantization library QKeras are well supported in ``hls4ml``. Currently, the Keras v2 (``tf.keras``) is the preferred version, and the future versions of ``hls4ml`` will expand support for Keras v3. The frontend is based on the parsing the serialized json representation of the model. +Keras and the quantization library QKeras are well supported in ``hls4ml``. Both Keras v2 (``tf.keras``) and the new Keras v3 are supported. While the Keras v2 support is based on parsing the serialized json representation of the model, the Keras v3 support uses direct model inspection. -Currently, ``hls4ml`` can parse most Keras layers, including core layers, convolutional layers, pooling layers, recurrent layers, merging/reshaping layers and activation layers, implemented either via sequential or functional API. Notably missing are the attention and normalization layers. The equivalent QKeras API and quantizers are also supported. The ``Lambda`` layers don't save their state in the serialized format and are thus impossible to parse. In this case, the ``Lambda`` layers can be implemented as custom layers and parsed via the :ref:`Extension API`. +Currently, ``hls4ml`` can parse most Keras layers, including core layers, convolutional layers, pooling layers, recurrent layers, merging/reshaping layers and activation layers, implemented either via sequential or functional API. Notably missing are the attention and normalization layers. The equivalent QKeras API and quantizers are also supported for Keras v2, but QKeras is currently not compatible with Keras v3. The ``Lambda`` layers don't save their state in the serialized format and are thus impossible to parse. In this case, the ``Lambda`` layers can be implemented as custom layers and parsed via the :ref:`Extension API`. The ``data_format='channels_first'`` parameter of Keras layers is supported, but not extensively tested. All HLS implementations in ``hls4ml`` are based on ``channels_last`` data format and need to be converted to that format before the HLS code can be emitted. We encourage users of ``channels_first`` to report their experiences to developers on GitHub. diff --git a/docs/intro/setup.rst b/docs/intro/setup.rst index 6ba0c4ce0e..004c7523de 100644 --- a/docs/intro/setup.rst +++ b/docs/intro/setup.rst @@ -46,7 +46,9 @@ Dependencies The ``hls4ml`` library requires python 3.10 or later, and depends on a number of Python packages and external tools for synthesis and simulation. Python dependencies are automatically managed by ``pip`` or ``conda``. -* `TensorFlow `_ (version 2.8 to 2.14) and `QKeras `_ are required by the Keras converter. One may want to install newer versions of QKeras from GitHub. Newer versions of TensorFlow can be used, but QKeras and hl4ml do not currently support Keras v3. +* `TensorFlow `_ (version 2.8 to 2.14) and `QKeras `_ are required by the Keras v2 converter. One may want to install newer versions of QKeras from GitHub. + +* `Keras `_ (version 3.0.0 and newer) is required by the Keras v3 converter. Keras v3 supports multiple backends for training and inference, and the convertion is not tied any specific backend. Notice that Keras v3 may **not** coexist with Keras v2 in the same Python environment. * `ONNX `_ (version 1.4.0 and newer) is required by the ONNX converter. From c81028e690cdaf72431ebdbad7382052811f1d37 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 11 Mar 2025 21:28:27 +0000 Subject: [PATCH 12/30] more documentation --- docs/frontend/keras.rst | 12 ++++++++++-- docs/intro/setup.rst | 20 +++++++++++++++----- example-models | 2 +- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/docs/frontend/keras.rst b/docs/frontend/keras.rst index 31adf21efc..9ede7b1d8c 100644 --- a/docs/frontend/keras.rst +++ b/docs/frontend/keras.rst @@ -1,11 +1,19 @@ ================ -Keras and QKeras +Keras and its quantized variants ================ Keras and the quantization library QKeras are well supported in ``hls4ml``. Both Keras v2 (``tf.keras``) and the new Keras v3 are supported. While the Keras v2 support is based on parsing the serialized json representation of the model, the Keras v3 support uses direct model inspection. -Currently, ``hls4ml`` can parse most Keras layers, including core layers, convolutional layers, pooling layers, recurrent layers, merging/reshaping layers and activation layers, implemented either via sequential or functional API. Notably missing are the attention and normalization layers. The equivalent QKeras API and quantizers are also supported for Keras v2, but QKeras is currently not compatible with Keras v3. The ``Lambda`` layers don't save their state in the serialized format and are thus impossible to parse. In this case, the ``Lambda`` layers can be implemented as custom layers and parsed via the :ref:`Extension API`. +Currently, ``hls4ml`` can parse most Keras layers, including core layers, convolutional layers, pooling layers, recurrent layers, merging/reshaping layers and activation layers, implemented either via sequential or functional API. Notably missing are the attention and normalization layers. The ``Lambda`` layers don't save their state in the serialized format and are thus impossible to parse. In this case, the ``Lambda`` layers can be implemented as custom layers and parsed via the :ref:`Extension API`. The ``data_format='channels_first'`` parameter of Keras layers is supported, but not extensively tested. All HLS implementations in ``hls4ml`` are based on ``channels_last`` data format and need to be converted to that format before the HLS code can be emitted. We encourage users of ``channels_first`` to report their experiences to developers on GitHub. + +* `QKeras `_ + The equivalent QKeras API and its quantizers are also supported by ``hls4ml``. QKeras is not compatible with Keras v3. +* `HGQ `_ + The equivalent HGQ API is also supported. HGQ is not compatible with Keras v3. See `advanced/HGQ <../advanced/hgq.html>`__ for more information. +* `HGQ2 `_ + HGQ2 is based on Keras v3. Its support in hls4ml is currently under development. + The development team of ``hls4ml`` is currently exploring options for QKeras alternative and will provide a drop-in replacement API compatible with Keras v3. diff --git a/docs/intro/setup.rst b/docs/intro/setup.rst index 004c7523de..10f78ca865 100644 --- a/docs/intro/setup.rst +++ b/docs/intro/setup.rst @@ -43,16 +43,26 @@ version can be installed directly from ``git``: Dependencies ============ -The ``hls4ml`` library requires python 3.10 or later, and depends on a number of Python packages and external tools for synthesis and simulation. Python dependencies are automatically managed -by ``pip`` or ``conda``. +.. note:: + As of version 1.1.0+, all conversion frontend specific packages are optional. Only install the packages you need. -* `TensorFlow `_ (version 2.8 to 2.14) and `QKeras `_ are required by the Keras v2 converter. One may want to install newer versions of QKeras from GitHub. +The ``hls4ml`` library requires python 3.10 or later, and depends on a number of Python packages and external tools for synthesis and simulation. Python dependencies are automatically managed by ``pip`` or ``conda``. -* `Keras `_ (version 3.0.0 and newer) is required by the Keras v3 converter. Keras v3 supports multiple backends for training and inference, and the convertion is not tied any specific backend. Notice that Keras v3 may **not** coexist with Keras v2 in the same Python environment. +The following Python packages are all optional and are only required if you intend to use the corresponding converter. Only install the packages you need. + +* `Keras `_ is required by the Keras converter. + * `TensorFlow `_ (version 2.8 to 2.14) is required by the Keras v2 converter (keras v2 is included in TensorFlow). + * `Keras ` 3.0 or above is required by the Keras v3 converter. Keras v3 supports multiple backends for training and inference, and the conversion is not tied any specific backend. Notice that Keras v3 may **not** coexist with Keras v2 in the same Python environment. * `ONNX `_ (version 1.4.0 and newer) is required by the ONNX converter. -* `PyTorch `_ package is optional. If not installed, the PyTorch converter will not be available. +* `PyTorch `_ is required by the PyTorch converter. + +* Quantization support + * `QKeras `_: based on Keras v2. See `frontend/keras <../frontend/keras.html>`_ for more details + * `HGQ `_: Based on Keras v2. See `advanced/HGQ <../advanced/hgq.html>`_ for more details. + * `Brevitas `_: Based on PyTorch. See `frontend/pytorch <../frontend/pytorch.html>`_ for more details. + * `QONNX `_: Based on ONNX. See `frontend/onnx <../frontend/onnx.html>`_ for more details. Running C simulation from Python requires a C++11-compatible compiler. On Linux, a GCC C++ compiler ``g++`` is required. Any version from a recent Linux should work. On MacOS, the *clang*-based ``g++`` is enough. For the oneAPI backend, one must have oneAPI installed, along with the FPGA compiler, diff --git a/example-models b/example-models index c6bb3c0686..3cfbcfd062 160000 --- a/example-models +++ b/example-models @@ -1 +1 @@ -Subproject commit c6bb3c0686d52439d8c53d7407903bf78e852562 +Subproject commit 3cfbcfd062f60492507d21ff0e91559b3bdd6550 From eed63301c013bd09d176668e6b365ab1d092ab82 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 15 Apr 2025 17:56:30 -0700 Subject: [PATCH 13/30] backport validate einsum function --- hls4ml/utils/einsum_utils.py | 64 +++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/hls4ml/utils/einsum_utils.py b/hls4ml/utils/einsum_utils.py index 43ceb2ba96..7476c7bab1 100644 --- a/hls4ml/utils/einsum_utils.py +++ b/hls4ml/utils/einsum_utils.py @@ -75,8 +75,6 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, . # Invalid broadcasting if '0' in sax_in0 or '0' in sax_in1 or '0' in sax_out: - if '0' in sax_in0 and '0' in sax_in1: - raise ValueError(f"einsum string {fn} is invalid: both input0 and input1 allows broadcasting") if '0' not in sax_out: raise ValueError(f"einsum string {fn} is invalid: output does not allow broadcasting, but inputs do") if '0' not in sax_in0 and '0' not in sax_in1: @@ -88,31 +86,45 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, . _common_in = sax_in0 & sax_in1 - # Invalid input dimensions - if '0' in sax_in0: - if len(sax_in0) - 1 > len(shape0): - raise ValueError(f"Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given") - # Replace broadcasting indices with free indices - n_broadcast = len(shape0) - len(sax_in0) + 1 - in0 = in0.replace('0', free_indices[:n_broadcast]) - out = out.replace('0', free_indices[:n_broadcast]) - ax_in0 = list(in0) - ax_out = list(out) - else: - if len(sax_in0) != len(shape0): - raise ValueError(f"Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given") - if '0' in sax_in1: - if len(sax_in1) - 1 > len(shape1): - raise ValueError(f"Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given") - # Replace broadcasting indices with free indices - n_broadcast = len(shape1) - len(sax_in1) + 1 - in1 = in1.replace('0', free_indices[:n_broadcast]) - out = out.replace('0', free_indices[:n_broadcast]) - ax_in1 = list(in1) - ax_out = list(out) + if '0' in sax_in0 and '0' in sax_in1: + # Simultaneous axes expansion in both inputs + n_boardcast0 = len(shape0) - len(sax_in0) + 1 + n_boardcast1 = len(shape1) - len(sax_in1) + 1 + assert n_boardcast0 == n_boardcast1, f'... expands to {n_boardcast0} and {n_boardcast1}-axis in input0 and input1.' + # Replace expansion indices with free indices + in0 = in0.replace('0', free_indices[:n_boardcast0]) + in1 = in1.replace('0', free_indices[:n_boardcast1]) + out = out.replace('0', free_indices[:n_boardcast0]) + ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out) + _common_in = set(ax_in0) & set(ax_in1) + else: - if len(sax_in1) != len(shape1): - raise ValueError(f"Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given") + # Axes expansion in input0 or input1 only + if '0' in sax_in0: + if len(sax_in0) - 1 > len(shape0): + raise ValueError(f"Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given") + # Replace auto expansion indices with free indices + n_broadcast = len(shape0) - len(sax_in0) + 1 + in0 = in0.replace('0', free_indices[:n_broadcast]) + out = out.replace('0', free_indices[:n_broadcast]) + ax_in0 = list(in0) + ax_out = list(out) + else: + if len(sax_in0) != len(shape0): + raise ValueError(f"Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given") + + if '0' in sax_in1: + if len(sax_in1) - 1 > len(shape1): + raise ValueError(f"Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given") + # Replace expansion indices with free indices + n_broadcast = len(shape1) - len(sax_in1) + 1 + in1 = in1.replace('0', free_indices[:n_broadcast]) + out = out.replace('0', free_indices[:n_broadcast]) + ax_in1 = list(in1) + ax_out = list(out) + else: + if len(sax_in1) != len(shape1): + raise ValueError(f"Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given") # Input dimension mismatch for a in _common_in: From cda903e88e72c84d24a36bfe9b17d4bdd5571601 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 15 Apr 2025 18:04:21 -0700 Subject: [PATCH 14/30] docstring style --- hls4ml/converters/keras_v3/_base.py | 112 ++++++++++++---------------- hls4ml/utils/einsum_utils.py | 105 ++++++++++---------------- 2 files changed, 89 insertions(+), 128 deletions(-) diff --git a/hls4ml/converters/keras_v3/_base.py b/hls4ml/converters/keras_v3/_base.py index 6f50ed6523..c0549be7f0 100644 --- a/hls4ml/converters/keras_v3/_base.py +++ b/hls4ml/converters/keras_v3/_base.py @@ -37,26 +37,23 @@ def register(cls: str) -> Callable[[T_kv3_handler], T_kv3_handler]: ... def register(cls: str | type): """Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class. - Parameters - ---------- - cls : str|type - If str, the key to register the handler under. If type, the class to register the handler for. - - Examples - -------- - ```python - @keras_dispatcher.register - class MyLayerHandler(KerasV3LayerHandler): - handles = ('my_package.src.submodule.MyLayer', 'MyLayer2') - - def handle(self, layer, inp_tensors, out_tensors): - # handler code + Args: + cls: If str, the key to register the handler under. If type, the class to register the handler for. + + Examples: + ```python + @keras_dispatcher.register + class MyLayerHandler(KerasV3LayerHandler): + handles = ('my_package.src.submodule.MyLayer', 'MyLayer2') + + def handle(self, layer, inp_tensors, out_tensors): + # handler code - @keras_dispatcher.register('MyLayer3') - def my_layer_handler(layer, inp_tensors, out_tensors): - # handler code - ``` + @keras_dispatcher.register('MyLayer3') + def my_layer_handler(layer, inp_tensors, out_tensors): + # handler code + ``` """ def deco(func): @@ -91,40 +88,34 @@ def __call__( in_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'], ) -> tuple[dict[str, Any], ...]: - """Handle a keras layer. Return a tuple of dictionaries, each - dictionary representing a layer (module) in the HLS model. One - layer may correspond one or more dictionaries (e.g., layers with - activation functions will be split into two layers). - - Some common attributes are automatically added to the dictionary - if the handler returns a single dictionary. If the handler - returns multiple dictionaries, the attributes must be added - manually. Anything returned by the handler will override the - automatic attributes. - - Automatic attributes: - name - class_name - module - - input_keras_tensor_names - input_shape - - output_keras_tensor_names - - If the layer has an activation function, an additional - dictionary will be added to the return value representing the - activation function. - - - Parameters - ---------- - layer : keras.Layer - The layer to be converted to HLS configuration(s). - in_tensors : Sequence[KerasTensor] - The list of input tensors to the layer. - out_tensors : Sequence[KerasTensor] - The list of output tensors from the layer. - - Returns - ------- - dict[str, Any] | tuple[dict[str, Any], ...] - layer configuration(s) for the HLS model to be consumed by - the ModelGraph constructor + """Handle a keras layer. Return a tuple of dictionaries, each dictionary representing + a layer (module) in the HLS model. + + One layer may correspond to one or more dictionaries + (e.g., layers with activation functions will be split into two layers). + + Some common attributes are automatically added to the dictionary if the handler returns a single dictionary. + If the handler returns multiple dictionaries, the attributes must be added manually. + Anything returned by the handler will override the automatic attributes. + + Automatic attributes: + - name + - class_name + - module + - input_keras_tensor_names + - input_shape + - output_keras_tensor_names + + If the layer has an activation function, an additional dictionary will be added to the return value + representing the activation function. + + Args: + layer: The layer to be converted to HLS configuration(s). + in_tensors: The list of input tensors to the layer. + out_tensors: The list of output tensors from the layer. + + Returns: + Layer configuration(s) for the HLS model to be consumed by the ModelGraph constructor. """ name = layer.name @@ -199,17 +190,12 @@ def handle( def load_weight(self, layer: 'keras.Layer', key: str): """Load a weight from a layer. - Parameters - ---------- - layer : keras.Layer - The layer to load the weight from. - key : str - The key of the weight to load. - - Returns - ------- - np.ndarray - The weight. + Args: + layer: The layer to load the weight from. + key: The key of the weight to load. + + Returns: + np.ndarray: The weight. """ import keras diff --git a/hls4ml/utils/einsum_utils.py b/hls4ml/utils/einsum_utils.py index 7476c7bab1..d30a27eb88 100644 --- a/hls4ml/utils/einsum_utils.py +++ b/hls4ml/utils/einsum_utils.py @@ -16,26 +16,18 @@ class EinsumRecipe(TypedDict): def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, ...]): - """Validate, resolve broadcasting, and compute output shape for einsum string - - Parameters - ---------- - fn : str - einsum string, e.g. 'ij,jk->ik' - shape0 : tuple[int,...] - shape of input0 - shape1 : tuple[int,...] - shape of input1 - - Returns - ------- - tuple[str, tuple[int,...]] - einsum string w/o broadcasting, and output shape - - Raises - ------ - ValueError - If the einsum string is invalid, or if it is incompatible with the input shapes + """Validate, resolve broadcasting, and compute output shape for einsum string. + + Args: + fn: einsum string, e.g. 'ij,jk->ik' + shape0: shape of input0 + shape1: shape of input1 + + Returns: + tuple[str, tuple[int,...]]: einsum string w/o broadcasting, and output shape + + Raises: + ValueError: If the einsum string is invalid, or if it is incompatible with the input shapes """ inp, out = map(str.strip, fn.split('->')) in0, in1 = map(str.strip, inp.split(',')) @@ -140,21 +132,15 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, . def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int, ...]) -> EinsumRecipe: - """Parse einsum operation on two input arrays, return a recipe for execution - - Parameters - ---------- - fn : str - einsum string, e.g. 'ij,jk->ik' - input : np.ndarray - input0, the first input array - input1 : np.ndarray - input1, the second input array - - Returns - ------- - EinsumRecipe - einsum recipe; executed by _exec_einsum + """Parse einsum operation on two input arrays, return a recipe for execution. + + Args: + fn: einsum string, e.g. 'ij,jk->ik' + input_shape0: shape of the first input array + input_shape1: shape of the second input array + + Returns: + EinsumRecipe: einsum recipe; executed by _exec_einsum """ fn, _ = _validate_einsum_expr(fn, input_shape0, input_shape1) @@ -209,21 +195,15 @@ def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) -> np.ndarray: - """Execute einsum operation on two input arrays - - Parameters - ---------- - recipe : EinsumRecipe - einsum recipe - input0 : np.ndarray - input0, the first input array - input1 : np.ndarray - input1, the second input array - - Returns - ------- - np.ndarray - output array + """Execute einsum operation on two input arrays. + + Args: + recipe: einsum recipe + input0: the first input array + input1: the second input array + + Returns: + np.ndarray: output array """ sum_axis0, sum_axis1 = recipe['direct_sum_axis'] if sum_axis0: @@ -248,21 +228,16 @@ def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) - def einsum(fn: str, input0: np.ndarray, input1: np.ndarray) -> np.ndarray: """Execute einsum operation on two input arrays. - WARNING: Order of multiplication is reversed -- watchout if you are using non-commutative operators - - Parameters - ---------- - fn : str - einsum string, e.g. 'ij,jk->ik' - input : np.ndarray - input0, the first input array - input1 : np.ndarray - input1, the second input array - - Returns - ------- - np.ndarray - output array + Warning: + Order of multiplication is reversed -- watchout if you are using non-commutative operators + + Args: + fn: einsum string, e.g. 'ij,jk->ik' + input0: the first input array + input1: the second input array + + Returns: + np.ndarray: output array """ recipe = parse_einsum(fn, input0.shape, input1.shape) return _exec_einsum(recipe, input0, input1) From eccde4eb6155d562d077d4e69de8fa6847943ddd Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 15 Apr 2025 18:23:38 -0700 Subject: [PATCH 15/30] quote format --- hls4ml/backends/vivado/passes/einsum_dense.py | 4 +- hls4ml/converters/keras_v3/_base.py | 20 +++++----- hls4ml/converters/keras_v3/core.py | 2 +- hls4ml/converters/keras_v3/einsum_dense.py | 4 +- hls4ml/converters/keras_v3_to_hls.py | 38 +++++++++---------- hls4ml/utils/einsum_utils.py | 28 +++++++------- 6 files changed, 48 insertions(+), 48 deletions(-) diff --git a/hls4ml/backends/vivado/passes/einsum_dense.py b/hls4ml/backends/vivado/passes/einsum_dense.py index 1b4b183039..12359ef754 100644 --- a/hls4ml/backends/vivado/passes/einsum_dense.py +++ b/hls4ml/backends/vivado/passes/einsum_dense.py @@ -6,7 +6,7 @@ # Shared Dense template -dense_config_template = """struct config{index}_dense : nnet::dense_config {{ +dense_config_template = '''struct config{index}_dense : nnet::dense_config {{ static const unsigned n_in = {n_in}; static const unsigned n_out = {n_out}; static const unsigned reuse_factor = {reuse}; @@ -20,7 +20,7 @@ using kernel = nnet::{dense_function}; template using product = nnet::product::{product_type}; -}};\n""" +}};\n''' # EinsumDense template diff --git a/hls4ml/converters/keras_v3/_base.py b/hls4ml/converters/keras_v3/_base.py index c0549be7f0..22a28df005 100644 --- a/hls4ml/converters/keras_v3/_base.py +++ b/hls4ml/converters/keras_v3/_base.py @@ -35,7 +35,7 @@ def register(cls: str) -> Callable[[T_kv3_handler], T_kv3_handler]: ... def register(cls: str | type): - """Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class. + '''Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class. Args: cls: If str, the key to register the handler under. If type, the class to register the handler for. @@ -54,7 +54,7 @@ def handle(self, layer, inp_tensors, out_tensors): def my_layer_handler(layer, inp_tensors, out_tensors): # handler code ``` - """ + ''' def deco(func): if isinstance(cls, str): @@ -77,7 +77,7 @@ def maybe_add_attrs(config: dict[str, Any] | DefaultConfig, obj: Any, *attrs: st class KerasV3LayerHandler: - """Base class for keras v3 layer handlers. Subclass this class to create a handler for a specific layer type.""" + '''Base class for keras v3 layer handlers. Subclass this class to create a handler for a specific layer type.''' handles = () default_config: DefaultConfig @@ -88,7 +88,7 @@ def __call__( in_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'], ) -> tuple[dict[str, Any], ...]: - """Handle a keras layer. Return a tuple of dictionaries, each dictionary representing + '''Handle a keras layer. Return a tuple of dictionaries, each dictionary representing a layer (module) in the HLS model. One layer may correspond to one or more dictionaries @@ -116,7 +116,7 @@ def __call__( Returns: Layer configuration(s) for the HLS model to be consumed by the ModelGraph constructor. - """ + ''' name = layer.name class_name = layer.__class__.__name__ @@ -142,7 +142,7 @@ def __call__( if isinstance(config0, tuple): for conf in config0: for key in mandatory_keys: - assert key in conf, f"Key {key} missing from layer {name} handled by {self.__class__.__name__}" + assert key in conf, f'Key {key} missing from layer {name} handled by {self.__class__.__name__}' return config0 config = {} @@ -165,8 +165,8 @@ def maybe_get_activation_config(self, layer, out_tensors): activation = getattr(layer, 'activation', None) name = layer.name if activation not in (keras.activations.linear, None): - assert len(out_tensors) == 1, f"Layer {name} has more than one output, but has an activation function" - assert isinstance(activation, FunctionType), f"Activation function for layer {name} is not a function" + assert len(out_tensors) == 1, f'Layer {name} has more than one output, but has an activation function' + assert isinstance(activation, FunctionType), f'Activation function for layer {name} is not a function' intermediate_tensor_name = f'{out_tensors[0].name}_activation' act_cls_name = activation.__name__ act_config = { @@ -188,7 +188,7 @@ def handle( return {} def load_weight(self, layer: 'keras.Layer', key: str): - """Load a weight from a layer. + '''Load a weight from a layer. Args: layer: The layer to load the weight from. @@ -196,7 +196,7 @@ def load_weight(self, layer: 'keras.Layer', key: str): Returns: np.ndarray: The weight. - """ + ''' import keras return keras.ops.convert_to_numpy(getattr(layer, key)) diff --git a/hls4ml/converters/keras_v3/core.py b/hls4ml/converters/keras_v3/core.py index f3ac9a0d75..0e809b23e2 100644 --- a/hls4ml/converters/keras_v3/core.py +++ b/hls4ml/converters/keras_v3/core.py @@ -26,7 +26,7 @@ def handle( kernel = self.load_weight(layer, 'kernel') bias = self.load_weight(layer, 'bias') if layer.use_bias else None - n_in, n_out = kernel.shape + n_in, n_out = kernel.shape # type: ignore config = { 'data_format': 'channels_last', diff --git a/hls4ml/converters/keras_v3/einsum_dense.py b/hls4ml/converters/keras_v3/einsum_dense.py index 8eb000fcf7..f317160c5d 100644 --- a/hls4ml/converters/keras_v3/einsum_dense.py +++ b/hls4ml/converters/keras_v3/einsum_dense.py @@ -9,7 +9,7 @@ def strip_batch_dim(equation: str, einsum_dense: bool = True): - """Remove the batch dimension from the equation. + '''Remove the batch dimension from the equation. Args: equation (str): The einsum equation. @@ -17,7 +17,7 @@ def strip_batch_dim(equation: str, einsum_dense: bool = True): Returns: str: The einsum equation without the batch dimension. - """ + ''' _inps, out = equation.split('->') inp0, inp1 = _inps.split(',') diff --git a/hls4ml/converters/keras_v3_to_hls.py b/hls4ml/converters/keras_v3_to_hls.py index 5c0168cc1e..cbf8d0d427 100644 --- a/hls4ml/converters/keras_v3_to_hls.py +++ b/hls4ml/converters/keras_v3_to_hls.py @@ -17,7 +17,7 @@ def get_io_tensors(layer: 'keras.Layer', node_whitelist: set[int] | None = None): - """Given a keras layer, return a list of tuples of input and output + '''Given a keras layer, return a list of tuples of input and output tensors. If the layer is called only once (i.e., no shared layers), the list will contain only one tuple. @@ -37,7 +37,7 @@ def get_io_tensors(layer: 'keras.Layer', node_whitelist: set[int] | None = None) ------- list[tuple[tuple['KerasTensor', ...], tuple['KerasTensor', ...]]] A list of tuples of input and output tensors. - """ + ''' in_nodes = layer._inbound_nodes if node_whitelist is not None: in_nodes = [node for node in in_nodes if id(node) in node_whitelist] @@ -51,7 +51,7 @@ def get_io_tensors(layer: 'keras.Layer', node_whitelist: set[int] | None = None) def resolve_dependency_relation(model: 'keras.Model'): - """Given a keras model, return the following information: + '''Given a keras model, return the following information: - A list of input tensor names - A list of output tensor names - A list of (layer_name, input_tensor_names, output_tensor_names) tuples @@ -66,13 +66,13 @@ def resolve_dependency_relation(model: 'keras.Model'): ------- tuple[tuple[str, ...], tuple[str, ...], list[tuple[str, tuple[str, ...], tuple[str, ...]]], dict[str, KerasTensor]] inp_tensor_names, out_tensor_names, layer_io, tensors - """ + ''' tensors: dict[str, 'KerasTensor'] = {} - "tensor_name -> KerasTensor" + 'tensor_name -> KerasTensor' depends_on: dict[str, tuple[str, ...]] = {} - "tensor_name -> {tensor_name}" + 'tensor_name -> {tensor_name}' layer_io: list[tuple[str, tuple[str, ...], tuple[str, ...]]] = [] - "layer_name -> ((input_tensor_names), (output_tensor_names))" + 'layer_name -> ((input_tensor_names), (output_tensor_names))' inputs = tuple(t.name for t in model.inputs) outputs = tuple(t.name for t in model.outputs) @@ -92,7 +92,7 @@ def resolve_dependency_relation(model: 'keras.Model'): class UniqueName: - """Helper class to generate unique names for layers, if one being used multiple times.""" + '''Helper class to generate unique names for layers, if one being used multiple times.''' def __init__(self): self.used_names: set[str] = set() @@ -114,7 +114,7 @@ def reset(self): class KerasV3HandlerDispatcher: - """Dispatcher class to handle different types of keras v3 layers.""" + '''Dispatcher class to handle different types of keras v3 layers.''' def __init__(self, layer_handlers: dict[str, T_kv3_handler], v2_layer_handlers=None): self.registry = layer_handlers @@ -123,7 +123,7 @@ def __init__(self, layer_handlers: dict[str, T_kv3_handler], v2_layer_handlers=N def __call__( self, layer: 'keras.Layer', in_tensors: Sequence['keras.KerasTensor'], out_tensors: Sequence['keras.KerasTensor'] ) -> tuple[dict[str, Any], ...]: - assert layer.built, f"Layer {layer.name} is not built" + assert layer.built, f'Layer {layer.name} is not built' ret = self.v3_call(layer, in_tensors, out_tensors) if ret is not None: @@ -133,7 +133,7 @@ def __call__( return ret raise ValueError( - f"Layer {layer.__class__.__module__}.{layer.__class__.__name__} not found in either v3 or v2 handlers" + f'Layer {layer.__class__.__module__}.{layer.__class__.__name__} not found in either v3 or v2 handlers' ) def v3_call( @@ -141,7 +141,7 @@ def v3_call( ): cls_name = layer.__class__.__name__ module = layer.__module__ - key = f"{module}.{cls_name}" + key = f'{module}.{cls_name}' # keras v3 handlers handler = self.registry.get(key, None) @@ -155,7 +155,7 @@ def v2_call( self, layer: 'keras.layers.Layer', inp_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'] ): # keras v2 handlers fallback - print(f"v2 handler used for layer {layer.name}") + print(f'v2 handler used for layer {layer.name}') import keras @@ -164,7 +164,7 @@ def v2_call( class DummyReader: def get_weights_data(self, layer_name, var_name): - assert layer_name == layer.name, f"Processing {layer.name}, but handler tried to read {layer_name}" + assert layer_name == layer.name, f'Processing {layer.name}, but handler tried to read {layer_name}' for w in layer.weights: if var_name in w.name: return np.array(w) @@ -186,7 +186,7 @@ def get_weights_data(self, layer_name, var_name): activation = getattr(layer, 'activation', None) if activation not in (keras.activations.linear, None): - assert isinstance(activation, FunctionType), f"Activation function for layer {layer.name} is not a function" + assert isinstance(activation, FunctionType), f'Activation function for layer {layer.name} is not a function' intermediate_tensor_name = f'{output_names[0]}_activation' ret[0]['output_keras_tensor_names'] = (intermediate_tensor_name,) act_cls_name = activation.__name__ @@ -202,7 +202,7 @@ def get_weights_data(self, layer_name, var_name): def parse_keras_v3_model(model: 'keras.Model'): - """Parse a keras model into a list of dictionaries, each + '''Parse a keras model into a list of dictionaries, each representing a layer in the HLS model, and a list of input and output layer names. @@ -220,9 +220,9 @@ def parse_keras_v3_model(model: 'keras.Model'): ------ ValueError If a circular dependency is detected. - """ + ''' - assert model.built, "Model must be built before parsing" + assert model.built, 'Model must be built before parsing' import keras @@ -267,7 +267,7 @@ def parse_keras_v3_model(model: 'keras.Model'): break # Restart the loop to add another layer else: # If no layer was added in the loop, then there is a circular dependency - raise ValueError("Circular dependency detected") + raise ValueError('Circular dependency detected') # Mark inputs[inp layer name] for ModelGraph to parse from i/o keras tensor names provides: dict[str, str] = {} # tensor_name -> src_layer_name diff --git a/hls4ml/utils/einsum_utils.py b/hls4ml/utils/einsum_utils.py index d30a27eb88..b01e26edf8 100644 --- a/hls4ml/utils/einsum_utils.py +++ b/hls4ml/utils/einsum_utils.py @@ -16,7 +16,7 @@ class EinsumRecipe(TypedDict): def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, ...]): - """Validate, resolve broadcasting, and compute output shape for einsum string. + '''Validate, resolve broadcasting, and compute output shape for einsum string. Args: fn: einsum string, e.g. 'ij,jk->ik' @@ -28,7 +28,7 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, . Raises: ValueError: If the einsum string is invalid, or if it is incompatible with the input shapes - """ + ''' inp, out = map(str.strip, fn.split('->')) in0, in1 = map(str.strip, inp.split(',')) alphabets = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' @@ -74,7 +74,7 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, . # Output index out of nowhere if remaining := sax_out - sax_in0 - sax_in1: - raise ValueError(f"einsum string {fn} is invalid: output subscripts {remaining} not found in inputs") + raise ValueError(f'einsum string {fn} is invalid: output subscripts {remaining} not found in inputs') _common_in = sax_in0 & sax_in1 @@ -82,7 +82,7 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, . # Simultaneous axes expansion in both inputs n_boardcast0 = len(shape0) - len(sax_in0) + 1 n_boardcast1 = len(shape1) - len(sax_in1) + 1 - assert n_boardcast0 == n_boardcast1, f'... expands to {n_boardcast0} and {n_boardcast1}-axis in input0 and input1.' + assert n_boardcast0 == n_boardcast1, f"'...' expands to {n_boardcast0} and {n_boardcast1}-axis in input0 and input1." # Replace expansion indices with free indices in0 = in0.replace('0', free_indices[:n_boardcast0]) in1 = in1.replace('0', free_indices[:n_boardcast1]) @@ -94,7 +94,7 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, . # Axes expansion in input0 or input1 only if '0' in sax_in0: if len(sax_in0) - 1 > len(shape0): - raise ValueError(f"Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given") + raise ValueError(f'Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given') # Replace auto expansion indices with free indices n_broadcast = len(shape0) - len(sax_in0) + 1 in0 = in0.replace('0', free_indices[:n_broadcast]) @@ -103,11 +103,11 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, . ax_out = list(out) else: if len(sax_in0) != len(shape0): - raise ValueError(f"Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given") + raise ValueError(f'Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given') if '0' in sax_in1: if len(sax_in1) - 1 > len(shape1): - raise ValueError(f"Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given") + raise ValueError(f'Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given') # Replace expansion indices with free indices n_broadcast = len(shape1) - len(sax_in1) + 1 in1 = in1.replace('0', free_indices[:n_broadcast]) @@ -116,7 +116,7 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, . ax_out = list(out) else: if len(sax_in1) != len(shape1): - raise ValueError(f"Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given") + raise ValueError(f'Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given') # Input dimension mismatch for a in _common_in: @@ -132,7 +132,7 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, . def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int, ...]) -> EinsumRecipe: - """Parse einsum operation on two input arrays, return a recipe for execution. + '''Parse einsum operation on two input arrays, return a recipe for execution. Args: fn: einsum string, e.g. 'ij,jk->ik' @@ -141,7 +141,7 @@ def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int Returns: EinsumRecipe: einsum recipe; executed by _exec_einsum - """ + ''' fn, _ = _validate_einsum_expr(fn, input_shape0, input_shape1) @@ -195,7 +195,7 @@ def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) -> np.ndarray: - """Execute einsum operation on two input arrays. + '''Execute einsum operation on two input arrays. Args: recipe: einsum recipe @@ -204,7 +204,7 @@ def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) - Returns: np.ndarray: output array - """ + ''' sum_axis0, sum_axis1 = recipe['direct_sum_axis'] if sum_axis0: input0 = np.sum(input0, axis=sum_axis0) @@ -226,7 +226,7 @@ def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) - def einsum(fn: str, input0: np.ndarray, input1: np.ndarray) -> np.ndarray: - """Execute einsum operation on two input arrays. + '''Execute einsum operation on two input arrays. Warning: Order of multiplication is reversed -- watchout if you are using non-commutative operators @@ -238,6 +238,6 @@ def einsum(fn: str, input0: np.ndarray, input1: np.ndarray) -> np.ndarray: Returns: np.ndarray: output array - """ + ''' recipe = parse_einsum(fn, input0.shape, input1.shape) return _exec_einsum(recipe, input0, input1) From 6dfeb99d19c6980000dc53fbcc123863461bf181 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 18 Apr 2025 10:34:23 -0700 Subject: [PATCH 16/30] restore example-models version --- example-models | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example-models b/example-models index 3cfbcfd062..c6bb3c0686 160000 --- a/example-models +++ b/example-models @@ -1 +1 @@ -Subproject commit 3cfbcfd062f60492507d21ff0e91559b3bdd6550 +Subproject commit c6bb3c0686d52439d8c53d7407903bf78e852562 From 8284757aab9dc37e4594b761c7293096b66fd640 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 18 Apr 2025 10:36:47 -0700 Subject: [PATCH 17/30] pre-commit update --- .pre-commit-config.yaml | 2 +- hls4ml/converters/keras_v3/_base.py | 3 ++- hls4ml/converters/keras_v3/conv.py | 2 +- hls4ml/converters/keras_v3/core.py | 3 ++- hls4ml/converters/keras_v3/einsum_dense.py | 2 +- hls4ml/converters/keras_v3_to_hls.py | 3 ++- hls4ml/model/optimizer/passes/infer_precision.py | 2 +- 7 files changed, 10 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a597d86f58..737cc1ec0a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: rev: v3.19.1 hooks: - id: pyupgrade - args: ["--py36-plus"] + args: ["--py310-plus"] - repo: https://github.com/pycqa/flake8 rev: 7.1.2 diff --git a/hls4ml/converters/keras_v3/_base.py b/hls4ml/converters/keras_v3/_base.py index 22a28df005..22652b1b6c 100644 --- a/hls4ml/converters/keras_v3/_base.py +++ b/hls4ml/converters/keras_v3/_base.py @@ -1,6 +1,7 @@ import typing +from collections.abc import Callable, Sequence from types import FunctionType -from typing import Any, Callable, Sequence, TypedDict, overload +from typing import Any, TypedDict, overload class DefaultConfig(TypedDict, total=False): diff --git a/hls4ml/converters/keras_v3/conv.py b/hls4ml/converters/keras_v3/conv.py index adf6221822..756a889803 100644 --- a/hls4ml/converters/keras_v3/conv.py +++ b/hls4ml/converters/keras_v3/conv.py @@ -1,6 +1,6 @@ import typing +from collections.abc import Sequence from math import ceil -from typing import Sequence from ._base import KerasV3LayerHandler, register diff --git a/hls4ml/converters/keras_v3/core.py b/hls4ml/converters/keras_v3/core.py index 0e809b23e2..2ce3ac4165 100644 --- a/hls4ml/converters/keras_v3/core.py +++ b/hls4ml/converters/keras_v3/core.py @@ -1,7 +1,8 @@ import inspect import typing +from collections.abc import Sequence from math import prod -from typing import Any, Sequence +from typing import Any import numpy as np diff --git a/hls4ml/converters/keras_v3/einsum_dense.py b/hls4ml/converters/keras_v3/einsum_dense.py index f317160c5d..738d10a796 100644 --- a/hls4ml/converters/keras_v3/einsum_dense.py +++ b/hls4ml/converters/keras_v3/einsum_dense.py @@ -1,5 +1,5 @@ import typing -from typing import Sequence +from collections.abc import Sequence from ._base import KerasV3LayerHandler, register diff --git a/hls4ml/converters/keras_v3_to_hls.py b/hls4ml/converters/keras_v3_to_hls.py index cbf8d0d427..f94928e811 100644 --- a/hls4ml/converters/keras_v3_to_hls.py +++ b/hls4ml/converters/keras_v3_to_hls.py @@ -1,7 +1,8 @@ import typing +from collections.abc import Callable, Sequence from itertools import chain from types import FunctionType -from typing import Any, Callable, Sequence +from typing import Any if typing.TYPE_CHECKING: import keras diff --git a/hls4ml/model/optimizer/passes/infer_precision.py b/hls4ml/model/optimizer/passes/infer_precision.py index bd439e4a0f..919bc0c3c2 100644 --- a/hls4ml/model/optimizer/passes/infer_precision.py +++ b/hls4ml/model/optimizer/passes/infer_precision.py @@ -1,5 +1,5 @@ import math -from typing import Iterable +from collections.abc import Iterable import numpy as np From e5ad92ce4b00dab72c48b72fbf6844d586e72962 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 06:02:13 -0700 Subject: [PATCH 18/30] kv3 handler update --- hls4ml/converters/keras_v3/__init__.py | 1 + hls4ml/converters/keras_v3/_base.py | 81 ++++------- hls4ml/converters/keras_v3/conv.py | 151 ++++++++++++--------- hls4ml/converters/keras_v3/core.py | 76 ++++++++--- hls4ml/converters/keras_v3/einsum_dense.py | 19 ++- hls4ml/converters/keras_v3/pooling.py | 74 ++++++++++ 6 files changed, 262 insertions(+), 140 deletions(-) create mode 100644 hls4ml/converters/keras_v3/pooling.py diff --git a/hls4ml/converters/keras_v3/__init__.py b/hls4ml/converters/keras_v3/__init__.py index 6dffcb71d5..150d0e241f 100644 --- a/hls4ml/converters/keras_v3/__init__.py +++ b/hls4ml/converters/keras_v3/__init__.py @@ -1,6 +1,7 @@ from . import conv # noqa: F401 from . import core # noqa: F401 from . import einsum_dense # noqa: F401 +from . import pooling # noqa: F401 from ._base import registry as layer_handlers __all__ = ['layer_handlers'] diff --git a/hls4ml/converters/keras_v3/_base.py b/hls4ml/converters/keras_v3/_base.py index 22652b1b6c..a3c23d4654 100644 --- a/hls4ml/converters/keras_v3/_base.py +++ b/hls4ml/converters/keras_v3/_base.py @@ -1,7 +1,7 @@ import typing from collections.abc import Callable, Sequence from types import FunctionType -from typing import Any, TypedDict, overload +from typing import Any, TypedDict class DefaultConfig(TypedDict, total=False): @@ -18,7 +18,7 @@ class DefaultConfig(TypedDict, total=False): if typing.TYPE_CHECKING: import keras - from keras.api import KerasTensor + from keras import KerasTensor T_kv3_handler = Callable[ ['keras.Layer', Sequence['keras.KerasTensor'], Sequence['keras.KerasTensor']], tuple[dict[str, Any], ...] @@ -27,50 +27,6 @@ class DefaultConfig(TypedDict, total=False): registry: dict[str, T_kv3_handler] = {} -@overload -def register(cls: type) -> type: ... - - -@overload -def register(cls: str) -> Callable[[T_kv3_handler], T_kv3_handler]: ... - - -def register(cls: str | type): - '''Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class. - - Args: - cls: If str, the key to register the handler under. If type, the class to register the handler for. - - Examples: - ```python - @keras_dispatcher.register - class MyLayerHandler(KerasV3LayerHandler): - handles = ('my_package.src.submodule.MyLayer', 'MyLayer2') - - def handle(self, layer, inp_tensors, out_tensors): - # handler code - - - @keras_dispatcher.register('MyLayer3') - def my_layer_handler(layer, inp_tensors, out_tensors): - # handler code - ``` - ''' - - def deco(func): - if isinstance(cls, str): - registry[cls] = func - for k in getattr(func, 'handles', ()): - registry[k] = func - if isinstance(cls, type): - return cls - return func - - if isinstance(cls, type): - return deco(cls()) - return deco - - def maybe_add_attrs(config: dict[str, Any] | DefaultConfig, obj: Any, *attrs: str): for attr in attrs: if attr not in config and hasattr(obj, attr): @@ -78,7 +34,7 @@ def maybe_add_attrs(config: dict[str, Any] | DefaultConfig, obj: Any, *attrs: st class KerasV3LayerHandler: - '''Base class for keras v3 layer handlers. Subclass this class to create a handler for a specific layer type.''' + """Base class for keras v3 layer handlers. Subclass this class to create a handler for a specific layer type.""" handles = () default_config: DefaultConfig @@ -89,7 +45,7 @@ def __call__( in_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'], ) -> tuple[dict[str, Any], ...]: - '''Handle a keras layer. Return a tuple of dictionaries, each dictionary representing + """Handle a keras layer. Return a tuple of dictionaries, each dictionary representing a layer (module) in the HLS model. One layer may correspond to one or more dictionaries @@ -117,7 +73,7 @@ def __call__( Returns: Layer configuration(s) for the HLS model to be consumed by the ModelGraph constructor. - ''' + """ name = layer.name class_name = layer.__class__.__name__ @@ -189,7 +145,7 @@ def handle( return {} def load_weight(self, layer: 'keras.Layer', key: str): - '''Load a weight from a layer. + """Load a weight from a layer. Args: layer: The layer to load the weight from. @@ -197,7 +153,30 @@ def load_weight(self, layer: 'keras.Layer', key: str): Returns: np.ndarray: The weight. - ''' + """ import keras return keras.ops.convert_to_numpy(getattr(layer, key)) + + +def register(cls: type): + """Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class. + + Args: + cls: the class to register the handler for. + + Examples: + ```python + @keras_dispatcher.register + class MyLayerHandler(KerasV3LayerHandler): + handles = ('my_package.src.submodule.MyLayer', 'MyLayer2') + + def handle(self, layer, inp_tensors, out_tensors): + # handler code + ``` + """ + + fn = cls() + for k in fn.handles: + registry[k] = fn + return cls diff --git a/hls4ml/converters/keras_v3/conv.py b/hls4ml/converters/keras_v3/conv.py index 756a889803..cff353abfe 100644 --- a/hls4ml/converters/keras_v3/conv.py +++ b/hls4ml/converters/keras_v3/conv.py @@ -1,16 +1,82 @@ import typing from collections.abc import Sequence from math import ceil +from typing import Any from ._base import KerasV3LayerHandler, register if typing.TYPE_CHECKING: import keras - from keras.api import KerasTensor + from keras import KerasTensor + + +def gen_conv_config( + in_shape: tuple[int, ...], + out_shape: tuple[int, ...], + ker_px_shape: tuple[int, ...], + strides: tuple[int, ...], + padding: str, + data_format: str, + name: str, +) -> dict[str, Any]: + if data_format == 'channels_last': + *px_in_shape, ch_in = in_shape + *px_out_shape, ch_out = out_shape + else: + ch_in, *px_in_shape = in_shape + ch_out, *px_out_shape = out_shape + if not px_out_shape: + px_out_shape = [1] * len(px_in_shape) + + if padding == 'same': + n_padding = [ceil(N / n) * n - N for N, n in zip(px_in_shape, ker_px_shape)] + n_padding0 = [p // 2 for p in n_padding] + n_padding1 = [p - p0 for p, p0 in zip(n_padding, n_padding0)] + elif padding == 'valid': + n_padding0 = [0] * len(px_in_shape) + n_padding1 = [0] * len(px_in_shape) + elif padding == 'causal': + assert len(px_in_shape) == 1, f'Invalid padding mode {padding} for layer {name}: ndim > 1' + n_padding0 = [ker_px_shape[0] - 1] + [0] * (len(px_in_shape) - 1) + n_padding1 = [0] * len(px_in_shape) + else: + raise ValueError(f'Invalid padding mode {padding} for layer {name}') + + if len(ker_px_shape) == 1: + config = { + 'filt_width': ker_px_shape[0], + 'stride_width': strides[0], + 'pad_left': n_padding0[0], + 'pad_right': n_padding1[0], + 'in_width': px_in_shape[0], + 'out_width': px_out_shape[0], + } + + elif len(ker_px_shape) == 2: + config = { + 'filt_height': ker_px_shape[0], + 'filt_width': ker_px_shape[1], + 'stride_height': strides[0], + 'stride_width': strides[1], + 'pad_top': n_padding0[0], + 'pad_bottom': n_padding1[0], + 'pad_left': n_padding0[1], + 'pad_right': n_padding1[1], + 'in_height': px_in_shape[0], + 'in_width': px_in_shape[1], # type: ignore + 'out_height': px_out_shape[0], + 'out_width': px_out_shape[1], + } + else: + raise ValueError(f'Only 1D and 2D layers are supported, got {len(ker_px_shape)}D') + + config['n_filt'] = ch_out + config['n_chan'] = ch_in + return config @register -class KV3ConvHandler(KerasV3LayerHandler): +class ConvHandler(KerasV3LayerHandler): handles = ( 'keras.src.layers.convolutional.conv1d.Conv1D', 'keras.src.layers.convolutional.conv2d.Conv2D', @@ -30,13 +96,13 @@ def handle( from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv - assert len(in_tensors) == 1, f"Layer {layer.name} has more than one input" - assert len(out_tensors) == 1, f"Layer {layer.name} has more than one output" + assert len(in_tensors) == 1, f'Layer {layer.name} has more than one input' + assert len(out_tensors) == 1, f'Layer {layer.name} has more than one output' in_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore out_shape: tuple[int, ...] = out_tensors[0].shape[1:] # type: ignore - assert all(isinstance(x, int) for x in in_shape), f"Layer {layer.name} has non-fixed size input: {in_shape}" - assert all(isinstance(x, int) for x in out_shape), f"Layer {layer.name} has non-fixed size output: {out_shape}" + assert all(isinstance(x, int) for x in in_shape), f'Layer {layer.name} has non-fixed size input: {in_shape}' + assert all(isinstance(x, int) for x in out_shape), f'Layer {layer.name} has non-fixed size output: {out_shape}' kernel = self.load_weight(layer, 'kernel') if layer.use_bias: @@ -47,65 +113,24 @@ def handle( ker_px_shape: tuple[int, ...] = layer.kernel_size data_format = layer.data_format - if data_format == 'channels_last': - *px_in_shape, ch_in = in_shape - *px_out_shape, ch_out = out_shape - else: - ch_in, *px_in_shape = in_shape - ch_out, *px_out_shape = out_shape - - if layer.padding == 'same': - n_padding = [ceil(N / n) * n - N for N, n in zip(px_in_shape, ker_px_shape)] - n_padding0 = [p // 2 for p in n_padding] - n_padding1 = [p - p0 for p, p0 in zip(n_padding, n_padding0)] - elif layer.padding == 'valid': - n_padding0 = [0] * len(px_in_shape) - n_padding1 = [0] * len(px_in_shape) - elif layer.padding == 'causal': - n_padding0 = [ker_px_shape[0] - 1] + [0] * (len(px_in_shape) - 1) - n_padding1 = [0] * len(px_in_shape) - else: - raise ValueError(f"Invalid padding mode {layer.padding} for layer {layer.name}") + config = gen_conv_config( + in_shape=in_shape, + out_shape=out_shape, + ker_px_shape=ker_px_shape, + strides=layer.strides, + data_format=data_format, + padding=layer.padding, + name=layer.name, + ) - config = { - 'bias_data': bias, - 'data_format': data_format, - 'weight_data': kernel, - 'n_filt': ch_out, - 'n_chan': ch_in, - } + config.update( + { + 'bias_data': bias, + 'data_format': data_format, + 'weight_data': kernel, + } + ) - if layer.rank == 1: - config.update( - { - 'filt_width': ker_px_shape[0], - 'stride_width': layer.strides[0], - 'pad_left': n_padding0[0], - 'pad_right': n_padding1[0], - 'in_width': px_in_shape[0], - 'out_width': px_out_shape[0], - } - ) - elif layer.rank == 2: - config.update( - { - 'filt_height': ker_px_shape[0], - 'filt_width': ker_px_shape[1], - 'stride_height': layer.strides[0], - 'stride_width': layer.strides[1], - 'pad_top': n_padding0[0], - 'pad_bottom': n_padding1[0], - 'pad_left': n_padding0[1], - 'pad_right': n_padding1[1], - 'in_height': px_in_shape[0], - 'in_width': px_in_shape[1], - 'out_height': px_out_shape[0], - 'out_width': px_out_shape[1], - } - ) - else: - _cls = f"{layer.__class__.__module__}.{layer.__class__.__name__}" - raise ValueError(f"Only 1D and 2D conv layers are supported, got {_cls} (rank={layer.rank})") if isinstance(layer, BaseDepthwiseConv): config['depthwise_data'] = kernel config['depth_multiplier'] = layer.depth_multiplier diff --git a/hls4ml/converters/keras_v3/core.py b/hls4ml/converters/keras_v3/core.py index 2ce3ac4165..010caf7a22 100644 --- a/hls4ml/converters/keras_v3/core.py +++ b/hls4ml/converters/keras_v3/core.py @@ -10,12 +10,12 @@ if typing.TYPE_CHECKING: import keras - from keras.api import KerasTensor + from keras import KerasTensor from keras.src.layers.merging.base_merge import Merge @register -class KV3DenseHandler(KerasV3LayerHandler): +class DenseHandler(KerasV3LayerHandler): handles = ('keras.src.layers.core.dense.Dense',) def handle( @@ -24,7 +24,6 @@ def handle( in_tensors: Sequence['KerasTensor'], out_tensors: Sequence['KerasTensor'], ): - kernel = self.load_weight(layer, 'kernel') bias = self.load_weight(layer, 'bias') if layer.use_bias else None n_in, n_out = kernel.shape # type: ignore @@ -40,7 +39,7 @@ def handle( @register -class KV3InputHandler(KerasV3LayerHandler): +class InputHandler(KerasV3LayerHandler): handles = ('keras.src.layers.core.input_layer.InputLayer',) def handle( @@ -54,7 +53,7 @@ def handle( @register -class KV3MergeHandler(KerasV3LayerHandler): +class MergeHandler(KerasV3LayerHandler): handles = ( 'keras.src.layers.merging.add.Add', 'keras.src.layers.merging.multiply.Multiply', @@ -73,33 +72,38 @@ def handle( out_tensors: Sequence['KerasTensor'], cls_name: str | None = None, ): - assert len(out_tensors) == 1, f"Merge layer {layer.name} has more than one output" + assert len(out_tensors) == 1, f'Merge layer {layer.name} has more than one output' output_shape = list(out_tensors[0].shape[1:]) cls_name = cls_name or layer.__class__.__name__ - config: dict[str, Any] = { - 'output_shape': output_shape, - 'op': cls_name.lower(), - } + config: dict[str, Any] = {'output_shape': output_shape} + op = cls_name.lower() match cls_name.lower(): case 'Concatenate': rank = len(output_shape) class_name = f'Concatenate{rank}d' config['axis'] = layer.axis case 'Dot': - class_name = f'Dot{len(output_shape)}d' - rank = len(output_shape) - assert rank == 1, f"Dot product only supported for 1D tensors, got {rank}D on layer {layer.name}" + msg = ( + 'Dot product only supported flatten tensors, got input shapes' + f'{in_tensors[0].shape} and {in_tensors[1].shape} for layer {layer.name}.' + ) + assert all(len(t.shape) == 2 for t in in_tensors), msg + assert in_tensors[0].shape[1] == in_tensors[1].shape[0], f'Input shape mismatch for layer {layer.name}.' + class_name = 'Dot' + op = 'dot1d' + config['axes'] = layer.axes case _: class_name = 'Merge' config['class_name'] = class_name + config['op'] = op return config @register -class KV3ActivationHandler(KerasV3LayerHandler): +class ActivationHandler(KerasV3LayerHandler): handles = ('keras.src.layers.activations.activation.Activation',) def handle( @@ -133,11 +137,12 @@ def handle( config['activation'] = activation.__name__ config['class_name'] = class_name + config['n_in'] = prod(in_tensors[0].shape[1:]) # type: ignore return (config,) @register -class KV3ReLUHandler(KerasV3LayerHandler): +class ReLUHandler(KerasV3LayerHandler): handles = ( 'keras.src.layers.activations.leaky_relu.LeakyReLU', 'keras.src.layers.activations.prelu.PReLU', @@ -171,7 +176,7 @@ def handle( @register -class KV3SoftmaxHandler(KerasV3LayerHandler): +class SoftmaxHandler(KerasV3LayerHandler): handles = ('keras.src.layers.activations.softmax.Softmax',) def handle( @@ -189,22 +194,22 @@ def handle( config = {} config.update(self.default_config) if len(in_tensors) == 2: - raise NotImplementedError("Masked softmax not supported yet") + raise NotImplementedError('Masked softmax not supported yet') config['class_name'] = 'MaskedSoftmax' elif len(in_tensors) == 1: config['class_name'] = 'Softmax' else: - raise ValueError(f"Too many inputs for softmax layer {layer.name}: expected 1 or 2, got {len(in_tensors)}") + raise ValueError(f'Too many inputs for softmax layer {layer.name}: expected 1 or 2, got {len(in_tensors)}') config['axis'] = layer.axis config['activation'] = 'softmax' - config['n_outer'] = (n_outer,) + config['n_outer'] = n_outer config['n_inner'] = n_inner return (config,) @register -class KV3HardActivationHandler(KerasV3LayerHandler): +class EluHandler(KerasV3LayerHandler): handles = ('keras.src.layers.activations.elu.ELU',) def handle( @@ -219,5 +224,36 @@ def handle( config['class_name'] = 'ELU' config['activ_param'] = float(layer.alpha) config['activation'] = 'elu' + config['n_in'] = prod(in_tensors[0].shape[1:]) # type: ignore return (config,) + + +@register +class ReshapeHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.reshaping.reshape.Reshape', 'keras.src.layers.reshaping.flatten.Flatten') + + def handle( + self, + layer: 'keras.layers.Reshape', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + return { + 'class_name': 'Reshape', + 'target_shape': list(out_tensors[0].shape[1:]), + } + + +@register +class PermuteHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.reshaping.permute.Permute',) + + def handle( + self, + layer: 'keras.layers.Permute', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {'class_name': 'Transpose', 'perm': [dim - 1 for dim in layer.dims]} # rm batch dim + return config diff --git a/hls4ml/converters/keras_v3/einsum_dense.py b/hls4ml/converters/keras_v3/einsum_dense.py index 738d10a796..f6b7db29a2 100644 --- a/hls4ml/converters/keras_v3/einsum_dense.py +++ b/hls4ml/converters/keras_v3/einsum_dense.py @@ -5,11 +5,11 @@ if typing.TYPE_CHECKING: import keras - from keras.api import KerasTensor + from keras import KerasTensor def strip_batch_dim(equation: str, einsum_dense: bool = True): - '''Remove the batch dimension from the equation. + """Remove the batch dimension from the equation. Args: equation (str): The einsum equation. @@ -17,7 +17,7 @@ def strip_batch_dim(equation: str, einsum_dense: bool = True): Returns: str: The einsum equation without the batch dimension. - ''' + """ _inps, out = equation.split('->') inp0, inp1 = _inps.split(',') @@ -29,13 +29,20 @@ def strip_batch_dim(equation: str, einsum_dense: bool = True): assert inp0[0] not in inp1, f'Error in eq: {equation}: Batch dim is used in the kernel.' inp0, out = inp0[1:], out[1:] else: - assert inp0[0] == inp1[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the inputs and output.' - inp0, inp1, out = inp0[1:], inp1[1:], out[1:] + if inp0.startswith('...'): + # fmt: off + assert inp1.startswith('...') and out.startswith('...'), ( + f'Error in eq: {equation}: Batch dim mismatch for the inputs and output.' + ) + # fmt: on + else: + assert inp0[0] == inp1[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the inputs and output.' + inp0, inp1, out = inp0[1:], inp1[1:], out[1:] return f'{inp0},{inp1}->{out}' @register -class KV3EinsumDenseHandler(KerasV3LayerHandler): +class EinsumDenseHandler(KerasV3LayerHandler): handles = ('keras.src.layers.core.einsum_dense.EinsumDense',) def handle( diff --git a/hls4ml/converters/keras_v3/pooling.py b/hls4ml/converters/keras_v3/pooling.py new file mode 100644 index 0000000000..0906580f18 --- /dev/null +++ b/hls4ml/converters/keras_v3/pooling.py @@ -0,0 +1,74 @@ +import typing +from collections.abc import Sequence + +from ._base import KerasV3LayerHandler, register +from .conv import gen_conv_config + +if typing.TYPE_CHECKING: + from keras import KerasTensor + from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling + from keras.src.layers.pooling.base_pooling import BasePooling + + +@register +class PoolingHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.pooling.max_pooling1d.MaxPooling1D', + 'keras.src.layers.pooling.max_pooling2d.MaxPooling2D', + 'keras.src.layers.pooling.max_pooling3d.MaxPooling3D', + 'keras.src.layers.pooling.average_pooling1d.AveragePooling1D', + 'keras.src.layers.pooling.average_pooling2d.AveragePooling2D', + 'keras.src.layers.pooling.average_pooling3d.AveragePooling3D', + 'keras.src.layers.pooling.global_average_pooling1d.GlobalAveragePooling1D', + 'keras.src.layers.pooling.global_average_pooling2d.GlobalAveragePooling2D', + 'keras.src.layers.pooling.global_average_pooling3d.GlobalAveragePooling3D', + 'keras.src.layers.pooling.global_max_pooling1d.GlobalMaxPooling1D', + 'keras.src.layers.pooling.global_max_pooling2d.GlobalMaxPooling2D', + 'keras.src.layers.pooling.global_max_pooling3d.GlobalMaxPooling3D', + ) + + def handle( + self, + layer: 'BasePooling | BaseGlobalPooling', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + from keras.src.layers.pooling.base_pooling import BasePooling + + assert len(in_tensors) == 1, f'Layer {layer.name} has more than one input' + assert len(out_tensors) == 1, f'Layer {layer.name} has more than one output' + + in_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore + out_shape: tuple[int, ...] = out_tensors[0].shape[1:] # type: ignore + assert all(isinstance(x, int) for x in in_shape), f'Layer {layer.name} has non-fixed size input: {in_shape}' + assert all(isinstance(x, int) for x in out_shape), f'Layer {layer.name} has non-fixed size output: {out_shape}' + + data_format = layer.data_format + + if data_format == 'channels_last': + *px_in_shape, _ = in_shape + else: + _, *px_in_shape = in_shape + + pool_size: tuple[int, ...] = layer.pool_size if isinstance(layer, BasePooling) else tuple(px_in_shape) + + strides = layer.strides if isinstance(layer, BasePooling) else pool_size + padding = layer.padding if isinstance(layer, BasePooling) else 'valid' + config = gen_conv_config( + in_shape=in_shape, + out_shape=out_shape, + ker_px_shape=pool_size, + strides=strides, + data_format=data_format, + padding=padding, + name=layer.name, + ) + + config['pool_width'] = config.pop('filt_width') + if 'filt_height' in config: + config['pool_height'] = config.pop('filt_height') + if len(px_in_shape) == 1: + # inconsistent pooling1d config key name... + config['n_in'] = config['in_width'] + config['n_out'] = config['out_width'] + return config From 6aec7f665f885acfdc7c48c5d5008b1e12aba6fa Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 06:03:13 -0700 Subject: [PATCH 19/30] force keras>=3.10 --- pyproject.toml | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 041428ea9f..2149cb610a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ optional-dependencies.doc = [ "sphinx-rtd-theme", ] optional-dependencies.hgq = [ "hgq>=0.2.3" ] +optional-dependencies.keras-v3 = [ "keras>=3.10" ] optional-dependencies.onnx = [ "onnx>=1.4" ] optional-dependencies.optimization = [ "keras-tuner==1.1.3", @@ -86,6 +87,25 @@ write_to = "hls4ml/_version.py" line-length = 125 skip-string-normalization = true +[tool.ruff] +target-version = "py310" + +line-length = 125 +indent-width = 4 +include = [ "hls4ml/**/*.py", "tests/**/*.py" ] +exclude = [ "hls4ml/_version.py", "hls4ml/templates/**" ] + +format.quote-style = "single" +format.skip-magic-trailing-comma = false +format.docstring-code-line-length = 125 +format.docstring-code-format = true +lint.select = [ "E", "F", "F401", "I", "W" ] +lint.ignore = [ "E741" ] +lint.per-file-ignores = { "__init__.py" = [ "F401" ] } + +lint.fixable = [ "ALL" ] +lint.unfixable = [ ] + [tool.isort] profile = "black" line_length = 125 From 64261aa229f87afde2648084faacc8ac0323d6e4 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 06:13:47 -0700 Subject: [PATCH 20/30] isolate merge handlers --- hls4ml/converters/keras_v3/__init__.py | 1 + hls4ml/converters/keras_v3/core.py | 52 ----------------------- hls4ml/converters/keras_v3/merge.py | 59 ++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 52 deletions(-) create mode 100644 hls4ml/converters/keras_v3/merge.py diff --git a/hls4ml/converters/keras_v3/__init__.py b/hls4ml/converters/keras_v3/__init__.py index 150d0e241f..5e2dfb4ebe 100644 --- a/hls4ml/converters/keras_v3/__init__.py +++ b/hls4ml/converters/keras_v3/__init__.py @@ -1,6 +1,7 @@ from . import conv # noqa: F401 from . import core # noqa: F401 from . import einsum_dense # noqa: F401 +from . import merge # noqa: F401 from . import pooling # noqa: F401 from ._base import registry as layer_handlers diff --git a/hls4ml/converters/keras_v3/core.py b/hls4ml/converters/keras_v3/core.py index 010caf7a22..b7b51968a7 100644 --- a/hls4ml/converters/keras_v3/core.py +++ b/hls4ml/converters/keras_v3/core.py @@ -2,7 +2,6 @@ import typing from collections.abc import Sequence from math import prod -from typing import Any import numpy as np @@ -11,7 +10,6 @@ if typing.TYPE_CHECKING: import keras from keras import KerasTensor - from keras.src.layers.merging.base_merge import Merge @register @@ -52,56 +50,6 @@ def handle( return config -@register -class MergeHandler(KerasV3LayerHandler): - handles = ( - 'keras.src.layers.merging.add.Add', - 'keras.src.layers.merging.multiply.Multiply', - 'keras.src.layers.merging.average.Average', - 'keras.src.layers.merging.maximum.Maximum', - 'keras.src.layers.merging.minimum.Minimum', - 'keras.src.layers.merging.concatenate.Concatenate', - 'keras.src.layers.merging.subtract.Subtract', - 'keras.src.layers.merging.dot.Dot', - ) - - def handle( - self, - layer: 'Merge', - in_tensors: Sequence['KerasTensor'], - out_tensors: Sequence['KerasTensor'], - cls_name: str | None = None, - ): - assert len(out_tensors) == 1, f'Merge layer {layer.name} has more than one output' - output_shape = list(out_tensors[0].shape[1:]) - - cls_name = cls_name or layer.__class__.__name__ - config: dict[str, Any] = {'output_shape': output_shape} - - op = cls_name.lower() - match cls_name.lower(): - case 'Concatenate': - rank = len(output_shape) - class_name = f'Concatenate{rank}d' - config['axis'] = layer.axis - case 'Dot': - msg = ( - 'Dot product only supported flatten tensors, got input shapes' - f'{in_tensors[0].shape} and {in_tensors[1].shape} for layer {layer.name}.' - ) - assert all(len(t.shape) == 2 for t in in_tensors), msg - assert in_tensors[0].shape[1] == in_tensors[1].shape[0], f'Input shape mismatch for layer {layer.name}.' - class_name = 'Dot' - op = 'dot1d' - config['axes'] = layer.axes - case _: - class_name = 'Merge' - - config['class_name'] = class_name - config['op'] = op - return config - - @register class ActivationHandler(KerasV3LayerHandler): handles = ('keras.src.layers.activations.activation.Activation',) diff --git a/hls4ml/converters/keras_v3/merge.py b/hls4ml/converters/keras_v3/merge.py new file mode 100644 index 0000000000..96c5547bae --- /dev/null +++ b/hls4ml/converters/keras_v3/merge.py @@ -0,0 +1,59 @@ +import typing +from collections.abc import Sequence +from typing import Any + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + from keras import KerasTensor + from keras.src.layers.merging.base_merge import Merge + + +@register +class MergeHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.merging.add.Add', + 'keras.src.layers.merging.multiply.Multiply', + 'keras.src.layers.merging.average.Average', + 'keras.src.layers.merging.maximum.Maximum', + 'keras.src.layers.merging.minimum.Minimum', + 'keras.src.layers.merging.concatenate.Concatenate', + 'keras.src.layers.merging.subtract.Subtract', + 'keras.src.layers.merging.dot.Dot', + ) + + def handle( + self, + layer: 'Merge', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + cls_name: str | None = None, + ): + assert len(out_tensors) == 1, f'Merge layer {layer.name} has more than one output' + output_shape = list(out_tensors[0].shape[1:]) + + cls_name = cls_name or layer.__class__.__name__ + config: dict[str, Any] = {'output_shape': output_shape} + + op = cls_name.lower() + match cls_name.lower(): + case 'Concatenate': + rank = len(output_shape) + class_name = f'Concatenate{rank}d' + config['axis'] = layer.axis + case 'Dot': + msg = ( + 'Dot product only supported flatten tensors, got input shapes' + f'{in_tensors[0].shape} and {in_tensors[1].shape} for layer {layer.name}.' + ) + assert all(len(t.shape) == 2 for t in in_tensors), msg + assert in_tensors[0].shape[1] == in_tensors[1].shape[0], f'Input shape mismatch for layer {layer.name}.' + class_name = 'Dot' + op = 'dot1d' + config['axes'] = layer.axes + case _: + class_name = 'Merge' + + config['class_name'] = class_name + config['op'] = op + return config From b8ed033aae8f6e4055406fc849f3262879ab3e44 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 06:14:30 -0700 Subject: [PATCH 21/30] rm abomination --- hls4ml/converters/keras_to_hls.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/hls4ml/converters/keras_to_hls.py b/hls4ml/converters/keras_to_hls.py index ea3f96c236..e002909bda 100644 --- a/hls4ml/converters/keras_to_hls.py +++ b/hls4ml/converters/keras_to_hls.py @@ -6,8 +6,6 @@ from .keras_v3_to_hls import parse_keras_v3_model -MAXMULT = 4096 - class KerasReader: def get_weights_data(self, layer_name, var_name): From 009ae8eaa995466db3caab66c3080c959c279cef Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 06:30:13 -0700 Subject: [PATCH 22/30] mv xpose config gen to utils --- hls4ml/backends/fpga/fpga_backend.py | 35 ----------------- .../oneapi/passes/reshaping_templates.py | 3 +- hls4ml/backends/vivado/passes/einsum.py | 7 ++-- hls4ml/backends/vivado/passes/einsum_dense.py | 5 ++- .../vivado/passes/reshaping_templates.py | 3 +- hls4ml/utils/transpose_utils.py | 38 +++++++++++++++++++ 6 files changed, 49 insertions(+), 42 deletions(-) create mode 100644 hls4ml/utils/transpose_utils.py diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index 95d900fd62..009d742f3c 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -913,41 +913,6 @@ def generate_conv2d_line_buffer_fn( return generated_code - @staticmethod - def transpose_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]): - """ - Generate new shape and perm_strides for a permute operation. Operates by mapping the output index - to input input index by: - - unravel the output index - - map each dimension to the corresponding stride in the input tensor, sum - The operation can be expressed as: - - new_shape = tuple(shape[i] for i in perm) - strides = np.cumprod((shapes[1:] + (1,))[::-1])[::-1] - perm_strides = [strides[i] for i in perm] - out[index] = inp[np.dot(np.unravel_index(index, new_shape), perm_strides)] - - Args: - name (str): The name of the configuration. - shape (tuple[int, ...]): The shape of the input tensor. - perm (tuple[int, ...]): The permutation of the dimensions. - - Returns: - dict: Dictionary containing the configuration. - """ - new_shape = tuple(shape[i] for i in perm) - strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1] - perm_strides = tuple(int(strides[i]) for i in perm) - return dict( - dims=len(shape), - N=math.prod(shape), - from_shape=', '.join(str(x) for x in shape), - perm=', '.join(str(x) for x in perm), - perm_strides=', '.join(str(x) for x in perm_strides), - to_shape=', '.join(str(x) for x in new_shape), - config_name=name, - ) - @model_optimizer() def write_hls(self, model): self.writer.write_hls(model) diff --git a/hls4ml/backends/oneapi/passes/reshaping_templates.py b/hls4ml/backends/oneapi/passes/reshaping_templates.py index 80b467b944..0f07584440 100644 --- a/hls4ml/backends/oneapi/passes/reshaping_templates.py +++ b/hls4ml/backends/oneapi/passes/reshaping_templates.py @@ -3,6 +3,7 @@ from hls4ml.backends.oneapi.oneapi_template import StreamFunctionCallTemplate, TaskSequenceTemplate from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate from hls4ml.model.layers import Reshape, Resize, Transpose, ZeroPadding1D, ZeroPadding2D +from hls4ml.utils.transpose_utils import transpose_config_gen # ZeroPadding templates @@ -185,7 +186,7 @@ def format(self, node): perm = tuple(node.get_attr('perm')) name = f'config{node.index}' - conf = node.model.config.backend.transpose_config_gen(name, shape, perm) + conf = transpose_config_gen(name, shape, perm) return transpose_config_template.format(**conf) diff --git a/hls4ml/backends/vivado/passes/einsum.py b/hls4ml/backends/vivado/passes/einsum.py index 4f976c63af..4f076e8cd6 100644 --- a/hls4ml/backends/vivado/passes/einsum.py +++ b/hls4ml/backends/vivado/passes/einsum.py @@ -3,6 +3,7 @@ from hls4ml.backends.backend import get_backend from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate from hls4ml.model.layers import Einsum +from hls4ml.utils.transpose_utils import transpose_config_gen from .reshaping_templates import transpose_config_template @@ -81,11 +82,11 @@ def format(self, node: Einsum): tpose_inp1_conf_name = f'config{node.index}_tpose_inp1' tpose_out_conf_name = f'config{node.index}_tpose_out' - conf = node.model.config.backend.transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs) + conf = transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs) inp0_tpose_conf = transpose_config_template.format(**conf) - conf = node.model.config.backend.transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs) + conf = transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs) inp1_tpose_conf = transpose_config_template.format(**conf) - conf = node.model.config.backend.transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) out_tpose_conf = transpose_config_template.format(**conf) return '\n\n'.join((inp0_tpose_conf, inp1_tpose_conf, out_tpose_conf, einsum_conf)) diff --git a/hls4ml/backends/vivado/passes/einsum_dense.py b/hls4ml/backends/vivado/passes/einsum_dense.py index 12359ef754..5f6dab64e0 100644 --- a/hls4ml/backends/vivado/passes/einsum_dense.py +++ b/hls4ml/backends/vivado/passes/einsum_dense.py @@ -1,6 +1,7 @@ from hls4ml.backends.backend import get_backend from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate from hls4ml.model.layers import EinsumDense +from hls4ml.utils.transpose_utils import transpose_config_gen from .reshaping_templates import transpose_config_template @@ -118,9 +119,9 @@ def format(self, node: EinsumDense): tpose_inp_conf_name = f'config{node.index}_tpose_inp' tpose_out_conf_name = f'config{node.index}_tpose_out' - conf = node.model.config.backend.transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs) + conf = transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs) inp_tpose_conf = transpose_config_template.format(**conf) - conf = node.model.config.backend.transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) + conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) out_tpose_conf = transpose_config_template.format(**conf) if strategy.lower() == 'distributed_arithmetic': diff --git a/hls4ml/backends/vivado/passes/reshaping_templates.py b/hls4ml/backends/vivado/passes/reshaping_templates.py index 69944e4497..0a14efc1af 100644 --- a/hls4ml/backends/vivado/passes/reshaping_templates.py +++ b/hls4ml/backends/vivado/passes/reshaping_templates.py @@ -1,5 +1,6 @@ from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate from hls4ml.model.layers import Resize, Transpose, ZeroPadding1D, ZeroPadding2D +from hls4ml.utils.transpose_utils import transpose_config_gen # ZeroPadding templates @@ -131,7 +132,7 @@ def format(self, node): shape = tuple(node.get_input_variable().shape) perm = tuple(node.get_attr('perm')) name = f'config{node.index}' - conf = node.model.config.backend.transpose_config_gen(name, shape, perm) + conf = transpose_config_gen(name, shape, perm) return transpose_config_template.format(**conf) diff --git a/hls4ml/utils/transpose_utils.py b/hls4ml/utils/transpose_utils.py new file mode 100644 index 0000000000..7e399c8e7c --- /dev/null +++ b/hls4ml/utils/transpose_utils.py @@ -0,0 +1,38 @@ +import math + +import numpy as np + + +def transpose_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]): + """ + Generate new shape and perm_strides for a permute operation. Operates by mapping the output index + to input input index by: + - unravel the output index + - map each dimension to the corresponding stride in the input tensor, sum + The operation can be expressed as: + + new_shape = tuple(shape[i] for i in perm) + strides = np.cumprod((shapes[1:] + (1,))[::-1])[::-1] + perm_strides = [strides[i] for i in perm] + out[index] = inp[np.dot(np.unravel_index(index, new_shape), perm_strides)] + + Args: + name (str): The name of the configuration. + shape (tuple[int, ...]): The shape of the input tensor. + perm (tuple[int, ...]): The permutation of the dimensions. + + Returns: + dict: Dictionary containing the configuration. + """ + new_shape = tuple(shape[i] for i in perm) + strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1] + perm_strides = tuple(int(strides[i]) for i in perm) + return dict( + dims=len(shape), + N=math.prod(shape), + from_shape=', '.join(str(x) for x in shape), + perm=', '.join(str(x) for x in perm), + perm_strides=', '.join(str(x) for x in perm_strides), + to_shape=', '.join(str(x) for x in new_shape), + config_name=name, + ) From 3ea3490e0b9bb702ea19f8d358404d8389ad6f16 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 06:40:07 -0700 Subject: [PATCH 23/30] attributes.attributes -> attributes --- hls4ml/backends/vivado/passes/einsum.py | 22 ++++++++--------- hls4ml/backends/vivado/passes/einsum_dense.py | 24 +++++++++---------- hls4ml/model/layers.py | 16 ++++++------- .../model/optimizer/passes/hgq_proxy_model.py | 2 +- test/pytest/test_keras_v3_api.py | 2 +- 5 files changed, 33 insertions(+), 33 deletions(-) diff --git a/hls4ml/backends/vivado/passes/einsum.py b/hls4ml/backends/vivado/passes/einsum.py index 4f076e8cd6..22f092903c 100644 --- a/hls4ml/backends/vivado/passes/einsum.py +++ b/hls4ml/backends/vivado/passes/einsum.py @@ -49,7 +49,7 @@ def __init__(self): def format(self, node: Einsum): default_params = self._default_config_params(node) - strategy = node.attributes.attributes['strategy'] + strategy = node.attributes['strategy'] io_type = node.model.config.get_config_value('IOType') assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now' @@ -58,10 +58,10 @@ def format(self, node: Einsum): # EinsumDense config params = default_params.copy() params['strategy'] = strategy - params['n_free0'] = node.attributes.attributes['n_free0'] - params['n_free1'] = node.attributes.attributes['n_free1'] - params['n_contract'] = node.attributes.attributes['n_contract'] - params['n_inplace'] = node.attributes.attributes['n_inplace'] + params['n_free0'] = node.attributes['n_free0'] + params['n_free1'] = node.attributes['n_free1'] + params['n_contract'] = node.attributes['n_contract'] + params['n_inplace'] = node.attributes['n_inplace'] inp0_t = node.get_input_variable(node.inputs[0]).type.precision inp1_t = node.get_input_variable(node.inputs[1]).type.precision params['product_type'] = get_backend('vivado').product_type(inp0_t, inp1_t) @@ -72,12 +72,12 @@ def format(self, node: Einsum): einsum_conf = self.template.format(**params) # inp/out transpose config - inp0_shape = node.attributes.attributes['inp0_shape'] - inp1_shape = node.attributes.attributes['inp1_shape'] - out_interpert_shape = node.attributes.attributes['out_interpert_shape'] - inp0_tpose_idxs = node.attributes.attributes['inp0_tpose_idxs'] - inp1_tpose_idxs = node.attributes.attributes['inp1_tpose_idxs'] - out_tpose_idxs = node.attributes.attributes['out_tpose_idxs'] + inp0_shape = node.attributes['inp0_shape'] + inp1_shape = node.attributes['inp1_shape'] + out_interpert_shape = node.attributes['out_interpert_shape'] + inp0_tpose_idxs = node.attributes['inp0_tpose_idxs'] + inp1_tpose_idxs = node.attributes['inp1_tpose_idxs'] + out_tpose_idxs = node.attributes['out_tpose_idxs'] tpose_inp0_conf_name = f'config{node.index}_tpose_inp0' tpose_inp1_conf_name = f'config{node.index}_tpose_inp1' tpose_out_conf_name = f'config{node.index}_tpose_out' diff --git a/hls4ml/backends/vivado/passes/einsum_dense.py b/hls4ml/backends/vivado/passes/einsum_dense.py index 5f6dab64e0..8af86add20 100644 --- a/hls4ml/backends/vivado/passes/einsum_dense.py +++ b/hls4ml/backends/vivado/passes/einsum_dense.py @@ -65,9 +65,9 @@ def dense_config(self, node: EinsumDense): dense_params = self._default_config_params(node) strategy = node.attributes['strategy'] dense_params['strategy'] = strategy - dense_params['n_in'] = node.attributes.attributes['n_contract'] - dense_params['n_out'] = node.attributes.attributes['n_free_kernel'] - if node.attributes.attributes['n_inplace'] == 1: + dense_params['n_in'] = node.attributes['n_contract'] + dense_params['n_out'] = node.attributes['n_free_kernel'] + if node.attributes['n_inplace'] == 1: dense_params['nzeros'] = node.get_weights('weight').nzeros # type: ignore else: dense_params['nzeros'] = '-1; // Not making sense when kernels are switching' @@ -91,10 +91,10 @@ def format(self, node: EinsumDense): # EinsumDense config params = default_params.copy() params['strategy'] = strategy - params['n_free_data'] = node.attributes.attributes['n_free_data'] - params['n_free_kernel'] = node.attributes.attributes['n_free_kernel'] - params['n_contract'] = node.attributes.attributes['n_contract'] - params['n_inplace'] = node.attributes.attributes['n_inplace'] + params['n_free_data'] = node.attributes['n_free_data'] + params['n_free_kernel'] = node.attributes['n_free_kernel'] + params['n_contract'] = node.attributes['n_contract'] + params['n_inplace'] = node.attributes['n_inplace'] if strategy.lower() == 'latency': params['kernel_config'] = f'typedef config{node.index}_dense dense_conf' else: @@ -104,7 +104,7 @@ def format(self, node: EinsumDense): index = node.index conf = f'constexpr static auto da_kernel = nnet::einsum_dense{index}_da_kernel<{inp_t}, {result_t}>' params['kernel_config'] = conf - pf = node.attributes.attributes['parallelization_factor'] + pf = node.attributes['parallelization_factor'] if pf < 0: pf = params['n_inplace'] params['parallelization_factor'] = pf @@ -112,10 +112,10 @@ def format(self, node: EinsumDense): einsum_conf = self.template.format(**params) # inp/out transpose config - inp_shape = node.attributes.attributes['inp_shape'] - out_interpert_shape = node.attributes.attributes['out_interpert_shape'] - inp_tpose_idxs = node.attributes.attributes['inp_tpose_idxs'] - out_tpose_idxs = node.attributes.attributes['out_tpose_idxs'] + inp_shape = node.attributes['inp_shape'] + out_interpert_shape = node.attributes['out_interpert_shape'] + inp_tpose_idxs = node.attributes['inp_tpose_idxs'] + out_tpose_idxs = node.attributes['out_tpose_idxs'] tpose_inp_conf_name = f'config{node.index}_tpose_inp' tpose_out_conf_name = f'config{node.index}_tpose_out' diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 041ef8ab8d..91935e9b61 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1670,8 +1670,8 @@ def initialize(self): dims = [f'N_LAYER_{self.index}'] self.add_output_variable(list(out_shape), dims) - kernel: np.ndarray = self.attributes.attributes['weight_data'] - bias: np.ndarray | None = self.attributes.attributes['bias_data'] + kernel: np.ndarray = self.attributes['weight_data'] + bias: np.ndarray | None = self.attributes['bias_data'] equation = self.attributes['equation'] inp_shape = self.attributes['inp_shape'] out_shape = self.attributes['out_shape'] @@ -1705,9 +1705,9 @@ def to_original_kernel(tkernel: np.ndarray) -> np.ndarray: # The transpose is just to match the shape in case of have real bias, no real effect. bias = np.zeros(out_shape).transpose(np.argsort(out_tpose_idxs)) - self.attributes.attributes['weight_data'] = kernel - self.attributes.attributes['to_original_kernel'] = to_original_kernel - self.attributes.attributes['bias_data'] = bias + self.attributes['weight_data'] = kernel + self.attributes['to_original_kernel'] = to_original_kernel + self.attributes['bias_data'] = bias self.attributes['inp_tpose_idxs'] = inp_tpose_idxs self.attributes['out_tpose_idxs'] = out_tpose_idxs self.attributes['out_interpert_shape'] = recipe['out_interpert_shape'] @@ -1715,7 +1715,7 @@ def to_original_kernel(tkernel: np.ndarray) -> np.ndarray: self.attributes['n_free_kernel'] = recipe['L1'] self.attributes['n_inplace'] = recipe['I'] self.attributes['n_contract'] = recipe['C'] - pf = self.attributes.attributes.get('parallelization_factor', recipe['L0']) + pf = self.attributes.get('parallelization_factor', recipe['L0']) self.attributes['parallelization_factor'] = pf self.add_weights(compression=self.model.config.get_compression(self)) @@ -1760,7 +1760,7 @@ def initialize(self): inp0_tpose_idxs, inp1_tpose_idxs = recipe['in_transpose_idxs'] out_tpose_idxs = recipe['out_transpose_idxs'] - self.attributes.attributes.update(recipe) + self.attributes.update(recipe) self.attributes['n_free0'] = recipe['L0'] self.attributes['n_free1'] = recipe['L1'] self.attributes['n_inplace'] = recipe['I'] @@ -1771,7 +1771,7 @@ def initialize(self): self.attributes['inp1_tpose_idxs'] = inp1_tpose_idxs self.attributes['out_tpose_idxs'] = out_tpose_idxs - pf = self.attributes.attributes.get('parallelization_factor', recipe['L0']) + pf = self.attributes.get('parallelization_factor', recipe['L0']) self.attributes['parallelization_factor'] = pf diff --git a/hls4ml/model/optimizer/passes/hgq_proxy_model.py b/hls4ml/model/optimizer/passes/hgq_proxy_model.py index 13e48aac43..27847803da 100644 --- a/hls4ml/model/optimizer/passes/hgq_proxy_model.py +++ b/hls4ml/model/optimizer/passes/hgq_proxy_model.py @@ -129,7 +129,7 @@ def transform(self, model, node: FixedPointQuantizer): weight_var.update_precision(precision) # Well, it turned out that there is yet ANOTHER copy saved in config. model.config.layer_name_precision[f'{name}_{k[:-2]}'] = v - elif k in target_node.attributes.attributes: + elif k in target_node.attributes: target_node.set_attr(k, v) elif k == 'parallelization_factor': target_node.set_attr(k, int(v)) diff --git a/test/pytest/test_keras_v3_api.py b/test/pytest/test_keras_v3_api.py index 81ac5c240c..a77ce32fcc 100644 --- a/test/pytest/test_keras_v3_api.py +++ b/test/pytest/test_keras_v3_api.py @@ -115,7 +115,7 @@ def test_activations(activation_function, backend, io_type): np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.02) for layer in hls_model.get_layers(): - print(layer.attributes.attributes['class_name']) + print(layer.attributes['class_name']) assert len(model.layers) + 1 == len(hls_model.get_layers()) assert list(hls_model.get_layers())[2].attributes['class_name'] == activation_function.__class__.__name__ From 5d4bdfe4e71e3ea05aa451e53d502f34bae0e369 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 06:56:26 -0700 Subject: [PATCH 24/30] isolate keras v2 and v3 to hls --- hls4ml/contrib/kl_layer/kl_layer.py | 2 +- hls4ml/converters/__init__.py | 25 ++--- hls4ml/converters/keras/convolution.py | 2 +- hls4ml/converters/keras/core.py | 2 +- hls4ml/converters/keras/graph.py | 2 +- hls4ml/converters/keras/hgq_proxy_model.py | 2 +- hls4ml/converters/keras/merge.py | 2 +- hls4ml/converters/keras/model.py | 2 +- hls4ml/converters/keras/pooling.py | 2 +- hls4ml/converters/keras/qkeras.py | 2 +- hls4ml/converters/keras/recurrent.py | 2 +- hls4ml/converters/keras/reshape.py | 2 +- hls4ml/converters/keras/reshaping.py | 2 +- .../{keras_to_hls.py => keras_v2_to_hls.py} | 11 +-- hls4ml/converters/keras_v3_to_hls.py | 98 ++++++++++--------- 15 files changed, 80 insertions(+), 78 deletions(-) rename hls4ml/converters/{keras_to_hls.py => keras_v2_to_hls.py} (97%) diff --git a/hls4ml/contrib/kl_layer/kl_layer.py b/hls4ml/contrib/kl_layer/kl_layer.py index 44b610d327..c3c27a849a 100644 --- a/hls4ml/contrib/kl_layer/kl_layer.py +++ b/hls4ml/contrib/kl_layer/kl_layer.py @@ -21,7 +21,7 @@ from tensorflow.python.ops import math_ops import hls4ml -from hls4ml.converters.keras_to_hls import parse_default_keras_layer +from hls4ml.converters.keras_v2_to_hls import parse_default_keras_layer from hls4ml.model.attributes import ConfigurableAttribute, TypeAttribute from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 47569b1ad9..162ad53a30 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -3,13 +3,13 @@ import yaml -from hls4ml.converters.keras_to_hls import KerasFileReader # noqa: F401 -from hls4ml.converters.keras_to_hls import KerasModelReader # noqa: F401 -from hls4ml.converters.keras_to_hls import KerasReader # noqa: F401 -from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401 -from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401 -from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler -from hls4ml.converters.keras_v3_to_hls import parse_keras_v3_model # noqa: F401 +from hls4ml.converters.keras_v2_to_hls import KerasFileReader # noqa: F401 +from hls4ml.converters.keras_v2_to_hls import KerasModelReader # noqa: F401 +from hls4ml.converters.keras_v2_to_hls import KerasReader # noqa: F401 +from hls4ml.converters.keras_v2_to_hls import get_supported_keras_layers # noqa: F401 +from hls4ml.converters.keras_v2_to_hls import parse_keras_model # noqa: F401 +from hls4ml.converters.keras_v2_to_hls import keras_v2_to_hls, register_keras_layer_handler +from hls4ml.converters.keras_v3_to_hls import keras_v3_to_hls, parse_keras_v3_model # noqa: F401 from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers # noqa: F401 from hls4ml.converters.onnx_to_hls import parse_onnx_model # noqa: F401 from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler @@ -18,8 +18,6 @@ pytorch_to_hls, register_pytorch_layer_handler, ) - -# from hls4ml.converters.pytorch_to_hls import parse_pytorch_model # noqa: F401 from hls4ml.model import ModelGraph from hls4ml.utils.config import create_config from hls4ml.utils.dependency import requires @@ -116,7 +114,7 @@ def convert_from_config(config): elif 'PytorchModel' in yamlConfig: model = pytorch_to_hls(yamlConfig) else: - model = keras_to_hls(yamlConfig) + model = keras_v2_to_hls(yamlConfig) return model @@ -216,8 +214,13 @@ def convert_from_keras_model( config['HLSConfig']['Model'] = _check_model_config(model_config) _check_hls_config(config, hls_config) + if 'KerasModel' in config: + import keras + + if keras.__version__ >= '3.0': + return keras_v3_to_hls(config) - return keras_to_hls(config) + return keras_v2_to_hls(config) @requires('_torch') diff --git a/hls4ml/converters/keras/convolution.py b/hls4ml/converters/keras/convolution.py index 2900fe019f..d6be2c518a 100644 --- a/hls4ml/converters/keras/convolution.py +++ b/hls4ml/converters/keras/convolution.py @@ -1,4 +1,4 @@ -from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer +from hls4ml.converters.keras_v2_to_hls import get_weights_data, keras_handler, parse_default_keras_layer from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d, parse_data_format diff --git a/hls4ml/converters/keras/core.py b/hls4ml/converters/keras/core.py index 637bb6d401..93ae7995ff 100644 --- a/hls4ml/converters/keras/core.py +++ b/hls4ml/converters/keras/core.py @@ -1,4 +1,4 @@ -from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer +from hls4ml.converters.keras_v2_to_hls import get_weights_data, keras_handler, parse_default_keras_layer from hls4ml.model.quantizers import BinaryQuantizer, TernaryQuantizer from hls4ml.model.types import IntegerPrecisionType diff --git a/hls4ml/converters/keras/graph.py b/hls4ml/converters/keras/graph.py index 954bf20b8f..a12242b574 100644 --- a/hls4ml/converters/keras/graph.py +++ b/hls4ml/converters/keras/graph.py @@ -1,4 +1,4 @@ -from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer +from hls4ml.converters.keras_v2_to_hls import get_weights_data, keras_handler, parse_default_keras_layer from hls4ml.model.quantizers import TernaryQuantizer diff --git a/hls4ml/converters/keras/hgq_proxy_model.py b/hls4ml/converters/keras/hgq_proxy_model.py index 1598759253..ddc86fb8a6 100644 --- a/hls4ml/converters/keras/hgq_proxy_model.py +++ b/hls4ml/converters/keras/hgq_proxy_model.py @@ -1,4 +1,4 @@ -from hls4ml.converters.keras_to_hls import KerasReader, keras_handler, parse_default_keras_layer +from hls4ml.converters.keras_v2_to_hls import KerasReader, keras_handler, parse_default_keras_layer @keras_handler('FixedPointQuantizer', 'HGQ>FixedPointQuantizer') diff --git a/hls4ml/converters/keras/merge.py b/hls4ml/converters/keras/merge.py index 1423308cff..00d3792430 100644 --- a/hls4ml/converters/keras/merge.py +++ b/hls4ml/converters/keras/merge.py @@ -1,4 +1,4 @@ -from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer +from hls4ml.converters.keras_v2_to_hls import keras_handler, parse_default_keras_layer merge_layers = ['Add', 'Subtract', 'Multiply', 'Average', 'Maximum', 'Minimum', 'Concatenate', 'Dot'] diff --git a/hls4ml/converters/keras/model.py b/hls4ml/converters/keras/model.py index 3f22907058..145042bb50 100644 --- a/hls4ml/converters/keras/model.py +++ b/hls4ml/converters/keras/model.py @@ -1,4 +1,4 @@ -from hls4ml.converters.keras_to_hls import ( +from hls4ml.converters.keras_v2_to_hls import ( KerasFileReader, KerasModelReader, KerasNestedFileReader, diff --git a/hls4ml/converters/keras/pooling.py b/hls4ml/converters/keras/pooling.py index 1f6dd07c4e..8a905db94c 100644 --- a/hls4ml/converters/keras/pooling.py +++ b/hls4ml/converters/keras/pooling.py @@ -1,4 +1,4 @@ -from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer +from hls4ml.converters.keras_v2_to_hls import keras_handler, parse_default_keras_layer from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d, parse_data_format pooling_layers = ['MaxPooling1D', 'MaxPooling2D', 'AveragePooling1D', 'AveragePooling2D'] diff --git a/hls4ml/converters/keras/qkeras.py b/hls4ml/converters/keras/qkeras.py index 8d50eb512e..fd670ad0d9 100644 --- a/hls4ml/converters/keras/qkeras.py +++ b/hls4ml/converters/keras/qkeras.py @@ -1,7 +1,7 @@ from hls4ml.converters.keras.convolution import parse_conv1d_layer, parse_conv2d_layer from hls4ml.converters.keras.core import parse_batchnorm_layer, parse_dense_layer from hls4ml.converters.keras.recurrent import parse_rnn_layer -from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer +from hls4ml.converters.keras_v2_to_hls import keras_handler, parse_default_keras_layer from hls4ml.model.quantizers import QKerasBinaryQuantizer, QKerasPO2Quantizer, QKerasQuantizer from hls4ml.model.types import FixedPrecisionType diff --git a/hls4ml/converters/keras/recurrent.py b/hls4ml/converters/keras/recurrent.py index 55dd5bf82e..9f98b33f76 100644 --- a/hls4ml/converters/keras/recurrent.py +++ b/hls4ml/converters/keras/recurrent.py @@ -1,4 +1,4 @@ -from hls4ml.converters.keras_to_hls import ( +from hls4ml.converters.keras_v2_to_hls import ( KerasModelReader, KerasNestedFileReader, KerasWrappedLayerFileReader, diff --git a/hls4ml/converters/keras/reshape.py b/hls4ml/converters/keras/reshape.py index 7d58252703..244c1615e4 100644 --- a/hls4ml/converters/keras/reshape.py +++ b/hls4ml/converters/keras/reshape.py @@ -1,6 +1,6 @@ import numpy as np -from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer +from hls4ml.converters.keras_v2_to_hls import keras_handler, parse_default_keras_layer from hls4ml.converters.utils import parse_data_format diff --git a/hls4ml/converters/keras/reshaping.py b/hls4ml/converters/keras/reshaping.py index b6c0052973..404ce52f73 100644 --- a/hls4ml/converters/keras/reshaping.py +++ b/hls4ml/converters/keras/reshaping.py @@ -1,6 +1,6 @@ import collections.abc -from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer +from hls4ml.converters.keras_v2_to_hls import keras_handler, parse_default_keras_layer @keras_handler('ZeroPadding1D') diff --git a/hls4ml/converters/keras_to_hls.py b/hls4ml/converters/keras_v2_to_hls.py similarity index 97% rename from hls4ml/converters/keras_to_hls.py rename to hls4ml/converters/keras_v2_to_hls.py index e002909bda..6099bb138f 100644 --- a/hls4ml/converters/keras_to_hls.py +++ b/hls4ml/converters/keras_v2_to_hls.py @@ -4,8 +4,6 @@ from hls4ml.model import ModelGraph -from .keras_v3_to_hls import parse_keras_v3_model - class KerasReader: def get_weights_data(self, layer_name, var_name): @@ -356,14 +354,7 @@ def parse_keras_model(model_arch, reader): return layer_list, input_layers, output_layers, output_shapes -def keras_to_hls(config): - if 'KerasModel' in config: - import keras - - if keras.__version__ >= '3.0': - layer_list, input_layers, output_layers, _ = parse_keras_v3_model(config['KerasModel']) - return ModelGraph(config, layer_list, input_layers, output_layers) - +def keras_v2_to_hls(config): model_arch, reader = get_model_arch(config) layer_list, input_layers, output_layers, _ = parse_keras_model(model_arch, reader) print('Creating HLS model') diff --git a/hls4ml/converters/keras_v3_to_hls.py b/hls4ml/converters/keras_v3_to_hls.py index f94928e811..0b89022c76 100644 --- a/hls4ml/converters/keras_v3_to_hls.py +++ b/hls4ml/converters/keras_v3_to_hls.py @@ -4,11 +4,13 @@ from types import FunctionType from typing import Any +import numpy as np + +from hls4ml.model import ModelGraph + if typing.TYPE_CHECKING: import keras - from keras.api import KerasTensor - -import numpy as np + from keras import KerasTensor from .keras_v3 import layer_handlers as v3_layer_handlers @@ -19,26 +21,24 @@ def get_io_tensors(layer: 'keras.Layer', node_whitelist: set[int] | None = None): '''Given a keras layer, return a list of tuples of input and output - tensors. If the layer is called only once (i.e., no shared layers), + tensors. If the layer is called only once (i.e., a layer is not used multiple times in the same model), the list will contain only one tuple. The layer must have been built before calling this function. - Parameters - ---------- - layer : keras.Layer - The layer to get input and output tensors from. - node_whitelist : set[int]|None, optional - If not None, only return tensors from nodes with ids in this - set, used to filter out nodes that are not part of the model, by - default None + Args: + layer: The layer to get input and output tensors from. + node_whitelist: If not None, only return tensors from nodes + with ids in this set, used to filter out nodes that are not + part of the model. Defaults to None. - Returns - ------- - list[tuple[tuple['KerasTensor', ...], tuple['KerasTensor', ...]]] - A list of tuples of input and output tensors. + Returns: + A list of tuples of input and output tensors. Each inner tuple + contains two tuples: the first with input KerasTensors and the + second with output KerasTensors. ''' + in_nodes = layer._inbound_nodes if node_whitelist is not None: in_nodes = [node for node in in_nodes if id(node) in node_whitelist] @@ -53,21 +53,25 @@ def get_io_tensors(layer: 'keras.Layer', node_whitelist: set[int] | None = None) def resolve_dependency_relation(model: 'keras.Model'): '''Given a keras model, return the following information: - - A list of input tensor names - - A list of output tensor names - - A list of (layer_name, input_tensor_names, output_tensor_names) tuples - - A dictionary of tensor_name -> KerasTensor - - Parameters - ---------- - model : keras.Model - The keras model to analyze. - - Returns - ------- - tuple[tuple[str, ...], tuple[str, ...], list[tuple[str, tuple[str, ...], tuple[str, ...]]], dict[str, KerasTensor]] - inp_tensor_names, out_tensor_names, layer_io, tensors + - A list of input tensor names + - A list of output tensor names + - A list of (layer_name, input_tensor_names, output_tensor_names) tuples + - A dictionary of tensor_name -> KerasTensor + + Args: + model: The keras model to analyze. + + Returns: + A tuple containing: + - inp_tensor_names (tuple[str, ...]): A tuple of input tensor names. + - out_tensor_names (tuple[str, ...]): A tuple of output tensor names. + - layer_io (list[tuple[str, tuple[str, ...], tuple[str, ...]]]): A list of + tuples, where each tuple contains the layer name, a tuple of its + input tensor names, and a tuple of its output tensor names. + - tensors (dict[str, KerasTensor]): A dictionary mapping tensor names + to KerasTensor objects. ''' + tensors: dict[str, 'KerasTensor'] = {} 'tensor_name -> KerasTensor' depends_on: dict[str, tuple[str, ...]] = {} @@ -163,7 +167,7 @@ def v2_call( config = layer.get_config() layer_dict = {'config': config, 'class_name': layer.__class__.__name__} - class DummyReader: + class IsolatedLayerReader: def get_weights_data(self, layer_name, var_name): assert layer_name == layer.name, f'Processing {layer.name}, but handler tried to read {layer_name}' for w in layer.weights: @@ -171,7 +175,7 @@ def get_weights_data(self, layer_name, var_name): return np.array(w) return None - reader = DummyReader() + reader = IsolatedLayerReader() input_shapes = [list(t.shape) for t in inp_tensors] input_names = [t.name for t in inp_tensors] output_names = [t.name for t in out_tensors] @@ -207,20 +211,19 @@ def parse_keras_v3_model(model: 'keras.Model'): representing a layer in the HLS model, and a list of input and output layer names. - Parameters - ---------- - model : keras.Model + Args: + model: keras.Model - Returns - ------- - tuple[list[dict[str, Any]], list[str], list[str], list[list[int]]] - layer_list, input_layer_names, output_layer_names, - batch_output_shapes + Returns: + A tuple containing: + - layer_list (list[dict[str, Any]]): A list of dictionaries, + each representing a layer in the HLS model. + - input_layer_names (list[str]): A list of input layer names. + - output_layer_names (list[str]): A list of output layer names. + - batch_output_shapes (list[list[int]]): A list of output shapes. - Raises - ------ - ValueError - If a circular dependency is detected. + Raises: + ValueError: If a circular dependency is detected. ''' assert model.built, 'Model must be built before parsing' @@ -230,7 +233,7 @@ def parse_keras_v3_model(model: 'keras.Model'): if isinstance(model, keras.Sequential): model = model._functional # everything is functional under the hood lol - from .keras_to_hls import layer_handlers as v2_layer_handlers # Delayed import to avoid circular import + from .keras_v2_to_hls import layer_handlers as v2_layer_handlers # Delayed import to avoid circular import keras_v3_dispatcher = KerasV3HandlerDispatcher(v3_layer_handlers, v2_layer_handlers) @@ -283,3 +286,8 @@ def parse_keras_v3_model(model: 'keras.Model'): batch_output_shapes = [list(tensors[tname].shape) for tname in model_outputs] return layer_list, input_layer_names, output_layer_names, batch_output_shapes + + +def keras_v3_to_hls(config): + layer_list, input_layers, output_layers, _ = parse_keras_v3_model(config['KerasModel']) + return ModelGraph(config, layer_list, input_layers, output_layers) From 150a3f606038c9be9dd01ca95de6bbbd9f2c5189 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 06:59:40 -0700 Subject: [PATCH 25/30] update tests for api changes --- test/pytest/test_einsum_dense.py | 2 +- test/pytest/test_fetch_example.py | 2 +- test/pytest/test_garnet.py | 4 ++-- test/pytest/test_keras_h5_loader.py | 2 +- test/pytest/test_keras_v3_api.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/pytest/test_einsum_dense.py b/test/pytest/test_einsum_dense.py index dbddf545ff..566a0bb37f 100644 --- a/test/pytest/test_einsum_dense.py +++ b/test/pytest/test_einsum_dense.py @@ -9,7 +9,7 @@ if keras.__version__ < '3.0.0': pytest.skip('Only keras v3 is supported for now', allow_module_level=True) -from keras.api.layers import EinsumDense, Input +from keras.layers import EinsumDense, Input test_root_path = Path(__file__).parent diff --git a/test/pytest/test_fetch_example.py b/test/pytest/test_fetch_example.py index 6e640a94a0..5de748c348 100644 --- a/test/pytest/test_fetch_example.py +++ b/test/pytest/test_fetch_example.py @@ -28,5 +28,5 @@ def test_fetch_example_utils(backend): config['Backend'] = backend config['OutputDir'] = str(test_root_path / f'hls4mlprj_fetch_example_{backend}') - hls_model = hls4ml.converters.keras_to_hls(config) + hls_model = hls4ml.converters.keras_v2_to_hls(config) hls_model.compile() # For now, it is enough if it compiles, we're only testing downloading works as expected diff --git a/test/pytest/test_garnet.py b/test/pytest/test_garnet.py index 057fe36c78..c910e43dac 100644 --- a/test/pytest/test_garnet.py +++ b/test/pytest/test_garnet.py @@ -44,7 +44,7 @@ def garnet_models(): cfg['HLSConfig'] = config cfg['KerasModel'] = model - hls_model = hls4ml.converters.keras_to_hls(cfg) + hls_model = hls4ml.converters.keras_v2_to_hls(cfg) hls_model.compile() return model, hls_model @@ -78,7 +78,7 @@ def garnet_stack_models(): cfg['HLSConfig'] = config cfg['KerasModel'] = model - hls_model = hls4ml.converters.keras_to_hls(cfg) + hls_model = hls4ml.converters.keras_v2_to_hls(cfg) hls_model.compile() return model, hls_model diff --git a/test/pytest/test_keras_h5_loader.py b/test/pytest/test_keras_h5_loader.py index 0c42adee31..2987fc4466 100644 --- a/test/pytest/test_keras_h5_loader.py +++ b/test/pytest/test_keras_h5_loader.py @@ -32,7 +32,7 @@ def test_keras_h5_loader(backend): } model.save(config['KerasH5']) - hls_model = hls4ml.converters.keras_to_hls(config) + hls_model = hls4ml.converters.keras_v2_to_hls(config) hls_model.compile() data = np.random.rand(1000, 10).astype(np.float32) pred = hls_model.predict(data) diff --git a/test/pytest/test_keras_v3_api.py b/test/pytest/test_keras_v3_api.py index a77ce32fcc..8ced9ce578 100644 --- a/test/pytest/test_keras_v3_api.py +++ b/test/pytest/test_keras_v3_api.py @@ -8,7 +8,7 @@ if keras.__version__ < '3.0': pytest.skip('Keras API tests are only for Keras 3.0 and above', allow_module_level=True) -from keras.api.layers import ( +from keras.layers import ( ELU, Activation, AveragePooling1D, From ec914d1ef974dfa1f5a9ad3da88c8079eb053c90 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 07:02:39 -0700 Subject: [PATCH 26/30] update docs --- docs/frontend/keras.rst | 6 +++--- docs/intro/setup.rst | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/frontend/keras.rst b/docs/frontend/keras.rst index 9ede7b1d8c..093db120b4 100644 --- a/docs/frontend/keras.rst +++ b/docs/frontend/keras.rst @@ -1,6 +1,6 @@ -================ +================================ Keras and its quantized variants -================ +================================ Keras and the quantization library QKeras are well supported in ``hls4ml``. Both Keras v2 (``tf.keras``) and the new Keras v3 are supported. While the Keras v2 support is based on parsing the serialized json representation of the model, the Keras v3 support uses direct model inspection. @@ -10,7 +10,7 @@ The ``data_format='channels_first'`` parameter of Keras layers is supported, but * `QKeras `_ - The equivalent QKeras API and its quantizers are also supported by ``hls4ml``. QKeras is not compatible with Keras v3. + The equivalent QKeras API and its quantizers are also supported by ``hls4ml``. QKeras is not compatible with Keras v3. Currently, only HGQ2 is compatible with Keras v3 (see below). * `HGQ `_ The equivalent HGQ API is also supported. HGQ is not compatible with Keras v3. See `advanced/HGQ <../advanced/hgq.html>`__ for more information. * `HGQ2 `_ diff --git a/docs/intro/setup.rst b/docs/intro/setup.rst index 3c21290981..682d5fe54e 100644 --- a/docs/intro/setup.rst +++ b/docs/intro/setup.rst @@ -42,7 +42,7 @@ Dependencies The ``hls4ml`` library requires python 3.10 or later, and depends on a number of Python packages and external tools for synthesis and simulation. Python dependencies are automatically managed by ``pip`` or ``conda``. -The following Python packages are all optional and are only required if you intend to use the corresponding converter. Only install the packages you need. +The following Python packages are all optional and are only required if you intend to use the corresponding converter. * `Keras `_ is required by the Keras converter. * `TensorFlow `_ (version 2.8 to 2.14) is required by the Keras v2 converter (keras v2 is included in TensorFlow). From 312328e8634f42ca2c7ed1210de7456fad40f0a8 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 07:34:22 -0700 Subject: [PATCH 27/30] mv einops to vivado backend, rm unused args --- hls4ml/backends/vivado/passes/einsum.py | 12 +-- hls4ml/backends/vivado/passes/einsum_dense.py | 1 - hls4ml/backends/vivado/vivado_backend.py | 79 +++++++++++++++++ hls4ml/converters/keras_v3/merge.py | 2 +- hls4ml/model/layers.py | 85 ------------------- .../templates/vivado/nnet_utils/nnet_einsum.h | 15 ++-- .../vivado/nnet_utils/nnet_einsum_dense.h | 1 - test/pytest/test_einsum_dense.py | 2 +- 8 files changed, 94 insertions(+), 103 deletions(-) diff --git a/hls4ml/backends/vivado/passes/einsum.py b/hls4ml/backends/vivado/passes/einsum.py index 22f092903c..ef2daaaf88 100644 --- a/hls4ml/backends/vivado/passes/einsum.py +++ b/hls4ml/backends/vivado/passes/einsum.py @@ -12,8 +12,8 @@ einsum_config_template = ''' struct config{index} {{ - typedef config{index}_tpose_inp0 tpose_inp0_conf; - typedef config{index}_tpose_inp1 tpose_inp1_conf; + typedef config{index}_tpose_inp0 tpose_inp0_config; + typedef config{index}_tpose_inp1 tpose_inp1_config; typedef config{index}_tpose_out tpose_out_conf; typedef {accum_t.name} accum_t; @@ -78,13 +78,13 @@ def format(self, node: Einsum): inp0_tpose_idxs = node.attributes['inp0_tpose_idxs'] inp1_tpose_idxs = node.attributes['inp1_tpose_idxs'] out_tpose_idxs = node.attributes['out_tpose_idxs'] - tpose_inp0_conf_name = f'config{node.index}_tpose_inp0' - tpose_inp1_conf_name = f'config{node.index}_tpose_inp1' + tpose_inp0_config_name = f'config{node.index}_tpose_inp0' + tpose_inp1_config_name = f'config{node.index}_tpose_inp1' tpose_out_conf_name = f'config{node.index}_tpose_out' - conf = transpose_config_gen(tpose_inp0_conf_name, inp0_shape, inp0_tpose_idxs) + conf = transpose_config_gen(tpose_inp0_config_name, inp0_shape, inp0_tpose_idxs) inp0_tpose_conf = transpose_config_template.format(**conf) - conf = transpose_config_gen(tpose_inp1_conf_name, inp1_shape, inp1_tpose_idxs) + conf = transpose_config_gen(tpose_inp1_config_name, inp1_shape, inp1_tpose_idxs) inp1_tpose_conf = transpose_config_template.format(**conf) conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs) out_tpose_conf = transpose_config_template.format(**conf) diff --git a/hls4ml/backends/vivado/passes/einsum_dense.py b/hls4ml/backends/vivado/passes/einsum_dense.py index 8af86add20..e8e3304512 100644 --- a/hls4ml/backends/vivado/passes/einsum_dense.py +++ b/hls4ml/backends/vivado/passes/einsum_dense.py @@ -45,7 +45,6 @@ static const unsigned strategy = nnet::{strategy}; static const unsigned reuse_factor = {reuse_factor}; static const unsigned parallelization_factor = {parallelization_factor}; // Only useful when n_inplace > 1 - static const bool store_weights_in_bram = false; // NOT USED }}; ''' diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index fa564d2b0c..d7b58c6e44 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -35,6 +35,7 @@ from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType, PackedType from hls4ml.report import parse_vivado_report from hls4ml.utils import attribute_descriptions as descriptions +from hls4ml.utils.einsum_utils import parse_einsum class VivadoBackend(FPGABackend): @@ -690,6 +691,57 @@ def init_garnet_stack(self, layer): @layer_optimizer(EinsumDense) def init_einsum_dense(self, layer: EinsumDense) -> None: + kernel: np.ndarray = layer.attributes['weight_data'] + bias: np.ndarray | None = layer.attributes['bias_data'] + equation = layer.attributes['equation'] + inp_shape = layer.attributes['inp_shape'] + out_shape = layer.attributes['out_shape'] + + kernel_shape = kernel.shape + recipe = parse_einsum(equation, inp_shape, kernel_shape) + assert not any(recipe['direct_sum_axis']), ( + 'Do not put direct sum indices (e.g., only appears in one of the operands) in the equation.' + 'Use explicit addition operator before instead.' + ) + inp_tpose_idxs, ker_tpose_idxs = recipe['in_transpose_idxs'] + out_tpose_idxs = recipe['out_transpose_idxs'] + + # Pre-transpose kernel (and bias) to save a transpose in cpp. Shouldn't matter for latency strategy though. + # hls4ml dense acts like i,ij->j + # parser assumes ij,j->i, so we need to transpose the kernel to match + kernel = kernel.transpose(ker_tpose_idxs) + kernel = kernel.reshape(recipe['I'], recipe['L1'], recipe['C']).transpose(0, 2, 1) + + def to_original_kernel(tkernel: np.ndarray) -> np.ndarray: + _kernel = tkernel.transpose(0, 2, 1) + _kernel = _kernel.reshape(tuple(kernel_shape[i] for i in ker_tpose_idxs)) + return _kernel.transpose(np.argsort(ker_tpose_idxs)) + + # TODO: for weight in bram mode (resource), broadcasting bias here shall be avoided. + if bias is not None: + bias = np.broadcast_to(bias, out_shape).transpose(np.argsort(out_tpose_idxs)) + else: + # The automatically created bias is just the last dimension of the output shape + # Which is too small in general for einsum dense. + # The transpose is just to match the shape in case of have real bias, no real effect. + bias = np.zeros(out_shape).transpose(np.argsort(out_tpose_idxs)) + + layer.attributes['weight_data'] = kernel + layer.attributes['to_original_kernel'] = to_original_kernel + layer.attributes['bias_data'] = bias + layer.attributes['inp_tpose_idxs'] = inp_tpose_idxs + layer.attributes['out_tpose_idxs'] = out_tpose_idxs + layer.attributes['out_interpert_shape'] = recipe['out_interpert_shape'] + layer.attributes['n_free_data'] = recipe['L0'] + layer.attributes['n_free_kernel'] = recipe['L1'] + layer.attributes['n_inplace'] = recipe['I'] + layer.attributes['n_contract'] = recipe['C'] + pf = layer.attributes.get('parallelization_factor', recipe['L0']) + layer.attributes['parallelization_factor'] = pf + + layer.add_weights(compression=layer.model.config.get_compression(layer)) + layer.add_bias() + strategy: str | None = layer.model.config.get_strategy(layer) if not strategy: layer.set_attr('strategy', 'latency') @@ -702,6 +754,33 @@ def init_einsum_dense(self, layer: EinsumDense) -> None: @layer_optimizer(Einsum) def init_einsum(self, layer: Einsum) -> None: + + equation = layer.attributes['equation'] + inp0_shape = layer.attributes['inp0_shape'] + inp1_shape = layer.attributes['inp1_shape'] + + recipe = parse_einsum(equation, inp0_shape, inp1_shape) + assert not any(recipe['direct_sum_axis']), ( + 'Do not put direct sum indices (e.g., only appears in one of the operands) in the equation.' + 'Use explicit addition operator before instead.' + ) + inp0_tpose_idxs, inp1_tpose_idxs = recipe['in_transpose_idxs'] + out_tpose_idxs = recipe['out_transpose_idxs'] + + layer.attributes.update(recipe) + layer.attributes['n_free0'] = recipe['L0'] + layer.attributes['n_free1'] = recipe['L1'] + layer.attributes['n_inplace'] = recipe['I'] + layer.attributes['n_contract'] = recipe['C'] + layer.attributes['out_interpert_shape'] = recipe['out_interpert_shape'] + + layer.attributes['inp0_tpose_idxs'] = inp0_tpose_idxs + layer.attributes['inp1_tpose_idxs'] = inp1_tpose_idxs + layer.attributes['out_tpose_idxs'] = out_tpose_idxs + + pf = layer.attributes.get('parallelization_factor', recipe['L0']) + layer.attributes['parallelization_factor'] = pf + strategy: str | None = layer.model.config.get_strategy(layer) if not strategy: layer.set_attr('strategy', 'latency') diff --git a/hls4ml/converters/keras_v3/merge.py b/hls4ml/converters/keras_v3/merge.py index 96c5547bae..8ed4dd5060 100644 --- a/hls4ml/converters/keras_v3/merge.py +++ b/hls4ml/converters/keras_v3/merge.py @@ -36,7 +36,7 @@ def handle( config: dict[str, Any] = {'output_shape': output_shape} op = cls_name.lower() - match cls_name.lower(): + match cls_name: case 'Concatenate': rank = len(output_shape) class_name = f'Concatenate{rank}d' diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 91935e9b61..e843417685 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -27,7 +27,6 @@ find_minimum_width, ) from hls4ml.utils import attribute_descriptions as descriptions -from hls4ml.utils.einsum_utils import parse_einsum from hls4ml.utils.string_utils import convert_to_snake_case # TODO move this to some utility module @@ -1669,67 +1668,10 @@ def initialize(self): else: dims = [f'N_LAYER_{self.index}'] self.add_output_variable(list(out_shape), dims) - - kernel: np.ndarray = self.attributes['weight_data'] - bias: np.ndarray | None = self.attributes['bias_data'] - equation = self.attributes['equation'] - inp_shape = self.attributes['inp_shape'] - out_shape = self.attributes['out_shape'] - - kernel_shape = kernel.shape - recipe = parse_einsum(equation, inp_shape, kernel_shape) - assert not any(recipe['direct_sum_axis']), ( - 'Do not put direct sum indices (e.g., only appears in one of the operands) in the equation.' - 'Use explicit addition operator before instead.' - ) - inp_tpose_idxs, ker_tpose_idxs = recipe['in_transpose_idxs'] - out_tpose_idxs = recipe['out_transpose_idxs'] - - # Pre-transpose kernel (and bias) to save a transpose in cpp. Shouldn't matter for latency strategy though. - # hls4ml dense acts like i,ij->j - # parser assumes ij,j->i, so we need to transpose the kernel to match - kernel = kernel.transpose(ker_tpose_idxs) - kernel = kernel.reshape(recipe['I'], recipe['L1'], recipe['C']).transpose(0, 2, 1) - - def to_original_kernel(tkernel: np.ndarray) -> np.ndarray: - _kernel = tkernel.transpose(0, 2, 1) - _kernel = _kernel.reshape(tuple(kernel_shape[i] for i in ker_tpose_idxs)) - return _kernel.transpose(np.argsort(ker_tpose_idxs)) - - # TODO: for weight in bram mode (resource), broadcasting bias here shall be avoided. - if bias is not None: - bias = np.broadcast_to(bias, out_shape).transpose(np.argsort(out_tpose_idxs)) - else: - # The automatically created bias is just the last dimension of the output shape - # Which is too small in general for einsum dense. - # The transpose is just to match the shape in case of have real bias, no real effect. - bias = np.zeros(out_shape).transpose(np.argsort(out_tpose_idxs)) - - self.attributes['weight_data'] = kernel - self.attributes['to_original_kernel'] = to_original_kernel - self.attributes['bias_data'] = bias - self.attributes['inp_tpose_idxs'] = inp_tpose_idxs - self.attributes['out_tpose_idxs'] = out_tpose_idxs - self.attributes['out_interpert_shape'] = recipe['out_interpert_shape'] - self.attributes['n_free_data'] = recipe['L0'] - self.attributes['n_free_kernel'] = recipe['L1'] - self.attributes['n_inplace'] = recipe['I'] - self.attributes['n_contract'] = recipe['C'] - pf = self.attributes.get('parallelization_factor', recipe['L0']) - self.attributes['parallelization_factor'] = pf - self.add_weights(compression=self.model.config.get_compression(self)) self.add_bias() -class Matmul(Layer): - _expected_attributes = [ - TypeAttribute('accum'), - Attribute('inup1_shape', value_type=tuple), - Attribute('inp2_shape', value_type=tuple), - ] - - class Einsum(Layer): _expected_attributes = [ TypeAttribute('accum'), @@ -1747,33 +1689,6 @@ def initialize(self): dims = [f'N_LAYER_{self.index}'] self.add_output_variable(list(out_shape), dims) - equation = self.attributes['equation'] - inp0_shape = self.attributes['inp0_shape'] - inp1_shape = self.attributes['inp1_shape'] - out_shape = self.attributes['out_shape'] - - recipe = parse_einsum(equation, inp0_shape, inp1_shape) - assert not any(recipe['direct_sum_axis']), ( - 'Do not put direct sum indices (e.g., only appears in one of the operands) in the equation.' - 'Use explicit addition operator before instead.' - ) - inp0_tpose_idxs, inp1_tpose_idxs = recipe['in_transpose_idxs'] - out_tpose_idxs = recipe['out_transpose_idxs'] - - self.attributes.update(recipe) - self.attributes['n_free0'] = recipe['L0'] - self.attributes['n_free1'] = recipe['L1'] - self.attributes['n_inplace'] = recipe['I'] - self.attributes['n_contract'] = recipe['C'] - self.attributes['out_interpert_shape'] = recipe['out_interpert_shape'] - - self.attributes['inp0_tpose_idxs'] = inp0_tpose_idxs - self.attributes['inp1_tpose_idxs'] = inp1_tpose_idxs - self.attributes['out_tpose_idxs'] = out_tpose_idxs - - pf = self.attributes.get('parallelization_factor', recipe['L0']) - self.attributes['parallelization_factor'] = pf - layer_map = { 'Input': Input, diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h b/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h index cc2917783c..0901787505 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_einsum.h @@ -8,8 +8,8 @@ namespace nnet { struct config_einsum { - typedef void tpose_inp0_conf; - typedef void tpose_inp1_conf; + typedef void tpose_inp0_config; + typedef void tpose_inp1_config; typedef void tpose_out_conf; // Layer Sizes @@ -23,28 +23,27 @@ struct config_einsum { static const unsigned strategy; static const unsigned reuse_factor; static const unsigned multiplier_limit; - static const bool store_weights_in_bram = false; // NOT USED template using product = nnet::product::mult; }; template -void einsum(const data0_T data0[CONFIG_T::tpose_inp0_conf::N], const data1_T data1[CONFIG_T::tpose_inp1_conf::N], +void einsum(const data0_T data0[CONFIG_T::tpose_inp0_config::N], const data1_T data1[CONFIG_T::tpose_inp1_config::N], res_T res[CONFIG_T::tpose_out_conf::N]) { #pragma HLS PIPELINE II = CONFIG_T::reuse_factor #pragma HLS ALLOCATION operation instances = mul limit = CONFIG_T::multiplier_limit - data0_T tpose_i0[CONFIG_T::tpose_inp0_conf::N]; - data1_T tpose_i1[CONFIG_T::tpose_inp1_conf::N]; + data0_T tpose_i0[CONFIG_T::tpose_inp0_config::N]; + data1_T tpose_i1[CONFIG_T::tpose_inp1_config::N]; res_T tpose_o[CONFIG_T::tpose_out_conf::N]; #pragma HLS ARRAY_PARTITION variable = tpose_i0 complete #pragma HLS ARRAY_PARTITION variable = tpose_i1 complete #pragma HLS ARRAY_PARTITION variable = tpose_o complete - nnet::transpose(data0, tpose_i0); - nnet::transpose(data1, tpose_i1); + nnet::transpose(data0, tpose_i0); + nnet::transpose(data1, tpose_i1); // for l0 in range(L0): // for i in range(I): diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h index 9f26ff0bd7..f095b02d93 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h @@ -30,7 +30,6 @@ struct einsum_dense_config { static const unsigned strategy = latency; static const unsigned reuse_factor = 1; static const unsigned parallelization_factor = 1000; // Only useful when n_inplace > 1 - static const bool store_weights_in_bram = false; // NOT USED // Product function to use template using product = nnet::product::mult; diff --git a/test/pytest/test_einsum_dense.py b/test/pytest/test_einsum_dense.py index 566a0bb37f..dd773b1642 100644 --- a/test/pytest/test_einsum_dense.py +++ b/test/pytest/test_einsum_dense.py @@ -14,7 +14,7 @@ test_root_path = Path(__file__).parent -@pytest.mark.parametrize('strategy', ['latency', 'distributed_arithmetic']) +@pytest.mark.parametrize('strategy', ['latency']) @pytest.mark.parametrize('io_type', ['io_parallel']) @pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) @pytest.mark.parametrize( From 9c585aab2e6de0d97d01784d45843e23775c43b2 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 27 May 2025 07:38:36 -0700 Subject: [PATCH 28/30] post merge fix --- hls4ml/converters/keras_v3_to_hls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hls4ml/converters/keras_v3_to_hls.py b/hls4ml/converters/keras_v3_to_hls.py index 0b89022c76..bc9d15f259 100644 --- a/hls4ml/converters/keras_v3_to_hls.py +++ b/hls4ml/converters/keras_v3_to_hls.py @@ -290,4 +290,4 @@ def parse_keras_v3_model(model: 'keras.Model'): def keras_v3_to_hls(config): layer_list, input_layers, output_layers, _ = parse_keras_v3_model(config['KerasModel']) - return ModelGraph(config, layer_list, input_layers, output_layers) + return ModelGraph.from_layer_list(config, layer_list, input_layers, output_layers) From 81535223eab86db8e730a3b626dc1d98648fad33 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 30 May 2025 02:01:17 -0700 Subject: [PATCH 29/30] quality-of-life changes --- hls4ml/converters/keras_v3_to_hls.py | 3 ++- hls4ml/model/graph.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/hls4ml/converters/keras_v3_to_hls.py b/hls4ml/converters/keras_v3_to_hls.py index bc9d15f259..25d0610788 100644 --- a/hls4ml/converters/keras_v3_to_hls.py +++ b/hls4ml/converters/keras_v3_to_hls.py @@ -100,7 +100,8 @@ class UniqueName: '''Helper class to generate unique names for layers, if one being used multiple times.''' def __init__(self): - self.used_names: set[str] = set() + self.used_names: set[str] = set('input') + # input is reserved in hls4ml, avoid conflict with it def next_name(self, name: str): i = 0 diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 0f80a020d8..77b4bdc74a 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -867,6 +867,12 @@ def _compute_n_samples(self, x): return int(n_sample) def predict(self, x): + if isinstance(x, np.ndarray) and not x.flags['C_CONTIGUOUS']: + x = np.ascontiguousarray(x) + + # Compile the model if it wasn't compiled yet + if self._top_function_lib is None: + self.compile() top_function, ctype = self._get_top_function(x) n_samples = self._compute_n_samples(x) n_inputs = len(self.get_input_variables()) @@ -882,10 +888,9 @@ def predict(self, x): inp = [np.asarray(x[i])] else: inp = [np.asarray(xj[i]) for xj in x] - argtuple = inp - argtuple += predictions - argtuple = tuple(argtuple) - top_function(*argtuple) + inp = [_inp if _inp.flags['C_CONTIGUOUS'] else np.ascontiguousarray(_inp) for _inp in inp] + + top_function(*inp, *predictions) output.append(predictions) # Convert to list of numpy arrays (one for each output) From c4733b2d388bc01364399a9710db70452fe2a035 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 30 May 2025 05:24:44 -0700 Subject: [PATCH 30/30] fix some qol changes --- hls4ml/model/graph.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 77b4bdc74a..33a066091b 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -867,12 +867,6 @@ def _compute_n_samples(self, x): return int(n_sample) def predict(self, x): - if isinstance(x, np.ndarray) and not x.flags['C_CONTIGUOUS']: - x = np.ascontiguousarray(x) - - # Compile the model if it wasn't compiled yet - if self._top_function_lib is None: - self.compile() top_function, ctype = self._get_top_function(x) n_samples = self._compute_n_samples(x) n_inputs = len(self.get_input_variables()) @@ -888,7 +882,7 @@ def predict(self, x): inp = [np.asarray(x[i])] else: inp = [np.asarray(xj[i]) for xj in x] - inp = [_inp if _inp.flags['C_CONTIGUOUS'] else np.ascontiguousarray(_inp) for _inp in inp] + inp = [np.ascontiguousarray(_inp) for _inp in inp] top_function(*inp, *predictions) output.append(predictions)