Skip to content

Commit fab6844

Browse files
committed
fix bug load_from_file #148
1 parent 498a0b5 commit fab6844

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/thundersvm/thundersvm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -465,19 +465,19 @@ def load_from_file(self, path):
465465
n_feature = (c_int * 1)()
466466
thundersvm.get_sv_max_index(c_void_p(self.model), n_feature)
467467
self.n_features = n_feature[0]
468-
469468
csr_row = (c_int * (self.n_sv + 1))()
470469
csr_col = (c_int * (self.n_sv * self.n_features))()
471470
csr_data = (c_float * (self.n_sv * self.n_features))()
472471
data_size = (c_int * 1)()
473-
thundersvm.get_sv(csr_row, csr_col, csr_data, data_size, c_void_p(self.model))
472+
sv_indices = (c_int * self.n_sv)()
473+
thundersvm.get_sv(csr_row, csr_col, csr_data, data_size, sv_indices, c_void_p(self.model))
474474
self.row = np.array([csr_row[index] for index in range(0, self.n_sv + 1)])
475475
self.col = np.array([csr_col[index] for index in range(0, data_size[0])])
476476
self.data = np.array([csr_data[index] for index in range(0, data_size[0])])
477477
self.support_vectors_ = sp.csr_matrix((self.data, self.col, self.row))
478478
# if self._sparse == False:
479479
# self.support_vectors_ = self.support_vectors_.toarray(order = 'C')
480-
480+
self.support_ = np.array([sv_indices[index] for index in range(0, self.n_sv)]).astype(int)
481481
dual_coef = (c_float * ((self.n_classes - 1) * self.n_sv))()
482482
thundersvm.get_coef(dual_coef, self.n_classes, self.n_sv, c_void_p(self.model))
483483
self.dual_coef_ = np.array([dual_coef[index] for index in range(0, (self.n_classes - 1) * self.n_sv)]).astype(float)

0 commit comments

Comments
 (0)