Skip to content

PyTorch Pool1d + Squeeze + io_stream + Catapult failure #1320

Open
@sei-jgwohlbier

Description

@sei-jgwohlbier

Prerequisites

Please make sure to check off these prerequisites before submitting a bug report.

  • Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
  • Check that the issue hasn't already been reported, by checking the currently open issues.
  • If there are steps to reproduce the problem, make sure to write them down below.
  • If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.

Quick summary

PyTorch Pool1d + Squeeze + io_stream + Catapult fails to compile

Details

Running the test case below, which combines Pool1d, squeeze, io_stream, and the Catapult backend, fails with the error shown below.

Steps to Reproduce

Add what needs to be done to reproduce the bug. Add commented code examples and make sure to include the original model files / code, and the commit hash you are working on.

  1. Clone the hls4ml repository
  2. Checkout the master branch, with commit hash: [71bf4ae]
  3. Run conversion on code below.
from pathlib import Path

import numpy as np
import os
import pprint
import shutil
import torch
import torch.nn as nn
from torchinfo import summary

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model

# Directory containing this script; the hls4ml project directory is created inside it.
test_root_path = Path(__file__).parent

if __name__ == "__main__":

    class test(nn.Module):
        """Minimal reproducer model: AvgPool1d -> squeeze -> ReLU.

        The pooling kernel spans ``size_in`` elements, so an input of shape
        ``(batch, channels, size_in)`` is pooled to ``(batch, channels, 1)``.
        ``torch.squeeze`` is deliberately called without a ``dim`` argument,
        so every size-1 dimension is removed before the activation.
        """

        def __init__(self, size_in, momentum=0.2):
            """Build the pooling and activation layers.

            Args:
                size_in: Kernel size handed to ``nn.AvgPool1d``.
                momentum: Accepted but not used by this module.
            """
            super().__init__()
            self.avgpool = nn.AvgPool1d(kernel_size=size_in)
            self.relu = nn.ReLU()

        def forward(self, x):
            """Return ``relu(squeeze(avgpool(x)))``."""
            pooled = self.avgpool(x)
            squeezed = torch.squeeze(pooled)
            return self.relu(squeezed)

    # Problem dimensions for the reproducer.
    n_in = 2  # channels
    size_in = 4  # sequence length; also passed to the model as the pooling kernel size
    n_batch = 3

    # PyTorch layout: (batch, channels, length).
    X_input_shape = (n_batch, n_in, size_in)

    model = test(size_in)

    io_type = "io_stream"
    backend = "Catapult"
    # backend = "Vivado"
    output_dir = str(test_root_path / f"hls4mlprj_pool_squeeze_{backend}_io_stream")

    # Remove the project directory and archive left behind by an earlier run.
    if os.path.exists(output_dir):
        print("delete project dir")
        shutil.rmtree(output_dir)
    archive = output_dir + ".tar.gz"
    if os.path.isfile(archive):
        print("delete tar.gz")
        os.remove(archive)

    model.eval()
    print(model)
    summary(model, input_size=X_input_shape)

    X_input = np.random.rand(*X_input_shape)

    # Reference output from PyTorch, used as the expected testbench result.
    with torch.no_grad():
        pytorch_prediction = model(torch.Tensor(X_input))

    # X_input_hls is channels last
    X_input_hls = np.ascontiguousarray(X_input.transpose(0, 2, 1))

    # Write the testbench data: one flattened sample per line. Opening with
    # "wb" truncates any stale file, so no explicit delete is required.
    ipf = "./tb_input_features.dat"
    with open(ipf, "wb") as f:
        np.savetxt(f, X_input_hls.reshape(n_batch, -1))

    opf = "./tb_output_predictions.dat"
    with open(opf, "wb") as f:
        np.savetxt(f, pytorch_prediction.reshape(n_batch, -1).numpy())

    # Default fixed-point precision applied model-wide.
    default_precision = "ap_fixed<32,12>"
    config = config_from_pytorch_model(
        model,
        input_shape=X_input_shape[-2:],  # (channels, length), batch dimension dropped
        backend=backend,
        channels_last_conversion="internal",
        default_precision=default_precision,
        granularity="name",
        transpose_outputs=False,
    )
    config["Model"]["Strategy"] = "Resource"
    config["Model"]["BramFactor"] = 0

    pprint.pprint(config)

    hls_model = convert_from_pytorch_model(
        model,
        output_dir=output_dir,
        input_data_tb=ipf,
        output_data_tb=opf,
        backend=backend,
        hls_config=config,
        io_type=io_type,
    )

    # This is the step that fails in the report: the generated firmware does
    # not build, so the shared library cannot be loaded.
    hls_model.compile()

Expected behavior

Successful compilation: `hls_model.compile()` builds and loads the shared library without errors.

Actual behavior

Output:

