Skip to content

Commit 014c8d2

Browse files
authored
Merge pull request #124 from AmperesAvengement/remove_dynamic_cast_from_headers
Remove dynamic_cast from headers
2 parents b964011 + 4055b25 commit 014c8d2

File tree

6 files changed: 90 additions and 77 deletions.

src/BLR/BLRMatrixMPI.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,24 @@ namespace strumpack {
139139
return m;
140140
}
141141

142+
// Mutable access to local tile (i, j) viewed as a DenseTile.
//
// The tile is read from blocks_ at index i + j*rowblockslocal().
// Both the index range and the dynamic type of the tile are verified with
// assert() only; the value actually returned comes from a static_cast.
// When asserts are compiled out (NDEBUG) nothing is checked, so the caller
// must guarantee that the stored tile really is a DenseTile<scalar_t>.
template<typename scalar_t> DenseTile<scalar_t>&
143+
BLRMatrixMPI<scalar_t>::ltile_dense(std::size_t i, std::size_t j) {
144+
assert(i < rowblockslocal() && j < colblockslocal());  // index range check
145+
assert(dynamic_cast<DenseTile<scalar_t>*>  // type check: cast must be non-null
146+
(blocks_[i+j*rowblockslocal()].get()));
147+
return *static_cast<DenseTile<scalar_t>*>  // unchecked downcast of the same tile
148+
(blocks_[i+j*rowblockslocal()].get());
149+
}
150+
151+
// Read-only access to local tile (i, j) viewed as a DenseTile.
//
// Const counterpart of the mutable overload: same lookup in blocks_ at
// index i + j*rowblockslocal(), same assert-only index and type checks,
// and the same unchecked static_cast for the returned reference.
template<typename scalar_t> const DenseTile<scalar_t>&
152+
BLRMatrixMPI<scalar_t>::ltile_dense(std::size_t i, std::size_t j) const {
153+
assert(i < rowblockslocal() && j < colblockslocal());  // index range check
154+
assert(dynamic_cast<const DenseTile<scalar_t>*>  // type check: cast must be non-null
155+
(blocks_[i+j*rowblockslocal()].get()));
156+
return *static_cast<const DenseTile<scalar_t>*>  // unchecked downcast of the same tile
157+
(blocks_[i+j*rowblockslocal()].get());
158+
}
159+
142160
template<typename scalar_t>
143161
typename RealType<scalar_t>::value_type
144162
BLRMatrixMPI<scalar_t>::normF() const {

src/BLR/BLRMatrixMPI.hpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -272,20 +272,8 @@ namespace strumpack {
272272
return *blocks_[i+j*rowblockslocal()].get();
273273
}
274274

275-
DenseTile<scalar_t>& ltile_dense(std::size_t i, std::size_t j) {
276-
assert(i < rowblockslocal() && j < colblockslocal());
277-
assert(dynamic_cast<DenseTile<scalar_t>*>
278-
(blocks_[i+j*rowblockslocal()].get()));
279-
return *static_cast<DenseTile<scalar_t>*>
280-
(blocks_[i+j*rowblockslocal()].get());
281-
}
282-
const DenseTile<scalar_t>& ltile_dense(std::size_t i, std::size_t j) const {
283-
assert(i < rowblockslocal() && j < colblockslocal());
284-
assert(dynamic_cast<const DenseTile<scalar_t>*>
285-
(blocks_[i+j*rowblockslocal()].get()));
286-
return *static_cast<const DenseTile<scalar_t>*>
287-
(blocks_[i+j*rowblockslocal()].get());
288-
}
275+
// Access local tile (i, j) as a DenseTile (mutable and const overloads).
// Only declared here; the bodies, which contain the dynamic_cast-based
// assert, are defined in src/BLR/BLRMatrixMPI.cpp.
DenseTile<scalar_t>& ltile_dense(std::size_t i, std::size_t j);
276+
const DenseTile<scalar_t>& ltile_dense(std::size_t i, std::size_t j) const;
289277

290278
std::unique_ptr<BLRTile<scalar_t>>&
291279
block(std::size_t i, std::size_t j) {

src/HSS/HSSMatrix.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,16 @@ namespace strumpack {
138138
(new HSSMatrix<scalar_t>(*this));
139139
}
140140

141+
// Returns child c of this HSS matrix as a pointer-to-const HSSMatrix.
//
// The pointer is taken from ch_[c].get() and converted with dynamic_cast,
// so the result is nullptr when that child is not an HSSMatrix<scalar_t>.
// c indexes ch_ directly; no range check is done in this function.
template<typename scalar_t> const HSSMatrix<scalar_t>*
142+
HSSMatrix<scalar_t>::child(int c) const {
143+
return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
144+
}
145+
146+
// Returns child c of this HSS matrix as a mutable HSSMatrix pointer.
//
// The pointer is taken from ch_[c].get() and converted with dynamic_cast,
// so the result is nullptr when that child is not an HSSMatrix<scalar_t>.
// c indexes ch_ directly; no range check is done in this function.
template<typename scalar_t> HSSMatrix<scalar_t>*
147+
HSSMatrix<scalar_t>::child(int c) {
148+
return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
149+
}
150+
141151
template<typename scalar_t> void
142152
HSSMatrix<scalar_t>::delete_trailing_block() {
143153
B01_.clear();

src/HSS/HSSMatrix.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,19 +191,15 @@ namespace strumpack {
191191
* matrix. The value of c should be 0 or 1, and this HSS matrix
192192
* should not be a leaf!
193193
*/
194-
const HSSMatrix<scalar_t>* child(int c) const {
195-
return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
196-
}
194+
// Declaration only: the body (a dynamic_cast of ch_[c]) is defined in
// src/HSS/HSSMatrix.cpp.
const HSSMatrix<scalar_t>* child(int c) const;
197195

198196
/**
199197
* Return a raw (non-owning) pointer to child c of this HSS
200198
* matrix. A child of an HSS matrix is itself an HSS matrix. The
201199
* value of c should be 0 or 1, and this HSS matrix should not
202200
* be a leaf!
203201
*/
204-
HSSMatrix<scalar_t>* child(int c) {
205-
return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
206-
}
202+
// Declaration only: the body (a dynamic_cast of ch_[c]) is defined in
// src/HSS/HSSMatrix.cpp.
HSSMatrix<scalar_t>* child(int c);
207203

208204
/**
209205
* Initialize this HSS matrix as the compressed HSS

src/HSS/HSSMatrixMPI.Schur.hpp

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -39,63 +39,6 @@ namespace strumpack {
3939
* Phi = (D0^{-1} * U0 * B01 * V1big^C)^C
4040
* = V1big * (D0^{-1} * U0 * B01)^C
4141
*/
42-
template<typename scalar_t> void HSSMatrixMPI<scalar_t>::Schur_update
43-
(DistM_t& Theta, DistM_t& Vhat, DistM_t& DUB01, DistM_t& Phi) const {
44-
if (this->leaf()) return;
45-
auto ch0 = child(0);
46-
auto ch1 = child(1);
47-
DistM_t DU(grid(), ch0->U_rows(), ch0->U_rank());
48-
if (auto ch0mpi =
49-
dynamic_cast<const HSSMatrixMPI<scalar_t>*>(child(0))) {
50-
DistM_t chDU;
51-
if (ch0mpi->active()) {
52-
chDU = ch0->ULV_mpi_.D_.solve(ch0mpi->U_.dense(), ch0->ULV_mpi_.piv_);
53-
STRUMPACK_SCHUR_FLOPS
54-
(!ch0->ULV_mpi_.D_.is_master() ? 0 :
55-
blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0mpi->U_.cols()));
56-
}
57-
copy(ch0->U_rows(), ch0->U_rank(), chDU, 0, 0, DU, 0, 0, grid()->ctxt_all());
58-
} else {
59-
auto ch0seq = dynamic_cast<const HSSMatrix<scalar_t>*>(child(0));
60-
DenseM_t chDU;
61-
if (ch0seq->active()) {
62-
chDU = ch0->ULV_mpi_.D_.gather().solve
63-
(ch0seq->U_.dense(), ch0->ULV_mpi_.piv_, ch0seq->openmp_task_depth_);
64-
STRUMPACK_SCHUR_FLOPS
65-
(!ch0->ULV_mpi_.D_.is_master() ? 0 :
66-
blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0seq->U_.cols()));
67-
}
68-
copy(ch0->U_rows(), ch0->U_rank(), chDU, 0/*rank ch0*/, DU, 0, 0, grid()->ctxt_all());
69-
}
70-
DUB01 = DistM_t(grid(), ch0->U_rows(), ch1->V_rank());
71-
gemm(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.), DUB01);
72-
STRUMPACK_SCHUR_FLOPS
73-
(gemm_flops(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.)));
74-
75-
DistM_t _theta(ch1->grid(grid_local()), B10_.rows(), B10_.cols());
76-
copy(B10_.rows(), B10_.cols(), B10_, 0, 0, _theta, 0, 0, grid()->ctxt_all());
77-
auto DUB01t = DUB01.transpose();
78-
DistM_t _phi(ch1->grid(grid_local()), DUB01t.rows(), DUB01t.cols());
79-
copy(DUB01t.rows(), DUB01t.cols(), DUB01t, 0, 0, _phi, 0, 0, grid()->ctxt_all());
80-
DUB01t.clear();
81-
82-
DistSubLeaf<scalar_t> Theta_br(_theta.cols(), ch1, grid_local()),
83-
Phi_br(_phi.cols(), ch1, grid_local());
84-
DistM_t Theta_ch(ch1->grid(grid_local()), ch1->rows(), _theta.cols());
85-
DistM_t Phi_ch(ch1->grid(grid_local()), ch1->cols(), _phi.cols());
86-
long long int flops = 0;
87-
ch1->apply_UV_big(Theta_br, _theta, Phi_br, _phi, flops);
88-
STRUMPACK_SCHUR_FLOPS(flops);
89-
Theta_br.from_block_row(Theta_ch);
90-
Phi_br.from_block_row(Phi_ch);
91-
Theta = DistM_t(grid(), Theta_ch.rows(), Theta_ch.cols());
92-
Phi = DistM_t(grid(), Phi_ch.rows(), Phi_ch.cols());
93-
copy(Theta.rows(), Theta.cols(), Theta_ch, 0, 0, Theta, 0, 0, grid()->ctxt_all());
94-
copy(Phi.rows(), Phi.cols(), Phi_ch, 0, 0, Phi, 0, 0, grid()->ctxt_all());
95-
96-
Vhat = DistM_t(grid(), Phi.cols(), Theta.cols());
97-
copy(Vhat.rows(), Vhat.cols(), ch0->ULV_mpi_.Vhat(), 0, 0, Vhat, 0, 0, grid()->ctxt_all());
98-
}
9942

