Skip to content
This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Commit c9e5c26

Browse files
committed
refactor: decouple loading from model
1 parent ecb9175 commit c9e5c26

File tree

2 files changed

+91
-110
lines changed

2 files changed

+91
-110
lines changed

llama-rs/src/loader2.rs

Lines changed: 78 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ use memmap2::Mmap;
55
use std::{
66
collections::HashMap,
77
fs::File,
8-
io::{BufRead, BufReader, Seek},
8+
io::{BufRead, BufReader, Read, Seek},
99
ops::ControlFlow,
1010
path::{Path, PathBuf},
1111
};
1212

1313
use crate::{
14-
loader_common::FileType, util, Hyperparameters, LoadError, LoadProgress, Model, TokenId,
15-
Vocabulary,
14+
loader_common::FileType, model::TensorLoader, util, Hyperparameters, LoadError, LoadProgress,
15+
Model, TokenId, Vocabulary,
1616
};
1717

1818
impl LoadError {
@@ -48,7 +48,7 @@ pub(crate) fn load(
4848
return Err(LoadError::MultipartNotSupported { paths });
4949
}
5050

51-
let mut file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed {
51+
let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed {
5252
source: e,
5353
path: main_path.to_owned(),
5454
})?;
@@ -102,28 +102,86 @@ pub(crate) fn load(
102102
None
103103
};
104104

105-
let model = Model::new_loader2(
105+
struct TensorLoader2<'a> {
106+
path: PathBuf,
107+
file: File,
108+
tensors: HashMap<String, TensorInfo>,
109+
context: ggml::Context,
110+
mmap: Option<Mmap>,
111+
load_progress_callback: &'a mut dyn FnMut(LoadProgress),
112+
loaded_tensors: HashMap<String, ggml::Tensor>,
113+
}
114+
impl TensorLoader<LoadError> for TensorLoader2<'_> {
115+
fn load(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, LoadError> {
116+
let info = self
117+
.tensors
118+
.get(name)
119+
.ok_or_else(|| LoadError::UnknownTensor {
120+
path: self.path.clone(),
121+
tensor_name: name.to_owned(),
122+
})?;
123+
124+
let ctx = &self.context;
125+
let mut tensor = match ne.len() {
126+
1 => ctx.new_tensor_1d(info.element_type, ne[0]),
127+
2 => ctx.new_tensor_2d(info.element_type, ne[0], ne[1]),
128+
3 => ctx.new_tensor_3d(info.element_type, ne[0], ne[1], ne[2]),
129+
_ => {
130+
return Err(LoadError::InvariantBroken {
131+
path: self.path.clone(),
132+
invariant: format!(
133+
"the tensor {name} had an unsupported dimension count: {ne:?}"
134+
),
135+
})
136+
}
137+
};
138+
139+
match self.mmap.as_ref() {
140+
Some(mmap) => unsafe {
141+
let ptr = mmap.as_ptr().offset(info.start_offset as isize);
142+
tensor.set_data(ptr as *mut std::ffi::c_void);
143+
},
144+
None => {
145+
let buf: &mut [u8] = unsafe {
146+
std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes())
147+
};
148+
self.file.seek(SeekFrom::Start(info.start_offset))?;
149+
self.file.read_exact(buf)?;
150+
}
151+
}
152+
153+
self.loaded_tensors.insert(name.to_owned(), tensor.share());
154+
(self.load_progress_callback)(LoadProgress::PartTensorLoaded {
155+
file: &self.path,
156+
current_tensor: self.loaded_tensors.len(),
157+
tensor_count: self.tensors.len(),
158+
});
159+
160+
Ok(tensor)
161+
}
162+
163+
fn finish(self) -> (ggml::Context, HashMap<String, ggml::Tensor>, Option<Mmap>) {
164+
(self.context, self.loaded_tensors, self.mmap)
165+
}
166+
}
167+
168+
let tensors_len = tensors.len();
169+
let tl = TensorLoader2 {
170+
path: path.clone(),
171+
file,
172+
tensors,
106173
context,
107-
hyperparameters,
108-
vocabulary,
109-
n_ff,
110-
path.clone(),
111-
&mut file,
112-
&tensors,
113174
mmap,
114-
|tensor_index| {
115-
(load_progress_callback)(LoadProgress::PartTensorLoaded {
116-
file: &path,
117-
current_tensor: tensor_index,
118-
tensor_count: tensors.len(),
119-
});
120-
},
121-
)?;
175+
load_progress_callback: &mut load_progress_callback,
176+
loaded_tensors: Default::default(),
177+
};
178+
179+
let model = Model::new_loader2(hyperparameters, vocabulary, n_ff, tl)?;
122180

123181
(load_progress_callback)(LoadProgress::PartLoaded {
124182
file: &path,
125183
byte_size: 0,
126-
tensor_count: tensors.len(),
184+
tensor_count: tensors_len,
127185
});
128186

129187
Ok(model)

llama-rs/src/model.rs

Lines changed: 13 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
1-
use std::{
2-
collections::HashMap,
3-
fs::File,
4-
io::{Read, Seek, SeekFrom},
5-
path::{Path, PathBuf},
6-
};
1+
use std::{collections::HashMap, error::Error, path::Path};
72

