Skip to content

Commit f8aab09

Browse files
committed
added RopeScalingType
1 parent dd1fbea commit f8aab09

File tree

3 files changed

+45
-5
lines changed

3 files changed

+45
-5
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llama-cpp-2/src/context/params.rs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,45 @@ use llama_cpp_sys_2::{ggml_type, llama_context_params};
33
use std::fmt::Debug;
44
use std::num::NonZeroU32;
55

6+
/// A rusty wrapper around `rope_scaling_type`.
///
/// Discriminants are fixed at the `i8` values used by the C API
/// (presumably mirroring llama.cpp's rope-scaling constants — TODO
/// confirm against llama.h), so the `i8` conversions below are direct.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}
19+
20+
/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if
21+
/// the value is not recognized.
22+
impl From<i8> for RopeScalingType {
23+
fn from(value: i8) -> Self {
24+
match value {
25+
0 => Self::None,
26+
1 => Self::Linear,
27+
2 => Self::Yarn,
28+
_ => Self::Unspecified,
29+
}
30+
}
31+
}
32+
33+
/// Create a `c_int` from a `RopeScalingType`.
34+
impl From<RopeScalingType> for i8 {
35+
fn from(value: RopeScalingType) -> Self {
36+
match value {
37+
RopeScalingType::None => 0,
38+
RopeScalingType::Linear => 1,
39+
RopeScalingType::Yarn => 2,
40+
RopeScalingType::Unspecified => -1,
41+
}
42+
}
43+
}
44+
645
/// A safe wrapper around `llama_context_params`.
746
#[derive(Debug, Clone, Copy, PartialEq)]
847
#[allow(
@@ -18,7 +57,7 @@ pub struct LlamaContextParams {
1857
pub n_batch: u32,
1958
pub n_threads: u32,
2059
pub n_threads_batch: u32,
21-
pub rope_scaling_type: i8,
60+
pub rope_scaling_type: RopeScalingType,
2261
pub rope_freq_base: f32,
2362
pub rope_freq_scale: f32,
2463
pub yarn_ext_factor: f32,
@@ -83,7 +122,7 @@ impl From<llama_context_params> for LlamaContextParams {
83122
mul_mat_q,
84123
logits_all,
85124
embedding,
86-
rope_scaling_type,
125+
rope_scaling_type: RopeScalingType::from(rope_scaling_type),
87126
yarn_ext_factor,
88127
yarn_attn_factor,
89128
yarn_beta_fast,
@@ -131,7 +170,7 @@ impl From<LlamaContextParams> for llama_context_params {
131170
mul_mat_q,
132171
logits_all,
133172
embedding,
134-
rope_scaling_type,
173+
rope_scaling_type: i8::from(rope_scaling_type),
135174
yarn_ext_factor,
136175
yarn_attn_factor,
137176
yarn_beta_fast,

llama-cpp-2/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ pub enum StringToTokenError {
182182
/// let elapsed = end - start;
183183
///
184184
/// assert!(elapsed >= 10)
185+
#[must_use]
185186
pub fn ggml_time_us() -> i64 {
186187
unsafe { llama_cpp_sys_2::ggml_time_us() }
187188
}

0 commit comments

Comments
 (0)