Skip to content

Commit d8ca243

Browse files
Merge pull request #64 from Radonirinaunimi/chebyshev
Optimize `Chebyshev` for Multi-dimensional data using batch vectorization
2 parents 56a92eb + a454896 commit d8ca243

File tree

15 files changed

+861
-54
lines changed

15 files changed

+861
-54
lines changed

.github/actions/cache-data/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ runs:
1010
uses: actions/cache@v4
1111
with:
1212
path: neopdf-data
13-
key: data-v8
13+
key: data-v10
1414
- name: Download data if cache miss
1515
if: steps.cache-data.outputs.cache-hit != 'true'
1616
run: |

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111

12+
- Added a logic to compute Chebyshev interpolations in batches (https://github.com/Radonirinaunimi/neopdf/pull/64).
1213
- Added proper LHAPDF drop-in compatibility layer for no-code migration.
1314
- Added an interface to the Wolfram Language to allow Rust APIs to be called in
1415
Mathematica.

maintainer/download-data.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ NEOPDF_SETS=(
2222
nNNPDF30_nlo_as_0118
2323
)
2424

25+
TMDLIB_SETS=(
26+
MAP22_grids_FF_Km_N3LL
27+
)
28+
2529
# Store the data in the root of the repository
2630
cd ..
2731
test -d neopdf-data || mkdir neopdf-data
@@ -37,3 +41,8 @@ done
3741
for neo in "${NEOPDF_SETS[@]}"; do
3842
wget --no-verbose --no-clobber -P neopdf-data "https://data.nnpdf.science/neopdf/data/${neo}.neopdf.lz4"
3943
done
44+
45+
# Dowload TMDlib sets
46+
for tmd in "${TMDLIB_SETS[@]}"; do
47+
wget --no-verbose --no-clobber -P neopdf-data "https://data.nnpdf.science/neopdf/data/${tmd}.neopdf.lz4"
48+
done

neopdf/benches/bench_pdf.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,32 @@ fn xfxq2(c: &mut Criterion) {
1010
});
1111
}
1212

13+
fn xfxq2_cheby(c: &mut Criterion) {
14+
let pdf = PDF::load("MAP22_grids_FF_Km_N3LL.neopdf.lz4", 0);
15+
16+
c.bench_function("xfxq2_cheby", |b| {
17+
b.iter(|| {
18+
pdf.xfxq2(
19+
std::hint::black_box(2),
20+
std::hint::black_box(&[1e-2, 5e-1, 10.0]),
21+
)
22+
})
23+
});
24+
}
25+
26+
fn xfxq2_cheby_batch(c: &mut Criterion) {
27+
let pdf = PDF::load("MAP22_grids_FF_Km_N3LL.neopdf.lz4", 0);
28+
29+
c.bench_function("xfxq2_cheby_batch", |b| {
30+
b.iter(|| {
31+
pdf.xfxq2_cheby_batch(
32+
std::hint::black_box(2),
33+
std::hint::black_box(&[&[1e-2, 5e-1, 10.0]]),
34+
)
35+
})
36+
});
37+
}
38+
1339
fn xfxq2s(c: &mut Criterion) {
1440
let pdf = PDF::load("NNPDF40_nnlo_as_01180", 0);
1541

@@ -45,5 +71,12 @@ fn xfxq2_members(c: &mut Criterion) {
4571
});
4672
}
4773

48-
criterion_group!(benches, xfxq2, xfxq2s, xfxq2_members);
74+
criterion_group!(
75+
benches,
76+
xfxq2,
77+
xfxq2s,
78+
xfxq2_members,
79+
xfxq2_cheby,
80+
xfxq2_cheby_batch
81+
);
4982
criterion_main!(benches);

neopdf/src/gridpdf.rs

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
use core::panic;
99
use ndarray::{Array1, Array2};
1010
use serde::{Deserialize, Serialize};
11+
use std::collections::HashMap;
1112
use thiserror::Error;
1213