83
use crate::{
94
loader, loader2, loader_common::FileType, vocabulary::TokenId, EvaluateOutputRequest,
@@ -12,8 +7,6 @@ use crate::{
127
};
138
use memmap2::Mmap;
149

15-
use ggml_loader::TensorInfo;
16-
1710
/// The weights for the LLaMA model. All the mutable state is split into a
1811
/// separate struct `InferenceSession`.
1912
pub struct Model {
@@ -117,92 +110,17 @@ impl Model {
117110
}
118111
}
119112

120-
#[allow(clippy::too_many_arguments)]
121-
pub(crate) fn new_loader2(
122-
context: ggml::Context,
113+
pub(crate) fn new_loader2<E: Error>(
123114
hyperparameters: Hyperparameters,
124115
vocabulary: Vocabulary,
125116
n_ff: usize,
126-
path: PathBuf,
127-
file: &mut File,
128-
tensors: &HashMap<String, TensorInfo>,
129-
mmap: Option<Mmap>,
130-
progress_callback: impl FnMut(usize),
131-
) -> Result<Model, LoadError> {
117+
tensor_loader: impl TensorLoader<E>,
118+
) -> Result<Model, E> {
132119
let n_embd = hyperparameters.n_embd;
133120
let n_layer = hyperparameters.n_layer;
134121
let n_vocab = hyperparameters.n_vocab;
135122

136-
struct TensorLoader<'a, F: FnMut(usize)> {
137-
// Input
138-
path: PathBuf,
139-
file: &'a mut File,
140-
tensors: &'a HashMap<String, TensorInfo>,
141-
context: &'a ggml::Context,
142-
mmap_ptr: Option<*const u8>,
143-
progress_callback: F,
144-
145-
// Output
146-
loaded_tensors: HashMap<String, ggml::Tensor>,
147-
}
148-
impl<F: FnMut(usize)> TensorLoader<'_, F> {
149-
fn load(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, LoadError> {
150-
let info = self
151-
.tensors
152-
.get(name)
153-
.ok_or_else(|| LoadError::UnknownTensor {
154-
path: self.path.clone(),
155-
tensor_name: name.to_owned(),
156-
})?;
157-
158-
let ctx = self.context;
159-
let mut tensor = match ne.len() {
160-
1 => ctx.new_tensor_1d(info.element_type, ne[0]),
161-
2 => ctx.new_tensor_2d(info.element_type, ne[0], ne[1]),
162-
3 => ctx.new_tensor_3d(info.element_type, ne[0], ne[1], ne[2]),
163-
_ => {
164-
return Err(LoadError::InvariantBroken {
165-
path: self.path.clone(),
166-
invariant: format!(
167-
"the tensor {name} had an unsupported dimension count: {ne:?}"
168-
),
169-
})
170-
}
171-
};
172-
173-
match self.mmap_ptr {
174-
Some(mmap) => unsafe {
175-
let ptr = mmap.offset(info.start_offset as isize);
176-
tensor.set_data(ptr as *mut std::ffi::c_void);
177-
},
178-
None => {
179-
let buf: &mut [u8] = unsafe {
180-
std::slice::from_raw_parts_mut(
181-
tensor.data() as *mut u8,
182-
tensor.nbytes(),
183-
)
184-
};
185-
self.file.seek(SeekFrom::Start(info.start_offset))?;
186-
self.file.read_exact(buf)?;
187-
}
188-
}
189-
190-
self.loaded_tensors.insert(name.to_owned(), tensor.share());
191-
(self.progress_callback)(self.loaded_tensors.len());
192-
193-
Ok(tensor)
194-
}
195-
}
196-
let mut tl = TensorLoader {
197-
path,
198-
file,
199-
tensors,
200-
context: &context,
201-
mmap_ptr: mmap.as_ref().map(|m| m.as_ptr()),
202-
progress_callback,
203-
204-
loaded_tensors: Default::default(),
205-
};
123+
let mut tl = tensor_loader;
206124

207125
let tok_embeddings = tl.load("tok_embeddings.weight", &[n_embd, n_vocab])?;
208126
let norm = tl.load("norm.weight", &[n_embd])?;
@@ -246,7 +164,7 @@ impl Model {
246164
layers.push(layer);
247165
}
248166

249-
let tensors = tl.loaded_tensors;
167+
let (_context, tensors, _mmap) = tl.finish();
250168

251169
Ok(Model {
252170
hyperparameters,
@@ -256,8 +174,8 @@ impl Model {
256174
output,
257175
layers,
258176
tensors,
259-
_context: context,
260-
_mmap: mmap,
177+
_context,
178+
_mmap,
261179
})
262180
}
263181

@@ -626,6 +544,11 @@ pub struct Hyperparameters {
626544
pub file_type: FileType,
627545
}
628546

547+
pub(crate) trait TensorLoader<E: Error> {
548+
fn load(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, E>;
549+
fn finish(self) -> (ggml::Context, HashMap<String, ggml::Tensor>, Option<Mmap>);
550+
}
551+
629552
struct Layer {
630553
attention_norm: ggml::Tensor,
631554

0 commit comments

Comments (0)