@@ -29,13 +29,14 @@ def get_low_dim_basis(inf_matrix: InfluenceMatrix, compression: str = 'wavelet')
29
29
:type compression: str
30
30
:return: a list that contains the dimension reduction basis in the format of array(float)
31
31
"""
32
+ low_dim_basis = {}
32
33
num_of_beams = len (inf_matrix .beamlets_dict )
33
- low_dim_basis = list ()
34
+ num_of_beamlets = inf_matrix . beamlets_dict [ num_of_beams - 1 ][ 'end_beamlet' ] + 1
34
35
beam_id = [inf_matrix .beamlets_dict [i ]['beam_id' ] for i in range (num_of_beams )]
35
36
beamlets = inf_matrix .get_bev_2d_grid (beam_id = beam_id )
36
37
index_position = list ()
37
- num_of_beamlets = inf_matrix .beamlets_dict [num_of_beams - 1 ]['end_beamlet' ] + 1
38
38
for ind in range (num_of_beams ):
39
+ low_dim_basis [beam_id [ind ]] = []
39
40
for i in range (inf_matrix .beamlets_dict [ind ]['start_beamlet' ],
40
41
inf_matrix .beamlets_dict [ind ]['end_beamlet' ] + 1 ):
41
42
index_position .append ((np .where (beamlets [ind ] == i )[0 ][0 ], np .where (beamlets [ind ] == i )[1 ][0 ]))
@@ -64,9 +65,11 @@ def get_low_dim_basis(inf_matrix: InfluenceMatrix, compression: str = 'wavelet')
64
65
inf_matrix .beamlets_dict [b ]['end_beamlet' ] + 1 ):
65
66
approximation [ind ] = approximation_coeffs [index_position [ind ]]
66
67
horizontal [ind ] = horizontal_coeffs [index_position [ind ]]
67
- low_dim_basis .append (np .stack (( approximation , horizontal )))
68
+ low_dim_basis [ beam_id [ b ]] .append (np .transpose ( np . stack ([ approximation , horizontal ] )))
68
69
beamlet_2d_grid [row ][col ] = 0
69
- low_dim_basis = np .transpose (np .concatenate (low_dim_basis , axis = 0 ))
70
- u , s , vh = scipy .sparse .linalg .svds (low_dim_basis , k = min (low_dim_basis .shape [0 ], low_dim_basis .shape [1 ]) - 1 )
71
- ind = np .where (s > 0.0001 )
72
- return u [:, ind [0 ]]
70
+ for b in beam_id :
71
+ low_dim_basis [b ] = np .concatenate (low_dim_basis [b ], axis = 1 )
72
+ u , s , vh = scipy .sparse .linalg .svds (low_dim_basis [b ], k = min (low_dim_basis [b ].shape [0 ], low_dim_basis [b ].shape [1 ]) - 1 )
73
+ ind = np .where (s > 0.0001 )
74
+ low_dim_basis [b ] = u [:, ind [0 ]]
75
+ return np .concatenate ([low_dim_basis [b ] for b in beam_id ], axis = 1 )
0 commit comments