An implementation of a LLaDA-inspired Masked Diffusion Model for text generation, featuring pure byte-level tokenization and mixed-precision training (AMP) for efficiency. This project explores iterative text generation using diffusion, allowing the model to refine its output over multiple steps.
Disclaimer: This software is provided by a high-school student for demonstration, educational, and personal learning purposes. It is provided "AS IS" without warranty of any kind. The author disclaims all warranties, express or implied. This project incorporates and adapts code from other repositories (e.g., RoPE from Llama, MLA), which are referenced where applicable. Documentation and conventions may have inconsistencies; please report any issues.
- LLaDA-Inspired Masked Diffusion: Implements the core concepts of masked diffusion for text, where parts of the input are masked and iteratively unmasked/predicted.
- Pure Byte-Level Tokenization: Treats text as a sequence of raw bytes (0-255), plus special tokens. This simplifies tokenization, avoids out-of-vocabulary issues, and makes the model inherently multilingual at a byte level (see the sketch after this list).
- Mixed Precision Training (AMP): Utilizes `torch.cuda.amp` for faster training and a reduced memory footprint on compatible GPUs (bfloat16 or float16).
- Rotary Positional Embeddings (RoPE): Employs RoPE for effective relative positional encoding, adapted from implementations like Llama.
- Configurable Transformer Architecture:
- Standard Transformer blocks with pre-normalization (RMSNorm).
- Simplified Multi-Head Latent Attention (MLA) variant.
- Feed-Forward Network with Mish activation (chosen for training benefits over SwiGLU/SiLU, though potentially slower per step).
- Conditional Generation: Supports generating text conditioned on a given prompt during inference.
- Efficient Data Handling: Custom `ByteLevelTextDataset` for loading and chunking byte sequences.
- TensorBoard Logging: Integrated for monitoring training loss and learning rate.
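As a point of reference for the byte-level scheme above, here is a minimal encode/decode sketch. It is illustrative only: the special-token layout (a single `[MASK]` ID of 256) and vocabulary size are assumptions, not necessarily what `main.py` uses.

```python
import torch

# Hypothetical vocabulary: 256 raw byte values plus one [MASK] token.
# The actual special-token layout in main.py may differ.
MASK_TOKEN_ID = 256
VOCAB_SIZE = 257

def encode(text: str) -> torch.Tensor:
    """Encode a UTF-8 string as a tensor of byte IDs in [0, 255]."""
    return torch.tensor(list(text.encode("utf-8")), dtype=torch.long)

def decode(token_ids: torch.Tensor) -> str:
    """Decode byte IDs back to text, skipping special tokens and
    replacing invalid UTF-8 sequences instead of raising."""
    raw = bytes(int(t) for t in token_ids if int(t) < 256)
    return raw.decode("utf-8", errors="replace")

ids = encode("Hello, 世界")          # multilingual text is just a byte sequence
assert decode(ids) == "Hello, 世界"  # lossless round trip
```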
Traditional auto-regressive language models generate text token by token from left to right, making it difficult to revise earlier decisions or plan globally. Diffusion models, successful in image and audio synthesis, offer an alternative:
- Iterative Refinement: They generate content by progressively denoising an initial state (e.g., a fully masked sequence) over multiple steps, allowing for global refinement.
- Non-Autoregressive Potential: This iterative nature can lead to better planning and coherence, as the model isn't locked into early choices.
- Simplicity of Byte-Level Processing: Byte-level tokenization removes the complexity of subword tokenizers (like BPE or SentencePiece) and vocabulary limitations, directly processing raw byte streams.
- Python 3.10+
- PyTorch (tested with 2.0+). CUDA is highly recommended for performance.
- `tqdm` (for progress bars)
- `tensorboard` (for logging, install via `pip install tensorboard`)
All configuration parameters are managed within the `CONFIG` class in the main Python script.
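For orientation, a rough sketch of what such a class can look like is shown below. Only the attribute names mentioned elsewhere in this README (`train_data_file`, `validation_data_file`, `inference_steps`, `rope_max_length_buffer`) come from the project; every other field and all default values are placeholders, so consult `main.py` for the authoritative settings.

```python
class CONFIG:
    """Illustrative sketch of the configuration class in main.py.
    Only the names referenced in this README are taken from the project;
    the remaining fields and all default values are placeholders."""

    # Data files.
    train_data_file = "train.txt"
    validation_data_file = "validation.txt"

    # Diffusion inference.
    inference_steps = 64              # iterative unmasking steps (see design notes)

    # RoPE precomputation buffer length.
    rope_max_length_buffer = 2048     # placeholder value

    # Placeholder training hyperparameters.
    batch_size = 16
    learning_rate = 3e-4
    max_seq_len = 512
```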
- Create two plain text files:
  - `train.txt` (or as specified in `CONFIG.train_data_file`): Your training data.
  - `validation.txt` (or as specified in `CONFIG.validation_data_file`): Your validation data.
- The script will automatically create dummy versions if these files are not found, which is useful for initial testing but not for actual training.
- The text should be UTF-8 encoded. The model will process the raw bytes.
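For context, the `ByteLevelTextDataset` mentioned in the features list might look roughly like the following. This is a hedged sketch: the chunking strategy, special tokens, and streaming behaviour of the real class in `main.py` may differ.

```python
import torch
from torch.utils.data import Dataset

class ByteLevelTextDatasetSketch(Dataset):
    """Loads a UTF-8 text file as raw bytes and serves fixed-length chunks.
    Illustrative only; the project's ByteLevelTextDataset may differ."""

    def __init__(self, path: str, seq_len: int):
        with open(path, "rb") as f:          # read raw bytes, no tokenizer needed
            data = f.read()
        self.data = torch.tensor(list(data), dtype=torch.long)  # values 0-255
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) // self.seq_len

    def __getitem__(self, idx):
        start = idx * self.seq_len
        return self.data[start : start + self.seq_len]  # the training loop handles masking
```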
To start training, simply run the Python script:
```bash
python main.py
```
- Training progress, loss, and other information will be printed to the console.
- Checkpoints and validation outputs (if configured) will be saved periodically.
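Conceptually, each training step in a LLaDA-style masked diffusion model corrupts the batch by masking a random fraction of tokens and scores the model only on those masked positions. The sketch below illustrates that idea; the masking schedule, loss weighting, and mask-token ID (assumed to be 256 here) are not taken from `main.py` and may differ from it.

```python
import torch
import torch.nn.functional as F

MASK_TOKEN_ID = 256  # assumed mask ID, as in the tokenizer sketch above

def masked_diffusion_training_step_sketch(model, batch):
    """One conceptual masked-diffusion training step: mask a random fraction
    of tokens and compute the loss only on the masked positions."""
    B, T = batch.shape
    # Sample a masking ratio per sequence, then mask tokens independently.
    mask_ratio = torch.rand(B, 1, device=batch.device)
    is_masked = torch.rand(B, T, device=batch.device) < mask_ratio
    is_masked[:, 0] |= ~is_masked.any(dim=1)   # ensure at least one masked token

    noisy = batch.masked_fill(is_masked, MASK_TOKEN_ID)
    logits = model(noisy)                      # (B, T, vocab_size)

    # Cross-entropy only over the masked positions.
    loss = F.cross_entropy(logits[is_masked], batch[is_masked])
    return loss
```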
Training loss and learning rate are logged to TensorBoard. To view them:
- Open a new terminal in the project directory.
- Ensure your virtual environment is activated.
- Run: `tensorboard --logdir=runs`
- Open the URL provided (`http://localhost:6006/`) in your browser.
Currently, inference is primarily demonstrated within the validation loop of the training script (the `run_conditional_diffusion_inference` function).
- Prompt: A sequence of bytes (text).
- Response Length: The desired length of the text to generate.
- Process: The function iteratively refines a masked sequence conditioned on the prompt over `CONFIG.inference_steps` steps (see the sketch below).
- Output: The generated byte sequence, decoded to UTF-8 text (with replacement for invalid bytes).
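A conceptual version of that loop is sketched below. The model call signature, the confidence-based remasking schedule, and the mask-token ID are assumptions for illustration; `run_conditional_diffusion_inference` in `main.py` remains the authoritative implementation.

```python
import torch

@torch.no_grad()
def conditional_diffusion_inference_sketch(model, prompt_ids, response_len,
                                           steps, mask_id=256):
    """Iteratively unmask a response conditioned on a fixed prompt.
    Conceptual sketch only; details differ from main.py."""
    device = prompt_ids.device
    response = torch.full((response_len,), mask_id, dtype=torch.long, device=device)

    for step in range(steps):
        seq = torch.cat([prompt_ids, response]).unsqueeze(0)   # (1, T)
        logits = model(seq)[0, len(prompt_ids):]               # response positions only
        probs = logits.softmax(dim=-1)
        conf, pred = probs.max(dim=-1)

        still_masked = response == mask_id
        if not still_masked.any():
            break
        # Unmask the most confident still-masked positions this step.
        k = max(1, int(still_masked.sum().item() / (steps - step)))
        conf = conf.masked_fill(~still_masked, float("-inf"))
        top_idx = conf.topk(k).indices
        response[top_idx] = pred[top_idx]

    return response
```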
- Byte-Level Tokenization:
- Pros: Simplicity, no out-of-vocabulary (OOV) issues, handles any UTF-8 character.
- Cons: Sequences become longer compared to subword tokenization.
- Mish Activation in FFN: The `FeedForward` layer uses `F.mish(self.w1(x)) * self.w3(x)` instead of the more common `F.silu(self.w1(x)) * self.w3(x)` (SwiGLU/SiLU). Mish might lead to better/faster convergence in some cases, though it is slightly more computationally expensive than SiLU (a sketch appears after this list).
- Simplified Multi-Head Latent Attention: The implemented `SimplifiedMLAttention` uses a single linear layer to compress the input for K and V, and then separate linear layers to decompress them. This is a simplified take on latent attention rather than a full reproduction of a specific complex MLA architecture (such as DeepSeek's, which also includes decoupled RoPE and other specifics). It is also sketched after this list.
- RoPE Implementation: RoPE is applied to the full head dimension after the Q/K projections. The `freqs_cis` buffer is precomputed up to `CONFIG.rope_max_length_buffer`.
- Bidirectional Attention: Consistent with LLaDA's design for a non-autoregressive diffusion model, the attention mechanism is bidirectional (causal masking is not strictly enforced).
- Inference Prompt/Response Lengths: During validation, prompts are generally longer than the generated response length. The `CONFIG.inference_steps` value (e.g., 64) provides a good balance for generation quality in tests.
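To make the two architectural notes above concrete, here are two hedged sketches. First, the Mish-gated feed-forward block: the projection names `w1`/`w2`/`w3` follow the expressions quoted above, while the hidden size and bias choices are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForwardSketch(nn.Module):
    """Gated FFN using Mish instead of SiLU: mish(w1(x)) * w3(x) -> w2.
    Illustrative sketch; layer names follow the expressions quoted above."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.mish(self.w1(x)) * self.w3(x))
```

Second, the shared-compression idea behind `SimplifiedMLAttention`: one linear layer compresses the input into a latent for K and V, and separate layers decompress it. RoPE and any normalization are omitted here, and the real module in `main.py` may differ.

```python
class SimplifiedMLAttentionSketch(nn.Module):
    """Bidirectional attention with a shared latent compression for K/V.
    Illustrative sketch of the idea described above; RoPE is omitted."""

    def __init__(self, dim: int, n_heads: int, latent_dim: int):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.kv_compress = nn.Linear(dim, latent_dim, bias=False)   # shared compression
        self.k_decompress = nn.Linear(latent_dim, dim, bias=False)  # separate decompression
        self.v_decompress = nn.Linear(latent_dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        latent = self.kv_compress(x)                        # (B, T, latent_dim)
        q, k, v = self.q_proj(x), self.k_decompress(latent), self.v_decompress(latent)
        # Split heads: (B, n_heads, T, head_dim)
        q, k, v = (t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        # No causal mask: attention is bidirectional, as in LLaDA-style models.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.out_proj(out)
```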
```
.
├── main.py          # Main Python script with model, training, etc.
├── train.txt        # Training data file
├── validation.txt   # Validation data file
└── runs/            # Directory for TensorBoard logs (tensorboard --logdir=runs)
```
- Tokenization: Change to something more robust, as byte-level tokenization, while very simple, causes significant performance degradation. 90%
- Full MLA Implementation: Fully integrate and test the more complex `MultiHeadLatentAttention`. 75%
- Hyperparameter Optimization: Conduct systematic hyperparameter tuning.
- Extensive Evaluation: Implement more evaluation metrics besides validation loss.
- Standalone Inference Script: Create a dedicated script for inference with a trained model. 75%
- Dataset Handling: More robust dataset handling, streaming for very large datasets. 100% ✅
- Detailed Profiling: Profile different components (attention, FFN, data loading) to identify bottlenecks.
- Inspired by the LLaDA paper
- Rotary Positional Embeddings (RoPE)
- Multi-Head Latent Attention (MLA)
- The author's self-research and learning journey.
```bibtex
@misc{Simple_Masked_Diffusion_Language_Model,
  author       = {nuni-neomu-areumdawo},
  title        = {Open-Source Masked Diffusion Language Model},
  year         = {2025},
  publisher    = {GitHub},
  journal      = {GitHub repository},
  howpublished = {\url{https://github.com/nuni-neomu-areumdawo/Diffusion-Language-Model}}
}
```