Skip to content
This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Commit 8254deb

Browse files
authored
Merge pull request #152 from philpax/load-tensors-as-stored
fix #149 - load tensors by type, ignoring filetype
2 parents 1b20306 + c9e5c26 commit 8254deb

File tree

13 files changed

+402
-241
lines changed

13 files changed

+402
-241
lines changed

ggml-loader/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ edition = "2021"
77

88
[dependencies]
99
ggml = { path = "../ggml" }
10-
thiserror = "*"
10+
thiserror = "1.0"

ggml-loader/src/lib.rs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use util::*;
1010

1111
pub type ElementType = ggml::Type;
1212

13-
/// file type containing the model
13+
/// the format of the file containing the model
1414
#[derive(Debug, PartialEq, Clone, Copy)]
1515
#[allow(clippy::upper_case_acronyms)]
1616
pub enum ContainerType {
@@ -21,7 +21,6 @@ pub enum ContainerType {
2121
/// mmap-able format
2222
GGJT,
2323
}
24-
2524
impl ContainerType {
2625
pub fn support_mmap(&self) -> bool {
2726
match self {
@@ -64,10 +63,19 @@ pub struct TensorInfo {
6463
pub n_dims: usize,
6564
pub dims: [usize; 2],
6665
pub n_elements: usize,
67-
pub ftype: ElementType,
66+
pub element_type: ElementType,
6867
/// offset of the start of the tensor data, measured from the start of the file
6968
pub start_offset: u64,
7069
}
70+
impl TensorInfo {
71+
pub fn calc_size(&self) -> usize {
72+
let mut size = ggml::type_size(self.element_type);
73+
for &dim in &self.dims[0..self.n_dims] {
74+
size *= dim;
75+
}
76+
size / ggml::blck_size(self.element_type)
77+
}
78+
}
7179

7280
/// Info in hyperparameter used for later loading tasks. Used in callback.
7381
/// see [`LoadHandler::load_hyper_parameters`]
@@ -78,10 +86,7 @@ pub struct PartialHyperparameters {
7886

7987
pub enum TensorDataTreatment<'a> {
8088
CopyInto(&'a mut [u8]),
81-
SeekPast {
82-
/// should be `tensor.nbytes`
83-
n_bytes: usize,
84-
},
89+
Skip,
8590
}
8691

8792
#[allow(unused_variables)]
@@ -173,7 +178,9 @@ pub fn load_weights<T, R: BufRead + Seek>(
173178
// load tensor header
174179
let n_dims: usize = read_i32(reader)?.try_into()?;
175180
let name_len = read_i32(reader)?;
176-
let ftype = decode_element_type_res(read_i32(reader)?)?;
181+
let ftype = read_i32(reader)?;
182+
let ftype =
183+
ggml::Type::try_from(ftype).map_err(|_| LoadError::UnsupportedElementType(ftype))?;
177184

178185
let mut n_elements: usize = 1;
179186
let mut dims = [1usize, 1];
@@ -214,9 +221,10 @@ pub fn load_weights<T, R: BufRead + Seek>(
214221
dims,
215222
n_dims,
216223
n_elements,
217-
ftype,
224+
element_type: ftype,
218225
start_offset: offset_aligned,
219226
};
227+
let n_bytes = tensor_info.calc_size();
220228

221229
match controlflow_to_result(handler.tensor_buffer(tensor_info))? {
222230
TensorDataTreatment::CopyInto(buf) => {
@@ -225,7 +233,7 @@ pub fn load_weights<T, R: BufRead + Seek>(
225233
}
226234
reader.read_exact(buf)?;
227235
}
228-
TensorDataTreatment::SeekPast { n_bytes } => {
236+
TensorDataTreatment::Skip => {
229237
// skip if no buffer is given
230238
reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?;
231239
}

ggml-loader/src/util.rs

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
pub use std::io::{BufRead, Seek, SeekFrom};
22
use std::ops::ControlFlow;
33

4-
use crate::{ElementType, LoadError};
4+
use crate::LoadError;
55

66
pub fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> {
77
let mut bytes = [0u8; N];
@@ -35,33 +35,6 @@ pub fn has_data_left(reader: &mut impl BufRead) -> Result<bool, std::io::Error>
3535
reader.fill_buf().map(|b| !b.is_empty())
3636
}
3737

38-
pub fn decode_element_type(ftype: i32) -> Option<ElementType> {
39-
match ftype {
40-
0 => Some(ggml::Type::F32),
41-
1 => Some(ggml::Type::F16),
42-
2 => Some(ggml::Type::Q4_0),
43-
3 => Some(ggml::Type::Q4_1),
44-
_ => None,
45-
}
46-
}
47-
48-
pub fn encode_element_type(element_type: ElementType) -> Option<i32> {
49-
match element_type {
50-
ggml::Type::F32 => Some(0),
51-
ggml::Type::F16 => Some(1),
52-
ggml::Type::Q4_0 => Some(2),
53-
ggml::Type::Q4_1 => Some(3),
54-
_ => None,
55-
}
56-
}
57-
58-
pub fn decode_element_type_res<T>(ftype: i32) -> Result<ElementType, LoadError<T>> {
59-
match decode_element_type(ftype) {
60-
Some(x) => Ok(x),
61-
None => Err(LoadError::UnsupportedElementType(ftype)),
62-
}
63-
}
64-
6538
pub fn controlflow_to_result<A, B>(x: ControlFlow<A, B>) -> Result<B, LoadError<A>> {
6639
match x {
6740
ControlFlow::Continue(x) => Ok(x),

ggml/src/lib.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ pub const FILE_MAGIC_UNVERSIONED: u32 = 0x67676d6c;
2424
/// The currently-supported format version for `ggml` files.
2525
pub const FORMAT_VERSION: u32 = 1;
2626

27+
/// The size of a `ggml` object.
28+
pub const OBJECT_SIZE: usize = ggml_sys::GGML_OBJECT_SIZE;
29+
2730
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
2831
/// The type of a value in `ggml`.
2932
pub enum Type {
@@ -32,6 +35,12 @@ pub enum Type {
3235
Q4_0,
3336
/// Quantized 4-bit (type 1); used by GPTQ.
3437
Q4_1,
38+
/// Quantized 4-bit (type 2).
39+
Q4_2,
40+
/// Quantized 4-bit (type 3).
41+
Q4_3,
42+
/// Quantized 8-bit (type 0).
43+
Q8_0,
3544
/// Integer 32-bit.
3645
I32,
3746
/// Float 16-bit.
@@ -44,6 +53,9 @@ impl From<Type> for ggml_sys::ggml_type {
4453
match t {
4554
Type::Q4_0 => ggml_sys::ggml_type_GGML_TYPE_Q4_0,
4655
Type::Q4_1 => ggml_sys::ggml_type_GGML_TYPE_Q4_1,
56+
Type::Q4_2 => ggml_sys::ggml_type_GGML_TYPE_Q4_2,
57+
Type::Q4_3 => ggml_sys::ggml_type_GGML_TYPE_Q4_3,
58+
Type::Q8_0 => ggml_sys::ggml_type_GGML_TYPE_Q8_0,
4759
Type::I32 => ggml_sys::ggml_type_GGML_TYPE_I32,
4860
Type::F16 => ggml_sys::ggml_type_GGML_TYPE_F16,
4961
Type::F32 => ggml_sys::ggml_type_GGML_TYPE_F32,
@@ -56,6 +68,9 @@ impl TryFrom<ggml_sys::ggml_type> for Type {
5668
match t {
5769
ggml_sys::ggml_type_GGML_TYPE_Q4_0 => Ok(Type::Q4_0),
5870
ggml_sys::ggml_type_GGML_TYPE_Q4_1 => Ok(Type::Q4_1),
71+
ggml_sys::ggml_type_GGML_TYPE_Q4_2 => Ok(Type::Q4_2),
72+
ggml_sys::ggml_type_GGML_TYPE_Q4_3 => Ok(Type::Q4_3),
73+
ggml_sys::ggml_type_GGML_TYPE_Q8_0 => Ok(Type::Q8_0),
5974
ggml_sys::ggml_type_GGML_TYPE_I32 => Ok(Type::I32),
6075
ggml_sys::ggml_type_GGML_TYPE_F16 => Ok(Type::F16),
6176
ggml_sys::ggml_type_GGML_TYPE_F32 => Ok(Type::F32),
@@ -68,6 +83,9 @@ impl std::fmt::Display for Type {
6883
match self {
6984
Type::Q4_0 => write!(f, "q4_0"),
7085
Type::Q4_1 => write!(f, "q4_1"),
86+
Type::Q4_2 => write!(f, "q4_2"),
87+
Type::Q4_3 => write!(f, "q4_3"),
88+
Type::Q8_0 => write!(f, "q8_0"),
7189
Type::I32 => write!(f, "i32"),
7290
Type::F16 => write!(f, "f16"),
7391
Type::F32 => write!(f, "f32"),
@@ -510,6 +528,11 @@ pub struct Tensor {
510528
}
511529

512530
impl Tensor {
531+
/// Size of the `ggml_tensor` struct in bytes.
532+
///
533+
/// Exposed for purposes of determining context size.
534+
pub const C_TYPE_SIZE: usize = std::mem::size_of::<ggml_sys::ggml_tensor>();
535+
513536
/// Creates a shared copy of this tensor pointer.
514537
pub fn share(&self) -> Self {
515538
Tensor {

llama-cli/src/cli_args.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -373,12 +373,12 @@ pub struct Convert {
373373
pub directory: PathBuf,
374374

375375
/// File type to convert to
376-
#[arg(long, short = 't', value_enum, default_value_t = ElementType::Q4_0)]
377-
pub element_type: ElementType,
376+
#[arg(long, short = 't', value_enum, default_value_t = FileType::Q4_0)]
377+
pub file_type: FileType,
378378
}
379379

380380
#[derive(Parser, Debug, ValueEnum, Clone, Copy)]
381-
pub enum ElementType {
381+
pub enum FileType {
382382
/// Quantized 4-bit (type 0).
383383
Q4_0,
384384
/// Quantized 4-bit (type 1); used by GPTQ.
@@ -388,13 +388,13 @@ pub enum ElementType {
388388
/// Float 32-bit.
389389
F32,
390390
}
391-
impl From<ElementType> for llama_rs::ElementType {
392-
fn from(t: ElementType) -> Self {
391+
impl From<FileType> for llama_rs::FileType {
392+
fn from(t: FileType) -> Self {
393393
match t {
394-
ElementType::Q4_0 => llama_rs::ElementType::Q4_0,
395-
ElementType::Q4_1 => llama_rs::ElementType::Q4_1,
396-
ElementType::F16 => llama_rs::ElementType::F16,
397-
ElementType::F32 => llama_rs::ElementType::F32,
394+
FileType::Q4_0 => llama_rs::FileType::MostlyQ4_0,
395+
FileType::Q4_1 => llama_rs::FileType::MostlyQ4_1,
396+
FileType::F16 => llama_rs::FileType::MostlyF16,
397+
FileType::F32 => llama_rs::FileType::F32,
398398
}
399399
}
400400
}

llama-cli/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ fn main() -> Result<()> {
2222
Args::DumpTokens(args) => dump_tokens(&args)?,
2323
Args::Repl(args) => interactive(&args, false)?,
2424
Args::ChatExperimental(args) => interactive(&args, true)?,
25-
Args::Convert(args) => convert_pth_to_ggml(&args.directory, args.element_type.into()),
25+
Args::Convert(args) => convert_pth_to_ggml(&args.directory, args.file_type.into()),
2626
}
2727

2828
Ok(())

llama-rs/src/convert.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,19 @@ use std::{
1616
vec,
1717
};
1818

19-
use crate::{util, Hyperparameters, Vocabulary};
20-
use ggml_loader::util::encode_element_type;
19+
use crate::{loader_common::FileType, util, Hyperparameters, Vocabulary};
2120

2221
/// Converts a `pth` file to a `ggml` file.
23-
pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) {
22+
pub fn convert_pth_to_ggml(model_directory: &Path, file_type: FileType) {
2423
let tokenizer_path = model_directory.parent().unwrap().join("tokenizer.model");
2524
let vocab = load_vocabulary(tokenizer_path.as_path());
2625

27-
let hparams = load_hyperparameters(model_directory, element_type, &vocab);
26+
let hparams = load_hyperparameters(model_directory, file_type, &vocab);
2827

2928
let model_files = util::find_all_model_files(model_directory).unwrap();
3029

3130
for (i, _file) in model_files.iter().enumerate() {
32-
let fname_out = model_directory.join(format!("rust-model-{element_type}.bin"));
31+
let fname_out = model_directory.join(format!("rust-model-{file_type}.bin"));
3332
let mut file = File::create(fname_out).expect("Unable to create file");
3433
write_header(file.borrow_mut(), &hparams).unwrap();
3534
write_tokens(file.borrow_mut(), &vocab).unwrap();
@@ -66,11 +65,7 @@ fn load_vocabulary(path: &Path) -> Vocabulary {
6665
}
6766
}
6867

69-
fn load_hyperparameters(
70-
path: &Path,
71-
element_type: ggml::Type,
72-
vocab: &Vocabulary,
73-
) -> Hyperparameters {
68+
fn load_hyperparameters(path: &Path, file_type: FileType, vocab: &Vocabulary) -> Hyperparameters {
7469
#[derive(Deserialize)]
7570
struct HyperParametersJson {
7671
dim: usize,
@@ -83,7 +78,7 @@ fn load_hyperparameters(
8378
let json = read_to_string(path.join("params.json")).expect("Unable to read file");
8479
let json: HyperParametersJson = serde_json::from_str(&json).expect("Unable to parse json");
8580
Hyperparameters {
86-
element_type,
81+
file_type,
8782
n_ctx: 0,
8883
n_embd: json.dim,
8984
n_head: json.n_heads,
@@ -107,7 +102,7 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String
107102
i32::try_from(hparams.n_head).unwrap(),
108103
i32::try_from(hparams.n_layer).unwrap(),
109104
i32::try_from(hparams.n_embd / hparams.n_head).unwrap(),
110-
encode_element_type(hparams.element_type).unwrap(),
105+
hparams.file_type.into(),
111106
];
112107
let mut packed_values: Vec<u8> = vec![];
113108

llama-rs/src/inference_session.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl InferenceSession {
6868
.map(|(_, tok)| *tok)
6969
.collect();
7070

71-
if self.n_past + prompt_tokens.len() >= model.hparams.n_ctx {
71+
if self.n_past + prompt_tokens.len() >= model.n_ctx() {
7272
return Err(InferenceError::ContextFull);
7373
}
7474

@@ -96,7 +96,7 @@ impl InferenceSession {
9696
params: &InferenceParameters,
9797
rng: &mut impl rand::Rng,
9898
) -> Result<&'v [u8], InferenceError> {
99-
if self.n_past + 1 >= model.hparams.n_ctx {
99+
if self.n_past + 1 >= model.n_ctx() {
100100
return Err(InferenceError::ContextFull);
101101
}
102102

llama-rs/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub use inference_session::{
1919
InferenceSession, InferenceSessionParameters, InferenceSnapshot, ModelKVMemoryType,
2020
SnapshotError,
2121
};
22-
pub use loader_common::{LoadError, LoadProgress};
22+
pub use loader_common::{FileType, LoadError, LoadProgress};
2323
pub use model::{Hyperparameters, Model};
2424
pub use util::TokenUtf8Buffer;
2525
pub use vocabulary::{TokenBias, TokenId, Vocabulary};

llama-rs/src/loader.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::{
77
};
88

99
use crate::{
10+
loader_common::FileType,
1011
util::{self, mulf},
1112
LoadError, LoadProgress, Model, TokenId, Vocabulary,
1213
};
@@ -69,9 +70,9 @@ pub(crate) fn load(
6970
n_head: read_i32(&mut reader)?.try_into()?,
7071
n_layer: read_i32(&mut reader)?.try_into()?,
7172
n_rot: read_i32(&mut reader)?.try_into()?,
72-
element_type: {
73+
file_type: {
7374
let ftype = read_i32(&mut reader)?;
74-
decode_element_type(ftype).ok_or_else(|| LoadError::UnsupportedElementType(ftype))
75+
FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))
7576
}?,
7677
};
7778

@@ -108,7 +109,13 @@ pub(crate) fn load(
108109
// for the big tensors, we have the option to store the data in 16-bit
109110
// floats or quantized in order to save memory and also to speed up the
110111
// computation
111-
let wtype = hparams.element_type;
112+
let wtype = match hparams.file_type {
113+
FileType::F32 => ggml::Type::F32,
114+
FileType::MostlyF16 => ggml::Type::F16,
115+
FileType::MostlyQ4_0 => ggml::Type::Q4_0,
116+
FileType::MostlyQ4_1 => ggml::Type::Q4_1,
117+
_ => unimplemented!(),
118+
};
112119

113120
let n_embd = hparams.n_embd;
114121
let n_layer = hparams.n_layer;
@@ -159,7 +166,7 @@ pub(crate) fn load(
159166
(None, None)
160167
};
161168

162-
let mut model = Model::new(context, hparams, vocabulary, n_ff, wtype, model_type, mmap);
169+
let mut model = Model::new_loader1(context, hparams, vocabulary, n_ff, wtype, mmap);
163170
match model_type {
164171
ContainerType::GGMF | ContainerType::GGML => {
165172
let file_offset = reader.stream_position()?;
@@ -421,7 +428,7 @@ fn load_tensor_header_ggmf<'a>(
421428
}
422429

423430
fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option<usize> {
424-
let ftype = decode_element_type(ftype)?;
431+
let ftype = ggml::Type::try_from(ftype).ok()?;
425432
match ftype {
426433
ElementType::Q4_0 | ElementType::Q4_1 => {
427434
assert_eq!(ne[0] % 64, 0);

0 commit comments

Comments
 (0)