Skip to content

Commit 199724f

Browse files
committed
gguf : track writer state
1 parent da3256e commit 199724f

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

gguf-py/gguf/gguf.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,13 @@ def get_type(val):
462462
sys.exit()
463463

464464

465+
class WriterState:
466+
EMPTY = auto()
467+
HEADER = auto()
468+
KV_DATA = auto()
469+
TI_DATA = auto()
470+
471+
465472
class GGUFWriter:
466473
fout: BufferedWriter
467474
tensors: list[np.ndarray[Any, Any]]
@@ -476,24 +483,37 @@ def __init__(self, path: os.PathLike[str] | str, arch: str):
476483
self.ti_data = b""
477484
self.ti_data_count = 0
478485
self.tensors = []
486+
self.state = WriterState.EMPTY
479487

480488
self.add_architecture()
481489

482490
def write_header_to_file(self):
491+
if self.state is not WriterState.EMPTY:
492+
raise ValueError(f'Expected output file to be empty, got {self.state}')
493+
483494
self.fout.write(struct.pack("<I", GGUF_MAGIC))
484495
self.fout.write(struct.pack("<I", GGUF_VERSION))
485496
self.fout.write(struct.pack("<Q", self.ti_data_count))
486497
self.fout.write(struct.pack("<Q", self.kv_data_count))
487498
self.flush()
488-
# print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
499+
#print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
500+
self.state = WriterState.HEADER
489501

490502
def write_kv_data_to_file(self):
503+
if self.state is not WriterState.HEADER:
504+
raise ValueError(f'Expected output file to contain the header, got {self.state}')
505+
491506
self.fout.write(self.kv_data)
492507
self.flush()
508+
self.state = WriterState.KV_DATA
493509

494510
def write_ti_data_to_file(self):
511+
if self.state is not WriterState.KV_DATA:
512+
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
513+
495514
self.fout.write(self.ti_data)
496515
self.flush()
516+
self.state = WriterState.TI_DATA
497517

498518
def add_key(self, key: str):
499519
self.add_val(key, GGUFValueType.STRING, add_vtype=False)
@@ -629,6 +649,9 @@ def write_padding(self, fp: BinaryIO, n: int, align: int | None = None):
629649
fp.write(bytes([0] * pad))
630650

631651
def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
652+
if self.state is not WriterState.TI_DATA:
653+
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
654+
632655
self.write_padding(self.fout, self.fout.tell())
633656
tensor.tofile(self.fout)
634657
self.write_padding(self.fout, tensor.nbytes)

0 commit comments

Comments
 (0)