Add Multi-head Latent Attention (DeepSeekv2) #1945

Draft · wants to merge 16 commits into base: main
3 changes: 3 additions & 0 deletions .gitignore
@@ -17,6 +17,9 @@ checkpoints
out
wandb
events.out.tfevents*
evaluate_model
test.ipynb
mla_test.py

# test artifacts from tests/test_readme.py
**/custom_finetuning_dataset.json
52 changes: 52 additions & 0 deletions README_MLA.md
@@ -0,0 +1,52 @@
# Multi-Head Latent Attention (MLA)

## Overview
This document outlines the modifications made to the `litgpt` codebase to add support for the Multi-Head Latent Attention (MLA) block from [DeepSeekV2](https://arxiv.org/abs/2405.04434).

## Changes Made
1. **Configuration**: Added a `latent_attention: Optional[bool] = False` field to the model configuration (`litgpt/config.py`) to enable the MLA block.
2. **MLA module**: Implemented the MLA module as a separate component in the `litgpt` codebase (the low-rank KV compression it builds on is sketched right after this list).
3. **KVCacheCompressed**: Added the `KVCacheCompressed` class to cache the compressed latent key-value representations used by the MLA block.
4. **Model**: Modified the GPT model to include the **MLA block** as an alternative component based on the configuration parameter `latent_attention`.
5. **Training**: Updated the training script to support the MLA block, including training from the new configuration file `config_hub/pretrain/cfg.yaml`.
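
The sketch below illustrates the low-rank KV compression that MLA is built around, which is also what makes a compressed KV cache possible: instead of caching full per-head keys and values, only a small latent vector per token needs to be stored. This is an illustrative sketch only (decoupled RoPE is omitted), and names such as `MLASketch` and `d_latent` are placeholders rather than the identifiers used in this PR.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLASketch(nn.Module):
    """Minimal sketch of MLA-style low-rank KV compression (decoupled RoPE omitted)."""

    def __init__(self, n_embd: int, n_head: int, d_latent: int) -> None:
        super().__init__()
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.q_proj = nn.Linear(n_embd, n_embd, bias=False)     # queries kept full-rank for simplicity
        self.kv_down = nn.Linear(n_embd, d_latent, bias=False)  # compress hidden state into the latent c_KV
        self.k_up = nn.Linear(d_latent, n_embd, bias=False)     # expand latent into per-head keys
        self.v_up = nn.Linear(d_latent, n_embd, bias=False)     # expand latent into per-head values
        self.out_proj = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        c_kv = self.kv_down(x)  # (B, T, d_latent): this is all a compressed KV cache needs to store
        q = self.q_proj(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)
        k = self.k_up(c_kv).view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = self.v_up(c_kv).view(B, T, self.n_head, self.head_size).transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # FlashAttention path when available
        return self.out_proj(y.transpose(1, 2).reshape(B, T, C))
```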

## Installation
Follow the updated installation instructions in the `README.md` file.

## Usage
1. **Configuration**: Set the `latent_attention` parameter to `True` in the configuration file to enable the MLA block (a programmatic alternative is sketched after this list).
2. **Training**: Run the training script with the updated configuration file.
```bash
litgpt pretrain --config config_hub/pretrain/cfg.yaml
```
3. **Inference**: Use the trained model for inference as follows:
```bash
litgpt generate out/pretrain/mla/final/
```
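
If the configuration is built in Python rather than YAML, the same flag can be passed as an override. This is only a sketch, assuming the `latent_attention` field from this PR is available on `litgpt.config.Config`:

```python
from litgpt.config import Config

# Assumes the `latent_attention` flag added by this PR; other Config fields
# can be overridden the same way.
config = Config.from_name("pythia-160m", latent_attention=True)
print(config.latent_attention)  # True
```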

## Results
Results are available at [this link](https://docs.google.com/spreadsheets/d/1-VnTDoK5JuNPGMjory_z1hQkI7y-RgiTpTsUpa3bVEg/edit?usp=sharing).

The results highlight that MQA and GQA considerably reduce memory usage and increase the speed of training. However, this comes at the cost of a significant decrease in performance compared to the baseline model.

The MLA block demonstrates a better trade-off between memory usage, speed, and performance. It shows a slight drop in performance compared to the baseline model, while also reducing memory usage. This also comes with a slight increase in training and inference speed. Smaller projection dimensions have been tested for the MLA block, showing a consistent reduction of memory usage but with a significant drop in performance.

Overall, results are not as significant as expected due to the small scale of the model (limited by the GPU memory) and the short training time (~10k steps). Further experiments on larger models, bigger datasets, and longer training are expected to highlight the benefits of the MLA block. Also, further experiments with layer normalization and other hyperparameters are expected to improve the performance of the MLA block.

## Notes
- Pythia was used as the model family for the experiments because it is released in many versions at different scales.
- `pythia-160m` (160M parameters) was the largest model that could be trained on a single GPU with 16GB memory.
- For the same reason, the `tinystories` dataset was used for the experiments and the models were trained for only 100M tokens (~10k steps).
- Experiments on larger models, bigger datasets, and longer training are expected to further highlight the benefits of the MLA block.
- All the tested implementations use FlashAttention (as implemented in torch) by default.
- The resulting implementation of MLA depends on the `litgpt` codebase (especially the `CausalSelfAttention` class).
- The implementation of the MLA block is based on the DeepSeekV2 paper and includes support for KV caching (`KVCacheCompressed`) and decoupled RoPE (`apply_rope_mla`).
- A further improvement would be to optimize the implementation for speed and memory usage (for example, by merging matrices at inference like in LoRA).
> Fortunately, due to the associative law of matrix multiplication, we can absorb $W^{UK}$ into $W^{UQ}$, and $W^{UV}$ into $W^{O}$. Therefore, we do not need to compute keys and values out for each query. Through this optimization, we avoid the computational overhead for recomputing $k^C_t$ and $v^C_t$ during inference.

Unfortunately, this was not implemented due to time constraints.
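
For reference, this is the algebra behind that optimization in a simplified single-head form, using the paper's notation where $c_t^{Q}$ and $c_t^{KV}$ are the compressed query and key-value latents (a sketch of the identity only, not something implemented in this PR):

$$
q_t^{\top} k_t^{C} = \left(W^{UQ} c_t^{Q}\right)^{\top} \left(W^{UK} c_t^{KV}\right) = c_t^{Q\,\top} \left[ \left(W^{UQ}\right)^{\top} W^{UK} \right] c_t^{KV},
$$

so $(W^{UQ})^{\top} W^{UK}$ can be precomputed and folded into the query path, and analogously $W^{O} W^{UV}$ into the output projection, allowing inference to attend directly over the cached latents without materializing $k^C_t$ and $v^C_t$.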

## Visual Representation
A visual representation of the MLA block, annotated with my implementation notes:

![MLA Block](./mla.png)
131 changes: 131 additions & 0 deletions config_hub/pretrain/cfg.yaml
@@ -0,0 +1,131 @@
# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: pythia-160m

# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_name``. (type: Optional[Config], default: null)
model_config:
name: pythia-160m
hf_config:
org: EleutherAI
name: pythia-160m
block_size: 2048
n_layer: 12
n_embd: 768
n_head: 12
padding_multiple: 128
norm_class_name: LayerNorm
norm_qk: false

# Whether to use latent attention (MLA). (type: bool, default: false)
latent_attention: true
# Number of query groups: MQA (n_query_groups = 1), GQA (1 < n_query_groups < n_head), or MHA (n_query_groups = n_head).
# Not compatible with latent_attention.
n_query_groups: 12

# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
out_dir: out/pretrain/mla

# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed

# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:

# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
resume: false

# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: TinyStories

# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 1000

# Number of iterations between logging calls (type: int, default: 1)
log_interval: 100

# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512)
global_batch_size: 128

# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 4

# Number of iterations with learning rate warmup active (type: int, default: 2000)
lr_warmup_steps: 100

# Number of epochs to train on (type: Optional[int], default: null)
epochs:

# Total number of tokens to train on (type: Optional[int], default: 3000000000000)
max_tokens: 100000000

# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:

# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length:

# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
tie_embeddings:

# (type: Optional[float], default: 1.0)
max_norm: 1.0

# (type: float, default: 4e-05)
min_lr: 6e-5

# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 1000)
interval: 1000

# Number of tokens to generate (type: Optional[int], default: null)
max_new_tokens:

# Number of iterations (type: int, default: 100)
max_iters: 100

# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false

# Whether to evaluate on the validation set at the end of training
final_validation: true

# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW

init_args:
# (type: float, default: 0.001)
lr: 6e-4

# (type: float, default: 0.01)
weight_decay: 0.1

# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95

# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto

# How many nodes to use. (type: int, default: 1)
num_nodes: 1

# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# modules require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/EleutherAI/pythia-160m

# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: tensorboard)
logger_name: tensorboard

# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
3 changes: 3 additions & 0 deletions litgpt/__main__.py
@@ -3,6 +3,7 @@
import warnings

import torch
import torch._dynamo  # fall back to eager mode if torch.compile (TorchDynamo) fails
from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options

from litgpt.chat.base import main as chat_fn
@@ -28,6 +29,8 @@
from litgpt.scripts.download import download_from_hub as download_fn
from litgpt.scripts.merge_lora import merge_lora as merge_lora_fn

torch._dynamo.config.suppress_errors = True


def main() -> None:
parser_data = {
8 changes: 7 additions & 1 deletion litgpt/config.py
@@ -45,6 +45,7 @@ class Config:
# Transformer block (self-attention)
n_head: int = 32
head_size: Optional[int] = None
latent_attention: Optional[bool] = False
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
# to use grouped-query attention (GQA), set this to a value in between
@@ -76,7 +77,7 @@ class Config:
attention_logit_softcapping: Optional[float] = None
# Rotary position embedding (RoPE)
rope_base: int = 10000
rotary_percentage: float = 0.25
rotary_percentage: float = 0.5
rope_condense_ratio: int = 1
rope_adjustments: Optional[dict] = None
# Transformer block (MLP)
@@ -311,6 +312,7 @@ def norm_class(self) -> Type:
n_layer=6,
n_embd=128,
n_head=4,
latent_attention=True,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-31m/blob/main/config.json
@@ -321,6 +323,7 @@ def norm_class(self) -> Type:
n_layer=6,
n_embd=256,
n_head=8,
latent_attention=True,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
@@ -331,6 +334,7 @@ def norm_class(self) -> Type:
n_layer=6,
n_embd=512,
n_head=8,
latent_attention=True,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
@@ -341,6 +345,7 @@ def norm_class(self) -> Type:
n_layer=12,
n_embd=768,
n_head=12,
latent_attention=True,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
@@ -351,6 +356,7 @@ def norm_class(self) -> Type:
n_layer=24,
n_embd=1024,
n_head=16,
latent_attention=True,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
16 changes: 12 additions & 4 deletions litgpt/generate/base.py
@@ -430,11 +430,11 @@ def generate(
@torch.inference_mode()
def main(
checkpoint_dir: Path,
prompt: str = "What food do llamas eat?",
prompt: str = "Once upon a time,",
*,
sys_prompt: Optional[str] = None,
num_samples: int = 1,
max_new_tokens: int = 50,
num_samples: int = 100,
max_new_tokens: int = 500,
top_k: Optional[int] = 50,
top_p: float = 1.0,
temperature: float = 0.8,
@@ -537,6 +537,8 @@ def main(
fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

L.seed_everything(1234)
ts = []
ms = []
for i in range(num_samples):
t0 = time.perf_counter()
y = generate(
@@ -551,10 +553,16 @@
t = time.perf_counter() - t0
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
fabric.print(tokenizer.decode(y))
tokenizer.decode(y)
# fabric.print(tokenizer.decode(y))
tokens_generated = y.size(0) - prompt_length
fabric.print(
f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
)
ts.append(tokens_generated / t)
ms.append(torch.cuda.memory_allocated() / 1e9)
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
fabric.print(f"Average memory used: {sum(ms) / len(ms):.02f} GB", file=sys.stderr)
fabric.print(torch.cuda.memory_summary(), file=sys.stderr)
fabric.print(f"Average tokens/sec: {sum(ts) / len(ts):.02f}", file=sys.stderr)