Skip to content

Commit ab14019

Browse files
authored
Support diffusion models: Add Dream 7B (#14644)
* Support diffusion models: Add Dream 7B * Move diffusion to examples * Move stuff to examples. Add patch to not use kv-cache * Address review comments * Make sampling fast * llama: remove diffusion functions * Add basic timings + cleanup * More cleanup * Review comments: better formating, use LOG instead std::cerr, re-use batch, use ubatch instead of max_length * fixup! * Review: move everything to diffusion-cli for now
1 parent 6497834 commit ab14019

File tree

13 files changed

+804
-0
lines changed

13 files changed

+804
-0
lines changed

common/arg.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3423,5 +3423,34 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
34233423
}
34243424
).set_examples({LLAMA_EXAMPLE_SERVER}));
34253425

3426+
// diffusion parameters
3427+
add_opt(common_arg(
3428+
{ "--diffusion-steps" }, "N",
3429+
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
3430+
[](common_params & params, int value) { params.diffusion.steps = value; }
3431+
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
3432+
add_opt(common_arg(
3433+
{ "--diffusion-eps" }, "F",
3434+
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
3435+
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
3436+
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
3437+
add_opt(common_arg(
3438+
{ "--diffusion-algorithm" }, "N",
3439+
string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
3440+
params.diffusion.algorithm),
3441+
[](common_params & params, int value) { params.diffusion.algorithm = value; }
3442+
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
3443+
add_opt(common_arg(
3444+
{ "--diffusion-alg-temp" }, "F",
3445+
string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
3446+
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
3447+
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
3448+
add_opt(common_arg(
3449+
{ "--diffusion-visual" },
3450+
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
3451+
params.diffusion.visual_mode ? "true" : "false"),
3452+
[](common_params & params) { params.diffusion.visual_mode = true; }
3453+
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
3454+
34263455
return ctx_arg;
34273456
}

common/common.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ enum llama_example {
8181
LLAMA_EXAMPLE_LOOKUP,
8282
LLAMA_EXAMPLE_PARALLEL,
8383
LLAMA_EXAMPLE_TTS,
84+
LLAMA_EXAMPLE_DIFFUSION,
8485

8586
LLAMA_EXAMPLE_COUNT,
8687
};
@@ -218,6 +219,14 @@ struct common_params_vocoder {
218219
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
219220
};
220221

222+
// Settings for diffusion-based text generation (consumed by examples/diffusion).
struct common_params_diffusion {
    int32_t steps       = 64;    // number of denoising steps to run
    float   eps         = 1e-3f; // epsilon used when scheduling timesteps
    int32_t algorithm   = 0;     // token-selection algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY
    float   alg_temp    = 0.0f;  // temperature applied by the selected algorithm
    bool    visual_mode = false; // show the generation progressively on screen
};
229+
221230
enum common_reasoning_format {
222231
COMMON_REASONING_FORMAT_NONE,
223232
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
@@ -269,6 +278,7 @@ struct common_params {
269278
struct common_params_sampling sampling;
270279
struct common_params_speculative speculative;
271280
struct common_params_vocoder vocoder;
281+
struct common_params_diffusion diffusion;
272282

273283
struct common_params_model model;
274284

convert_hf_to_gguf.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2778,6 +2778,76 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
27782778
yield from super().modify_tensors(data_torch, name, bid)
27792779

27802780

2781+
@ModelBase.register("DreamModel")
class DreamModel(TextModel):
    """Dream 7B: a diffusion language model (generates via iterative denoising,
    so attention is non-causal)."""
    model_arch = gguf.MODEL_ARCH.DREAM

    def get_vocab_base(self) -> tuple[list[str], list[int], str]:
        """Build the token list and token-type list from the HF tokenizer.

        Gaps in the id space are filled with "[PAD{i}]" placeholders; entries
        from the added vocab are classified as CONTROL or USER_DEFINED, and
        everything else is NORMAL.
        """
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

        vocab = tokenizer.get_vocab()
        vocab_size = self.hparams.get("vocab_size", len(vocab))
        assert max(vocab.values()) < vocab_size

        tokpre = self.get_vocab_base_pre(tokenizer)

        id_to_token = {tok_id: tok for tok, tok_id in vocab.items()}
        added_vocab = tokenizer.get_added_vocab()

        tokens: list[str] = []
        toktypes: list[int] = []
        for i in range(vocab_size):
            if i not in id_to_token:
                # id has no token: emit a placeholder so indices stay dense
                tokens.append(f"[PAD{i}]")
                toktypes.append(gguf.TokenType.UNUSED)
                continue
            tok = id_to_token[i]
            tokens.append(tok)
            if tok not in added_vocab:
                toktypes.append(gguf.TokenType.NORMAL)
            elif hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder:
                # added token with decoder metadata: special ones (e.g. <|im_start|>)
                # become CONTROL tokens, the rest are USER_DEFINED
                if tokenizer.added_tokens_decoder[i].special:
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.USER_DEFINED)
            else:
                # no decoder metadata available: conservatively treat all
                # added-vocab entries as CONTROL tokens
                toktypes.append(gguf.TokenType.CONTROL)

        return tokens, toktypes, tokpre

    def set_vocab(self):
        # prefer a sentencepiece model when one is present; otherwise
        # fall back to the GPT-2 style BPE vocab
        try:
            self._set_vocab_sentencepiece()
        except FileNotFoundError:
            self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self._try_set_pooling_type()

        # diffusion denoises all positions at once, so attention is non-causal
        self.gguf_writer.add_causal_attention(False)

        # YaRN RoPE scaling, handled the same way as Qwen2
        rope_scaling = self.hparams.get("rope_scaling") or {}
        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])

        # the mask token id drives the denoising process, record it when present
        mask_token_id = self.hparams.get("mask_token_id")
        if mask_token_id is not None:
            self.gguf_writer.add_mask_token_id(mask_token_id)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # Dream is a base model: tensor names map directly through the parent class
        yield from super().modify_tensors(data_torch, name, bid)
2849+
2850+
27812851
@ModelBase.register("Ernie4_5_ForCausalLM")
27822852
class Ernie4_5Model(TextModel):
27832853
model_arch = gguf.MODEL_ARCH.ERNIE4_5

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ else()
3333
add_subdirectory(speculative-simple)
3434
add_subdirectory(gen-docs)
3535
add_subdirectory(training)
36+
add_subdirectory(diffusion)
3637
if (NOT GGML_BACKEND_DL)
3738
add_subdirectory(convert-llama2c-to-ggml)
3839
# these examples use the backends directly and cannot be built with dynamic loading

examples/diffusion/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Build the llama-diffusion-cli example (diffusion-based text generation).
set(TARGET llama-diffusion-cli)
# Single-source command-line executable.
add_executable(${TARGET} diffusion-cli.cpp)
install(TARGETS ${TARGET} RUNTIME)
# Link against the llama library and the shared `common` helpers;
# CMAKE_THREAD_LIBS_INIT supplies the toolchain's thread library.
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
# The common code requires C++17.
target_compile_features(${TARGET} PRIVATE cxx_std_17)

0 commit comments

Comments
 (0)