10043
/**
10144
* Apply Schur complement the direct way:

src/HSS/HSSMatrixMPI.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,64 @@ namespace strumpack {
170170
setup_ranges(roff, coff);
171171
}
172172

173+
// Fills the four output matrices Theta, Vhat, DUB01 and Phi from the two
// children of this node. Returns immediately, leaving the outputs untouched,
// when this node is a leaf.
//
// Steps, as written below:
//   1. DU    : result of ch0's ULV_mpi_.D_ solve applied to child 0's U
//              basis, copied onto this node's grid.
//   2. DUB01 : DU * B01_ (gemm with alpha = 1, beta = 0).
//   3. B10_ and the transpose of DUB01 are copied onto child 1's grid,
//      passed through ch1->apply_UV_big, and the results are copied back
//      onto this node's grid as Theta and Phi.
//   4. Vhat  : copy of ch0->ULV_mpi_.Vhat() on this node's grid.
//
// Every copy() passes grid()->ctxt_all() as its context argument.
// Presumably all ranks of that context must reach each copy - TODO confirm.
template<typename scalar_t> void HSSMatrixMPI<scalar_t>::Schur_update
174+
(DistM_t& Theta, DistM_t& Vhat, DistM_t& DUB01, DistM_t& Phi) const {
175+
if (this->leaf()) return;
176+
auto ch0 = child(0);
177+
auto ch1 = child(1);
178+
DistM_t DU(grid(), ch0->U_rows(), ch0->U_rank());
179+
// Child 0 is either a distributed HSSMatrixMPI or (else branch) handled as
// a sequential HSSMatrix; the two cases differ in how chDU is produced.
if (auto ch0mpi =
180+
dynamic_cast<const HSSMatrixMPI<scalar_t>*>(child(0))) {
181+
DistM_t chDU;
182+
// Only ranks for which the child is active do the solve; chDU stays
// default-constructed elsewhere.
if (ch0mpi->active()) {
183+
chDU = ch0->ULV_mpi_.D_.solve(ch0mpi->U_.dense(), ch0->ULV_mpi_.piv_);
184+
// Flops are counted once: every rank except the master of D_ adds 0.
STRUMPACK_SCHUR_FLOPS
185+
(!ch0->ULV_mpi_.D_.is_master() ? 0 :
186+
blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0mpi->U_.cols()));
187+
}
188+
copy(ch0->U_rows(), ch0->U_rank(), chDU, 0, 0, DU, 0, 0, grid()->ctxt_all());
189+
} else {
190+
// NOTE(review): ch0seq is dereferenced below without a null check.
// Presumably child(0) is always an HSSMatrixMPI or an HSSMatrix - confirm.
auto ch0seq = dynamic_cast<const HSSMatrix<scalar_t>*>(child(0));
191+
DenseM_t chDU;
192+
if (ch0seq->active()) {
193+
// D_ is gathered first, then the solve runs on the gathered matrix with
// the child's OpenMP task depth.
chDU = ch0->ULV_mpi_.D_.gather().solve
194+
(ch0seq->U_.dense(), ch0->ULV_mpi_.piv_, ch0seq->openmp_task_depth_);
195+
STRUMPACK_SCHUR_FLOPS
196+
(!ch0->ULV_mpi_.D_.is_master() ? 0 :
197+
blas::getrs_flops(ch0->ULV_mpi_.D_.rows(), ch0seq->U_.cols()));
198+
}
199+
// The dense chDU is copied into distributed DU; 0 is the source rank
// (see the inline "rank ch0" marker).
copy(ch0->U_rows(), ch0->U_rank(), chDU, 0/*rank ch0*/, DU, 0, 0, grid()->ctxt_all());
200+
}
201+
// DUB01 = DU * B01_
DUB01 = DistM_t(grid(), ch0->U_rows(), ch1->V_rank());
202+
gemm(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.), DUB01);
203+
STRUMPACK_SCHUR_FLOPS
204+
(gemm_flops(Trans::N, Trans::N, scalar_t(1.), DU, B01_, scalar_t(0.)));
205+
206+
// Move B10_ and DUB01^T onto child 1's grid as inputs for apply_UV_big.
DistM_t _theta(ch1->grid(grid_local()), B10_.rows(), B10_.cols());
207+
copy(B10_.rows(), B10_.cols(), B10_, 0, 0, _theta, 0, 0, grid()->ctxt_all());
208+
auto DUB01t = DUB01.transpose();
209+
DistM_t _phi(ch1->grid(grid_local()), DUB01t.rows(), DUB01t.cols());
210+
copy(DUB01t.rows(), DUB01t.cols(), DUB01t, 0, 0, _phi, 0, 0, grid()->ctxt_all());
211+
DUB01t.clear();  // the transposed temporary is no longer needed
212+
213+
// Apply child 1's bases to _theta and _phi; the results are held in
// DistSubLeaf form and then converted into Theta_ch / Phi_ch.
DistSubLeaf<scalar_t> Theta_br(_theta.cols(), ch1, grid_local()),
214+
Phi_br(_phi.cols(), ch1, grid_local());
215+
DistM_t Theta_ch(ch1->grid(grid_local()), ch1->rows(), _theta.cols());
216+
DistM_t Phi_ch(ch1->grid(grid_local()), ch1->cols(), _phi.cols());
217+
long long int flops = 0;
218+
ch1->apply_UV_big(Theta_br, _theta, Phi_br, _phi, flops);
219+
STRUMPACK_SCHUR_FLOPS(flops);
220+
Theta_br.from_block_row(Theta_ch);
221+
Phi_br.from_block_row(Phi_ch);
222+
// Copy the child-grid results back onto this node's grid as the outputs.
Theta = DistM_t(grid(), Theta_ch.rows(), Theta_ch.cols());
223+
Phi = DistM_t(grid(), Phi_ch.rows(), Phi_ch.cols());
224+
copy(Theta.rows(), Theta.cols(), Theta_ch, 0, 0, Theta, 0, 0, grid()->ctxt_all());
225+
copy(Phi.rows(), Phi.cols(), Phi_ch, 0, 0, Phi, 0, 0, grid()->ctxt_all());
226+
227+
// Vhat is sized Phi.cols() x Theta.cols() and filled from child 0's ULV data.
Vhat = DistM_t(grid(), Phi.cols(), Theta.cols());
228+
copy(Vhat.rows(), Vhat.cols(), ch0->ULV_mpi_.Vhat(), 0, 0, Vhat, 0, 0, grid()->ctxt_all());
229+
}
230+
173231
template<typename scalar_t> void
174232
HSSMatrixMPI<scalar_t>::setup_local_context() {
175233
if (!this->leaf()) {

0 commit comments

Comments
 (0)