1314
use super::alphas::AlphaS;
@@ -244,7 +245,7 @@ impl GridPDF {
244245
Some(ForcePositive::ClipNegative) => value.max(0.0),
245246
Some(ForcePositive::ClipSmall) => value.max(1e-10),
246247
Some(ForcePositive::NoClipping) => value,
247-
None => value,
248+
_ => value,
248249
}
249250
}
250251

@@ -286,9 +287,10 @@ impl GridPDF {
286287
Error::SubgridNotFound { x, q2 }
287288
})?;
288289

289-
let pid_idx = self.knot_array.pid_index(flavor_id).ok_or_else(|| {
290-
Error::InterpolationError(format!("Invalid flavor ID: {}", flavor_id))
291-
})?;
290+
let pid_idx = self
291+
.knot_array
292+
.pid_index(flavor_id)
293+
.ok_or_else(|| Error::InterpolationError(format!("Invalid flavor ID: {flavor_id}")))?;
292294

293295
let use_log = matches!(
294296
self.info.interpolator_type,
@@ -335,6 +337,81 @@ impl GridPDF {
335337
Array2::from_shape_vec(grid_shape, data).unwrap()
336338
}
337339

340+
/// Interpolates PDF values for multiple points in parallel using Chebyshev batch interpolation.
341+
///
342+
/// # Arguments
343+
///
344+
/// * `flavor_id` - The flavor ID.
345+
/// * `points` - A slice containing the collection of knots to interpolate on.
346+
/// A knot is a collection of points containing `(nucleon, alphas, x, Q2)`.
347+
///
348+
/// # Returns
349+
///
350+
/// A `Vec<f64>` of interpolated PDF values.
351+
pub fn xfxq2_cheby_batch(&self, flavor_id: i32, points: &[&[f64]]) -> Result<Vec<f64>, Error> {
352+
if points.is_empty() {
353+
return Ok(Vec::new());
354+
}
355+
356+
let pid_idx = self
357+
.knot_array
358+
.pid_index(flavor_id)
359+
.ok_or_else(|| Error::InterpolationError(format!("Invalid flavor ID: {flavor_id}")))?;
360+
361+
if !matches!(self.info.interpolator_type, InterpolatorType::LogChebyshev) {
362+
return Err(Error::InterpolationError(
363+
"xfxq2_cheby_batch only supports LogChebyshev interpolator".to_string(),
364+
));
365+
}
366+
367+
let mut subgrid_groups: HashMap<usize, Vec<(usize, &[f64])>> = HashMap::new();
368+
for (i, point) in points.iter().enumerate() {
369+
let subgrid_idx = self.knot_array.find_subgrid(point).ok_or_else(|| {
370+
let (x, q2) = self.get_x_q2(point);
371+
Error::SubgridNotFound { x, q2 }
372+
})?;
373+
374+
subgrid_groups
375+
.entry(subgrid_idx)
376+
.or_default()
377+
.push((i, *point));
378+
}
379+
380+
let mut all_results: Vec<(usize, f64)> = Vec::new();
381+
382+
for (subgrid_idx, group) in subgrid_groups {
383+
let subgrid = &self.knot_array.subgrids[subgrid_idx];
384+
385+
let (indices, group_points): (Vec<_>, Vec<_>) = group.into_iter().unzip();
386+
387+
let log_points: Vec<Vec<f64>> = group_points
388+
.iter()
389+
.map(|p| p.iter().map(|&v| v.ln()).collect::<Vec<f64>>())
390+
.collect();
391+
392+
let batch_interpolator =
393+
InterpolatorFactory::create_batch_interpolator(subgrid, pid_idx)
394+
.map_err(Error::InterpolationError)?;
395+
396+
let results = batch_interpolator
397+
.interpolate(log_points)
398+
.map_err(|e| Error::InterpolationError(e.to_string()))?;
399+
400+
for (original_index, result) in indices.into_iter().zip(results) {
401+
all_results.push((original_index, result));
402+
}
403+
}
404+
405+
// sort the results according to the original index
406+
all_results.sort_by_key(|&(i, _)| i);
407+
let final_results = all_results
408+
.into_iter()
409+
.map(|(_, r)| self.apply_force_positive(r))
410+
.collect();
411+
412+
Ok(final_results)
413+
}
414+
338415
/// Get the values of the momentum fraction `x` and momentum scale `Q2`.
339416
///
340417
/// # Arguments

neopdf/src/interpolator.rs

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//! The [`SubGrid`] struct is defined in `subgrid.rs`.
1313
1414
use ndarray::{s, OwnedRepr};
15+
use ninterp::data::{InterpData2D, InterpData3D};
1516
use ninterp::error::InterpolateError;
1617
use ninterp::interpolator::{
1718
Extrapolate, Interp2D, Interp2DOwned, Interp3D, Interp3DOwned, InterpND, InterpNDOwned,
@@ -23,7 +24,7 @@ use ninterp::strategy::Linear;
2324
use super::metadata::InterpolatorType;
2425
use super::strategy::{
2526
BilinearInterpolation, LogBicubicInterpolation, LogBilinearInterpolation,
26-
LogChebyshevInterpolation, LogTricubicInterpolation,
27+
LogChebyshevBatchInterpolation, LogChebyshevInterpolation, LogTricubicInterpolation,
2728
};
2829
use super::subgrid::SubGrid;
2930

@@ -115,6 +116,40 @@ where
115116
}
116117
}
117118

