@@ -877,3 +877,79 @@ def forward(self, x):
877
877
rtol = 0
878
878
atol = 5.0e-2
879
879
np .testing .assert_allclose (hls_prediction , pytorch_prediction , rtol = rtol , atol = atol )
880
+
881
+
882
+ class EinsumOuterProduct (nn .Module ):
883
+ def __init__ (self ):
884
+ super ().__init__ ()
885
+
886
+ def forward (self , x , y ):
887
+ return torch .einsum ('bi,bj->bij' , x , y )
888
+
889
+
890
+ class EinsumBatchMatMul (nn .Module ):
891
+ def __init__ (self ):
892
+ super ().__init__ ()
893
+
894
+ def forward (self , x , y ):
895
+ return torch .einsum ('bij,bjk->bik' , x , y )
896
+
897
+
898
+ @pytest .mark .parametrize ('backend' , ['Vivado' , 'Vitis' ])
899
+ @pytest .mark .parametrize ('io_type' , ['io_parallel' ])
900
+ def test_einsum_outer_product (backend , io_type ):
901
+
902
+ model = EinsumOuterProduct ()
903
+ model .eval ()
904
+
905
+ X_input = np .random .rand (3 , 4 )
906
+ Y_input = np .random .rand (3 , 5 )
907
+
908
+ pytorch_prediction = model (torch .Tensor (X_input ), torch .Tensor (Y_input )).detach ().numpy ()
909
+
910
+ config = config_from_pytorch_model (
911
+ model ,
912
+ [(None , 4 ), (None , 5 )],
913
+ default_precision = 'ap_fixed<16,6>' ,
914
+ channels_last_conversion = "internal" ,
915
+ transpose_outputs = False ,
916
+ )
917
+ output_dir = str (test_root_path / f'hls4mlprj_pytorch_einsum_outer_product_{ backend } _{ io_type } ' )
918
+
919
+ hls_model = convert_from_pytorch_model (model , hls_config = config , output_dir = output_dir , backend = backend , io_type = io_type )
920
+
921
+ hls_model .compile ()
922
+
923
+ hls_prediction = np .reshape (hls_model .predict ([X_input , Y_input ]), pytorch_prediction .shape )
924
+
925
+ np .testing .assert_allclose (hls_prediction , pytorch_prediction , rtol = 1e-2 , atol = 0.01 )
926
+
927
+
928
+ @pytest .mark .parametrize ('backend' , ['Vivado' , 'Vitis' ])
929
+ @pytest .mark .parametrize ('io_type' , ['io_parallel' ])
930
+ def test_einsum_batch_matmul (backend , io_type ):
931
+
932
+ model = EinsumBatchMatMul ()
933
+ model .eval ()
934
+
935
+ X_input = np .random .rand (3 , 2 , 5 )
936
+ Y_input = np .random .rand (3 , 5 , 4 )
937
+
938
+ pytorch_prediction = model (torch .Tensor (X_input ), torch .Tensor (Y_input )).detach ().numpy ()
939
+
940
+ config = config_from_pytorch_model (
941
+ model ,
942
+ [(None , 2 , 5 ), (None , 5 , 4 )],
943
+ default_precision = 'ap_fixed<16,6>' ,
944
+ channels_last_conversion = "internal" ,
945
+ transpose_outputs = False ,
946
+ )
947
+ output_dir = str (test_root_path / f'hls4mlprj_pytorch_einsum_batch_matmul_{ backend } _{ io_type } ' )
948
+
949
+ hls_model = convert_from_pytorch_model (model , hls_config = config , output_dir = output_dir , backend = backend , io_type = io_type )
950
+
951
+ hls_model .compile ()
952
+
953
+ hls_prediction = np .reshape (hls_model .predict ([X_input , Y_input ]), pytorch_prediction .shape )
954
+
955
+ np .testing .assert_allclose (hls_prediction , pytorch_prediction , rtol = 1e-2 , atol = 0.01 )
0 commit comments