diff --git a/.gitignore b/.gitignore index 8c416e1b37..b235c703e7 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,9 @@ checkpoints out wandb events.out.tfevents* +evaluate_model +test.ipynb +mla_test.py # test artifacts from tests/test_readme.py **/custom_finetuning_dataset.json diff --git a/README_MLA.md b/README_MLA.md new file mode 100644 index 0000000000..01ef90c310 --- /dev/null +++ b/README_MLA.md @@ -0,0 +1,52 @@ +# Multi-Head Latent Attention (MLA) + +## Overview +This document outlines the modifications made to the codebase in the `litgpt` repository to add support for Multi-Head Latent Attention (MLA) block from [DeepSeekV2](https://arxiv.org/abs/2405.04434). + +## Changes Made +1. **Configuration**: Added `latent_attention: Optional[bool] = False` parameter to the configuration file to enable the MLA block. +2. **MLA module**: Implemented the MLA module as a separate component in the `litgpt` codebase. +3. **KVCacheCompressed**: Added support for the `KVCacheCompressed` class to store the key-value pairs for the MLA block. +4. **Model**: Modified the GPT model to include the **MLA block** as an alternative component based on the configuration parameter `latent_attention`. +5. **Training**: Updated the training script to support the MLA block and added support for training with the new configuration file `config_hub/pretrain/cfg.yaml`. + +## Installation +Follow the updated installation instructions in the `README.md` file. + +## Usage +1. **Configuration**: Set the `latent_attention` parameter to `True` in the configuration file to enable the MLA block. +2. **Training**: Run the training script with the updated configuration file. + ```bash + litgpt pretrain --config config_hub/pretrain/cfg.yaml + ``` +3. **Inference**: Use the trained model for inference as follows: + ```bash + litgpt generate out/pretrain/mla/final/ + ``` + +## Results +Results are available at [this link](https://docs.google.com/spreadsheets/d/1-VnTDoK5JuNPGMjory_z1hQkI7y-RgiTpTsUpa3bVEg/edit?usp=sharing). + +The results highlight that MQA and GQA considerably reduce memory usage and increase the speed of training. However, this comes at the cost of a significant decrease in performance compared to the baseline model. + +The MLA block demonstrates a better trade-off between memory usage, speed, and performance. It shows a slight drop in performance compared to the baseline model, while also reducing memory usage. This also comes with a slight increase in training and inference speed. Smaller projection dimensions have been tested for the MLA block, showing a consistent reduction of memory usage but with a significant drop in performance. + +Overall, results are not as significant as expected due to the small scale of the model (limited by the GPU memory) and the short training time (~10k steps). Further experiments on larger models, bigger datasets, and longer training are expected to highlight the benefits of the MLA block. Also, further experiments with layer normalization and other hyperparameters are expected to improve the performance of the MLA block. + +## Notes +- Pythia was used as model for the experiments because it comes with many versions at different scales. +- `pythia-160m` (160M parameters) was the largest model that could be trained on a single GPU with 16GB memory. +- For the same reason, the `tinystories` dataset was used for the experiments and the models were trained for only 100M tokens (~10k steps). +- Experiments on larger models, bigger datasets, and longer training are expected to further highlight the benefits of the MLA block. +- All the tested implementations use FlashAttention (as implemented in torch) by default. +- The resulting implementation of MLA depends on the `litgpt` codebase (especially the `CausalSelfAttention` class). +- The implementation of the MLA block is based on the DeepSeekV2 paper and includes support for KV caching (`KVCacheCompressed`) and decoupled RoPE (`apply_rope_mla`). +- A further improvement would be to optimize the implementation for speed and memory usage (for example, by merging matrices at inference like in LoRA). + > Fortunately, due to the associative law of matrix multiplication, we can absorb $𝑊^{𝑈𝐾}$ into $𝑊^{𝑈𝑄}$ , and $𝑊^{𝑈𝑉}$ into $𝑊^{𝑂}$. Therefore, we do not need to compute keys and values out for each query. Through this optimization, we avoid the computational overhead for recomputing $k^C_t$ and $v^𝐶_𝑡$ during inference. + + Unfortunately, this was not implemented due to time constraints. + +## Visual Representation +The visual representation of the MLA block with my implementation notes is as follows: + +![MLA Block](./mla.png) diff --git a/config_hub/pretrain/cfg.yaml b/config_hub/pretrain/cfg.yaml new file mode 100644 index 0000000000..5244372b9a --- /dev/null +++ b/config_hub/pretrain/cfg.yaml @@ -0,0 +1,131 @@ +# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with +# ``model_config``. (type: Optional[str], default: null) +model_name: pythia-160m + +# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with +# ``model_config``. (type: Optional[Config], default: null) +model_config: + name: pythia-160m + hf_config: + org: EleutherAI + name: pythia-160m + block_size: 2048 + n_layer: 12 + n_embd: 768 + n_head: 12 + padding_multiple: 128 + norm_class_name: LayerNorm + norm_qk: false + + # Whether to use latent attention (MLA). (type: bool, default: false) + latent_attention: true + # Whether to use MQA (head_size = 1), MLA (1 < head_size < n_head), or MHA (head_size = n_head). + # Not compatible with latent_attention. + n_query_groups: 12 + +# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in +# /teamspace/jobs//share. (type: , default: out/pretrain) +out_dir: out/pretrain/mla + +# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-mixed + +# Optional path to a checkpoint directory to initialize the model from. +# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null) +initial_checkpoint_dir: + +# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume +# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing +# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. +# (type: Union[bool, Literal["auto"], Path], default: False) +resume: false + +# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``. +data: TinyStories + +# Training-related arguments. See ``litgpt.args.TrainArgs`` for details +train: + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 1000 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 100 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512) + global_batch_size: 128 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 4 + + # Number of iterations with learning rate warmup active (type: int, default: 2000) + lr_warmup_steps: 100 + + # Number of epochs to train on (type: Optional[int], default: null) + epochs: + + # Total number of tokens to train on (type: Optional[int], default: 3000000000000) + max_tokens: 100000000 + + # Limits the number of optimizer steps to run. (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples. Off by default (type: Optional[int], default: null) + max_seq_length: + + # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False) + tie_embeddings: + + # (type: Optional[float], default: 1.0) + max_norm: 1.0 + + # (type: float, default: 4e-05) + min_lr: 6e-5 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + # Number of optimizer steps between evaluation calls (type: int, default: 1000) + interval: 1000 + + # Number of tokens to generate (type: Optional[int], default: null) + max_new_tokens: + + # Number of iterations (type: int, default: 100) + max_iters: 100 + + # Whether to evaluate on the validation set at the beginning of the training + initial_validation: false + + # Whether to evaluate on the validation set at the end the training + final_validation: true + +# Optimizer-related arguments +optimizer: + class_path: torch.optim.AdamW + + init_args: + # (type: float, default: 0.001) + lr: 6e-4 + + # (type: float, default: 0.01) + weight_decay: 0.1 + + # (type: tuple, default: (0.9,0.999)) + betas: + - 0.9 + - 0.95 + +# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto) +devices: auto + +# How many nodes to use. (type: int, default: 1) +num_nodes: 1 + +# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data +# module require this. (type: Optional[Path], default: null) +tokenizer_dir: checkpoints/EleutherAI/pythia-160m + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: tensorboard) +logger_name: tensorboard + +# The random seed to use for reproducibility. (type: int, default: 42) +seed: 42 diff --git a/litgpt/__main__.py b/litgpt/__main__.py index a92e2027c2..7e2b2dd46a 100644 --- a/litgpt/__main__.py +++ b/litgpt/__main__.py @@ -3,6 +3,7 @@ import warnings import torch +import torch._dynamo # fallback to eager mode if torchscript fails from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options from litgpt.chat.base import main as chat_fn @@ -28,6 +29,8 @@ from litgpt.scripts.download import download_from_hub as download_fn from litgpt.scripts.merge_lora import merge_lora as merge_lora_fn +torch._dynamo.config.suppress_errors = True + def main() -> None: parser_data = { diff --git a/litgpt/config.py b/litgpt/config.py index 6b7748cf63..400eeab43d 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -45,6 +45,7 @@ class Config: # Transformer block (self-attention) n_head: int = 32 head_size: Optional[int] = None + latent_attention: Optional[bool] = False # to use multi-head attention (MHA), set this to `n_head` (default) # to use multi-query attention (MQA), set this to 1 # to use grouped-query attention (GQA), set this to a value in between @@ -76,7 +77,7 @@ class Config: attention_logit_softcapping: Optional[float] = None # Rotary position embedding (RoPE) rope_base: int = 10000 - rotary_percentage: float = 0.25 + rotary_percentage: float = 0.5 rope_condense_ratio: int = 1 rope_adjustments: Optional[dict] = None # Transformer block (MLP) @@ -311,6 +312,7 @@ def norm_class(self) -> Type: n_layer=6, n_embd=128, n_head=4, + latent_attention=True, padding_multiple=128, ), # https://huggingface.co/EleutherAI/pythia-31m/blob/main/config.json @@ -321,6 +323,7 @@ def norm_class(self) -> Type: n_layer=6, n_embd=256, n_head=8, + latent_attention=True, padding_multiple=128, ), # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json @@ -331,6 +334,7 @@ def norm_class(self) -> Type: n_layer=6, n_embd=512, n_head=8, + latent_attention=True, padding_multiple=128, ), # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json @@ -341,6 +345,7 @@ def norm_class(self) -> Type: n_layer=12, n_embd=768, n_head=12, + latent_attention=True, padding_multiple=128, ), # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json @@ -351,6 +356,7 @@ def norm_class(self) -> Type: n_layer=24, n_embd=1024, n_head=16, + latent_attention=True, padding_multiple=128, ), # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index 565ef08e23..12c4061a0d 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -430,11 +430,11 @@ def generate( @torch.inference_mode() def main( checkpoint_dir: Path, - prompt: str = "What food do llamas eat?", + prompt: str = "Once upon a time,", *, sys_prompt: Optional[str] = None, - num_samples: int = 1, - max_new_tokens: int = 50, + num_samples: int = 100, + max_new_tokens: int = 500, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -537,6 +537,8 @@ def main( fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) L.seed_everything(1234) + ts = [] + ms = [] for i in range(num_samples): t0 = time.perf_counter() y = generate( @@ -551,10 +553,16 @@ def main( t = time.perf_counter() - t0 for block in model.transformer.h: block.attn.kv_cache.reset_parameters() - fabric.print(tokenizer.decode(y)) + tokenizer.decode(y) + # fabric.print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length fabric.print( f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr ) + ts.append(tokens_generated / t) + ms.append(torch.cuda.memory_allocated() / 1e9) if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) + fabric.print(f"Average memory used: {sum(ms) / len(ms):.02f} GB", file=sys.stderr) + fabric.print(torch.cuda.memory_summary(), file=sys.stderr) + fabric.print(f"Average tokens/sec: {sum(ts) / len(ts):.02f}", file=sys.stderr) diff --git a/litgpt/model.py b/litgpt/model.py index f2c3c99f6b..bc639d9f56 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -272,7 +272,7 @@ def __init__( ) self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps) - self.attn = CausalSelfAttention(config, block_idx) + self.attn = CausalSelfAttention(config, block_idx) if not config.latent_attention else MLA(config, block_idx) self.post_attention_norm = ( config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity() ) @@ -528,6 +528,218 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwa super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) +class MLA(nn.Module): + """ + Multi-Head Latent Attention (MLA) block from DeepSeekV2 https://arxiv.org/abs/2405.04434 + """ + + def __init__(self, config: Config, block_idx: int) -> None: + super().__init__() + + # key-value (2/3) and query (1/2) projection dimensions + self.q_proj_dim = config.n_embd // 2 + self.kv_proj_dim = 2 * config.n_embd // 3 + + # qk channel division for RoPE (50%-50%) + self.qk_rope_dim = config.head_size // 2 + self.qk_nope_dim = config.head_size // 2 # no positional embedding + + # q projections (bottleneck) + self.dq = nn.Linear(config.n_embd, self.q_proj_dim, bias=False) # down-projection + self.uq = nn.Linear(self.q_proj_dim, config.n_embd, bias=False) # up-projection + + # kv projections + self.dkv = nn.Linear( + config.n_embd, self.kv_proj_dim + self.qk_rope_dim, bias=False + ) # latent dimension for kv + shared key for RoPE + self.ukv = nn.Linear( + self.kv_proj_dim, config.n_embd + (config.n_head * self.qk_nope_dim), bias=False + ) # up-projection only for LoRA part + + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # unchanged + + # cache is disabled by default + self.kv_cache: Optional[KVCacheCompressed] = None + + # layer norm for projections + if config.norm_qk: + self.norm_kv = config.norm_class(self.kv_proj_dim, eps=config.norm_eps) + self.norm_q = config.norm_class(self.q_proj_dim, eps=config.norm_eps) + else: + self.norm_q = self.norm_kv = None + + # configuration + self.config = config + + def apply_rope_mla(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + """ + Applies RoPE transform to `x`. Note that `cos`, `sin` need to have a batch + dimension. + + Args: + x: Input tensor, `(B, ..., T, head_size)` + cos: Cached cosines, `(B, T, head_size)` or `(1, T, head_size)` + sin: Cached sines, `(B, T, head_size)` or `(1, T, head_size)` + + Returns: + Encoded tensor, `(B, ..., T, head_size)` + """ + if cos.dim() != 3: + raise ValueError(f"cos must be three-dimensional, but shape is {cos.shape}") + if cos.shape != sin.shape: + raise ValueError(f"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}") + head_size_half = x.size(-1) // 2 + x1 = x[..., :head_size_half] # (B, ..., T, head_size/2) + x2 = x[..., head_size_half:] # (B, ..., T, head_size/2) + rotated = torch.cat((-x2, x1), dim=-1) # (B, ..., T, head_size) + dims_diff = x.dim() - cos.dim() + if dims_diff > 0: + # Ensure that shapes of `x`, `cos`, `sin` align + new_shape = cos.shape[0:2] + (1,) * dims_diff + cos.shape[2:] + cos = cos.view(*new_shape) + sin = sin.view(*new_shape) + + roped = (x * cos) + (rotated * sin) + return roped.to(dtype=x.dtype) + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + input_pos_maxp1: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # notation: + # - B | batch size + # - T | time-step (sequence length) + # - C | model's embeddings size (n_embd) + # - C* | attentions's embeddings size + # - nh_(q,k,v) | number of heads for query, key and value + # - hs | head size + + B, T, _ = x.size() # batch size, sequence length + + # q projections + latent_q = self.dq(x) + if self.norm_q: + latent_q = self.norm_q(latent_q) + q = self.uq(latent_q) + q = q.view(B, T, self.config.n_head, self.config.head_size) # (B, T, nh_q, hs) + q, q_for_rope = torch.split(q, [self.qk_nope_dim, self.qk_rope_dim], dim=-1) # split channels for RoPE + + # q decoupled for RoPE + q_for_rope = self.apply_rope_mla(q_for_rope[..., : self.config.rope_n_elem], cos, sin) + + # kv projections + if self.kv_cache: # kv cache + new_kv = self.dkv(x) + latent_kv = self.kv_cache(input_pos, new_kv) + + old_kv = latent_kv[..., : input_pos[0], :] + old_kv, old_k_for_rope = torch.split(old_kv, [self.kv_proj_dim, self.qk_rope_dim], dim=-1) + + new_kv, new_k_for_rope = torch.split(new_kv, [self.kv_proj_dim, self.qk_rope_dim], dim=-1) + + if self.norm_kv: # normalized separately as in the original implementation + new_kv = self.norm_kv(new_kv) + old_kv = self.norm_kv(old_kv) + + kv_for_lora = torch.cat([old_kv, new_kv], dim=1) + k_for_rope = torch.cat([old_k_for_rope, new_k_for_rope], dim=1) + + else: # no cache + latent_kv = self.dkv(x) + + kv_for_lora, k_for_rope = torch.split( + latent_kv, [self.kv_proj_dim, self.qk_rope_dim], dim=-1 + ) # split LoRA and RoPE additional shared head + + if self.norm_kv: + kv_for_lora = self.norm_kv(kv_for_lora) + + # kv projection back + kv = self.ukv(kv_for_lora) + + # Split qkv into query, key and value matrices. + # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the + # embedding size (C) into num_heads (nh) and head_size (hs). + kv = kv.view(B, -1, self.config.n_head, self.config.head_size + self.qk_nope_dim).transpose(1, 2) + k, v = torch.split(kv, [self.qk_nope_dim, self.config.head_size], dim=-1) + + # k Rope + k_for_rope = k_for_rope.view(B, -1, 1, self.qk_rope_dim) # reshape to make it a 1-head tensor + k_for_rope = self.apply_rope_mla(k_for_rope[..., : self.config.rope_n_elem], cos, sin).transpose(1, 2) + + # apply position encoding to each head + k_for_rope = k_for_rope.repeat(1, self.config.n_head, 1, 1) + + # split into multiple heads + q = torch.cat([q, q_for_rope], dim=-1).transpose(1, 2) # (B, nh_q, T, hs) + k = torch.cat([k, k_for_rope], dim=-1) + v = v # already reshaped before the split + + # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are + # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector + # of size `hs`. + + # Efficient attention using Flash Attention CUDA kernels. + # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled. + # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) + y = self.scaled_dot_product_attention(q, k, v, mask) + + # Re-assemble all head outputs side by side. + y = y.reshape(B, T, self.config.head_size * self.config.n_head) + + # Output projection + return self.proj(y) # (B, T, C) + + def scaled_dot_product_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) + + # with softcapping we cannot use SDPA + if self.config.attention_logit_softcapping is not None: + scores = q @ k.mT * scale + scores = do_softcapping(scores, self.config.attention_logit_softcapping) + if mask is None: + mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) + mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) + scores = scores + mask + scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype) + y = scores @ v + else: + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None + ) + return y.transpose(1, 2) + + def build_kv_cache( + self, + batch_size: int, + max_seq_length: int, + rope_cache_length: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> "KVCacheCompressed": + kv_shape = (batch_size, max_seq_length, self.kv_proj_dim + self.qk_rope_dim) + return KVCacheCompressed(kv_shape=kv_shape, device=device, dtype=dtype) + + def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with legacy checkpoints.""" + + for attr in ("weight", "bias"): + legacy_key = f"{prefix}attn.{attr}" + current_key = f"{prefix}qkv.{attr}" + if legacy_key in state_dict: + state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) + + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + class GptNeoxMLP(nn.Module): def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None: super().__init__() @@ -823,6 +1035,50 @@ def reset_parameters(self) -> None: torch.nn.init.zeros_(self.v) +class KVCacheCompressed(nn.Module): + """ + Buffers `kv` have shape + `(batch_size, max_seq_length, kv_proj_dim + qk_rope_dim)`. + """ + + def __init__( + self, + kv_shape: Tuple[int, int, int, int], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + self.register_buffer("kv", torch.zeros(kv_shape, device=device, dtype=dtype), persistent=False) + + def forward(self, input_pos: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Writes new values `kv` into the cache at the positions specified + by `input_pos` along the sequence dimension (`max_seq_length`). The batch + size of `kv` (`bs`) must be smaller or equal to `KVCacheCompressed` batch + size. Returns the full buffers, adjusted to the batch size `bs`. + + Args: + input_pos: Position index, `(bs, T)` or `(T,)` + kv: New values, `(bs, T, kv_proj_dim + qk_rope_dim)` + + Returns: + kv_full `(bs, max_seq_length, kv_proj_dim + qk_rope_dim)` + + """ + # move the buffer to the activation dtype for when AMP is used + self.kv = self.kv.to(kv.dtype) + # update the cache + bs = kv.size(0) + kv = batched_index_copy_(self.kv[:bs, ...], -2, input_pos, kv) + return kv + + def reset_parameters(self) -> None: + torch.nn.init.zeros_(self.kv) + + def get_cache(self) -> torch.Tensor: + return self.kv + + def build_mask_cache(max_seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor: ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool) return torch.tril(ones).unsqueeze(0).unsqueeze(0) diff --git a/litgpt/pretrain.py b/litgpt/pretrain.py index 4c5e0944cf..6c2b70087d 100644 --- a/litgpt/pretrain.py +++ b/litgpt/pretrain.py @@ -22,7 +22,7 @@ from litgpt.args import EvalArgs, LogArgs, TrainArgs from litgpt.config import name_to_config from litgpt.data import DataModule, TinyLlama -from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP +from litgpt.model import GPT, MLA, Block, CausalSelfAttention, Config, LLaMAMLP from litgpt.utils import ( CycleIterator, capture_hparams, @@ -487,7 +487,7 @@ def init_weights(module, std): # need a separate loop because `mod.proj` below is a `nn.Linear` too for mod in model.modules(): - if isinstance(mod, (LLaMAMLP, CausalSelfAttention)): + if isinstance(mod, (LLaMAMLP, CausalSelfAttention, MLA)): mod.proj.reset_parameters = partial(init_weights, mod.proj, std=(1 / math.sqrt(n_embd) / n_layer)) if not isinstance(fabric.strategy, FSDPStrategy): diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index d7ce885c55..60ff9a113c 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -77,6 +77,14 @@ def copy_weights_gpt_neox( "transformer.h.{}.attn.qkv.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", + # "transformer.h.{}.attn.dq.bias": "gpt_neox.layers.{}.attention.down_query.bias", + # "transformer.h.{}.attn.dq.weight": "gpt_neox.layers.{}.attention.down_query.weight", + # "transformer.h.{}.attn.uq.bias": "gpt_neox.layers.{}.attention.up_query.bias", + # "transformer.h.{}.attn.uq.weight": "gpt_neox.layers.{}.attention.up_query.weight", + # "transformer.h.{}.attn.dkv.bias": "gpt_neox.layers.{}.attention.down_keyvalue.bias", + # "transformer.h.{}.attn.dkv.weight": "gpt_neox.layers.{}.attention.down_keyvalue.weight", + # "transformer.h.{}.attn.ukv.bias": "gpt_neox.layers.{}.attention.up_keyvalue.bias", + # "transformer.h.{}.attn.ukv.weight": "gpt_neox.layers.{}.attention.up_keyvalue.weight", "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", diff --git a/mla.png b/mla.png new file mode 100644 index 0000000000..5c3a941acd Binary files /dev/null and b/mla.png differ