
Commit f99ca4e

Authored by ysjprojects, shijie.yu, and pre-commit-ci[bot]
Qwen3 MoE Preliminary: add intermediate_size argument to MLP modules (#2046)
Co-authored-by: shijie.yu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e6740f5 commit f99ca4e

File tree: litgpt/config.py, litgpt/model.py (2 files changed, +13 -8 lines)


litgpt/config.py

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ class Config:
     rope_adjustments: Optional[dict] = None
     # Transformer block (MLP)
     intermediate_size: Optional[int] = None
+    moe_intermediate_size: Optional[int] = None
     bias: bool = True
     mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
     gelu_approximate: str = "none"

litgpt/model.py

Lines changed: 12 additions & 8 deletions
@@ -516,10 +516,11 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwa
 
 
 class GptNeoxMLP(nn.Module):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()
-        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        self.intermediate_size = intermediate_size or config.intermediate_size
+        self.fc = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -529,11 +530,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class LLaMAMLP(nn.Module):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()
-        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        self.intermediate_size = intermediate_size or config.intermediate_size
+        self.fc_1 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.fc_2 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -555,7 +557,9 @@ class LLaMAMoE(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
         self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
-        self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
+        self.experts = nn.ModuleList(
+            LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
+        )
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
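Taken together, the change lets callers override the hidden width per module: each MLP falls back to config.intermediate_size when no override is passed, while LLaMAMoE hands config.moe_intermediate_size to its experts. A standalone sketch of that fallback pattern (it does not import litgpt; TinyConfig and TinyMLP are illustrative stand-ins, not the patched classes):

from typing import Optional
import torch
import torch.nn as nn

class TinyConfig:
    # Illustrative stand-in for litgpt's Config with only the fields used below.
    def __init__(self, n_embd=64, intermediate_size=256, moe_intermediate_size=128, n_expert=4, bias=True):
        self.n_embd = n_embd
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.n_expert = n_expert
        self.bias = bias

class TinyMLP(nn.Module):
    # Same fallback pattern as the patched GptNeoxMLP/LLaMAMLP:
    # an explicit intermediate_size wins, otherwise config.intermediate_size is used.
    def __init__(self, config: TinyConfig, intermediate_size: Optional[int] = None) -> None:
        super().__init__()
        self.intermediate_size = intermediate_size or config.intermediate_size
        self.fc = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(torch.nn.functional.gelu(self.fc(x)))

cfg = TinyConfig()
dense = TinyMLP(cfg)                                           # hidden width 256 from the config
experts = nn.ModuleList(
    TinyMLP(cfg, intermediate_size=cfg.moe_intermediate_size)  # each expert uses the MoE width, 128
    for _ in range(cfg.n_expert)
)
assert dense.fc.out_features == 256 and experts[0].fc.out_features == 128

Because the override defaults to None, existing dense configs are unaffected; only configs that set moe_intermediate_size change how the experts are sized.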
