
Commit e3088e6

Authored by k223kim, pre-commit-ci[bot], Borda, and t-vi
[1/4] feat: add gemma 3 27b (#1998)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Thomas Viehmann <[email protected]>
Co-authored-by: Jirka B <[email protected]>
1 parent 74f0fd8 commit e3088e6

File tree: 4 files changed, +210 -9 lines changed

litgpt/config.py

Lines changed: 40 additions & 0 deletions
@@ -82,6 +82,7 @@ class Config:
     # The base period of the RoPE embeddings for local attention.
     # If not provided, rope_theta will be used for both local and global attention.
     rope_local_base_freq: Optional[float] = None
+    rope_indices: Optional[List] = None

     def __post_init__(self):
         if not self.name:
@@ -1053,6 +1054,45 @@ def norm_class(self) -> Type:
     copy["hf_config"]["name"] = f"{c['hf_config']['name']}-it"
     configs.append(copy)

+##################
+# Google Gemma 3
+##################
+gemma3 = [
+    # https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json
+    dict(
+        name="Gemma-3-27b-it",
+        hf_config=dict(org="google", name="gemma-3-27b-it"),
+        scale_embeddings=True,
+        attention_scores_scalar=168,
+        vocab_size=262144,
+        block_size=131072,
+        sliding_window_size=1024,
+        # 5 local layers for every global layer
+        sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(62)],
+        intermediate_size=21504,
+        n_embd=5376,
+        n_layer=62,
+        n_head=32,
+        n_query_groups=16,
+        head_size=128,
+        rotary_percentage=1.0,
+        rope_adjustments=dict(factor=8.0),
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="GemmaMLP",
+        gelu_approximate="tanh",
+        post_attention_norm=True,
+        post_mlp_norm=True,
+        norm_qk=True,
+        rope_base=1000000,
+        rope_local_base_freq=10000,
+        # 5 local layers for every global layer
+        rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(62)],
+    ),
+]
+configs.extend(gemma3)
+
 ##################
 # Google CodeGemma
 ##################
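The two per-layer lists encode Gemma 3's 5:1 local-to-global attention schedule: every sixth layer gets a 0 (global attention, full-context RoPE with rope_base=1000000), all other layers get a 1 (sliding-window attention with rope_local_base_freq=10000). A small sketch of how that pattern expands, for illustration only (the numbers are taken from the config above):

```python
# Sketch: expand the 5-local-to-1-global schedule used by Gemma-3-27b-it.
# Entry 0 marks a global layer, entry 1 a local (sliding-window) layer.
n_layer = 62
pattern = [0 if (i + 1) % 6 == 0 else 1 for i in range(n_layer)]

assert pattern[:6] == [1, 1, 1, 1, 1, 0]            # layers 1-5 local, layer 6 global
assert pattern.count(0) == n_layer // 6             # 10 global layers out of 62
assert pattern.count(1) == n_layer - n_layer // 6   # 52 local layers
```

sliding_window_indices and rope_indices stay separate per-layer switches even though Gemma 3 happens to use the same pattern for both.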

litgpt/model.py

Lines changed: 26 additions & 9 deletions
@@ -154,8 +154,18 @@ def forward(
         if self.config.scale_embeddings:
             x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)

-        for block in self.transformer.h:
-            x = block(x, cos, sin, mask, input_pos, input_pos_maxp1)
+        for block_idx, block in enumerate(self.transformer.h):
+            if self.config.rope_indices is not None:
+                x = block(
+                    x,
+                    cos[..., self.config.rope_indices[block_idx]],
+                    sin[..., self.config.rope_indices[block_idx]],
+                    mask,
+                    input_pos,
+                    input_pos_maxp1,
+                )
+            else:
+                x = block(x, cos, sin, mask, input_pos, input_pos_maxp1)
         x = self.transformer.ln_f(x)
         clamp_head = (
             partial(do_softcapping, thresh=self.config.final_logit_softcapping)
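When rope_indices is set, the RoPE cache is expected to carry one cos/sin table per base frequency, stacked along a trailing dimension, and `cos[..., self.config.rope_indices[block_idx]]` slices out the table for the current block (index 0 is assumed to select the global rope_base table, index 1 the local one, matching the config above). A standalone sketch of that per-layer selection under those assumptions:

```python
import torch

T, n_elem = 8, 4
theta_global, theta_local = 1_000_000.0, 10_000.0

def rope_angles(theta: float) -> torch.Tensor:
    # Standard RoPE angles: outer product of positions and inverse frequencies.
    inv_freq = 1.0 / (theta ** (torch.arange(0, n_elem, 2).float() / n_elem))
    t = torch.arange(T).float()
    angles = torch.outer(t, inv_freq)            # (T, n_elem // 2)
    return torch.cat((angles, angles), dim=-1)   # (T, n_elem)

# Assumed layout: global and local caches stacked along a trailing dim of size 2.
cos = torch.stack(
    (rope_angles(theta_global).cos(), rope_angles(theta_local).cos()), dim=-1
)  # (T, n_elem, 2)

rope_indices = [1, 1, 1, 1, 1, 0]  # 5 local layers, then 1 global layer
for block_idx in range(len(rope_indices)):
    cos_b = cos[..., rope_indices[block_idx]]    # (T, n_elem): per-layer selection
    assert cos_b.shape == (T, n_elem)
```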
@@ -186,6 +196,10 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso
             elif num_params_present == 4:
                 # These parameters should always be used together so that we don't interfere with standard rope
                 extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required}
+            elif "factor" in self.config.rope_adjustments:
+                # linear RoPE
+                adjusted_params_required = ["factor"]
+                extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required}
             else:
                 # Some but not all parameters are specified; raise an error
                 missing_params = [
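This branch covers rope_adjustments that contain only `factor`, which is how the Gemma 3 config above requests linear RoPE scaling (HF's `"rope_type": "linear"`). Linear scaling conventionally stretches the usable context by dividing the position indices by the factor before the angles are computed; a hedged, self-contained sketch of that idea (not litgpt's actual build_rope_cache):

```python
import torch

def linear_rope_cos_sin(seq_len: int, n_elem: int, base: float = 1_000_000.0, factor: float = 8.0):
    """Illustrative linear ("position interpolation") RoPE: positions are divided by `factor`."""
    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    t = torch.arange(seq_len).float() / factor   # the only change vs. unscaled RoPE
    angles = torch.outer(t, inv_freq)
    angles = torch.cat((angles, angles), dim=-1)
    return angles.cos(), angles.sin()

cos, sin = linear_rope_cos_sin(seq_len=16, n_elem=8)
print(cos.shape, sin.shape)  # torch.Size([16, 8]) torch.Size([16, 8])
```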
@@ -215,7 +229,10 @@ def set_kv_cache(
         dtype: Optional[torch.dtype] = None,
     ) -> None:
         if rope_cache_length is None:
-            rope_cache_length = self.cos.size(-1)
+            if len(self.cos.shape) == 2:
+                rope_cache_length = self.cos.size(-1)
+            else:
+                rope_cache_length = self.cos[..., 0].size(-1)

         if max_seq_length is None:
             max_seq_length = self.max_seq_length
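With a single RoPE base, `self.cos` is a 2-D (T, rope_n_elem) table; with the stacked dual-frequency cache it gains a trailing dimension, so the cache length has to be read from one slice. A toy shape check (the concrete shapes are illustrative assumptions):

```python
import torch

cos_2d = torch.randn(4096, 128)          # (T, rope_n_elem): single RoPE base
cos_3d = torch.randn(4096, 128, 2)       # (T, rope_n_elem, 2): global + local caches

for cos in (cos_2d, cos_3d):
    if len(cos.shape) == 2:
        rope_cache_length = cos.size(-1)
    else:
        rope_cache_length = cos[..., 0].size(-1)
    assert rope_cache_length == 128      # both paths recover rope_n_elem
```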
@@ -329,8 +346,8 @@ def __init__(self, config: Config, block_idx: int) -> None:
         self.apply_sliding_window_attention = config.sliding_window_indices[block_idx]

         if config.norm_qk:
-            self.norm_q = config.norm_class(config.head_size * config.n_head, eps=config.norm_eps)
-            self.norm_k = config.norm_class(config.head_size * config.n_query_groups, eps=config.norm_eps)
+            self.norm_q = config.norm_class(config.head_size, eps=config.norm_eps)
+            self.norm_k = config.norm_class(config.head_size, eps=config.norm_eps)
         else:
             self.norm_q = self.norm_k = None

@@ -370,10 +387,6 @@ def forward(
         # Split qkv into query, key and value matrices.
         q, k, v = qkv.split((query_size, key_size, value_size), dim=-1)  # 3x(B, T, C*)

-        if self.config.norm_qk:
-            q = self.norm_q(q)
-            k = self.norm_k(k)
-
         # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the
         # embedding size (C) into num_heads (nh) and head_size (hs).
         q = q.view(B, T, n_head, head_size)  # (B, T, nh_q, hs)
@@ -387,6 +400,10 @@ def forward(
         k = k.transpose(1, 2)  # (B, nh_k, T, hs)
         v = v.transpose(1, 2)  # (B, nh_v, T, hs)

+        if self.config.norm_qk:
+            q = self.norm_q(q)
+            k = self.norm_k(k)
+
         # Unlike standard positional embeddings rotary embeddings must be applied at every layer.
         q_roped = apply_rope(q[..., :rope_n_elem], cos, sin)
         k_roped = apply_rope(k[..., :rope_n_elem], cos, sin)
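Gemma 3 normalizes queries and keys per head: the RMSNorm weight has size head_size and is applied to the last dimension after the reshape to (B, nh, T, hs), just before apply_rope. The removed code normalized the flat (B, T, C) projections across all heads at once, which is why it disagreed with the HF reference. A standalone sketch of the per-head variant, using torch.nn.RMSNorm (PyTorch >= 2.4) as a stand-in for litgpt's norm class:

```python
import torch
from torch import nn

B, T, n_head, n_kv, hs = 2, 5, 8, 4, 16

# Per-head norms: one weight vector of size head_size, shared across all heads.
norm_q = nn.RMSNorm(hs, eps=1e-6)   # stand-in for config.norm_class(config.head_size, ...)
norm_k = nn.RMSNorm(hs, eps=1e-6)

q = torch.randn(B, n_head, T, hs)   # (B, nh_q, T, hs), i.e. after view + transpose
k = torch.randn(B, n_kv, T, hs)     # (B, nh_k, T, hs)

# Normalization acts on the last (head_size) dimension of each head independently.
q, k = norm_q(q), norm_k(k)
assert q.shape == (B, n_head, T, hs) and k.shape == (B, n_kv, T, hs)
```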

litgpt/scripts/convert_hf_checkpoint.py

Lines changed: 71 additions & 0 deletions
@@ -285,6 +285,77 @@ def copy_weights_gemma_2(
                 pbar.update(progress_per_file)


+def copy_weights_gemma_3(
+    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
+    state_dict: Dict[str, torch.Tensor],
+    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    saver: Optional[incremental_save] = None,
+    dtype: Optional[torch.dtype] = None,
+    pbar: Optional[tqdm] = None,
+    progress_per_file: Optional[float] = None,
+    debug_mode: Optional[bool] = False,
+) -> None:
+    weight_map = {
+        "model.embed_tokens.weight": "transformer.wte.weight",
+        "model.layers.{}.self_attn.q_proj.weight": None,
+        "model.layers.{}.self_attn.k_proj.weight": None,
+        "model.layers.{}.self_attn.v_proj.weight": None,
+        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
+        "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
+        "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
+        "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
+        "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight",
+        "model.layers.{}.pre_feedforward_layernorm.weight": "transformer.h.{}.norm_2.weight",
+        "model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight",
+        "model.norm.weight": "transformer.ln_f.weight",
+        "lm_head.weight": "lm_head.weight",
+        "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
+        "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
+    }
+
+    if progress_per_file is not None:
+        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
+
+    for from_name, param in hf_weights.items():
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        to_name = weight_map[name_template]
+        param = load_param(param, from_name, dtype, verbose=debug_mode)
+        if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
+            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
+            weight_name, weight_type = from_name.split(".")[-2:]
+            qkv[weight_type][weight_name] = param
+
+        if to_name is None:
+            continue
+        to_name = to_name.format(*ids)
+        if saver is not None:
+            param = saver.store_early(param)
+        state_dict[to_name] = param
+
+        if progress_per_file is not None:
+            pbar.update(progress_per_file)
+
+    if "lm_head.weight" not in state_dict:
+        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
+
+    for i in list(qkv_weights):
+        for weight_type in list(qkv_weights[i]):
+            qkv = qkv_weights[i][weight_type]
+            if len(qkv) != 3:
+                # qkv is split across different .bin files
+                continue
+            q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
+            k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
+            v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
+            qkv = torch.cat((q, k, v))
+            state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
+            del qkv_weights[i][weight_type]

+            if progress_per_file is not None:
+                pbar.update(progress_per_file)
+
+
 def copy_weights_phi(
     config: Config,
     qkv_weights: dict,
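The converter maps HF Gemma 3 parameter names onto litgpt names and fuses the separate q/k/v projections into litgpt's single attn.qkv weight by concatenating them row-wise, in the same q, k, v order that CausalSelfAttention.forward later splits. A hedged usage sketch with toy tensors (shapes chosen arbitrarily; in practice the function is driven by convert_hf_checkpoint with a real HF state dict, as in the test below):

```python
import torch
from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_3

n_embd, n_head, n_kv, hs = 32, 4, 2, 8

# Minimal fake HF state dict for a single layer (attention projections + embeddings only).
hf_weights = {
    "model.embed_tokens.weight": torch.randn(64, n_embd),
    "model.layers.0.self_attn.q_proj.weight": torch.randn(n_head * hs, n_embd),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(n_kv * hs, n_embd),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(n_kv * hs, n_embd),
    "model.layers.0.self_attn.o_proj.weight": torch.randn(n_embd, n_head * hs),
}

state_dict = {}
copy_weights_gemma_3({}, state_dict, hf_weights)

# q, k and v are concatenated row-wise into a single fused projection.
assert state_dict["transformer.h.0.attn.qkv.weight"].shape == ((n_head + 2 * n_kv) * hs, n_embd)
# Without an explicit lm_head, the embedding matrix is reused (weight tying).
assert state_dict["lm_head.weight"].shape == (64, n_embd)
```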

tests/test_model.py

Lines changed: 73 additions & 0 deletions
@@ -23,6 +23,7 @@
 from transformers.models.falcon import FalconConfig, FalconForCausalLM
 from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
+from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig
 from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM
 from transformers.models.llama import LlamaConfig, LlamaForCausalLM
 from transformers.models.mistral import MistralConfig, MistralForCausalLM
@@ -36,6 +37,7 @@
 from litgpt.scripts.convert_hf_checkpoint import (
     copy_weights_falcon,
     copy_weights_gemma_2,
+    copy_weights_gemma_3,
     copy_weights_gpt_neox,
     copy_weights_hf_llama,
     copy_weights_phi,
@@ -799,6 +801,77 @@ def test_against_original_gemma_2(model_name, device, dtype):
     torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)


+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ["gemma-3-27b-it"])
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                _RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_original_gemma_3(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    T = 20
+    ours_config = Config.from_name(
+        model_name,
+        block_size=T,
+        sliding_window_size=T // 2,
+        n_layer=2,
+        n_head=16,
+        n_embd=32,
+        intermediate_size=86,
+    )
+
+    theirs_config = Gemma3TextConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        head_dim=ours_config.head_size,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=ours_config.block_size,
+        sliding_window=ours_config.sliding_window_size,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.bias,
+        tie_word_embeddings=True,
+        hidden_act="gelu_pytorch_tanh",
+        attn_implementation="eager",
+        query_pre_attn_scalar=ours_config.attention_scores_scalar,
+        rope_scaling={"factor": 8.0, "rope_type": "linear"},
+        rope_local_base_freq=ours_config.rope_local_base_freq,
+    )
+
+    theirs_model = Gemma3ForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    # Gemma weights are shipped without `lm_head.weight`
+    theirs_state_dict.pop("lm_head.weight")
+    state_dict = {}
+
+    copy_weights_gemma_3({}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize(
     "model_name", ["Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview", "QwQ-32B"]