119+
/// An enum to dispatch batch interpolation to the correct Chebyshev interpolator.
120+
pub enum BatchInterpolator {
121+
Chebyshev2D(
122+
LogChebyshevBatchInterpolation<2>,
123+
InterpData2D<OwnedRepr<f64>>,
124+
),
125+
Chebyshev3D(
126+
LogChebyshevBatchInterpolation<3>,
127+
InterpData3D<OwnedRepr<f64>>,
128+
),
129+
}
130+
131+
impl BatchInterpolator {
132+
/// Interpolates a batch of points.
133+
pub fn interpolate(&self, points: Vec<Vec<f64>>) -> Result<Vec<f64>, InterpolateError> {
134+
match self {
135+
BatchInterpolator::Chebyshev2D(strategy, data) => {
136+
let points_2d: Vec<[f64; 2]> = points
137+
.into_iter()
138+
.map(|p| p.try_into().expect("Invalid point dimension for 2D"))
139+
.collect();
140+
strategy.interpolate(data, &points_2d)
141+
}
142+
BatchInterpolator::Chebyshev3D(strategy, data) => {
143+
let points_3d: Vec<[f64; 3]> = points
144+
.into_iter()
145+
.map(|p| p.try_into().expect("Invalid point dimension for 3D"))
146+
.collect();
147+
strategy.interpolate(data, &points_3d)
148+
}
149+
}
150+
}
151+
}
152+
118153
/// Factory for creating dynamic interpolators based on interpolation type and grid dimensions.
119154
pub struct InterpolatorFactory;
120155

