
Add functionality to use granularity option also for pytorch models #1051

Merged · 14 commits · Aug 27, 2024
Changes from 12 commits
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -6,3 +6,4 @@ sphinx_github_changelog
 sphinx_rtd_theme
 tensorflow<=2.15
 toposort>=1.5.0
+torch
5 changes: 2 additions & 3 deletions hls4ml/converters/__init__.py
@@ -10,6 +10,7 @@
 from hls4ml.converters.keras_to_hls import get_supported_keras_layers  # noqa: F401
 from hls4ml.converters.keras_to_hls import parse_keras_model  # noqa: F401
 from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler
+from hls4ml.converters.pytorch_to_hls import parse_pytorch_model  # noqa: F401
 from hls4ml.model import ModelGraph
 from hls4ml.utils.config import create_config
 from hls4ml.utils.symbolic_utils import LUTFunction
@@ -238,7 +239,6 @@ def convert_from_keras_model(

 def convert_from_pytorch_model(
     model,
-    input_shape,
     output_dir='my-hls-test',
     project_name='myproject',
     input_data_tb=None,
@@ -251,7 +251,6 @@ def convert_from_pytorch_model(

     Args:
         model: PyTorch model to convert.
-        input_shape (list): The shape of the input tensor. First element is the batch size, needs to be None
         output_dir (str, optional): Output directory of the generated HLS project. Defaults to 'my-hls-test'.
         project_name (str, optional): Name of the HLS project. Defaults to 'myproject'.
         input_data_tb (str, optional): String representing the path of input data in .npy or .dat format that will be
@@ -293,7 +292,6 @@ def convert_from_pytorch_model(
     config = create_config(output_dir=output_dir, project_name=project_name, backend=backend, **kwargs)

     config['PytorchModel'] = model
-    config['InputShape'] = input_shape
     config['InputData'] = input_data_tb
     config['OutputPredictions'] = output_data_tb
     config['HLSConfig'] = {}
@@ -303,6 +301,7 @@ def convert_from_pytorch_model(

     model_config = hls_config.get('Model', None)
     config['HLSConfig']['Model'] = _check_model_config(model_config)
+    config['InputShape'] = hls_config.get('InputShape', None)
Contributor:

I think we shouldn't have a default of None for the input shape, because it will propagate further and lead to errors that obscure the original cause.

Contributor (author):

Good point, we're now raising an exception if that parameter is not found. No point in continuing.


     _check_hls_config(config, hls_config)
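Taken together with the config_from_pytorch_model changes below, the input shape now travels through the HLS config rather than the converter call. A minimal sketch of the updated call sequence at this point in the PR, using a toy model (the shapes and output directory are placeholders):

```python
import torch

import hls4ml

# Toy model; any torch.nn.Module the parser supports works the same way.
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())

# The input shape is now passed when creating the config...
config = hls4ml.utils.config_from_pytorch_model(model, input_shape=(None, 8))

# ...and convert_from_pytorch_model reads it from the config
# (config['InputShape']) instead of taking it as an argument.
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model, hls_config=config, output_dir='my-hls-test'
)
```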
9 changes: 8 additions & 1 deletion hls4ml/converters/pytorch_to_hls.py
@@ -102,7 +102,7 @@ def decorator(function):
 # ----------------------------------------------------------------


-def pytorch_to_hls(config):
+def parse_pytorch_model(config):
     """Convert PyTorch model to hls4ml ModelGraph.

     Args:
@@ -351,6 +351,13 @@ def pytorch_to_hls(config):
     if len(input_layers) == 0:
         input_layers = None

+    # print('Creating HLS model')
+    # hls_model = ModelGraph(config, layer_list, inputs=input_layers)
+    return layer_list, input_layers
+
+
+def pytorch_to_hls(config):
+    layer_list, input_layers = parse_pytorch_model(config)
     print('Creating HLS model')
     hls_model = ModelGraph(config, layer_list, inputs=input_layers)
     return hls_model
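Splitting parse_pytorch_model out of pytorch_to_hls is what lets config_from_pytorch_model (below) inspect the parsed layers without building a ModelGraph. A rough sketch of the intermediate result, assuming a config dict with 'PytorchModel' and 'InputShape' already set as elsewhere in this PR:

```python
from hls4ml.converters.pytorch_to_hls import parse_pytorch_model

# config is the dict produced by config_from_pytorch_model in this PR.
layer_list, input_layers = parse_pytorch_model(config)

# Each entry is a dict describing one parsed layer.
for layer in layer_list:
    print(layer['name'], layer['class_name'])
```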
73 changes: 73 additions & 0 deletions hls4ml/utils/config.py
@@ -269,6 +269,7 @@ def make_layer_config(layer):

 def config_from_pytorch_model(
     model,
+    input_shape,
     granularity='model',
     backend=None,
     default_precision='ap_fixed<16,6>',
@@ -284,6 +285,7 @@

     Args:
         model: PyTorch model
+        input_shape (list): The shape of the input tensor. First element is the batch size, needs to be None
Contributor:

We have a nice opportunity here to get rid of the first None in the shape. Just ask users to not put it, but then put it in ourselves 😉

Contributor (author):

Good idea, that was always very ugly. Done.

Contributor:

In general I think we need to be careful about the batch dimensions. For ONNX it's usually just 1 (but can really be any value, e.g. 10), not None. I think we try to get rid of the batch dimension pretty quickly.
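(A hypothetical sketch of the suggestion above; the actual change appears to land in one of the later commits not shown in this 12-commit view. Users would pass the shape without the batch dimension, and the helper would prepend it internally:)

```python
def _add_batch_dim(input_shape):
    # Hypothetical helper: turn (8,) into (None, 8), and a list of shapes
    # for multi-input models into a list of batched shapes.
    if isinstance(input_shape, list):
        return [(None,) + tuple(shape) for shape in input_shape]
    return (None,) + tuple(input_shape)
```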

         granularity (str, optional): Granularity of the created config. Defaults to 'model'.
             Can be set to 'model', 'type' and 'layer'.

@@ -321,6 +323,77 @@ def config_from_pytorch_model(
     model_config['Strategy'] = 'Latency'

     config['Model'] = model_config
+    config['PytorchModel'] = model
+    config['InputShape'] = input_shape
Contributor:

I would add a check that the passed input shape makes sense. Later on, if it doesn't, it's not easy to figure out why.

Contributor (author):

I have been thinking about that. In PyTorch it seems like the exact input shape is not determined by the model architecture, so it's not possible to completely infer it during parsing. But we can still check some general features, like the number of dimensions of the input, based on the type of the first layer. I'll implement something along those lines.

Contributor:

I actually meant something much simpler: just check that the input shape is a list/iterable and not None, because technically you can pass None and you'll get a strange error later. Since this is one of the top user-facing functions, it makes sense to do some user input validation here so we have more confidence later.

Contributor (author):

Ah, got it. I changed it to enforce that the input shape is a tuple for a single input or a list of tuples for multiple inputs.
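(A sketch of such a check under the tuple-or-list-of-tuples convention just described; the actual validation in the later commits may differ:)

```python
def _validate_input_shape(input_shape):
    # Single input: a tuple such as (None, 8).
    if isinstance(input_shape, tuple):
        return
    # Multiple inputs: a non-empty list of tuples such as [(None, 8), (None, 4)].
    if isinstance(input_shape, list) and input_shape and all(isinstance(s, tuple) for s in input_shape):
        return
    raise Exception('input_shape must be a tuple, or a list of tuples for multiple inputs')
```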


+    if granularity.lower() not in ['model', 'type', 'name']:
+        raise Exception(
+            f'Invalid configuration granularity specified, expected "model", "type" or "name" got "{granularity}"'
+        )
+
+    if backend is not None:
+        backend = hls4ml.backends.get_backend(backend)
+
+    (
+        layer_list,
+        _,
+    ) = hls4ml.converters.parse_pytorch_model(config)
+
+    def make_layer_config(layer):
+        cls_name = layer['class_name']
+        if 'config' in layer.keys():
+            if 'activation' in layer['config'].keys():
+                if layer['config']['activation'] == 'softmax':
+                    cls_name = 'Softmax'
+
+        layer_cls = hls4ml.model.layers.layer_map[cls_name]
+        if backend is not None:
+            layer_cls = backend.create_layer_class(layer_cls)
+
+        layer_config = {}
+
+        config_attrs = [a for a in layer_cls.expected_attributes if a.configurable]
+        for attr in config_attrs:
+            if isinstance(attr, hls4ml.model.attributes.TypeAttribute):
+                precision_cfg = layer_config.setdefault('Precision', {})
+                name = attr.name
+                if name.endswith('_t'):
+                    name = name[:-2]
+                if attr.default is None:
+                    precision_cfg[name] = default_precision
+                else:
+                    precision_cfg[name] = str(attr.default)
+            else:
+                if attr.default is not None:
+                    layer_config[attr.config_name] = attr.default
+
+        if layer['class_name'] == 'Input':
+            dtype = layer['config']['dtype']
+            if dtype.startswith('int') or dtype.startswith('uint'):
+                typename = dtype[: dtype.index('int') + 3]
+                width = int(dtype[dtype.index('int') + 3 :])
+                layer_config['Precision']['result'] = f'ap_{typename}<{width}>'
+            # elif bool, q[u]int, ...
+
+        return layer_config
+
+    if granularity.lower() == 'type':
+        type_config = {}
+        for layer in layer_list:
+            if layer['class_name'] in type_config:
+                continue
+            layer_config = make_layer_config(layer)
+            type_config[layer['class_name']] = layer_config
+
+        config['LayerType'] = type_config
+
+    elif granularity.lower() == 'name':
+        name_config = {}
+        for layer in layer_list:
+            layer_config = make_layer_config(layer)
+            name_config[layer['name']] = layer_config
+
+        config['LayerName'] = name_config
+
     return config
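The block above is the point of the PR: per-type and per-layer configs, previously only generated for Keras models, are now produced for PyTorch models too. A short usage sketch (the 'fc1' layer name is illustrative; actual keys depend on the parsed model):

```python
config = hls4ml.utils.config_from_pytorch_model(
    model, input_shape=(None, 8), granularity='name'
)

# With granularity='name', each layer gets its own entry that can be tuned
# individually before conversion, e.g.:
config['LayerName']['fc1']['Precision']['weight'] = 'ap_fixed<8,3>'
```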
1 change: 1 addition & 0 deletions setup.cfg
@@ -32,6 +32,7 @@ install_requires =
     tabulate
     tensorflow
     tensorflow-model-optimization<=0.7.5
+    torch
Contributor:

Do we need this? It will force all installations of hls4ml to include both tf and torch, and since these have huge dependencies themselves it may bork the user's environment with partial updates. Ideally we would have hls4ml with optional extras to support the various frontends (e.g., pip install hls4ml[torch]), but this may not be easy with the current codebase.

Also, if we go down this route, we should remove the checks for existence of pytorch on the system in converters/__init__.py

Contributor (author):

I think we can get away with not making torch a requirement by moving the import of parse_pytorch_model from hls4ml.converters.pytorch_to_hls out of converters/__init__.py and directly into config.py. The only reason I put the import there is that it's there for Keras. The question is: is there a reason to place these imports in converters/__init__.py? If not, I can just move it and get rid of the requirement.

python_requires = >=3.10
include_package_data = True
scripts = scripts/hls4ml
Expand Down
33 changes: 22 additions & 11 deletions test/pytest/test_backend_config.py
@@ -31,7 +31,7 @@ def test_backend_config(framework, backend, part, clock_period, clock_unc):
         convert_fn = hls4ml.converters.convert_from_keras_model
     else:
         model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.ReLU())
-        config = hls4ml.utils.config_from_pytorch_model(model)
+        config = hls4ml.utils.config_from_pytorch_model(model, input_shape=(None, 1))
         convert_fn = hls4ml.converters.convert_from_pytorch_model

     if clock_unc is not None:
@@ -42,16 +42,27 @@ def test_backend_config(framework, backend, part, clock_period, clock_unc):
     test_dir = f'hls4mlprj_backend_config_{framework}_{backend}_part_{part}_period_{clock_period}_unc_{unc_str}'
     output_dir = test_root_path / test_dir

-    hls_model = convert_fn(
-        model,
-        input_shape=(None, 1),  # This serves as a test of handling unexpected values by the backend in keras converer
-        hls_config=config,
-        output_dir=str(output_dir),
-        backend=backend,
-        part=part,
-        clock_period=clock_period,
-        clock_uncertainty=clock_unc,
-    )
+    if framework == "keras":
+        hls_model = convert_fn(
+            model,
+            input_shape=(None, 1),  # This serves as a test of handling unexpected values by the backend in keras converer
+            hls_config=config,
+            output_dir=str(output_dir),
+            backend=backend,
+            part=part,
+            clock_period=clock_period,
+            clock_uncertainty=clock_unc,
+        )
+    else:
+        hls_model = convert_fn(
+            model,
+            hls_config=config,
+            output_dir=str(output_dir),
+            backend=backend,
+            part=part,
+            clock_period=clock_period,
+            clock_uncertainty=clock_unc,
+        )

     hls_model.write()
15 changes: 10 additions & 5 deletions test/pytest/test_batchnorm_pytorch.py
@@ -39,10 +39,12 @@ def test_batchnorm(data, backend, io_type):

     default_precision = 'ac_fixed<32, 1, true>' if backend == 'Quartus' else 'ac_fixed<32, 1>'

-    config = hls4ml.utils.config_from_pytorch_model(model, default_precision=default_precision, granularity='name')
+    config = hls4ml.utils.config_from_pytorch_model(
+        model, (None, in_shape), default_precision=default_precision, granularity='name'
+    )
     output_dir = str(test_root_path / f'hls4mlprj_batchnorm_{backend}_{io_type}')
     hls_model = hls4ml.converters.convert_from_pytorch_model(
-        model, (None, in_shape), backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir
+        model, backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir
     )
     hls_model.compile()

@@ -94,17 +96,20 @@ def test_batchnorm_fusion(fusion_data, backend, io_type):
     # We do not have an implementation of a transpose for io_stream, need to transpose inputs and outputs outside of hls4ml
     if io_type == 'io_stream':
         fusion_data = np.ascontiguousarray(fusion_data.transpose(0, 2, 1))
-        config = hls4ml.utils.config_from_pytorch_model(model, channels_last_conversion='internal', transpose_outputs=False)
+        config = hls4ml.utils.config_from_pytorch_model(
+            model, (None, n_in, size_in_height), channels_last_conversion='internal', transpose_outputs=False
+        )
     else:
-        config = hls4ml.utils.config_from_pytorch_model(model, channels_last_conversion='full', transpose_outputs=True)
+        config = hls4ml.utils.config_from_pytorch_model(
+            model, (None, n_in, size_in_height), channels_last_conversion='full', transpose_outputs=True
+        )

     config['Model']['Strategy'] = 'Resource'

     # conversion
     output_dir = str(test_root_path / f'hls4mlprj_block_{backend}_{io_type}')
     hls_model = hls4ml.converters.convert_from_pytorch_model(
         model,
-        (None, n_in, size_in_height),
         hls_config=config,
         output_dir=output_dir,
         backend=backend,
7 changes: 5 additions & 2 deletions test/pytest/test_merge_pytorch.py
@@ -43,12 +43,15 @@ def test_merge(merge_op, io_type, backend):

     batch_input_shape = (None,) + input_shape
     config = hls4ml.utils.config_from_pytorch_model(
-        model, default_precision='ap_fixed<32,16>', channels_last_conversion="internal", transpose_outputs=False
+        model,
+        [batch_input_shape, batch_input_shape],
+        default_precision='ap_fixed<32,16>',
+        channels_last_conversion="internal",
+        transpose_outputs=False,
     )
     output_dir = str(test_root_path / f'hls4mlprj_merge_pytorch_{merge_op}_{backend}_{io_type}')
     hls_model = hls4ml.converters.convert_from_pytorch_model(
         model,
-        [batch_input_shape, batch_input_shape],
         hls_config=config,
         output_dir=output_dir,
         io_type=io_type,