@@ -895,6 +895,18 @@ def forward(self, x, y):
895
895
return torch .einsum ('bij,bjk->bik' , x , y )
896
896
897
897
898
+ class EinsumSingleInput (nn .Module ):
899
+ def __init__ (self , input_dim = 8 ):
900
+ super ().__init__ ()
901
+ self .input_dim = input_dim
902
+ self .linear = nn .Linear (self .input_dim , self .input_dim )
903
+
904
+ def forward (self , x ):
905
+ """using torch einsum to get the dot product"""
906
+ out = self .linear (x )
907
+ return torch .einsum ("ij,ij->i" , out , out )
908
+
909
+
898
910
@pytest .mark .parametrize ('backend' , ['Vivado' , 'Vitis' ])
899
911
@pytest .mark .parametrize ('io_type' , ['io_parallel' ])
900
912
def test_einsum_outer_product (backend , io_type ):
@@ -953,3 +965,32 @@ def test_einsum_batch_matmul(backend, io_type):
953
965
hls_prediction = np .reshape (hls_model .predict ([X_input , Y_input ]), pytorch_prediction .shape )
954
966
955
967
np .testing .assert_allclose (hls_prediction , pytorch_prediction , rtol = 1e-2 , atol = 0.01 )
968
+
969
+
970
+ @pytest .mark .parametrize ('backend' , ['Vivado' , 'Vitis' ])
971
+ @pytest .mark .parametrize ('io_type' , ['io_parallel' ])
972
+ def test_einsum_single_input (backend , io_type ):
973
+
974
+ model = EinsumSingleInput ()
975
+ model .eval ()
976
+
977
+ X_input = np .random .rand (3 , 8 )
978
+
979
+ pytorch_prediction = model (torch .Tensor (X_input )).detach ().numpy ()
980
+
981
+ config = config_from_pytorch_model (
982
+ model ,
983
+ [(None , 8 )],
984
+ default_precision = 'ap_fixed<16,6>' ,
985
+ channels_last_conversion = "internal" ,
986
+ transpose_outputs = False ,
987
+ )
988
+ output_dir = str (test_root_path / f'hls4mlprj_pytorch_einsum_single_input_{ backend } _{ io_type } ' )
989
+
990
+ hls_model = convert_from_pytorch_model (model , hls_config = config , output_dir = output_dir , backend = backend , io_type = io_type )
991
+
992
+ hls_model .compile ()
993
+
994
+ hls_prediction = np .reshape (hls_model .predict (X_input ), pytorch_prediction .shape )
995
+
996
+ np .testing .assert_allclose (hls_prediction , pytorch_prediction , rtol = 1e-2 , atol = 0.01 )
0 commit comments