@@ -516,10 +516,11 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwa


 class GptNeoxMLP(nn.Module):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()
-        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        self.intermediate_size = intermediate_size or config.intermediate_size
+        self.fc = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
         self.config = config

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -529,11 +530,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class LLaMAMLP(nn.Module):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()
-        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        self.intermediate_size = intermediate_size or config.intermediate_size
+        self.fc_1 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.fc_2 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
         self.config = config

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -555,7 +557,9 @@ class LLaMAMoE(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
         self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
-        self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
+        self.experts = nn.ModuleList(
+            LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
+        )
         self.config = config

     def forward(self, x: torch.Tensor) -> torch.Tensor:
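
A minimal sketch of how the new `intermediate_size` override behaves, assuming the patched `LLaMAMLP` is importable from `lit_gpt.model`; the `TinyConfig` stand-in and its field values are hypothetical and only cover the fields the constructors above actually read.

from dataclasses import dataclass

from lit_gpt.model import LLaMAMLP  # assumed import path for the patched class


@dataclass
class TinyConfig:
    # Hypothetical stand-in for Config with only the fields used above.
    n_embd: int = 64
    intermediate_size: int = 256
    moe_intermediate_size: int = 128
    n_expert: int = 4
    bias: bool = False


config = TinyConfig()

# Default: the MLP falls back to config.intermediate_size.
dense_mlp = LLaMAMLP(config)
assert dense_mlp.fc_1.out_features == config.intermediate_size

# MoE expert: the per-expert width comes from config.moe_intermediate_size.
expert_mlp = LLaMAMLP(config, intermediate_size=config.moe_intermediate_size)
assert expert_mlp.fc_1.out_features == config.moe_intermediate_size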