@@ -462,6 +462,13 @@ def get_type(val):
462
462
sys .exit ()
463
463
464
464
465
+ class WriterState :
466
+ EMPTY = auto ()
467
+ HEADER = auto ()
468
+ KV_DATA = auto ()
469
+ TI_DATA = auto ()
470
+
471
+
465
472
class GGUFWriter :
466
473
fout : BufferedWriter
467
474
tensors : list [np .ndarray [Any , Any ]]
@@ -476,24 +483,37 @@ def __init__(self, path: os.PathLike[str] | str, arch: str):
476
483
self .ti_data = b""
477
484
self .ti_data_count = 0
478
485
self .tensors = []
486
+ self .state = WriterState .EMPTY
479
487
480
488
self .add_architecture ()
481
489
482
490
def write_header_to_file (self ):
491
+ if self .state is not WriterState .EMPTY :
492
+ raise ValueError (f'Expected output file to be empty, got { self .state } ' )
493
+
483
494
self .fout .write (struct .pack ("<I" , GGUF_MAGIC ))
484
495
self .fout .write (struct .pack ("<I" , GGUF_VERSION ))
485
496
self .fout .write (struct .pack ("<Q" , self .ti_data_count ))
486
497
self .fout .write (struct .pack ("<Q" , self .kv_data_count ))
487
498
self .flush ()
488
- # print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
499
+ #print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
500
+ self .state = WriterState .HEADER
489
501
490
502
def write_kv_data_to_file (self ):
503
+ if self .state is not WriterState .HEADER :
504
+ raise ValueError (f'Expected output file to contain the header, got { self .state } ' )
505
+
491
506
self .fout .write (self .kv_data )
492
507
self .flush ()
508
+ self .state = WriterState .KV_DATA
493
509
494
510
def write_ti_data_to_file (self ):
511
+ if self .state is not WriterState .KV_DATA :
512
+ raise ValueError (f'Expected output file to contain KV data, got { self .state } ' )
513
+
495
514
self .fout .write (self .ti_data )
496
515
self .flush ()
516
+ self .state = WriterState .TI_DATA
497
517
498
518
def add_key (self , key : str ):
499
519
self .add_val (key , GGUFValueType .STRING , add_vtype = False )
@@ -629,6 +649,9 @@ def write_padding(self, fp: BinaryIO, n: int, align: int | None = None):
629
649
fp .write (bytes ([0 ] * pad ))
630
650
631
651
def write_tensor_data (self , tensor : np .ndarray [Any , Any ]):
652
+ if self .state is not WriterState .TI_DATA :
653
+ raise ValueError (f'Expected output file to contain tensor info, got { self .state } ' )
654
+
632
655
self .write_padding (self .fout , self .fout .tell ())
633
656
tensor .tofile (self .fout )
634
657
self .write_padding (self .fout , tensor .nbytes )
0 commit comments