
Commit ca28629

Merge pull request #700 from nobodywho-ooo/model-meta-val-stuff
Expose string metadata fetching methods
2 parents: ffb0427 + c33b4a5

2 files changed: +108 −2 lines

llama-cpp-2/src/lib.rs

Lines changed: 16 additions & 0 deletions
```diff
@@ -82,6 +82,22 @@ pub enum ChatTemplateError {
     Utf8Error(#[from] std::str::Utf8Error),
 }
 
+/// Failed fetching metadata value
+#[derive(Debug, Eq, PartialEq, thiserror::Error)]
+pub enum MetaValError {
+    /// The provided string contains an unexpected null byte
+    #[error("null byte in string {0}")]
+    NullError(#[from] NulError),
+
+    /// The returned data contains invalid UTF-8
+    #[error("FromUtf8Error {0}")]
+    FromUtf8Error(#[from] FromUtf8Error),
+
+    /// Got negative return value. This happens if the key or index queried does not exist.
+    #[error("Negative return value. Likely due to a missing index or key. Got return value: {0}")]
+    NegativeReturn(i32),
+}
+
 /// Failed to Load context
 #[derive(Debug, Eq, PartialEq, thiserror::Error)]
 pub enum LlamaContextLoadError {
```
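As a rough usage sketch (not part of the commit): the `NegativeReturn` variant is what a caller sees when the queried key or index does not exist, so a lookup helper might treat it as "absent" rather than a hard failure. This assumes `MetaValError` is re-exported from the crate root and that `meta_val_str` is available on `LlamaModel`, as added in `model.rs` below.

```rust
use llama_cpp_2::{model::LlamaModel, MetaValError};

/// Look up a metadata key, treating a missing key as `None` instead of an error.
/// Sketch only; assumes a `LlamaModel` has already been loaded elsewhere.
fn meta_or_none(model: &LlamaModel, key: &str) -> Result<Option<String>, MetaValError> {
    match model.meta_val_str(key) {
        Ok(value) => Ok(Some(value)),
        // A negative return from the C API signals a missing key (or index).
        Err(MetaValError::NegativeReturn(_)) => Ok(None),
        // Null bytes in the key or invalid UTF-8 in the value remain real errors.
        Err(other) => Err(other),
    }
}
```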

llama-cpp-2/src/model.rs

Lines changed: 92 additions & 2 deletions
```diff
@@ -13,8 +13,9 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
 use crate::{
-    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
-    LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
+    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError,
+    LlamaLoraAdapterInitError, LlamaModelLoadError, MetaValError, NewLlamaChatMessageError,
+    StringToTokenError, TokenToStringError,
 };
 
 pub mod params;
@@ -490,6 +491,59 @@ impl LlamaModel {
         u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) }).unwrap()
     }
 
+    /// Get metadata value as a string by key name
+    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
+        let key_cstring = CString::new(key)?;
+        let key_ptr = key_cstring.as_ptr();
+
+        extract_meta_string(
+            |buf_ptr, buf_len| unsafe {
+                llama_cpp_sys_2::llama_model_meta_val_str(
+                    self.model.as_ptr(),
+                    key_ptr,
+                    buf_ptr,
+                    buf_len,
+                )
+            },
+            256,
+        )
+    }
+
+    /// Get the number of metadata key/value pairs
+    pub fn meta_count(&self) -> i32 {
+        unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
+    }
+
+    /// Get metadata key name by index
+    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
+        extract_meta_string(
+            |buf_ptr, buf_len| unsafe {
+                llama_cpp_sys_2::llama_model_meta_key_by_index(
+                    self.model.as_ptr(),
+                    index,
+                    buf_ptr,
+                    buf_len,
+                )
+            },
+            256,
+        )
+    }
+
+    /// Get metadata value as a string by index
+    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
+        extract_meta_string(
+            |buf_ptr, buf_len| unsafe {
+                llama_cpp_sys_2::llama_model_meta_val_str_by_index(
+                    self.model.as_ptr(),
+                    index,
+                    buf_ptr,
+                    buf_len,
+                )
+            },
+            256,
+        )
+    }
+
     /// Returns the rope type of the model.
     pub fn rope_type(&self) -> Option<RopeType> {
         match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
@@ -690,6 +744,42 @@ impl LlamaModel {
     }
 }
 
+/// Generic helper function for extracting string values from the C API.
+/// This is specifically useful for the metadata functions, where we pass in a buffer
+/// to be populated by a string, not yet knowing if the buffer is large enough.
+/// If the buffer was not large enough, we get the correct length back, which can be used to
+/// construct a buffer of appropriate size.
+fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
+where
+    F: Fn(*mut c_char, usize) -> i32,
+{
+    let mut buffer = vec![0u8; capacity];
+
+    // call the foreign function
+    let result = c_function(buffer.as_mut_ptr() as *mut c_char, buffer.len());
+    if result < 0 {
+        return Err(MetaValError::NegativeReturn(result));
+    }
+
+    // check if the response fit in our buffer
+    let returned_len = result as usize;
+    if returned_len >= capacity {
+        // buffer wasn't large enough, try again with the correct capacity.
+        return extract_meta_string(c_function, returned_len + 1);
+    }
+
+    // verify null termination
+    debug_assert_eq!(
+        buffer.get(returned_len),
+        Some(&0),
+        "should end with null byte"
+    );
+
+    // resize, convert, and return
+    buffer.truncate(returned_len);
+    Ok(String::from_utf8(buffer)?)
+}
+
 impl Drop for LlamaModel {
     fn drop(&mut self) {
        unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
```
0 commit comments
