diff --git a/litgpt/scripts/download.py b/litgpt/scripts/download.py
index 9771a37956..a922c31f56 100644
--- a/litgpt/scripts/download.py
+++ b/litgpt/scripts/download.py
@@ -5,6 +5,7 @@
 from contextlib import contextmanager
 import importlib.util
 from pathlib import Path
+import shutil
 from typing import List, Optional, Tuple
 
 import torch
@@ -62,7 +63,40 @@ def download_from_hub(
     download_files = ["tokenizer*", "generation_config.json", "config.json"]
     if not tokenizer_only:
-        bins, safetensors = find_weight_files(repo_id, access_token)
+        bins, safetensors, info = find_weight_files(repo_id, access_token)
+
+        total_weight_size_bytes = 0
+        if bins:
+            total_weight_size_bytes = sum(
+                (file.size or 0)
+                for file in info.siblings
+                if file.rfilename.endswith(".bin") or file.rfilename.endswith(".bin.index.json")
+            )
+        elif safetensors:
+            total_weight_size_bytes = sum(
+                (file.size or 0)
+                for file in info.siblings
+                if file.rfilename.endswith(".safetensors")
+            )
+        else:
+            raise ValueError(f"Couldn't find weight files for {repo_id}")
+
+        weight_size_gb = total_weight_size_bytes / (1024**3)
+        free_space_bytes = shutil.disk_usage(str(checkpoint_dir)).free
+        free_space_gb = free_space_bytes / (1024**3)
+
+        # 2x because we create lit_model.pth before deleting the downloaded weights,
+        # so we intermittently have 2 sets of weights on disk
+        if 2 * weight_size_gb > free_space_gb:
+            if os.getenv("LIGHTNING_CLOUD_SPACE_ID") is not None:
+                studio_text = " Please switch to a larger Studio with more disk space."
+            else:
+                studio_text = ""
+            raise RuntimeError(
+                f"Not enough disk space to download {repo_id} weights. "
+                f"Needed: ~{2 * weight_size_gb:.2f} GB, free: ~{free_space_gb:.2f} GB.{studio_text}"
+            )
+
         if bins:
             # covers `.bin` files and `.bin.index.json`
             download_files.append("*.bin*")
@@ -104,11 +138,11 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s
     from huggingface_hub.utils import filter_repo_objects
 
     with gated_repo_catcher(repo_id, access_token):
-        info = repo_info(repo_id, token=access_token)
+        info = repo_info(repo_id, token=access_token, files_metadata=True)
     filenames = [f.rfilename for f in info.siblings]
     bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"]))
     safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"]))
-    return bins, safetensors
+    return bins, safetensors, info
 
 
 @contextmanager
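
For reference, the added check reduces to two calls: repo_info(..., files_metadata=True) from huggingface_hub, which populates the per-file size attribute on info.siblings, and shutil.disk_usage(...), which reports the free bytes on the target filesystem. The sketch below approximates that logic in isolation (it sums both weight formats together rather than picking one); the helper name, repo id, and default directory are illustrative only and not part of the patch.

    import shutil

    from huggingface_hub import repo_info


    def has_room_for(repo_id: str, target_dir: str = ".") -> bool:
        # Sibling sizes are only populated when files_metadata=True is passed.
        info = repo_info(repo_id, files_metadata=True)
        weight_bytes = sum(
            (f.size or 0)
            for f in info.siblings
            if f.rfilename.endswith((".bin", ".bin.index.json", ".safetensors"))
        )
        # ~2x headroom: the converted lit_model.pth briefly coexists on disk
        # with the raw downloaded weights.
        return 2 * weight_bytes <= shutil.disk_usage(target_dir).free


    print(has_room_for("EleutherAI/pythia-14m"))
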
diff --git a/tests/test_rope.py b/tests/test_rope.py
index 0aa10aeb58..bd28d1f012 100644
--- a/tests/test_rope.py
+++ b/tests/test_rope.py
@@ -1,11 +1,12 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
+from dataclasses import dataclass
+
 import torch
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
 from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb_llama
-from transformers.models.llama.configuration_llama import LlamaConfig
 
 from litgpt.model import apply_rope, build_rope_cache
@@ -17,7 +18,23 @@ def test_rope_gptneox():
     x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
     position_ids = torch.arange(seq_len).unsqueeze(0)
 
-    theirs_rot_emb = GPTNeoXRotaryEmbedding(head_size, seq_len)
+    @dataclass
+    class RoPEConfig:
+        dim: int
+        max_position_embeddings: int
+        rope_theta: int
+        hidden_size: int
+        num_attention_heads: int
+
+    config = RoPEConfig(
+        dim=head_size,
+        max_position_embeddings=seq_len,
+        rope_theta=10_000,
+        hidden_size=head_size * n_head,
+        num_attention_heads=n_head
+    )
+
+    theirs_rot_emb = GPTNeoXRotaryEmbedding(config)
     theirs_cos, theirs_sin = theirs_rot_emb(x, position_ids)
 
     ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device)
@@ -35,13 +52,32 @@ def test_rope_llama_2():
     head_dim = 64
     rope_theta = 10_000
+    num_heads = 4
+    batch_size, seq_len = 1, 10
 
     ##################################
     # Compare cos and sin
     ##################################
     # transformer rope
-    rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
-    batch_size, seq_len = 1, 10
+
+    @dataclass
+    class RoPEConfig:
+        dim: int
+        max_position_embeddings: int
+        rope_theta: int
+        hidden_size: int
+        num_attention_heads: int
+
+    config = RoPEConfig(
+        dim=head_dim,
+        max_position_embeddings=seq_len,
+        rope_theta=rope_theta,
+        hidden_size=head_dim * num_heads,
+        num_attention_heads=num_heads
+    )
+
+    rot_emb = LlamaRotaryEmbedding(config)
+
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -56,8 +92,6 @@ def test_rope_llama_2():
     ##################################
     # Compare rotated tensors
     ##################################
-    # Settings
-    num_heads = 4
 
     # Dummy query and key tensors
     torch.manual_seed(123)
@@ -76,13 +110,33 @@ def test_rope_llama_3():
     head_dim = 64
     rope_theta = 50_000
+    num_heads = 4
+    batch_size, seq_len = 1, 10
 
     ##################################
     # Compare cos and sin
     ##################################
+
+    @dataclass
+    class RoPEConfig:
+        dim: int
+        max_position_embeddings: int
+        rope_theta: int
+        hidden_size: int
+        num_attention_heads: int
+        scaling_factor: float
+
+    config = RoPEConfig(
+        dim=head_dim,
+        max_position_embeddings=seq_len,
+        rope_theta=rope_theta,
+        hidden_size=head_dim * num_heads,
+        num_attention_heads=num_heads,
+        scaling_factor=None
+    )
+
     # transformer rope
-    rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
-    batch_size, seq_len = 1, 10
+    rot_emb = LlamaRotaryEmbedding(config)
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
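
In the unscaled case, both the Hugging Face rotary embeddings and litgpt's build_rope_cache should reduce to the textbook RoPE cache: angles m * theta_i with theta_i = base^(-2i/d) for position m and pair index i, tiled twice along the last dimension before taking cos and sin. A minimal reference sketch, assuming that convention (the function name is illustrative, not an API of either library):

    import torch


    def reference_rope_cache(seq_len: int, head_dim: int, base: int = 10_000):
        # theta_i = base^(-2i/d) for i in [0, d/2)
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        # angles[m, i] = m * theta_i, tiled to the full head dimension
        angles = torch.outer(torch.arange(seq_len).float(), inv_freq).repeat(1, 2)
        return angles.cos(), angles.sin()
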
@@ -97,8 +151,6 @@ def test_rope_llama_3():
     ##################################
     # Compare rotated tensors
     ##################################
-    # Settings
-    num_heads = 4
 
     # Dummy query and key tensors
     torch.manual_seed(123)
@@ -117,6 +169,8 @@ def test_rope_llama_3_1():
     head_dim = 32
     rope_theta = 50_000
+    num_heads = 4
+    batch_size, seq_len = 1, 131_072
 
     their_rope_config = {
         "factor": 8.0,
@@ -133,18 +187,32 @@ def test_rope_llama_3_1():
         "original_max_seq_len": 8192
     }
 
-    config = LlamaConfig(
-        rope_theta=rope_theta,
-        rope_scaling=their_rope_config,
-        head_dim=head_dim
-    )
-
     ##################################
     # Compare cos and sin
     ##################################
     # transformer rope
-    rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
-    batch_size, seq_len = 1, 131_072
+
+    @dataclass
+    class RoPEConfig:
+        dim: int
+        max_position_embeddings: int
+        rope_theta: int
+        hidden_size: int
+        num_attention_heads: int
+        rope_type: str
+        rope_scaling: dict
+
+    config = RoPEConfig(
+        dim=head_dim,
+        max_position_embeddings=seq_len,
+        rope_theta=rope_theta,
+        hidden_size=head_dim * num_heads,
+        num_attention_heads=num_heads,
+        rope_type="llama3",
+        rope_scaling=their_rope_config
+    )
+
+    rot_emb = LlamaRotaryEmbedding(config=config)
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -159,8 +227,6 @@ def test_rope_llama_3_1():
     ##################################
     # Compare rotated tensors
     ##################################
-    # Settings
-    num_heads = 4
 
     # Dummy query and key tensors
     torch.manual_seed(123)
@@ -179,6 +245,8 @@ def test_rope_llama_3_2():
     head_dim = 32
     rope_theta = 50_000
+    batch_size, seq_len = 1, 131_072
+    num_heads = 4
 
     their_rope_config = {
         "factor": 32.0,
@@ -195,18 +263,32 @@ def test_rope_llama_3_2():
         "original_max_seq_len": 8192
     }
 
-    config = LlamaConfig(
-        rope_theta=rope_theta,
-        rope_scaling=their_rope_config,
-        head_dim=head_dim
-    )
-
     ##################################
     # Compare cos and sin
     ##################################
     # transformer rope
-    rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
-    batch_size, seq_len = 1, 131_072
+    @dataclass
+    class RoPEConfig:
+        dim: int
+        max_position_embeddings: int
+        rope_theta: int
+        hidden_size: int
+        num_attention_heads: int
+        rope_type: str
+        rope_scaling: dict
+
+    config = RoPEConfig(
+        dim=head_dim,
+        max_position_embeddings=seq_len,
+        rope_theta=rope_theta,
+        hidden_size=head_dim * num_heads,
+        num_attention_heads=num_heads,
+        rope_type="llama3",
+        rope_scaling=their_rope_config
+    )
+
+    rot_emb = LlamaRotaryEmbedding(config)
+
    qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -221,8 +303,6 @@ def test_rope_llama_3_2():
     ##################################
     # Compare rotated tensors
     ##################################
-    # Settings
-    num_heads = 4
 
     # Dummy query and key tensors
     torch.manual_seed(123)
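
Each test above re-declares an essentially identical RoPEConfig stand-in because newer transformers releases construct GPTNeoXRotaryEmbedding and LlamaRotaryEmbedding from a config object rather than from positional dim/base arguments. A possible follow-up, not part of this patch, is to hoist one stand-in to module level with optional scaling fields; a sketch, with a hypothetical make_rope_config helper:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class RoPEConfig:
        dim: int
        max_position_embeddings: int
        rope_theta: int
        hidden_size: int
        num_attention_heads: int
        rope_type: str = "default"
        rope_scaling: Optional[dict] = None
        scaling_factor: Optional[float] = None


    def make_rope_config(head_dim: int, n_head: int, seq_len: int, theta: int, **extra) -> RoPEConfig:
        return RoPEConfig(
            dim=head_dim,
            max_position_embeddings=seq_len,
            rope_theta=theta,
            hidden_size=head_dim * n_head,
            num_attention_heads=n_head,
            **extra,
        )
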