Skip to content

Commit be62977

Browse files
Merge pull request #340 from SpM-lab/terasaki/matsus-ext
Add spir_basis_get_default_matsus_ext
2 parents ccc13d8 + b58b89b commit be62977

File tree

4 files changed

+160
-16
lines changed

4 files changed

+160
-16
lines changed

include/sparseir/sparseir.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,47 @@ int spir_basis_get_n_default_matsus(const spir_basis *b, bool positive_only,
663663
int spir_basis_get_default_matsus(const spir_basis *b, bool positive_only,
664664
int64_t *points);
665665

666+
/**
667+
* @brief Gets the number of default Matsubara sampling points for an IR basis.
668+
*
669+
* This function returns the number of default sampling points in Matsubara
670+
* frequencies (iωn) that are automatically chosen for optimal conditioning of
671+
* the sampling matrix. These points are the extrema of the highest-order basis
672+
* function in Matsubara frequencies.
673+
*
674+
* @param b Pointer to a finite temperature basis object (must be an IR basis)
675+
* @param positive_only If true, only positive frequencies are used
676+
* @param L Number of requested sampling points.
677+
* @param num_points_returned Pointer to store the number of sampling points returned.
678+
* @return An integer status code:
679+
* - 0 (SPIR_COMPUTATION_SUCCESS) on success
680+
* - A non-zero error code on failure
681+
*
682+
* @note This function is only available for IR basis objects
683+
* @note The default sampling points are chosen to provide near-optimal
684+
* conditioning for the given basis size
685+
* @see spir_basis_get_default_matsus
686+
*/
687+
int spir_basis_get_n_default_matsus_ext(const spir_basis *b, bool positive_only, int L, int *num_points_returned);
688+
689+
/**
690+
* @brief Gets the default Matsubara sampling points for an IR basis.
691+
*
692+
* This function fills the provided array with the default sampling points in
693+
* Matsubara frequencies (iωn) that are automatically chosen for optimal
694+
* conditioning of the sampling matrix. These points are the extrema of the
695+
* highest-order basis function in Matsubara frequencies.
696+
*
697+
* @param b Pointer to a finite temperature basis object (must be an IR basis)
698+
* @param positive_only If true, only positive frequencies are used
699+
* @param n_points Number of requested sampling points.
700+
* @param points Pre-allocated array to store the sampling points. The size of the array must be at least n_points.
701+
* @param n_points_returned Number of sampling points returned.
702+
* @return An integer status code:
703+
* - 0 (SPIR_COMPUTATION_SUCCESS) on success
704+
*/
705+
int spir_basis_get_default_matsus_ext(const spir_basis *b, bool positive_only, int n_points, int64_t *points, int *n_points_returned);
706+
666707
/**
667708
* @brief Creates a new Discrete Lehmann Representation (DLR) basis.
668709
*

src/cinterface.cpp

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ spir_kernel* spir_logistic_kernel_new(double lambda, int* status)
2525
try {
2626
auto kernel_ptr = std::make_shared<sparseir::LogisticKernel>(lambda);
2727
std::shared_ptr<sparseir::AbstractKernel> abstract_kernel = _safe_static_pointer_cast<sparseir::AbstractKernel>(kernel_ptr);
28-
28+
2929
// Check if dynamic_cast works at this point
3030
auto check_logistic = _safe_dynamic_pointer_cast<sparseir::LogisticKernel>(abstract_kernel);
31-
31+
3232
*status = SPIR_COMPUTATION_SUCCESS;
3333
return create_kernel(abstract_kernel);
3434
} catch (const std::exception &e) {
@@ -118,9 +118,9 @@ spir_sve_result* spir_sve_result_new(const spir_kernel *k, double epsilon, int*
118118
*status = SPIR_GET_IMPL_FAILED;
119119
return nullptr;
120120
}
121-
121+
122122
std::shared_ptr<sparseir::SVEResult> sve_result;
123-
123+
124124
if (auto logistic = std::dynamic_pointer_cast<sparseir::LogisticKernel>(impl)) {
125125
sve_result = std::make_shared<sparseir::SVEResult>(sparseir::compute_sve(*logistic, epsilon));
126126
} else if (auto bose = std::dynamic_pointer_cast<sparseir::RegularizedBoseKernel>(impl)) {
@@ -235,7 +235,7 @@ spir_sampling* spir_tau_sampling_new_with_matrix(int order, int statistics, int
235235
DEBUG_LOG("Error: Invalid statistics");
236236
return nullptr;
237237
}
238-
238+
239239
// check order
240240
if (order != SPIR_ORDER_ROW_MAJOR && order != SPIR_ORDER_COLUMN_MAJOR) {
241241
*status = SPIR_INVALID_ARGUMENT;
@@ -334,14 +334,14 @@ spir_sampling* spir_matsu_sampling_new_with_matrix(
334334
*status = SPIR_INVALID_ARGUMENT;
335335
return nullptr;
336336
}
337-
337+
338338
// check statistics
339339
if (statistics != SPIR_STATISTICS_FERMIONIC && statistics != SPIR_STATISTICS_BOSONIC) {
340340
*status = SPIR_INVALID_ARGUMENT;
341341
DEBUG_LOG("Error: Invalid statistics");
342342
return nullptr;
343343
}
344-
344+
345345
// check order
346346
if (order != SPIR_ORDER_ROW_MAJOR && order != SPIR_ORDER_COLUMN_MAJOR) {
347347
*status = SPIR_INVALID_ARGUMENT;
@@ -476,7 +476,7 @@ int spir_dlr2ir_dd(const spir_basis *dlr, int order, int ndim,
476476
auto impl = get_impl_basis(dlr);
477477
if (!impl)
478478
return SPIR_GET_IMPL_FAILED;
479-
479+
480480
if (!is_dlr_basis(dlr)) {
481481
DEBUG_LOG("Error: The basis is not a DLR basis");
482482
return SPIR_INVALID_ARGUMENT;
@@ -523,7 +523,7 @@ int spir_ir2dlr_dd(const spir_basis *dlr, int order,
523523
auto impl = get_impl_basis(dlr);
524524
if (!impl)
525525
return SPIR_GET_IMPL_FAILED;
526-
526+
527527
if (!is_dlr_basis(dlr)) {
528528
DEBUG_LOG("Error: The basis is not a DLR basis");
529529
return SPIR_INVALID_ARGUMENT;
@@ -545,7 +545,7 @@ int spir_ir2dlr_zz(const spir_basis *dlr, int order,
545545
auto impl = get_impl_basis(dlr);
546546
if (!impl)
547547
return SPIR_GET_IMPL_FAILED;
548-
548+
549549
if (!is_dlr_basis(dlr)) {
550550
DEBUG_LOG("Error: The basis is not a DLR basis");
551551
return SPIR_INVALID_ARGUMENT;
@@ -1016,6 +1016,74 @@ int spir_basis_get_default_matsus(const spir_basis *b, bool positive_only, int64
10161016
}
10171017
}
10181018

1019+
int spir_basis_get_n_default_matsus_ext(const spir_basis *b, bool positive_only, int L, int *num_points_returned)
1020+
{
1021+
if (!b || !num_points_returned) {
1022+
return SPIR_INVALID_ARGUMENT;
1023+
}
1024+
1025+
auto impl = get_impl_basis(b);
1026+
if (!impl) {
1027+
return SPIR_GET_IMPL_FAILED;
1028+
}
1029+
1030+
if (!is_ir_basis(b)) {
1031+
DEBUG_LOG("Error: The basis is not an IR basis");
1032+
return SPIR_INVALID_ARGUMENT;
1033+
}
1034+
1035+
try {
1036+
if (impl->get_statistics() == SPIR_STATISTICS_FERMIONIC) {
1037+
auto ir_basis = _safe_static_pointer_cast<_IRBasis<sparseir::Fermionic>>(impl);
1038+
auto points = ir_basis->default_matsubara_sampling_points_ext(L, positive_only);
1039+
*num_points_returned = points.size();
1040+
return SPIR_COMPUTATION_SUCCESS;
1041+
} else {
1042+
auto ir_basis = _safe_static_pointer_cast<_IRBasis<sparseir::Bosonic>>(impl);
1043+
auto points = ir_basis->default_matsubara_sampling_points_ext(L, positive_only);
1044+
*num_points_returned = points.size();
1045+
return SPIR_COMPUTATION_SUCCESS;
1046+
}
1047+
} catch (const std::exception &e) {
1048+
return SPIR_GET_IMPL_FAILED;
1049+
}
1050+
}
1051+
1052+
int spir_basis_get_default_matsus_ext(const spir_basis *b, bool positive_only, int L, int64_t *points, int *n_points_returned)
1053+
{
1054+
if (!b || !points) {
1055+
return SPIR_INVALID_ARGUMENT;
1056+
}
1057+
1058+
auto impl = get_impl_basis(b);
1059+
if (!impl) {
1060+
return SPIR_GET_IMPL_FAILED;
1061+
}
1062+
1063+
if (!is_ir_basis(b)) {
1064+
DEBUG_LOG("Error: The basis is not an IR basis");
1065+
return SPIR_INVALID_ARGUMENT;
1066+
}
1067+
1068+
try {
1069+
if (impl->get_statistics() == SPIR_STATISTICS_FERMIONIC) {
1070+
auto ir_basis = _safe_static_pointer_cast<_IRBasis<sparseir::Fermionic>>(impl);
1071+
auto matsubara_points = ir_basis->default_matsubara_sampling_points_ext(L, positive_only);
1072+
*n_points_returned = matsubara_points.size();
1073+
std::copy(matsubara_points.begin(), matsubara_points.end(), points);
1074+
return SPIR_COMPUTATION_SUCCESS;
1075+
} else {
1076+
auto ir_basis = _safe_static_pointer_cast<_IRBasis<sparseir::Bosonic>>(impl);
1077+
auto matsubara_points = ir_basis->default_matsubara_sampling_points_ext(L, positive_only);
1078+
*n_points_returned = matsubara_points.size();
1079+
std::copy(matsubara_points.begin(), matsubara_points.end(), points);
1080+
return SPIR_COMPUTATION_SUCCESS;
1081+
}
1082+
} catch (const std::exception &e) {
1083+
return SPIR_GET_IMPL_FAILED;
1084+
}
1085+
}
1086+
10191087
int spir_basis_get_stats(const spir_basis *b,
10201088
int *statistics)
10211089
{
@@ -1093,7 +1161,7 @@ int spir_funcs_batch_eval(const spir_funcs *funcs,
10931161
// result is a matrix of size n_funcs x num_points in column-major order
10941162
Eigen::MatrixXd result = std::dynamic_pointer_cast<AbstractContinuousFunctions>(impl)->operator()(Eigen::Map<const Eigen::VectorXd>(xs, num_points));
10951163

1096-
// out is a matrix of size num_points x n_funcs
1164+
// out is a matrix of size num_points x n_funcs
10971165
if (order == SPIR_ORDER_ROW_MAJOR) {
10981166
// Copy the results to the output array
10991167
for (int i = 0; i < num_points; ++i) {
@@ -1475,7 +1543,7 @@ int spir_funcs_get_roots(const spir_funcs *funcs, double *roots)
14751543

14761544
// Get the roots from the implementation
14771545
Eigen::VectorXd roots_vec = continuous_impl->roots();
1478-
1546+
14791547
// Copy the roots to the output array
14801548
std::memcpy(roots, roots_vec.data(), roots_vec.size() * sizeof(double));
14811549

src/cinterface_impl/helper_types.hpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ class AbstractContinuousFunctions : public _AbstractFuncs {
5050
virtual Eigen::MatrixXd operator()(const Eigen::VectorXd &xs) const = 0;
5151
virtual std::pair<double, double> get_domain() const = 0;
5252
virtual bool is_continuous_funcs() const override { return true; }
53-
53+
5454
// Returns the number of roots of the functions
5555
virtual int nroots() const = 0;
56-
56+
5757
// Returns the roots of the functions in non-ascending order
5858
virtual Eigen::VectorXd roots() const = 0;
5959
};
@@ -137,13 +137,13 @@ class OmegaFunctionsAdaptor : public AbstractContinuousFunctions {
137137
virtual std::shared_ptr<_AbstractFuncs> slice(const std::vector<size_t>& indices) const override {
138138
// First get the sliced implementation
139139
auto sliced_impl = impl->slice(indices);
140-
140+
141141
// Convert the sliced implementation to the correct type
142142
auto converted_impl = _safe_dynamic_pointer_cast<ImplType>(sliced_impl);
143143
if (!converted_impl) {
144144
throw std::runtime_error("Failed to convert sliced implementation to correct type");
145145
}
146-
146+
147147
// Create new adapter with the converted implementation
148148
return std::make_shared<OmegaFunctionsAdaptor<ImplType>>(converted_impl);
149149
}
@@ -295,6 +295,20 @@ class _IRBasis : public AbstractFiniteTempBasis {
295295
return points;
296296
}
297297

298+
std::vector<int64_t> default_matsubara_sampling_points_ext(int n_points, bool positive_only) const
299+
{
300+
bool fence = false;
301+
302+
std::vector<sparseir::MatsubaraFreq<S>> matsubara_points = impl->default_matsubara_sampling_points(n_points, fence, positive_only);
303+
std::vector<int64_t> points(matsubara_points.size());
304+
std::transform(
305+
matsubara_points.begin(), matsubara_points.end(), points.begin(),
306+
[](const sparseir::MatsubaraFreq<S> &freq) {
307+
return static_cast<int64_t>(freq.get_n());
308+
});
309+
return points;
310+
}
311+
298312
std::vector<double> default_omega_sampling_points() const
299313
{
300314
// convert from Eigen::VectorXd to std::vector<double>

test/cpp/cinterface_sampling.cxx

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,27 @@ void test_matsubara_sampling_constructor()
913913
REQUIRE(smpl_points_positive_only[i] == smpl_points_positive_only_org[i]);
914914
}
915915

916+
int n_matsubara_points_returned_ext;
917+
int L = n_points_org;
918+
919+
status = spir_basis_get_n_default_matsus_ext(basis, false, L, &n_matsubara_points_returned_ext);
920+
REQUIRE(status == SPIR_COMPUTATION_SUCCESS);
921+
922+
int64_t *smpl_points_ext = (int64_t *)malloc(n_matsubara_points_returned_ext * sizeof(int64_t));
923+
status = spir_basis_get_default_matsus_ext(basis, false, L, smpl_points_ext, &n_matsubara_points_returned_ext);
924+
REQUIRE(status == SPIR_COMPUTATION_SUCCESS);
925+
926+
status = spir_basis_get_n_default_matsus_ext(basis, true, (int)(L / 2), &n_matsubara_points_returned_ext);
927+
REQUIRE(status == SPIR_COMPUTATION_SUCCESS);
928+
929+
int64_t *smpl_points_ext_positive_only = (int64_t *)malloc(n_matsubara_points_returned_ext * sizeof(int64_t));
930+
status = spir_basis_get_default_matsus_ext(basis, true, (int)(L / 2), smpl_points_ext_positive_only, &n_matsubara_points_returned_ext);
931+
REQUIRE(status == SPIR_COMPUTATION_SUCCESS);
932+
933+
for (int i = 0; i < n_matsubara_points_returned_ext; ++i) {
934+
REQUIRE(smpl_points_ext_positive_only[i] >= 0);
935+
}
936+
916937
// Clean up
917938
spir_sampling_release(sampling);
918939
spir_sampling_release(sampling_positive_only);

0 commit comments

Comments
 (0)