Custom pipeline parallelism implementation for Llama3.1-70B enabling mechanistic interpretability research
This project presents a custom implementation of pipeline parallelism for the Llama3.1-70B model, designed specifically to enable mechanistic interpretability research. Unlike existing high-level frameworks such as vLLM and DeepSpeed, which abstract away layer-level access, our implementation provides direct access to individual transformer layers while maintaining efficient distributed inference across multiple GPUs.
- Layer-wise Access: Direct access to hidden states after each transformer layer
- Intervention Capabilities: Ability to modify representations at any pipeline stage
- Debugging Transparency: Complete visibility into tensor shapes and processing flow
- Custom Analysis: Freedom to implement custom probing and analysis tools
- Memory Efficiency: Enables research on large models without requiring prohibitive hardware
- 4-GPU Pipeline: Optimized for 4-GPU configurations with 80 transformer layers
The implementation targets a 4-GPU configuration where the Llama3.1-70B model (with 80 transformer layers) is distributed using pipeline parallelism:
| GPU | Components |
|-----|-------------|
| GPU 0 | Layers 0-19 + Embeddings + Rotary Embeddings |
| GPU 1 | Layers 20-39 |
| GPU 2 | Layers 40-59 |
| GPU 3 | Layers 60-79 + Layer Norm + Language Model Head |
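For illustration, this split can be expressed as an Accelerate-style device map built per rank, with everything outside the rank's own slice marked for disk offload. The helper name and exact module keys below are assumptions based on the current Hugging Face Llama module layout, not part of the repo's API:

```python
def build_device_map_for_rank(local_rank, num_layers=80, world_size=4):
    """Sketch: keep this rank's layers on its GPU, offload everything else to disk."""
    layers_per_rank = num_layers // world_size  # 20 layers per GPU
    start, end = local_rank * layers_per_rank, (local_rank + 1) * layers_per_rank

    device_map = {}
    for i in range(num_layers):
        device_map[f"model.layers.{i}"] = local_rank if start <= i < end else "disk"

    # First stage owns embeddings and rotary embeddings, last stage owns norm + head
    device_map["model.embed_tokens"] = local_rank if local_rank == 0 else "disk"
    device_map["model.rotary_emb"] = local_rank if local_rank == 0 else "disk"
    device_map["model.norm"] = local_rank if local_rank == world_size - 1 else "disk"
    device_map["lm_head"] = local_rank if local_rank == world_size - 1 else "disk"
    return device_map
```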
The device mapping assigns specific model components to GPUs while offloading unused components to disk. The model loading process bypasses CPU RAM entirely through Hugging Face Accelerate's intelligent dispatching mechanism:
- Meta Device Initialization: Model structure is initialized without allocating memory for weights
- Direct GPU Loading: Weights are loaded directly from disk to designated GPUs
- Disk Offloading: Unused components are offloaded to fast NVMe storage
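A minimal loading sketch along these lines uses Accelerate's `init_empty_weights` and `load_checkpoint_and_dispatch`. The paths, the offload folder, and the `build_device_map_for_rank` helper from the previous sketch are assumptions; the repo wraps this logic inside `load_model_shard`:

```python
import os
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

model_path = "/path/to/llama-3.1-70b"
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # set by torchrun
config = AutoConfig.from_pretrained(model_path)

# Meta device initialization: build the module tree without allocating weight memory
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Direct GPU loading + disk offloading: weights stream from disk to this rank's GPU,
# while components the rank does not own go to fast NVMe storage
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=model_path,
    device_map=build_device_map_for_rank(local_rank),  # sketch from the table above
    no_split_module_classes=["LlamaDecoderLayer"],
    offload_folder="/nvme/offload",                     # assumed offload location
    dtype=torch.bfloat16,
)
```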
- Python 3.8 or higher
- CUDA-capable GPUs (4 GPUs recommended)
- PyTorch 2.0+
- Transformers 4.30+
- Accelerate 0.20+
```bash
git clone https://github.com/your-username/pipeline-parallel-llama.git
cd pipeline-parallel-llama
pip install -e .
pip install -r requirements.txt
```
```bash
# Run distributed inference with 4 GPUs
torchrun --nproc_per_node=4 -m pipeline_parallel_llama.cli inference \
    --model-path /path/to/llama-3.1-70b \
    --prompt "The first number in this list [34, 56, 78] is: "
```
```python
from pipeline_parallel_llama import (
    setup_distributed,
    load_model_shard,
    generate_pipeline,
)

# Initialize distributed environment
rank, world_size, local_rank, device = setup_distributed()

# Load model shard for current GPU
model, tokenizer, start_layer, end_layer, offload_dir, rotary_emb = load_model_shard(
    model_path="/path/to/llama-3.1-70b",
    local_rank=local_rank,
    world_size=world_size,
)

# Run pipeline inference
result = generate_pipeline(
    model=model,
    tokenizer=tokenizer,
    start_layer=start_layer,
    end_layer=end_layer,
    prompt="Your prompt here",
    rank=rank,
    world_size=world_size,
    device=device,
    rotary_emb=rotary_emb,
)
```
One of the most significant technical challenges was handling rotary position embeddings when the model is sharded across devices: the `LlamaRotaryEmbedding` module resides on GPU 0 (see the table above) and is therefore inaccessible to downstream GPUs, causing pipeline failures.
Solution: We implemented a standalone rotary embedding accessible to all ranks:
```python
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

def create_rotary_embedding(config, device):
    """Create a standalone rotary embedding for computing position embeddings."""
    return LlamaRotaryEmbedding(config=config, device=device)
```
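With this helper, any rank can recompute the (cos, sin) tables locally for the sequence it receives, instead of depending on the module that lives on GPU 0. A sketch of the per-rank usage follows; the `compute_position_embeddings` helper is hypothetical and not part of the repo:

```python
import torch

def compute_position_embeddings(rotary_emb, hidden_states):
    """Hypothetical helper: recompute the (cos, sin) pair locally on any pipeline stage."""
    seq_len = hidden_states.shape[1]
    position_ids = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0)
    cos, sin = rotary_emb(hidden_states, position_ids)
    return cos, sin

# e.g. on a downstream rank, after receiving hidden_states:
#   rotary_emb = create_rotary_embedding(model.config, device)
#   position_embeddings = compute_position_embeddings(rotary_emb, hidden_states)
```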
The pipeline requires coordinated communication between GPUs. Each rank receives activations from the previous stage, processes them through assigned layers, and forwards them to the next stage:
```python
import torch
import torch.distributed as dist

def send_tensors_to_next_rank(hidden_states, input_ids, position_embeddings,
                              batch_size, seq_len, device, dst_rank):
    """Send tensors to the next rank in the pipeline."""
    # Send shape information first
    shape_tensor = torch.tensor([batch_size, seq_len], dtype=torch.long, device=device)
    dist.send(shape_tensor, dst=dst_rank)

    # Send actual data
    dist.send(hidden_states.contiguous().detach(), dst=dst_rank)
    dist.send(input_ids.contiguous().detach(), dst=dst_rank)
    # ... position embeddings are sent the same way
```
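The receiving side mirrors this protocol: it reads the shape tensor first, allocates matching buffers, and then receives the payload. The helper below is a sketch, not the repo's exact implementation; it assumes bfloat16 activations and passes in the hidden size explicitly (8192 for Llama3.1-70B, per the embedding shape in the results log):

```python
import torch
import torch.distributed as dist

def recv_tensors_from_prev_rank(hidden_size, device, src_rank, dtype=torch.bfloat16):
    """Hypothetical counterpart: receive tensors from the previous rank in the pipeline."""
    # Receive shape information first
    shape_tensor = torch.empty(2, dtype=torch.long, device=device)
    dist.recv(shape_tensor, src=src_rank)
    batch_size, seq_len = shape_tensor.tolist()

    # Allocate buffers and receive the actual data
    hidden_states = torch.empty(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
    input_ids = torch.empty(batch_size, seq_len, dtype=torch.long, device=device)
    dist.recv(hidden_states, src=src_rank)
    dist.recv(input_ids, src=src_rank)
    # ... position embeddings are received the same way
    return hidden_states, input_ids
```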
Each GPU processes its assigned layers while maintaining complete access to intermediate representations:
```python
import torch

def forward_layers(model, hidden_states, input_ids, start_layer, end_layer,
                   rank, device, position_embeddings):
    """Forward pass through model layers on the current rank."""
    # Position ids for the full (uncached) sequence processed by this stage
    seq_len = hidden_states.shape[1]
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

    for i in range(start_layer, end_layer):
        layer = model.model.layers[i]
        layer_outputs = layer(
            hidden_states,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
            # ... other parameters
        )
        hidden_states = layer_outputs[0]
        # Full visibility into each layer's output
    return hidden_states
```
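Because each stage owns the loop over its own layers, caching per-layer activations for probing is a small variation of the function above. The variant below is a hypothetical sketch, not part of the repo's API:

```python
import torch

def forward_layers_with_cache(model, hidden_states, start_layer, end_layer,
                              device, position_embeddings):
    """Hypothetical variant of forward_layers that records each layer's output for probing."""
    seq_len = hidden_states.shape[1]
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

    activation_cache = {}
    for i in range(start_layer, end_layer):
        layer = model.model.layers[i]
        hidden_states = layer(
            hidden_states,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
        )[0]
        # Move to CPU in float32 so probing does not hold GPU memory
        activation_cache[i] = hidden_states.detach().to("cpu", torch.float32)
    return hidden_states, activation_cache
```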
The implementation successfully demonstrates end-to-end inference with the following characteristics:
- Total inference time: 4.43 seconds for single token generation
- Memory efficiency: ~34-36GB allocated per GPU (within 94.5GB capacity)
- Functional accuracy: Correctly generates expected outputs
```text
[R0] Input IDs: tensor([[128000, 791, 1176, ...]], device='cuda:0'), Shape: torch.Size([1, 19])
[R0] Embedded: torch.Size([1, 19, 8192])
[R0] Position embeddings computed: cos=torch.Size([1, 19, 128]), sin=torch.Size([1, 19, 128])
[R0] Processing layers 0-19...
[R1] Receiving from GPU 0, processing layers 20-39...
[R2] Receiving from GPU 1, processing layers 40-59...
[R3] Receiving from GPU 2, processing layers 60-79...
[R3] Generated token ID: 1958
[R0] Generated token: '34'
[R0] Final result: '34'
[R0] Time taken: 4.43s
```
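The timing and memory figures above can be reproduced with standard PyTorch utilities; for example, a hypothetical wrapper around `generate_pipeline` might look like this:

```python
import time
import torch
from pipeline_parallel_llama import generate_pipeline

def timed_generate(device, rank, **pipeline_kwargs):
    """Hypothetical wrapper reporting wall-clock time and peak GPU memory for one rank."""
    torch.cuda.reset_peak_memory_stats(device)
    start = time.perf_counter()
    result = generate_pipeline(**pipeline_kwargs)  # same arguments as in the usage example
    elapsed = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated(device) / 1024 ** 3
    print(f"[R{rank}] Time taken: {elapsed:.2f}s, peak memory: {peak_gb:.1f} GB")
    return result
```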
Taken together, the layer-wise access, intervention capabilities, debugging transparency, custom analysis freedom, and memory efficiency described above are what make this implementation well suited to mechanistic interpretability research.
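As a concrete example of the intervention capability, edits to the residual stream can be applied directly inside the per-rank layer loop. The sketch below is hypothetical; the steering direction, scale, and target layer are placeholders rather than part of the repo's API:

```python
import torch

def add_steering_vector(hidden_states, direction, scale=5.0):
    """Hypothetical intervention: push the residual stream along a chosen direction."""
    direction = direction.to(hidden_states.device, hidden_states.dtype)
    return hidden_states + scale * direction / direction.norm()

# Inside the forward_layers loop, after a layer of interest on this rank:
#   if i == target_layer:
#       hidden_states = add_steering_vector(hidden_states, probe_direction)
```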
The command-line interface can be configured through environment variables:
- `PIPELINE_LLAMA_MODEL_PATH`: Path to the Llama model directory
- `PIPELINE_LLAMA_PROMPT`: Default prompt for inference
- `LOCAL_RANK`: GPU rank (set automatically by torchrun)
Adjust GPU memory limits in the model loading:
```python
max_memory = {0: "80GB", 1: "80GB", 2: "80GB", 3: "80GB"}
```
- Single Token Generation: Current implementation focuses on single-token generation
- No KV Caching: Lacks key-value caching for multi-token generation efficiency
- Static Pipeline: Fixed 4-GPU configuration without dynamic load balancing
- Implementation of dynamic key-value caching for multi-token generation
- Support for variable GPU configurations
- Integration with interpretability tools like TransformerLens
- Batched inference capabilities
- Memory optimization for longer sequences
We welcome contributions! Please see our Contributing Guidelines for details.
```bash
git clone https://github.com/your-username/pipeline-parallel-llama.git
cd pipeline-parallel-llama
pip install -e ".[dev]"
pytest tests/
```
If you use this work in your research, please cite:
```bibtex
@misc{guiomar2024pipeline,
  title={Reverse Engineering a Pipeline Parallel Llama3.1-70B with transformers, accelerate and torch.distributed},
  author={Guiomar, Gonçalo},
  year={2024},
  institution={ETH AI Center}
}
```
This project is licensed under the MIT License - see the LICENSE file for details.
- Author: Gonçalo Guiomar, ETH AI Center Fellow
- Institution: ETH AI Center
- Framework Dependencies: PyTorch, Transformers, Accelerate
For questions and support:
- Create an issue on GitHub
- Email: goncalo.guiomar@ai.ethz.ch
Note: This is a research prototype designed for mechanistic interpretability studies. While functional, it may require adaptation for production use cases.