Skip to content

Commit ccc13d8

Browse files
authored
Merge pull request #339 from SpM-lab/338-add-c-api-for-retrieving-default-sampling-for-spir_funcs
Add spir_basis_get_default_taus_ext
2 parents 3badc9f + a11afc0 commit ccc13d8

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

include/sparseir/sparseir.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,21 @@ int spir_basis_get_n_default_ws(const spir_basis *b, int *num_points);
598598
*/
599599
int spir_basis_get_default_ws(const spir_basis *b, double *points);
600600

601+
602+
/***
603+
* @brief Gets the default tau sampling points for ann IR basis.
604+
*
605+
* This function returns default tau sampling points for an IR basis object.
606+
*
607+
* @param b Pointer to the basis object
608+
* @param n_points Number of requested sampling points.
609+
* @param points Pre-allocated array to store the sampling points. The size of the array must be at least n_points.
610+
* @param n_points_returned Number of sampling points returned.
611+
* @return An integer status code:
612+
* - 0 (SPIR_COMPUTATION_SUCCESS) on success
613+
*/
614+
int spir_basis_get_default_taus_ext(const spir_basis *b, int n_points, double *points, int *n_points_returned);
615+
601616
/**
602617
* @brief Gets the number of default Matsubara sampling points for an IR basis.
603618
*

src/cinterface.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,56 @@ int spir_basis_get_default_taus(const spir_basis *b, double *points)
900900
}
901901
}
902902

903+
int spir_basis_get_default_taus_ext(
904+
const spir_basis *b, int n_points, double *points, int *n_points_returned)
905+
{
906+
if (!b || !points) {
907+
return SPIR_INVALID_ARGUMENT;
908+
}
909+
910+
auto impl = get_impl_basis(b);
911+
if (!impl) {
912+
return SPIR_GET_IMPL_FAILED;
913+
}
914+
915+
if (!is_ir_basis(b)) {
916+
DEBUG_LOG("Error: The basis is not an IR basis");
917+
return SPIR_INVALID_ARGUMENT;
918+
}
919+
920+
double beta = impl->get_beta();
921+
922+
try {
923+
Eigen::VectorXd tau_points;
924+
if (impl->get_statistics() == SPIR_STATISTICS_FERMIONIC) {
925+
auto ir_basis = _safe_static_pointer_cast<_IRBasis<sparseir::Fermionic>>(impl);
926+
tau_points = sparseir::default_sampling_points(
927+
*(ir_basis->get_impl()->sve_result->u), n_points
928+
);
929+
*n_points_returned = tau_points.size();
930+
} else {
931+
auto ir_basis = _safe_static_pointer_cast<_IRBasis<sparseir::Bosonic>>(impl);
932+
tau_points = sparseir::default_sampling_points(
933+
*(ir_basis->get_impl()->sve_result->u), n_points
934+
);
935+
*n_points_returned = tau_points.size();
936+
}
937+
938+
// Copy the requested number of points
939+
// rescale the points to the original domain
940+
for (int i = 0; i < *n_points_returned; ++i) {
941+
tau_points(i) = (tau_points(i) + 1) / 2 * beta;
942+
if (tau_points(i) > 0.5 * beta) {
943+
tau_points(i) -= beta;
944+
}
945+
}
946+
std::copy(tau_points.data(), tau_points.data() + *n_points_returned, points);
947+
return SPIR_COMPUTATION_SUCCESS;
948+
} catch (const std::exception &e) {
949+
return SPIR_GET_IMPL_FAILED;
950+
}
951+
}
952+
903953
int spir_basis_get_n_default_matsus(const spir_basis *b, bool positive_only, int *num_points)
904954
{
905955
if (!b || !num_points) {

test/cpp/cinterface_sampling.cxx

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ void test_tau_sampling()
8282
status = spir_basis_get_default_taus(basis, tau_points_org);
8383
REQUIRE(status == SPIR_COMPUTATION_SUCCESS);
8484

85+
int n_tau_points_ext = n_tau_points + 1;
86+
double *tau_points_ext = (double *)malloc(n_tau_points_ext * sizeof(double));
87+
int n_tau_points_returned;
88+
status = spir_basis_get_default_taus_ext(
89+
basis, n_tau_points_ext, tau_points_ext, &n_tau_points_returned);
90+
REQUIRE(status == SPIR_COMPUTATION_SUCCESS);
91+
REQUIRE(n_tau_points_returned == n_tau_points_ext);
92+
for (int i = 0; i < n_tau_points_returned; i++) {
93+
REQUIRE(tau_points_ext[i] >= -0.5 * beta);
94+
REQUIRE(tau_points_ext[i] <= 0.5 * beta);
95+
}
96+
8597
int sampling_status;
8698
spir_sampling *sampling = spir_tau_sampling_new(basis, n_tau_points, tau_points_org, &sampling_status);
8799
REQUIRE(sampling_status == SPIR_COMPUTATION_SUCCESS);

0 commit comments

Comments
 (0)