Byte-Level Masked Diffusion Language Model (LLaDA-Inspired)

An implementation of a LLaDA-inspired Masked Diffusion Model for text generation, featuring pure byte-level tokenization and mixed-precision training (AMP) for efficiency. This project explores iterative text generation using diffusion, allowing the model to refine its output over multiple steps.

Disclaimer: This software is provided by a high-school student for demonstration, educational, and personal learning purposes. It is provided "AS IS" without warranty of any kind. The author disclaims all warranties, express or implied. This project incorporates and adapts code from other repositories (e.g., RoPE from Llama, MLA), which are referenced where applicable. Documentation and conventions may have inconsistencies; please report any issues.

Key Features

  • LLaDA-Inspired Masked Diffusion: Implements the core concepts of masked diffusion for text, where parts of the input are masked and iteratively unmasked/predicted (see the training-step sketch after this list).
  • Pure Byte-Level Tokenization: Treats text as a sequence of raw bytes (0-255), plus special tokens. This simplifies tokenization, avoids out-of-vocabulary issues, and makes the model inherently multilingual at a byte level.
  • Mixed Precision Training (AMP): Utilizes torch.cuda.amp for faster training and reduced memory footprint on compatible GPUs (bfloat16 or float16).
  • Rotary Positional Embeddings (RoPE): Employs RoPE for effective relative positional encoding, adapted from implementations like Llama.
  • Configurable Transformer Architecture:
    • Standard Transformer blocks with pre-normalization (RMSNorm).
    • Simplified Multi-Head Latent Attention (MLA) variant.
    • Feed-Forward Network with Mish activation (chosen for training benefits over SwiGLU/SiLU, though potentially slower per step).
  • Conditional Generation: Supports generating text conditioned on a given prompt during inference.
  • Efficient Data Handling: Custom ByteLevelTextDataset for loading and chunking byte sequences.
  • TensorBoard Logging: Integrated for monitoring training loss and learning rate.
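
The training objective behind this can be sketched as follows. This is a minimal illustration only: it assumes a hypothetical MASK_ID for the special mask token and a model that maps byte ids to per-position logits, and the actual masking schedule and loss weighting in main.py may differ.

    import torch
    import torch.nn.functional as F

    MASK_ID = 256  # hypothetical id for the special [MASK] token (raw bytes use 0-255)

    def masked_diffusion_step(model, x):
        # x: LongTensor of byte ids, shape (batch, seq_len).
        b, n = x.shape
        t = torch.rand(b, 1, device=x.device).clamp_min(1e-3)      # masking ratio per sample
        mask = torch.rand(b, n, device=x.device) < t                # positions to corrupt
        noisy = torch.where(mask, torch.full_like(x, MASK_ID), x)   # replace with [MASK]
        logits = model(noisy)                                        # (batch, seq_len, vocab)
        ce = F.cross_entropy(logits.transpose(1, 2), x, reduction="none")
        # LLaDA-style objective: cross-entropy on the masked positions only, weighted by 1/t.
        return (ce * mask.float() / t).sum() / (b * n)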

Why This Approach?

Traditional auto-regressive language models generate text token by token from left to right, making it difficult to revise earlier decisions or plan globally. Diffusion models, successful in image and audio synthesis, offer an alternative:

  1. Iterative Refinement: They generate content by progressively denoising an initial state (e.g., a fully masked sequence) over multiple steps, allowing for global refinement.
  2. Non-Autoregressive Potential: This iterative nature can lead to better planning and coherence, as the model isn't locked into early choices.
  3. Simplicity of Byte-Level Processing: Byte-level tokenization removes the complexity of subword tokenizers (like BPE or SentencePiece) and vocabulary limitations, directly processing raw byte streams.

Prerequisites

  • Python 3.10+
  • PyTorch (tested with 2.0+). CUDA is highly recommended for performance.
  • tqdm (for progress bars)
  • tensorboard (for logging, install via pip install tensorboard)

Configuration

All configuration parameters are managed within the CONFIG class in the main Python script.
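
For orientation, the class might be shaped roughly as below. Only the fields mentioned elsewhere in this README (train_data_file, validation_data_file, inference_steps, rope_max_length_buffer) come from the project; the remaining names and all values are illustrative assumptions, not the actual defaults.

    class CONFIG:
        # Data (file names referenced elsewhere in this README)
        train_data_file = "train.txt"
        validation_data_file = "validation.txt"
        seq_length = 512                  # assumed chunk length in bytes

        # Model (names and sizes are assumptions)
        vocab_size = 256 + 2              # raw bytes plus special tokens (e.g. [MASK], [PAD])
        dim = 512
        n_layers = 8
        n_heads = 8
        rope_max_length_buffer = 4096     # referenced in the RoPE notes below

        # Training (assumptions)
        batch_size = 32
        learning_rate = 3e-4
        use_amp = True                    # mixed precision via torch.cuda.amp

        # Inference
        inference_steps = 64              # referenced in the Inference section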

Usage

1. Data Preparation

  • Create two plain text files:
    • train.txt (or as specified in CONFIG.train_data_file): Your training data.
    • validation.txt (or as specified in CONFIG.validation_data_file): Your validation data.
  • The script will automatically create dummy versions if these files are not found, which is useful for initial testing but not for actual training.
  • The text should be UTF-8 encoded. The model will process the raw bytes.
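
Chunking raw bytes is straightforward; a minimal sketch of such a dataset is shown below (the actual ByteLevelTextDataset may differ in naming and details):

    import torch
    from torch.utils.data import Dataset

    class ByteChunkDataset(Dataset):
        """Minimal sketch: read a UTF-8 text file as raw bytes (values 0-255)
        and serve fixed-length chunks. The real ByteLevelTextDataset may differ."""

        def __init__(self, path: str, seq_length: int):
            with open(path, "rb") as f:
                raw = f.read()
            self.data = torch.tensor(list(raw), dtype=torch.long)
            self.seq_length = seq_length

        def __len__(self):
            return max(0, self.data.numel() - self.seq_length)

        def __getitem__(self, idx):
            return self.data[idx : idx + self.seq_length]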

2. Training

To start training, simply run the Python script:

python main.py
  • Training progress, loss, and other information will be printed to the console.
  • Checkpoints and validation outputs (if configured) will be saved periodically.

3. Monitoring (TensorBoard)

Training loss and learning rate are logged to TensorBoard. To view them:

  1. Open a new terminal in the project directory.
  2. Ensure your virtual environment is activated.
  3. Run:
    tensorboard --logdir=runs
  4. Open the URL it prints (typically http://localhost:6006/) in your browser.
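
Inside the training script, this logging reduces to a few SummaryWriter calls along the following lines (the tag names here are assumptions, not necessarily those used in main.py):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir="runs")   # matches tensorboard --logdir=runs

    def log_step(loss_value: float, lr: float, step: int) -> None:
        writer.add_scalar("train/loss", loss_value, step)
        writer.add_scalar("train/lr", lr, step)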

Inference

Currently, inference is primarily demonstrated within the validation loop of the training script (run_conditional_diffusion_inference function).

  • Prompt: A sequence of bytes (text).
  • Response Length: The desired length of the text to generate.
  • Process: The function iteratively refines a masked sequence conditioned on the prompt over CONFIG.inference_steps.
  • Output: The generated byte sequence, decoded to UTF-8 text (with replacement for invalid bytes).
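
A rough sketch of such a loop is shown below. It assumes a hypothetical MASK_ID, a model that returns per-position logits, and a simple confidence-based unmasking rule; run_conditional_diffusion_inference may differ in these details (for example, in how many tokens are revealed or re-masked per step).

    import torch

    MASK_ID = 256  # hypothetical id for the special [MASK] token

    @torch.no_grad()
    def generate(model, prompt_bytes: bytes, response_length: int, steps: int = 64, device: str = "cuda"):
        prompt = torch.tensor(list(prompt_bytes), dtype=torch.long, device=device)
        resp = torch.full((response_length,), MASK_ID, dtype=torch.long, device=device)
        seq = torch.cat([prompt, resp]).unsqueeze(0)          # (1, prompt_len + response_length)
        resp_slice = slice(len(prompt), len(prompt) + response_length)

        for step in range(steps):
            logits = model(seq)                                # (1, total_len, vocab)
            conf, pred = logits.softmax(-1).max(-1)            # confidence and argmax per position
            masked = seq[0, resp_slice] == MASK_ID
            num_masked = int(masked.sum())
            if num_masked == 0:
                break
            # Reveal roughly an equal share of the remaining masked bytes each step,
            # keeping the most confident predictions first.
            k = max(1, num_masked // (steps - step))
            conf_resp = conf[0, resp_slice].masked_fill(~masked, float("-inf"))
            idx = conf_resp.topk(k).indices + resp_slice.start
            seq[0, idx] = pred[0, idx]

        out = [b for b in seq[0, resp_slice].tolist() if 0 <= b < 256]
        return bytes(out).decode("utf-8", errors="replace")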

Design Choices & Technical Notes

  • Byte-Level Tokenization:
    • Pros: Simplicity, no out-of-vocabulary (OOV) issues, handles any UTF-8 character.
    • Cons: Sequences become longer compared to subword tokenization.
  • Mish Activation in FFN: The FeedForward layer uses F.mish(self.w1(x)) * self.w3(x) instead of the more common F.silu(self.w1(x)) * self.w3(x) (SwiGLU/SiLU). Mish may converge better or faster in some cases, though it is slightly more computationally expensive than SiLU; see the FeedForward sketch after this list.
  • Simplified Multi-Head Latent Attention: The implemented SimplifiedMLAttention uses a single linear layer to compress the input for K and V, and separate linear layers to decompress them. This is a simplified form of latent attention rather than a full reproduction of a specific MLA architecture (such as DeepSeek's, which also includes decoupled RoPE and other details); see the attention sketch after this list.
  • RoPE Implementation: RoPE is applied to the full head dimension after the Q/K projections. The freqs_cis buffer is precomputed up to CONFIG.rope_max_length_buffer; see the RoPE sketch after this list.
  • Bidirectional Attention: Consistent with LLaDA's design for a non-autoregressive diffusion model, the attention mechanism is bidirectional (causal masking is not strictly enforced).
  • Inference Prompt/Response Lengths: During validation, prompts are generally longer than the generated response length. The CONFIG.inference_steps (e.g., 64) provides a good balance for generation quality in tests.
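
The FeedForward variant mentioned above looks essentially like the following sketch. The w1/w2/w3 naming follows the formula quoted in the list; the hidden size and the absence of biases are assumptions.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FeedForward(nn.Module):
        # Gated FFN with Mish instead of SiLU: F.mish(w1(x)) * w3(x), projected back by w2.
        def __init__(self, dim: int, hidden_dim: int):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x):
            return self.w2(F.mish(self.w1(x)) * self.w3(x))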
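
The simplified latent attention can be sketched as below. This follows the compress-then-decompress description above but is not the exact code from main.py; RoPE application to Q and K is omitted here for brevity.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SimplifiedMLAttention(nn.Module):
        # Sketch only: one linear layer compresses the input into a low-rank latent,
        # and separate linear layers decompress that latent into K and V. Q is projected
        # directly. Attention is bidirectional (no causal mask), as noted above.
        def __init__(self, dim: int, n_heads: int, latent_dim: int):
            super().__init__()
            self.n_heads, self.head_dim = n_heads, dim // n_heads
            self.wq = nn.Linear(dim, dim, bias=False)
            self.kv_compress = nn.Linear(dim, latent_dim, bias=False)   # shared compression
            self.k_decompress = nn.Linear(latent_dim, dim, bias=False)
            self.v_decompress = nn.Linear(latent_dim, dim, bias=False)
            self.wo = nn.Linear(dim, dim, bias=False)

        def forward(self, x):
            b, n, _ = x.shape
            latent = self.kv_compress(x)
            q = self.wq(x).view(b, n, self.n_heads, self.head_dim).transpose(1, 2)
            k = self.k_decompress(latent).view(b, n, self.n_heads, self.head_dim).transpose(1, 2)
            v = self.v_decompress(latent).view(b, n, self.n_heads, self.head_dim).transpose(1, 2)
            out = F.scaled_dot_product_attention(q, k, v)   # bidirectional: is_causal=False
            return self.wo(out.transpose(1, 2).reshape(b, n, -1))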
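
For the RoPE note above, a standard Llama-style helper of the kind the project adapts looks like this: precompute complex rotations once up to a maximum buffer length, then rotate Q/K across the full head dimension. The project's exact implementation may differ.

    import torch

    def precompute_freqs_cis(head_dim: int, max_len: int, theta: float = 10000.0):
        # Complex rotation factors, precomputed once up to the buffer length
        # (CONFIG.rope_max_length_buffer in this project).
        freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = torch.outer(torch.arange(max_len).float(), freqs)    # (max_len, head_dim/2)
        return torch.polar(torch.ones_like(angles), angles)           # complex64

    def apply_rotary_emb(x, freqs_cis):
        # x: (batch, seq_len, n_heads, head_dim); rotates pairs across the full head dimension.
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        rot = freqs_cis[: x.shape[1]].view(1, x.shape[1], 1, -1)
        return torch.view_as_real(x_c * rot).flatten(-2).type_as(x)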

File Structure

.
├── main.py                     # Main Python script with model, training, etc.
├── train.txt                   # Training data file
├── validation.txt              # Validation data file
└── runs/                       # Directory for TensorBoard logs (tensorboard --logdir=runs)

Ideas

  • Tokenization: Change to something more robust, since byte-level tokenization, while very simple, carries a significant performance penalty. 90%
  • Full MLA Implementation: Fully integrate and test the more complex MultiHeadLatentAttention. 75%
  • Hyperparameter Optimization: Conduct systematic hyperparameter tuning.
  • Extensive Evaluation: Implement more evaluation metrics besides validation loss.
  • Standalone Inference Script: Create a dedicated script for inference with a trained model. 75%
  • Dataset Handling: More robust dataset handling, streaming for very large datasets. 100% ✅
  • Detailed Profiling: Profile different components (attention, FFN, data loading) to identify bottlenecks.

Acknowledgements & References

  • Inspired by the LLaDA paper
  • Rotary Positional Embeddings (RoPE)
  • Multi-Head Latent Attention (MLA)
  • The author's self-research and learning journey.

Citation

@misc{DiffusionLanguageModel,
  author       = {nuni-neomu-areumdawo},
  title        = {Open-Source Masked Diffusion Language Model},
  year         = {2025},
  publisher    = {GitHub},
  journal      = {GitHub repository},
  howpublished = {\url{https://github.com/nuni-neomu-areumdawo/Diffusion-Language-Model}}
}