@@ -455,6 +490,94 @@ impl InterpolatorFactory {
455490
_ => panic!("Unsupported 5D interpolator: {:?}", interp_type),
456491
}
457492
}
493+
494+
pub fn create_batch_interpolator(
495+
subgrid: &SubGrid,
496+
pid_idx: usize,
497+
) -> Result<BatchInterpolator, String> {
498+
match subgrid.interpolation_config() {
499+
InterpolationConfig::TwoD => {
500+
let mut strategy = LogChebyshevBatchInterpolation::<2>::default();
501+
let grid_slice = subgrid.grid_slice(pid_idx).to_owned();
502+
503+
let data = InterpData2D::new(
504+
subgrid.xs.mapv(f64::ln),
505+
subgrid.q2s.mapv(f64::ln),
506+
grid_slice,
507+
)
508+
.map_err(|e| e.to_string())?;
509+
strategy.init(&data).map_err(|e| e.to_string())?;
510+
511+
Ok(BatchInterpolator::Chebyshev2D(strategy, data))
512+
}
513+
InterpolationConfig::ThreeDNucleons => {
514+
let mut strategy = LogChebyshevBatchInterpolation::<3>::default();
515+
let grid_data = subgrid.grid.slice(s![.., 0, pid_idx, 0, .., ..]).to_owned();
516+
517+
let reshaped_data = grid_data
518+
.into_shape_with_order((
519+
subgrid.nucleons.len(),
520+
subgrid.xs.len(),
521+
subgrid.q2s.len(),
522+
))
523+
.expect("Failed to reshape 3D data");
524+
525+
let data = InterpData3D::new(
526+
subgrid.nucleons.mapv(f64::ln),
527+
subgrid.xs.mapv(f64::ln),
528+
subgrid.q2s.mapv(f64::ln),
529+
reshaped_data,
530+
)
531+
.map_err(|e| e.to_string())?;
532+
strategy.init(&data).map_err(|e| e.to_string())?;
533+
534+
Ok(BatchInterpolator::Chebyshev3D(strategy, data))
535+
}
536+
InterpolationConfig::ThreeDAlphas => {
537+
let mut strategy = LogChebyshevBatchInterpolation::<3>::default();
538+
let grid_data = subgrid.grid.slice(s![0, .., pid_idx, 0, .., ..]).to_owned();
539+
540+
let reshaped_data = grid_data
541+
.into_shape_with_order((
542+
subgrid.alphas.len(),
543+
subgrid.xs.len(),
544+
subgrid.q2s.len(),
545+
))
546+
.expect("Failed to reshape 3D data");
547+
548+
let data = InterpData3D::new(
549+
subgrid.alphas.mapv(f64::ln),
550+
subgrid.xs.mapv(f64::ln),
551+
subgrid.q2s.mapv(f64::ln),
552+
reshaped_data,
553+
)
554+
.map_err(|e| e.to_string())?;
555+
strategy.init(&data).map_err(|e| e.to_string())?;
556+
557+
Ok(BatchInterpolator::Chebyshev3D(strategy, data))
558+
}
559+
InterpolationConfig::ThreeDKt => {
560+
let mut strategy = LogChebyshevBatchInterpolation::<3>::default();
561+
let grid_data = subgrid.grid.slice(s![0, 0, pid_idx, .., .., ..]).to_owned();
562+
563+
let reshaped_data = grid_data
564+
.into_shape_with_order((subgrid.kts.len(), subgrid.xs.len(), subgrid.q2s.len()))
565+
.expect("Failed to reshape 3D data");
566+
567+
let data = InterpData3D::new(
568+
subgrid.kts.mapv(f64::ln),
569+
subgrid.xs.mapv(f64::ln),
570+
subgrid.q2s.mapv(f64::ln),
571+
reshaped_data,
572+
)
573+
.map_err(|e| e.to_string())?;
574+
strategy.init(&data).map_err(|e| e.to_string())?;
575+
576+
Ok(BatchInterpolator::Chebyshev3D(strategy, data))
577+
}
578+
_ => Err("Unsupported dimension for batch interpolation".to_string()),
579+
}
580+
}
458581
}
459582

460583
#[cfg(test)]

neopdf/src/pdf.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,23 @@ impl PDF {
287287
self.grid_pdf.xfxq2s(pids, slice_points)
288288
}
289289

290+
/// Interpolates the PDF value (xf) for multiple points using Chebyshev batch interpolation.
291+
///
292+
/// Abstraction to the `GridPDF::xfxq2_cheby_batch` method.
293+
///
294+
/// # Arguments
295+
///
296+
/// * `pid` - The flavor ID.
297+
/// * `points` - A slice containing the collection of knots to interpolate on.
298+
/// A knot is a collection of points containing `(nucleon, alphas, x, Q2)`.
299+
///
300+
/// # Returns
301+
///
302+
/// A `Vec<f64>` of interpolated PDF values.
303+
pub fn xfxq2_cheby_batch(&self, pid: i32, points: &[&[f64]]) -> Vec<f64> {
304+
self.grid_pdf.xfxq2_cheby_batch(pid, points).unwrap()
305+
}
306+
290307
/// Interpolates the strong coupling constant `alpha_s` for a given Q2.
291308
///
292309
/// Abstraction to the `GridPDF::alphas_q2` method.

0 commit comments

Comments
 (0)