1212//! The [`SubGrid`] struct is defined in `subgrid.rs`.
1313
1414use ndarray:: { s, OwnedRepr } ;
15+ use ninterp:: data:: { InterpData2D , InterpData3D } ;
1516use ninterp:: error:: InterpolateError ;
1617use ninterp:: interpolator:: {
1718 Extrapolate , Interp2D , Interp2DOwned , Interp3D , Interp3DOwned , InterpND , InterpNDOwned ,
@@ -23,7 +24,7 @@ use ninterp::strategy::Linear;
2324use super :: metadata:: InterpolatorType ;
2425use super :: strategy:: {
2526 BilinearInterpolation , LogBicubicInterpolation , LogBilinearInterpolation ,
26- LogChebyshevInterpolation , LogTricubicInterpolation ,
27+ LogChebyshevBatchInterpolation , LogChebyshevInterpolation , LogTricubicInterpolation ,
2728} ;
2829use super :: subgrid:: SubGrid ;
2930
@@ -115,6 +116,40 @@ where
115116 }
116117}
117118
119+ /// An enum to dispatch batch interpolation to the correct Chebyshev interpolator.
120+ pub enum BatchInterpolator {
121+ Chebyshev2D (
122+ LogChebyshevBatchInterpolation < 2 > ,
123+ InterpData2D < OwnedRepr < f64 > > ,
124+ ) ,
125+ Chebyshev3D (
126+ LogChebyshevBatchInterpolation < 3 > ,
127+ InterpData3D < OwnedRepr < f64 > > ,
128+ ) ,
129+ }
130+
131+ impl BatchInterpolator {
132+ /// Interpolates a batch of points.
133+ pub fn interpolate ( & self , points : Vec < Vec < f64 > > ) -> Result < Vec < f64 > , InterpolateError > {
134+ match self {
135+ BatchInterpolator :: Chebyshev2D ( strategy, data) => {
136+ let points_2d: Vec < [ f64 ; 2 ] > = points
137+ . into_iter ( )
138+ . map ( |p| p. try_into ( ) . expect ( "Invalid point dimension for 2D" ) )
139+ . collect ( ) ;
140+ strategy. interpolate ( data, & points_2d)
141+ }
142+ BatchInterpolator :: Chebyshev3D ( strategy, data) => {
143+ let points_3d: Vec < [ f64 ; 3 ] > = points
144+ . into_iter ( )
145+ . map ( |p| p. try_into ( ) . expect ( "Invalid point dimension for 3D" ) )
146+ . collect ( ) ;
147+ strategy. interpolate ( data, & points_3d)
148+ }
149+ }
150+ }
151+ }
152+
118153/// Factory for creating dynamic interpolators based on interpolation type and grid dimensions.
119154pub struct InterpolatorFactory ;
120155
@@ -455,6 +490,94 @@ impl InterpolatorFactory {
455490 _ => panic ! ( "Unsupported 5D interpolator: {:?}" , interp_type) ,
456491 }
457492 }
493+
494+ pub fn create_batch_interpolator (
495+ subgrid : & SubGrid ,
496+ pid_idx : usize ,
497+ ) -> Result < BatchInterpolator , String > {
498+ match subgrid. interpolation_config ( ) {
499+ InterpolationConfig :: TwoD => {
500+ let mut strategy = LogChebyshevBatchInterpolation :: < 2 > :: default ( ) ;
501+ let grid_slice = subgrid. grid_slice ( pid_idx) . to_owned ( ) ;
502+
503+ let data = InterpData2D :: new (
504+ subgrid. xs . mapv ( f64:: ln) ,
505+ subgrid. q2s . mapv ( f64:: ln) ,
506+ grid_slice,
507+ )
508+ . map_err ( |e| e. to_string ( ) ) ?;
509+ strategy. init ( & data) . map_err ( |e| e. to_string ( ) ) ?;
510+
511+ Ok ( BatchInterpolator :: Chebyshev2D ( strategy, data) )
512+ }
513+ InterpolationConfig :: ThreeDNucleons => {
514+ let mut strategy = LogChebyshevBatchInterpolation :: < 3 > :: default ( ) ;
515+ let grid_data = subgrid. grid . slice ( s ! [ .., 0 , pid_idx, 0 , .., ..] ) . to_owned ( ) ;
516+
517+ let reshaped_data = grid_data
518+ . into_shape_with_order ( (
519+ subgrid. nucleons . len ( ) ,
520+ subgrid. xs . len ( ) ,
521+ subgrid. q2s . len ( ) ,
522+ ) )
523+ . expect ( "Failed to reshape 3D data" ) ;
524+
525+ let data = InterpData3D :: new (
526+ subgrid. nucleons . mapv ( f64:: ln) ,
527+ subgrid. xs . mapv ( f64:: ln) ,
528+ subgrid. q2s . mapv ( f64:: ln) ,
529+ reshaped_data,
530+ )
531+ . map_err ( |e| e. to_string ( ) ) ?;
532+ strategy. init ( & data) . map_err ( |e| e. to_string ( ) ) ?;
533+
534+ Ok ( BatchInterpolator :: Chebyshev3D ( strategy, data) )
535+ }
536+ InterpolationConfig :: ThreeDAlphas => {
537+ let mut strategy = LogChebyshevBatchInterpolation :: < 3 > :: default ( ) ;
538+ let grid_data = subgrid. grid . slice ( s ! [ 0 , .., pid_idx, 0 , .., ..] ) . to_owned ( ) ;
539+
540+ let reshaped_data = grid_data
541+ . into_shape_with_order ( (
542+ subgrid. alphas . len ( ) ,
543+ subgrid. xs . len ( ) ,
544+ subgrid. q2s . len ( ) ,
545+ ) )
546+ . expect ( "Failed to reshape 3D data" ) ;
547+
548+ let data = InterpData3D :: new (
549+ subgrid. alphas . mapv ( f64:: ln) ,
550+ subgrid. xs . mapv ( f64:: ln) ,
551+ subgrid. q2s . mapv ( f64:: ln) ,
552+ reshaped_data,
553+ )
554+ . map_err ( |e| e. to_string ( ) ) ?;
555+ strategy. init ( & data) . map_err ( |e| e. to_string ( ) ) ?;
556+
557+ Ok ( BatchInterpolator :: Chebyshev3D ( strategy, data) )
558+ }
559+ InterpolationConfig :: ThreeDKt => {
560+ let mut strategy = LogChebyshevBatchInterpolation :: < 3 > :: default ( ) ;
561+ let grid_data = subgrid. grid . slice ( s ! [ 0 , 0 , pid_idx, .., .., ..] ) . to_owned ( ) ;
562+
563+ let reshaped_data = grid_data
564+ . into_shape_with_order ( ( subgrid. kts . len ( ) , subgrid. xs . len ( ) , subgrid. q2s . len ( ) ) )
565+ . expect ( "Failed to reshape 3D data" ) ;
566+
567+ let data = InterpData3D :: new (
568+ subgrid. kts . mapv ( f64:: ln) ,
569+ subgrid. xs . mapv ( f64:: ln) ,
570+ subgrid. q2s . mapv ( f64:: ln) ,
571+ reshaped_data,
572+ )
573+ . map_err ( |e| e. to_string ( ) ) ?;
574+ strategy. init ( & data) . map_err ( |e| e. to_string ( ) ) ?;
575+
576+ Ok ( BatchInterpolator :: Chebyshev3D ( strategy, data) )
577+ }
578+ _ => Err ( "Unsupported dimension for batch interpolation" . to_string ( ) ) ,
579+ }
580+ }
458581}
459582
460583#[ cfg( test) ]
0 commit comments