@@ -3,6 +3,45 @@ use llama_cpp_sys_2::{ggml_type, llama_context_params};
3
3
use std:: fmt:: Debug ;
4
4
use std:: num:: NonZeroU32 ;
5
5
6
/// Typed counterpart of the raw `rope_scaling_type` value.
///
/// The enum is `#[repr(i8)]` and every variant carries an explicit discriminant, so a
/// variant can be cast to the raw `i8` value it stands for.
#[repr(i8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeScalingType {
    /// No scaling type has been specified (raw value `-1`).
    Unspecified = -1,
    /// Scaling is disabled (raw value `0`).
    None = 0,
    /// Linear scaling (raw value `1`).
    Linear = 1,
    /// Yarn (YaRN) scaling (raw value `2`).
    Yarn = 2,
}
19
+
20
+ /// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if
21
+ /// the value is not recognized.
22
+ impl From < i8 > for RopeScalingType {
23
+ fn from ( value : i8 ) -> Self {
24
+ match value {
25
+ 0 => Self :: None ,
26
+ 1 => Self :: Linear ,
27
+ 2 => Self :: Yarn ,
28
+ _ => Self :: Unspecified ,
29
+ }
30
+ }
31
+ }
32
+
33
+ /// Create a `c_int` from a `RopeScalingType`.
34
+ impl From < RopeScalingType > for i8 {
35
+ fn from ( value : RopeScalingType ) -> Self {
36
+ match value {
37
+ RopeScalingType :: None => 0 ,
38
+ RopeScalingType :: Linear => 1 ,
39
+ RopeScalingType :: Yarn => 2 ,
40
+ RopeScalingType :: Unspecified => -1 ,
41
+ }
42
+ }
43
+ }
44
+
6
45
/// A safe wrapper around `llama_context_params`.
7
46
#[ derive( Debug , Clone , Copy , PartialEq ) ]
8
47
#[ allow(
@@ -18,7 +57,7 @@ pub struct LlamaContextParams {
18
57
pub n_batch : u32 ,
19
58
pub n_threads : u32 ,
20
59
pub n_threads_batch : u32 ,
21
- pub rope_scaling_type : i8 ,
60
+ pub rope_scaling_type : RopeScalingType ,
22
61
pub rope_freq_base : f32 ,
23
62
pub rope_freq_scale : f32 ,
24
63
pub yarn_ext_factor : f32 ,
@@ -83,7 +122,7 @@ impl From<llama_context_params> for LlamaContextParams {
83
122
mul_mat_q,
84
123
logits_all,
85
124
embedding,
86
- rope_scaling_type,
125
+ rope_scaling_type : RopeScalingType :: from ( rope_scaling_type ) ,
87
126
yarn_ext_factor,
88
127
yarn_attn_factor,
89
128
yarn_beta_fast,
@@ -131,7 +170,7 @@ impl From<LlamaContextParams> for llama_context_params {
131
170
mul_mat_q,
132
171
logits_all,
133
172
embedding,
134
- rope_scaling_type,
173
+ rope_scaling_type : i8 :: from ( rope_scaling_type ) ,
135
174
yarn_ext_factor,
136
175
yarn_attn_factor,
137
176
yarn_beta_fast,
0 commit comments