@@ -352,7 +352,7 @@ def forward(
        x_r, x_i = x[..., ::2], x[..., 1::2]
        x_out_r = x_r * freqs_cos - x_i * freqs_sin
        x_out_i = x_r * freqs_sin + x_i * freqs_cos
-       x_out = torch.cat([x_out_r, x_out_i], dim=-1)
+       x_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2)
        return x_out


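A minimal, self-contained sketch (shapes and values here are assumptions, not taken from the PR) of why this change matters: with the even/odd slicing above, `torch.cat` places all real parts before all imaginary parts, whereas `torch.stack(...).flatten(2)` restores the interleaved layout the slicing assumes. Using an identity rotation makes the layout difference directly visible.

```python
# Sketch only: compare cat vs. stack+flatten for interleaved RoPE output.
# Assumed shape: (bsz, seq_len, head_dim); cos=1, sin=0 so output should equal input.
import torch

bsz, seq_len, head_dim = 1, 2, 4
x = torch.arange(bsz * seq_len * head_dim, dtype=torch.float).reshape(bsz, seq_len, head_dim)
x_r, x_i = x[..., ::2], x[..., 1::2]  # even / odd positions

freqs_cos = torch.ones(seq_len, head_dim // 2)
freqs_sin = torch.zeros(seq_len, head_dim // 2)
x_out_r = x_r * freqs_cos - x_i * freqs_sin
x_out_i = x_r * freqs_sin + x_i * freqs_cos

cat_out = torch.cat([x_out_r, x_out_i], dim=-1)                 # [r0, r1, i0, i1]: halves concatenated
stack_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2)  # [r0, i0, r1, i1]: interleaving restored

print(torch.equal(cat_out, x))    # False
print(torch.equal(stack_out, x))  # True
```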
@@ -378,6 +378,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
        self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
        self.attention_qkv_bias = config.attention_qkv_bias
        self.use_qk_norm = config.use_qk_norm
+       self.qk_norm_before_rope = config.qk_norm_before_rope
        self.use_conv2d = False

        self.wqs = nn.ModuleList(
@@ -449,12 +450,17 @@ def from_conv2ds(ts):
            new_ks = from_conv2ds(new_ks)
            new_vs = from_conv2ds(new_vs)

-       if self.use_qk_norm:
+       if self.use_qk_norm and self.qk_norm_before_rope:
            new_qs = [self.q_norm(q) for q in new_qs]
            new_ks = [self.k_norm(k) for k in new_ks]

        new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
        new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
+
+       if self.use_qk_norm and not self.qk_norm_before_rope:
+           new_qs = [self.q_norm(q) for q in new_qs]
+           new_ks = [self.k_norm(k) for k in new_ks]
+
        all_ks = []
        all_vs = []
        for i in range(self.n_kv_heads):
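For illustration only, a standalone sketch of the ordering the new `qk_norm_before_rope` flag selects (the `simple_rope` helper and toy shapes are assumptions, not the module's actual API): when the flag is set, Q/K are RMS-normalized before the rotary embedding is applied; otherwise the norm is applied to the already-rotated projections.

```python
# Schematic of the norm/RoPE ordering controlled by qk_norm_before_rope.
import torch
from torch import nn

def simple_rope(x, freqs_cos, freqs_sin):
    # Interleaved rotary embedding, matching the stack+flatten layout above.
    x_r, x_i = x[..., ::2], x[..., 1::2]
    return torch.stack(
        [x_r * freqs_cos - x_i * freqs_sin, x_r * freqs_sin + x_i * freqs_cos], dim=-1
    ).flatten(2)

head_dim, seq_len = 8, 3
q = torch.randn(1, seq_len, head_dim)
freqs_cos = torch.randn(seq_len, head_dim // 2)
freqs_sin = torch.randn(seq_len, head_dim // 2)
q_norm = nn.RMSNorm(head_dim)

qk_norm_before_rope = True
if qk_norm_before_rope:
    q_out = simple_rope(q_norm(q), freqs_cos, freqs_sin)  # normalize, then rotate
else:
    q_out = q_norm(simple_rope(q, freqs_cos, freqs_sin))  # rotate, then normalize
```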
@@ -505,6 +511,7 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
        if other.use_qk_norm:
            self.use_qk_norm = True
+           self.qk_norm_before_rope = other.qk_norm_before_rope
            self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
            self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
            self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)