Commit 170b992

call init_weights before generation (#1371)
Since #1338 the `freqs_cis` buffer is no longer persisted/read in any code path; the intention is that it is re-calculated at model loading/initialization. However, this requires calling `init_weights` on the model, which `scripts/test_generate.py` currently is not doing. As of right now, running generation on the pretrained Llama 3 models results in garbled outputs.

Convert weights:

```
python ./scripts/convert_llama_to_dcp.py /home/emozilla/hf/Llama-3-8B/original /home/emozilla/dcp/Llama-3-8B
```

Run generation:

```
CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml CHECKPOINT_DIR=/home/emozilla/dcp/Llama-3-8B PROMPT="A long time ago in a galaxy far, far away" ./scripts/generate/run_llama_generate.sh
```

HEAD:

```
<|begin_of_text|>A long time ago in a galaxy far, far away000 centershift Equity KelleyYe требаyrais& Romgraph1Kォ IDEA globalčil at390dagThe,inLikeBelow uptimeRoman_constsBothtz_RATE phủ
```

With fix:

```
<|begin_of_text|>A long time ago in a galaxy far, far away… Aspirations were bursting and Jedi were making a big imprint in the arts, in the government, and in our lives.  That was 34 or
```
1 parent 5a26243 commit 170b992

File tree

1 file changed: 2 additions, 0 deletions


scripts/generate/test_generate.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -140,6 +140,8 @@ def test_generate(

     # materalize model
     model.to_empty(device=device_type)
+    with torch.no_grad():
+        model.init_weights()
     model.eval()

     state_dict = model.state_dict()
```
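The underlying pattern can be sketched with a minimal, hypothetical module (`TinyRoPEModel` and its helpers are illustrative, not torchtitan's actual code): a buffer registered with `persistent=False` never round-trips through the checkpoint, so after `to_empty()` replaces its storage, the only way to get valid values back is to recompute them, which is what `init_weights` does before generation.

```python
import torch
import torch.nn as nn


class TinyRoPEModel(nn.Module):
    """Hypothetical stand-in for torchtitan's Transformer: it keeps a
    non-persistent `freqs_cis` buffer, which (since #1338) is excluded
    from the checkpoint state_dict."""

    def __init__(self, dim: int = 8, seq_len: int = 4):
        super().__init__()
        self.dim, self.seq_len = dim, seq_len
        self.proj = nn.Linear(dim, dim)
        # persistent=False: never saved to / loaded from a checkpoint
        self.register_buffer(
            "freqs_cis", self._freqs(dim, seq_len), persistent=False
        )

    @staticmethod
    def _freqs(dim: int, seq_len: int) -> torch.Tensor:
        # Standard RoPE-style frequencies as a complex tensor
        inv = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(seq_len).float()
        return torch.polar(torch.ones(seq_len, dim // 2), torch.outer(t, inv))

    def init_weights(self) -> None:
        # Recompute the non-persistent buffer (the real init_weights
        # also re-initializes parameters).
        self.freqs_cis = self._freqs(self.dim, self.seq_len)


model = TinyRoPEModel()
expected = model.freqs_cis.clone()

model.to_empty(device="cpu")  # as in test_generate.py: storages replaced, values lost
with torch.no_grad():
    model.init_weights()      # the fix: re-materialize the buffer
```

Because `freqs_cis` is non-persistent, `load_state_dict` alone can never restore it; skipping `init_weights` leaves the buffer uninitialized, which matches the garbled output shown above.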
