@@ -71,15 +71,13 @@ def build(self, model, target="all"):
71
71
raise Exception ("Currently untested on non-Linux OS" )
72
72
73
73
def _numpy_to_dat (self , model , x ):
74
- if len (self .get_input_variables ()) != 1 :
75
- raise Exception ("Currently unsupported for multi-input projects" )
74
+ if len (model .get_input_variables ()) != 1 :
75
+ raise Exception ("Currently unsupported for multi-input/output projects" )
76
76
77
77
# Verify numpy array of correct shape
78
- expected_shape = (np .newaxis , model .get_input_variables ()[0 ].size ())
79
- print (f"Expected model input shape: { expected_shape } " )
80
- print (f"Give numpy array shape: { x .shape } " )
81
- if expected_shape != x .shape :
82
- raise Exception (f'Input shape mismatch, got { x .shape } , expected { expected_shape } ' )
78
+ expected_shape = model .get_input_variables ()[0 ].size ()
79
+ if expected_shape != x .shape [- 1 ]:
80
+ raise Exception (f'Input shape mismatch, got { x .shape } , expected (_, { expected_shape } )' )
83
81
84
82
# Write to tb_data/tb_input_features.dat
85
83
input_dat = open (f'{ model .config .get_output_dir ()} /tb_data/tb_input_features.dat' , 'w' )
@@ -90,16 +88,8 @@ def _numpy_to_dat(self, model, x):
90
88
91
89
def _dat_to_numpy (self , model ):
92
90
expected_shape = model .get_output_variables ()[0 ].size ()
93
- y = np .array ([], dtype = float ).reshape (0 , expected_shape )
94
-
95
- output_dat = open (f'{ model .config .get_output_dir ()} /tb_data/hw_results.dat' , 'r' )
96
- for line in output_dat .readlines ():
97
- data = [list (map (float , line .strip ().split ()))]
98
- if len (data ) != expected_shape :
99
- raise Exception ('Error in output file. Does not match expected model output shape.' )
100
- y = np .concatenate (y , np .array (data )[np .newaxis , :], axis = 0 )
101
- output_dat .close ()
102
-
91
+ output_file = f'{ model .config .get_output_dir ()} /tb_data/hw_results.dat'
92
+ y = np .loadtxt (output_file , dtype = float ).reshape (- 1 , expected_shape )
103
93
return y
104
94
105
95
def hardware_predict (self , model , x ):
0 commit comments