@@ -94,53 +94,60 @@ def forward(self, prefix: paddle.Tensor):


class RotaryEmbeddings(nn.Layer):
-    def __init__(self, hidden_size, base=10000.0, learnable=False):
+    def __init__(self, hidden_size, base=10000.0, position_encoding_2d=True):
        super().__init__()
        self.dtype = paddle.get_default_dtype()
        inv_freq = 1.0 / (base ** (paddle.arange(0, hidden_size, 2).astype("float32") / hidden_size))
        inv_freq = inv_freq.astype(self.dtype)
-        self.learnable = learnable
-        if learnable:
-            self.inv_freq = nn.Parameter(inv_freq)
-            self.max_seq_len_cached = None
+        self.position_encoding_2d = position_encoding_2d
+        self.register_buffer("inv_freq", inv_freq)
+        self.max_seq_len_cached = -1
+        self.cos_cached = None
+        self.sin_cached = None
+
+    def get_rotary_embeds(self, cos, sin, position_ids):
+        # [s, b, 1, h/n]
+        cos = cos.squeeze(1)[position_ids].unsqueeze(2)
+        sin = sin.squeeze(1)[position_ids].unsqueeze(2)
+        return paddle.stack([cos, sin], axis=0)
+
+    def forward(self, position_ids):
+        seq_len = position_ids.max() + 1
+        if self.max_seq_len_cached < 0 or seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+
+            # x.shape = [b, s, n, h/n/2]
+            t = paddle.arange(seq_len, dtype=self.inv_freq.dtype)
+            # [s, h/n/2]
+            # TODO: Failed for fp16 when converting to static graph.
+            freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
+            freqs = freqs.cast(self.dtype)
+            # [s, h/n]
+            emb = paddle.concat([freqs, freqs], axis=-1)
+            if self.dtype == paddle.bfloat16:
+                emb = emb.cast("float32")
+            # [s, 1, h/n]
+            cos_cached = emb.cos().unsqueeze(1)
+            sin_cached = emb.sin().unsqueeze(1)
+
+            if self.dtype == paddle.bfloat16:
+                cos_cached = cos_cached.astype(self.dtype)
+                sin_cached = sin_cached.astype(self.dtype)
+
+            self.cos_cached, self.sin_cached = cos_cached, sin_cached
+
+        cos, sin = self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        if self.position_encoding_2d:
+            block_position_ids = position_ids[:, 1, :].transpose([1, 0])
+            position_ids = position_ids[:, 0, :].transpose([1, 0])
+            block_rotary_embeds = self.get_rotary_embeds(cos, sin, block_position_ids)
+            position_rotary_embeds = self.get_rotary_embeds(cos, sin, position_ids)
+            rotary_embeds = paddle.stack([position_rotary_embeds, block_rotary_embeds], axis=0)
        else:
-            self.register_buffer("inv_freq", inv_freq)
-            self.max_seq_len_cached = None
-            self.cos_cached = None
-            self.sin_cached = None
-
-    def forward(self, x, seq_dim=1, seq_len=None):
-        if seq_len is None:
-            seq_len = x.shape[seq_dim]
-
-        # x.shape = [b, s, n, h/n/2]
-        # TODO: Remove the condition for converting to static graph.
-        # if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
-        #     self.max_seq_len_cached = None if self.learnable else seq_len
-        # [s]
-        t = paddle.arange(seq_len).astype(self.dtype)
-        # [s, h/n/2]
-        # TODO: Failed for fp16 when converting to static graph.
-        freqs = paddle.einsum("i,j->ij", t.astype("float32"), self.inv_freq.astype("float32"))
-        freqs = freqs.astype(self.dtype)
-        # [s, h/n]
-        emb = paddle.concat([freqs, freqs], axis=-1)
-        if self.dtype == paddle.bfloat16:
-            emb = emb.astype("float32")
-        # [s, 1, h/n]
-        cos_cached = emb.cos().unsqueeze(1)
-        sin_cached = emb.sin().unsqueeze(1)
-
-        if self.dtype == paddle.bfloat16:
-            cos_cached = cos_cached.astype(self.dtype)
-            sin_cached = sin_cached.astype(self.dtype)
-
-        if self.learnable:
-            return cos_cached, sin_cached
-
-        self.cos_cached, self.sin_cached = cos_cached, sin_cached
-
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+            position_ids = position_ids.transpose([1, 0])
+            rotary_embeds = self.get_rotary_embeds(cos, sin, position_ids)
+
+        return rotary_embeds


class ChatGLMAttention(nn.Layer):
@@ -161,13 +168,6 @@ def __init__(self, config: ChatGLMConfig):
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.position_encoding_2d = config.position_encoding_2d
-        self.rotary_embeddings = RotaryEmbeddings(
-            self.hidden_size // (self.num_attention_heads * 2)
-            if self.position_encoding_2d
-            else self.hidden_size // self.num_attention_heads,
-            base=10000.0,
-            learnable=False,
-        )
        self.scale_mask_softmax = False
        self.dtype = paddle.get_default_dtype()

@@ -192,40 +192,32 @@ def _rotate_half(self, x):
        x1, x2 = paddle.chunk(x, 2, axis=-1)
        return paddle.concat([-x2, x1], axis=-1)

-    def _apply_rotary_position_embed_index(self, q, k, cos, sin, position_ids):
+    def _apply_rotary_position_embed_index(self, q, k, cos, sin):
        # q.shape = [s, b, n, h/n/2], cos.shape = [s, 1, h/n], position_ids.shape = [s, b]
-        # [s, b, 1, h/n]
-        cos = cos.squeeze(1)[position_ids].unsqueeze(2)
-        sin = sin.squeeze(1)[position_ids].unsqueeze(2)
        # [s, b, n, h/n]
        q = q * cos + self._rotate_half(q) * sin
        k = k * cos + self._rotate_half(k) * sin
        return q, k

-    def _core_attention(self, q_layer: Tensor, k_layer: Tensor, position_ids: Tensor):
+    def _core_attention(self, q_layer: Tensor, k_layer: Tensor, position_ids: Tensor, rotary_embeds: Tensor):
        # Set store_true, position_encoding_2d=False by default.
        if self.config.position_encoding_2d:
            # [s, b, n, h/n/2]
            q1, q2 = paddle.chunk(q_layer, 2, axis=-1)
            k1, k2 = paddle.chunk(k_layer, 2, axis=-1)
-            # [s, 1, h/n]
-            cos, sin = self.rotary_embeddings(q1, seq_len=position_ids.max() + 1)
-            # [s, b]
-            block_position_ids = position_ids[:, 1, :].transpose([1, 0])
-            position_ids = position_ids[:, 0, :].transpose([1, 0])
+
+            pcos, psin = rotary_embeds[0][0], rotary_embeds[0][1]
+            bcos, bsin = rotary_embeds[1][0], rotary_embeds[1][1]

            # [s, b, n, h/n]
-            q1, k1 = self._apply_rotary_position_embed_index(q1, k1, cos, sin, position_ids)
-            q2, k2 = self._apply_rotary_position_embed_index(q2, k2, cos, sin, block_position_ids)
+            q1, k1 = self._apply_rotary_position_embed_index(q1, k1, pcos, psin)
+            q2, k2 = self._apply_rotary_position_embed_index(q2, k2, bcos, bsin)
            q_layer = paddle.concat([q1, q2], axis=-1)
            k_layer = paddle.concat([k1, k2], axis=-1)
        else:
-            # [s, b]
-            position_ids = position_ids.transpose([1, 0])
-            # [s, 1, h/n]
-            cos, sin = self.rotary_embeddings(q_layer, seq_len=position_ids.max() + 1)
+            cos, sin = rotary_embeds[0], rotary_embeds[1]
            # [s, b, n, h/n]
-            q_layer, k_layer = self._apply_rotary_position_embed_index(q_layer, k_layer, cos, sin, position_ids)
+            q_layer, k_layer = self._apply_rotary_position_embed_index(q_layer, k_layer, cos, sin)
        return q_layer, k_layer

    def forward(
@@ -236,6 +228,7 @@ def forward(
        use_cache: bool = False,
        cache: Tensor = None,
        layer_id=0,
+        rotary_embeds=None,
    ):
        # [s, b, h]
        query_length, batch_size = hidden_states.shape[:2]
@@ -248,7 +241,7 @@ def forward(
        # [s, b, n, h//n]
        q_layer, k_layer, v_layer = paddle.split(mixed_layer, 3, axis=-1)
        # [s, b, n, h/n]
-        q_layer, k_layer = self._core_attention(q_layer, k_layer, position_ids)
+        q_layer, k_layer = self._core_attention(q_layer, k_layer, position_ids, rotary_embeds)

        if cache is not None:
            cache_k, cache_v = cache[0], cache[1]
@@ -342,6 +335,7 @@ def forward(
        position_ids: Tensor,
        use_cache: bool = False,
        cache: Tensor = None,
+        rotary_embeds: Tensor = None,
    ):
        # Layer norm before transformer layer
        attention_input = self.input_layernorm(hidden_states)
@@ -353,6 +347,7 @@ def forward(
            cache=cache,
            use_cache=use_cache,
            layer_id=self.layer_id,
+            rotary_embeds=rotary_embeds,
        )
        # Residual connection
        alpha = (2 * self.config.num_hidden_layers) ** 0.5
@@ -415,6 +410,13 @@ def __init__(self, config: ChatGLMConfig):
        self.position_encoding_2d = config.position_encoding_2d
        self.hidden_size = config.hidden_size
        self.enable_recompute = config.recompute
+        self.num_attention_heads = config.num_attention_heads
+        self.rotary_embeddings = RotaryEmbeddings(
+            self.hidden_size // (self.num_attention_heads * 2)
+            if self.position_encoding_2d
+            else self.hidden_size // self.num_attention_heads,
+            base=10000.0,
+        )
        # self.embedding_dropout = nn.Dropout(config.embedding_dropout_prob)

        if self.config.tensor_parallel_degree > 1:
@@ -468,6 +470,7 @@ def recompute_training(
        position_ids: Tensor,
        use_cache: bool,
        cache: Tensor,
+        rotary_embeds: Tensor,
    ):
        def create_custom_forward(module):
            def custom_forward(*inputs):
@@ -482,6 +485,7 @@ def custom_forward(*inputs):
            position_ids,
            use_cache,
            cache,
+            rotary_embeds,
            use_reentrant=False,
        )
        return hidden_states
@@ -509,6 +513,8 @@ def forward(
            inputs_embeds = self.word_embeddings(input_ids)
        inputs_embeds = inputs_embeds.transpose([1, 0, 2])

+        rotary_embeds = self.rotary_embeddings(position_ids)
+
        if cache is None:
            if self.config.pre_seq_len is not None:
                cache = self.get_prompt(batch_size=input_ids.shape[0], dtype=inputs_embeds.dtype)
@@ -537,6 +543,7 @@ def forward(
                    position_ids=position_ids,
                    use_cache=use_cache,
                    cache=cache_i,
+                    rotary_embeds=rotary_embeds,
                )
            else:
                hidden_states, new_cache = layer(
@@ -545,6 +552,7 @@ def forward(
                    position_ids=position_ids,
                    use_cache=use_cache,
                    cache=cache_i,
+                    rotary_embeds=rotary_embeds,
                )

        if use_cache:
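
Also not part of the diff: a shape-level sketch of the `rotary_embeds` tensor that `ChatGLMStack` now precomputes once per forward pass and threads through every layer, and of how `_core_attention` unpacks it in the 2D case. The random tensors and sizes below are assumptions for illustration only.

```python
import paddle

# Layout produced by RotaryEmbeddings with position_encoding_2d=True:
# [2 (position/block), 2 (cos/sin), seq_len, batch_size, 1, rotary_dim].
# seq_len=8, batch=2, rotary_dim=32 are arbitrary; real tables come from the module above.
rotary_embeds = paddle.randn([2, 2, 8, 2, 1, 32])

pcos, psin = rotary_embeds[0][0], rotary_embeds[0][1]  # token-position tables
bcos, bsin = rotary_embeds[1][0], rotary_embeds[1][1]  # block-position tables

def rotate_half(x):
    x1, x2 = paddle.chunk(x, 2, axis=-1)
    return paddle.concat([-x2, x1], axis=-1)

# q1 is the half of each head rotated with token positions, shape [s, b, n, h/n/2].
q1 = paddle.randn([8, 2, 4, 32])
q1 = q1 * pcos + rotate_half(q1) * psin  # cos/sin broadcast over the head axis
print(q1.shape)  # [8, 2, 4, 32]
```
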