Skip to content

Commit 07c252f

Browse files
committed
model : add Jamba to Mamba-specific hparams printing
1 parent 20f8e43 commit 07c252f

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

src/llama-model.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4842,16 +4842,6 @@ void llama_model::print_info() const {
48424842
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
48434843
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
48444844
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
4845-
}
4846-
4847-
if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) {
4848-
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
4849-
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
4850-
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
4851-
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
4852-
LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group);
4853-
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
4854-
48554845
if (!classifier_labels.empty()) {
48564846
LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
48574847

@@ -4862,6 +4852,15 @@ void llama_model::print_info() const {
48624852
}
48634853
}
48644854

4855+
if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA) {
4856+
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
4857+
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
4858+
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
4859+
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
4860+
LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group);
4861+
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
4862+
}
4863+
48654864
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
48664865
if (pimpl->n_elements >= 1e12) {
48674866
LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12);

0 commit comments

Comments
 (0)