delete project dir
delete tar.gz
test(
  (avgpool): AvgPool1d(kernel_size=(4,), stride=(4,), padding=(0,))
  (relu): ReLU()
)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
test                                     [3, 2]                    --
├─AvgPool1d: 1-1                         [3, 2, 1]                 --
├─ReLU: 1-2                              [3, 2]                    --
==========================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
Total mult-adds (M): 0
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
{'InputShape': (2, 4),
 'LayerName': {'avgpool': {'ConvImplementation': 'LineBuffer',
                           'Precision': {'accum': 'auto', 'result': 'auto'},
                           'ReuseFactor': 1,
                           'Trace': False},
               'relu': {'Precision': {'result': 'auto',
                                      'table': 'fixed<18,8,TRN,WRAP,0>'},
                        'ReuseFactor': 1,
                        'TableSize': 1024,
                        'Trace': False},
               'squeeze': {'Precision': {'result': 'auto'}, 'Trace': False},
               'x': {'Precision': {'result': 'auto'}, 'Trace': False}},
 'Model': {'BramFactor': 0,
           'ChannelsLastConversion': 'internal',
           'Precision': {'default': 'ap_fixed<32,12>'},
           'ReuseFactor': 1,
           'Strategy': 'Resource',
           'TraceOutput': False,
           'TransposeOutputs': False},
 'PytorchModel': test(
  (avgpool): AvgPool1d(kernel_size=(4,), stride=(4,), padding=(0,))
  (relu): ReLU()
)}
Interpreting Model ...
Topology:
Layer name: avgpool, layer type: AveragePooling1D, input shape: [[None, 2, 4]]
Layer name: squeeze, layer type: Reshape, input shape: [[None, 2, 1]]
Layer name: relu, layer type: Activation, input shape: [[None, 2]]
Writing HLS project
Copying NNET files to local firmware directory
... copying AC ac_types headers from /home/asicflow-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/writer/../templates/catapult/ac_types/
... copying AC ac_math headers from /home/asicflow-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/writer/../templates/catapult/ac_math/
... copying AC ac_simutils headers from /home/asicflow-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/writer/../templates/catapult/ac_simutils/
Done
Traceback (most recent call last):
  File "/home/asicflow-user/work/ece-18-725/ResNet/python/test_pool_squeeze.py", line 102, in <module>
    hls_model.compile()
  File "/home/asicflow-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/graph.py", line 801, in compile
    self._compile()
  File "/home/asicflow-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/graph.py", line 820, in _compile
    self._top_function_lib = ctypes.cdll.LoadLibrary(lib_name)
  File "/home/asicflow-user/miniconda3/envs/hls4ml/lib/python3.10/ctypes/__init__.py", line 452, in LoadLibrary
    return self._dlltype(name)
  File "/home/asicflow-user/miniconda3/envs/hls4ml/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /home/asicflow-user/work/ece-18-725/ResNet/python/hls4mlprj_pool_squeeze_Catapult_io_stream/firmware/myproject-72b5179b.so: cannot open shared object file: No such file or directory

Trying to run build_lib.sh shows the compilation error:

./build_lib.sh 
firmware/myproject.cpp: In function 'void myproject(ac_channel<nnet::array<ac_fixed<32, 12, true>, 2> >&, ac_channel<nnet::array<ac_fixed<32, 12, true>, 2> >&)':
firmware/myproject.cpp:43:53: error: cannot convert 'ac_channel<nnet::array<ac_fixed<32, 12, true>, 2> >' to 'nnet::array<ac_fixed<32, 12, true>, 2>*'
   43 |     nnet::transpose_2d<layer2_t, layer5_t, config5>(layer2_out, layer5_out); // transpose_input_for_squeeze
      |                                                     ^~~~~~~~~~
      |                                                     |
      |                                                     ac_channel<nnet::array<ac_fixed<32, 12, true>, 2> >
In file included from firmware/parameters.h:12,
                 from firmware/myproject.cpp:4:
firmware/nnet_utils/nnet_array.h:16:26: note:   initializing argument 1 of 'void nnet::transpose_2d(data_T*, res_T*) [with data_T = nnet::array<ac_fixed<32, 12, true>, 2>; res_T = nnet::array<ac_fixed<32, 12, true>, 1>; CONFIG_T = config5]'
   16 | void transpose_2d(data_T data[CONFIG_T::height * CONFIG_T::width], res_T data_t[CONFIG_T::height * CONFIG_T::width]) {
      |                   ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
firmware/nnet_utils/nnet_array.h: In instantiation of 'void nnet::transpose_2d(data_T*, res_T*) [with data_T = nnet::array<ac_fixed<32, 12, true>, 2>; res_T = nnet::array<ac_fixed<32, 12, true>, 1>; CONFIG_T = config5]':
firmware/myproject.cpp:43:75:   required from here
firmware/nnet_utils/nnet_array.h:21:46: error: no match for 'operator=' (operand types are 'nnet::array<ac_fixed<32, 12, true>, 1>' and 'nnet::array<ac_fixed<32, 12, true>, 2>')
   21 |             data_t[j * CONFIG_T::height + i] = data[i * CONFIG_T::width + j];
      |             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~
In file included from firmware/defines.h:4,
                 from firmware/myproject.h:8,
                 from firmware/myproject.cpp:3:
firmware/nnet_utils/nnet_types.h:21:12: note: candidate: 'nnet::array<T, N>& nnet::array<T, N>::operator=(const nnet::array<T, N>&) [with T = ac_fixed<32, 12, true>; unsigned int N = 1]'
   21 |     array &operator=(const array &other) {
      |            ^~~~~~~~
firmware/nnet_utils/nnet_types.h:21:35: note:   no known conversion for argument 1 from 'nnet::array<ac_fixed<32, 12, true>, 2>' to 'const nnet::array<ac_fixed<32, 12, true>, 1>&'
   21 |     array &operator=(const array &other) {
      |                      ~~~~~~~~~~~~~^~~~~
g++: error: myproject.o: No such file or directory

Optional

Possible fix

N/A

Additional context

Related bug: #1054

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions