Skip to content

Commit ec937d2

Browse files
authored
Use unique name for hls4ml layer in pytorch extension api test (#1255)
* use unique name for hls4ml layer in pytorch extension api test * add annotation and update docs for pytorch extension API
1 parent 887a17b commit ec937d2

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

docs/advanced/extension.rst

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ Extension API
55
``hls4ml`` natively supports a large number of neural network layers.
66
But what if a desired layer is not supported?
77
If it is standard enough and its implementation would benefit the community as a whole, we would welcome a contribution to add it to the standard set of supported layers.
8-
However, if it is a somewhat niche custom layer, there is another approach we can take to extend hls4ml through the *extension API*.
8+
However, if it is a somewhat niche custom layer, there is another approach we can take to extend hls4ml through the *extension API*. This feature is supported for both Keras and PyTorch layers.
99

10-
This documentation will walk through a complete `complete end-to-end example <https://github.com/fastmachinelearning/hls4ml/blob/main/test/pytest/test_extensions.py>`_, which is part of our testing suite.
10+
Complete end-to-end examples are available for both `keras <https://github.com/fastmachinelearning/hls4ml/blob/main/test/pytest/test_extensions.py>`_ and `pytorch <https://github.com/fastmachinelearning/hls4ml/blob/main/test/pytest/test_extensions_pytorch.py>`_, which are part of our testing suite. The description here uses the keras example.
1111
To implement a custom layer in ``hls4ml`` with the extension API, the required components are:
1212

1313
* Your custom layer class
@@ -18,9 +18,6 @@ To implement a custom layer in ``hls4ml`` with the extension API, the required c
1818
* Function config template
1919
* Registration of layer, source code, and templates
2020

21-
.. note::
22-
currently, then extension API supports keras models. Support for pytorch models is in development.
23-
2421
Complete example
2522
================
2623

test/pytest/test_extensions_pytorch.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def forward(self, inputs):
2222

2323

2424
# hls4ml layer implementation
25-
class HReverse(hls4ml.model.layers.Layer):
25+
# Note that the `Torch` suffix is added here to avoid clashes with other tests and is not mandatory
26+
class HReverseTorch(hls4ml.model.layers.Layer):
2627
'''hls4ml implementation of a hypothetical custom layer'''
2728

2829
def initialize(self):
@@ -34,10 +35,10 @@ def initialize(self):
3435

3536
# hls4ml optimizer to remove duplicate optimizer
3637
class RemoveDuplicateReverse(hls4ml.model.optimizer.OptimizerPass):
37-
'''OptimizerPass to remove consecutive HReverse layers.'''
38+
'''OptimizerPass to remove consecutive HReverseTorch layers.'''
3839

3940
def match(self, node):
40-
return isinstance(node, HReverse) and isinstance(node.get_input_node(), HReverse)
41+
return isinstance(node, HReverseTorch) and isinstance(node.get_input_node(), HReverseTorch)
4142

4243
def transform(self, model, node):
4344
first = node.get_input_node()
@@ -53,7 +54,7 @@ def parse_reverse_layer(operation, layer_name, input_names, input_shapes, node,
5354
assert operation == 'TReverse'
5455

5556
layer = {}
56-
layer['class_name'] = 'HReverse'
57+
layer['class_name'] = 'HReverseTorch'
5758
layer['name'] = layer_name
5859
layer['n_in'] = input_shapes[0][1]
5960

@@ -75,7 +76,7 @@ def parse_reverse_layer(operation, layer_name, input_names, input_shapes, node,
7576

7677
class HReverseConfigTemplate(hls4ml.backends.template.LayerConfigTemplate):
7778
def __init__(self):
78-
super().__init__(HReverse)
79+
super().__init__(HReverseTorch)
7980
self.template = rev_config_template
8081

8182
def format(self, node):
@@ -85,7 +86,7 @@ def format(self, node):
8586

8687
class HReverseFunctionTemplate(hls4ml.backends.template.FunctionCallTemplate):
8788
def __init__(self):
88-
super().__init__(HReverse, include_header=rev_include_list)
89+
super().__init__(HReverseTorch, include_header=rev_include_list)
8990
self.template = rev_function_template
9091

9192
def format(self, node):
@@ -126,7 +127,7 @@ def register_custom_layer():
126127
hls4ml.converters.register_pytorch_layer_handler('TReverse', parse_reverse_layer)
127128

128129
# Register the hls4ml's IR layer
129-
hls4ml.model.layers.register_layer('HReverse', HReverse)
130+
hls4ml.model.layers.register_layer('HReverseTorch', HReverseTorch)
130131

131132

132133
@pytest.mark.parametrize('backend_id', ['Vivado', 'Vitis', 'Quartus'])
@@ -136,7 +137,7 @@ def test_extensions_pytorch(tmp_path, backend_id):
136137
ip_flow = hls4ml.model.flow.get_flow(backend.get_default_flow())
137138
# Add the pass into the main optimization flow
138139
optimize_flow = [flow for flow in ip_flow.requires if ':optimize' in flow][0]
139-
optmizer_name = f'{backend_id.lower()}:remove_duplicate_reverse'
140+
optmizer_name = f'{backend_id.lower()}:remove_duplicate_reverse_torch'
140141
backend.register_pass(optmizer_name, RemoveDuplicateReverse, flow=optimize_flow)
141142

142143
# Register template passes for the given backend

0 commit comments

Comments
 (0)