@@ -810,3 +810,65 @@ def forward(self, x):
810
810
hls_prediction = hls_model .predict (hls_input ).flatten ()
811
811
812
812
np .testing .assert_allclose (hls_prediction , pytorch_prediction , rtol = 0 , atol = 5e-2 )
813
+
814
+
815
+ @pytest .mark .parametrize ('backend' , ['Vivado' , 'Vitis' , 'Quartus' ])
816
+ @pytest .mark .parametrize ('io_type' , ['io_parallel' , 'io_stream' ])
817
+ def test_view (backend , io_type ):
818
+
819
+ class TestModel (nn .Module ):
820
+ def __init__ (self , n_in , n_out , size_in ):
821
+ super ().__init__ ()
822
+ self .view_mult = n_out * size_in
823
+
824
+ self .conv1 = nn .Conv1d (
825
+ n_in ,
826
+ n_out ,
827
+ kernel_size = 3 ,
828
+ padding = 1 ,
829
+ bias = False ,
830
+ )
831
+
832
+ def forward (self , x ):
833
+ z = self .conv1 (x )
834
+ z = z .view (- 1 , self .view_mult )
835
+ return z
836
+
837
+ n_in = 2
838
+ n_out = 4
839
+ size_in = 128
840
+ n_batch = 100
841
+
842
+ model = TestModel (n_in , n_out , size_in )
843
+ model = model .to (memory_format = torch .channels_last )
844
+ model .eval ()
845
+
846
+ X_input = np .random .rand (n_batch , n_in , size_in )
847
+ pytorch_prediction = model (torch .Tensor (X_input )).detach ().numpy ()
848
+
849
+ # X_input is channels last
850
+ X_input = np .ascontiguousarray (X_input .transpose (0 , 2 , 1 ))
851
+ config = config_from_pytorch_model (model , inputs_channel_last = True , transpose_outputs = False )
852
+
853
+ output_dir = str (test_root_path / f'hls4mlprj_pytorch_view_{ backend } _{ io_type } ' )
854
+ hls_model = convert_from_pytorch_model (
855
+ model ,
856
+ (None , n_in , size_in ),
857
+ hls_config = config ,
858
+ output_dir = output_dir ,
859
+ backend = backend ,
860
+ io_type = io_type ,
861
+ )
862
+
863
+ hls_model .compile ()
864
+
865
+ # reshape hls prediction to channels last, then transpose, then reshape
866
+ # to match .view
867
+ hls_prediction = np .reshape (
868
+ np .transpose (np .reshape (hls_model .predict (X_input ), (n_batch , size_in , n_out )), (0 , 2 , 1 )),
869
+ (n_batch , size_in * n_out ),
870
+ )
871
+
872
+ rtol = 0
873
+ atol = 5.0e-2
874
+ np .testing .assert_allclose (hls_prediction , pytorch_prediction , rtol = rtol , atol = atol )
0 commit comments