Skip to content

Commit 35b5765

Browse files
feat: add set_u0_sgrad (#62)
1 parent ce2cca5 commit 35b5765

File tree

3 files changed

+90
-3
lines changed

3 files changed

+90
-3
lines changed

src/execution/compiler.rs

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,27 @@ impl<M: CodegenModule> Compiler<M> {
333333
});
334334
}
335335

336+
pub fn set_u0_sgrad(&self, yy: &[f64], dyy: &mut [f64], data: &[f64], ddata: &mut [f64]) {
337+
self.check_state_len(yy, "yy");
338+
self.check_state_len(dyy, "dyy");
339+
self.check_data_len(data, "data");
340+
self.check_data_len(ddata, "ddata");
341+
self.with_threading(|i, dim| unsafe {
342+
(self
343+
.jit_sens_grad_functions
344+
.as_ref()
345+
.expect("module does not support sens autograd")
346+
.set_u0_sgrad)(
347+
yy.as_ptr(),
348+
dyy.as_ptr() as *mut f64,
349+
data.as_ptr(),
350+
ddata.as_ptr() as *mut f64,
351+
i,
352+
dim,
353+
);
354+
});
355+
}
356+
336357
pub fn set_u0_rgrad(&self, yy: &[f64], dyy: &mut [f64], data: &[f64], ddata: &mut [f64]) {
337358
self.check_state_len(yy, "yy");
338359
self.check_state_len(dyy, "dyy");
@@ -798,7 +819,7 @@ impl<M: CodegenModule> Compiler<M> {
798819
&self,
799820
inputs: &[f64],
800821
dinputs: &[f64],
801-
data: &mut [f64],
822+
data: &[f64],
802823
ddata: &mut [f64],
803824
) {
804825
self.check_inputs_len(inputs, "inputs");
@@ -809,7 +830,7 @@ impl<M: CodegenModule> Compiler<M> {
809830
(self.jit_grad_functions.set_inputs_grad)(
810831
inputs.as_ptr(),
811832
dinputs.as_ptr(),
812-
data.as_mut_ptr(),
833+
data.as_ptr(),
813834
ddata.as_mut_ptr(),
814835
)
815836
};
@@ -1985,4 +2006,46 @@ mod tests {
19852006

19862007
handle.join().unwrap();
19872008
}
2009+
2010+
#[cfg(feature = "llvm")]
2011+
#[test]
2012+
fn test_u0_sgrad_llvm() {
2013+
test_u0_sgrad::<crate::LlvmModule>();
2014+
}
2015+
2016+
#[allow(dead_code)]
2017+
fn test_u0_sgrad<M: CodegenModuleCompile + CodegenModuleJit>() {
2018+
let full_text = "
2019+
in = [a]
2020+
a { 1.0 }
2021+
u { 2 * a * a }
2022+
F { -u }
2023+
";
2024+
let model = parse_ds_string(full_text).unwrap();
2025+
let discrete_model = DiscreteModel::build("test_u0_sgrad", &model).unwrap();
2026+
let compiler =
2027+
Compiler::<M>::from_discrete_model(&discrete_model, Default::default()).unwrap();
2028+
let mut data = compiler.get_new_data();
2029+
let mut ddata = compiler.get_new_data();
2030+
let a = vec![0.6];
2031+
let da = vec![1.0];
2032+
compiler.set_inputs(a.as_slice(), data.as_mut_slice());
2033+
compiler.set_inputs_grad(
2034+
a.as_slice(),
2035+
da.as_slice(),
2036+
data.as_slice(),
2037+
ddata.as_mut_slice(),
2038+
);
2039+
let mut u0 = vec![0.0];
2040+
let mut du0 = vec![0.0];
2041+
compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
2042+
compiler.set_u0_sgrad(
2043+
u0.as_mut_slice(),
2044+
du0.as_mut_slice(),
2045+
data.as_slice(),
2046+
ddata.as_mut_slice(),
2047+
);
2048+
assert_relative_eq!(u0.as_slice(), vec![2.0 * a[0] * a[0]].as_slice());
2049+
assert_relative_eq!(du0.as_slice(), vec![4.0 * a[0] * da[0]].as_slice());
2050+
}
19882051
}

src/execution/interface.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ pub type U0Func = unsafe extern "C" fn(
9191
thread_id: UIntType,
9292
thread_dim: UIntType,
9393
);
94+
pub type U0SensGradFunc = unsafe extern "C" fn(
95+
u: *const RealType,
96+
du: *mut RealType,
97+
data: *const RealType,
98+
ddata: *mut RealType,
99+
thread_id: UIntType,
100+
thread_dim: UIntType,
101+
);
94102
pub type U0GradFunc = unsafe extern "C" fn(
95103
u: *const RealType,
96104
du: *mut RealType,
@@ -347,13 +355,14 @@ impl JitGradRFunctions {
347355
}
348356

349357
pub(crate) struct JitSensGradFunctions {
358+
pub(crate) set_u0_sgrad: U0SensGradFunc,
350359
pub(crate) rhs_sgrad: RhsSensGradFunc,
351360
pub(crate) calc_out_sgrad: CalcOutSensGradFunc,
352361
}
353362

354363
impl JitSensGradFunctions {
355364
pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
356-
let required_symbols = ["rhs_sgrad", "calc_out_sgrad"];
365+
let required_symbols = ["rhs_sgrad", "calc_out_sgrad", "set_u0_sgrad"];
357366
for symbol in &required_symbols {
358367
if !symbol_map.contains_key(*symbol) {
359368
return Err(anyhow!("Missing required symbol: {}", symbol));
@@ -364,10 +373,13 @@ impl JitSensGradFunctions {
364373
let calc_out_sgrad = unsafe {
365374
std::mem::transmute::<*const u8, CalcOutSensGradFunc>(symbol_map["calc_out_sgrad"])
366375
};
376+
let set_u0_sgrad =
377+
unsafe { std::mem::transmute::<*const u8, U0SensGradFunc>(symbol_map["set_u0_sgrad"]) };
367378

368379
Ok(Self {
369380
rhs_sgrad,
370381
calc_out_sgrad,
382+
set_u0_sgrad,
371383
})
372384
}
373385
}

src/execution/llvm/codegen.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,18 @@ impl CodegenModuleCompile for LlvmModule {
399399
"rhs_sgrad",
400400
)?;
401401

402+
module.codegen_mut().compile_gradient(
403+
set_u0,
404+
&[
405+
CompileGradientArgType::DupNoNeed,
406+
CompileGradientArgType::DupNoNeed,
407+
CompileGradientArgType::Const,
408+
CompileGradientArgType::Const,
409+
],
410+
CompileMode::ForwardSens,
411+
"set_u0_sgrad",
412+
)?;
413+
402414
module.codegen_mut().compile_gradient(
403415
calc_out_full,
404416
&[

0 commit comments

Comments
 (0)