
Commit 976f8c5

Optimize ChatGLM performance (#6104)
* [speed] use cache for rotary embeddings
* [speed] remove repetitive computation
1 parent 78ee064 commit 976f8c5
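
The first bullet is the heart of the change: the cos/sin tables used by rotary position embedding are built once and cached instead of being recomputed by every attention layer on every forward pass. A minimal, self-contained sketch of that caching pattern (simplified and illustrative only, not the code in this diff; the class and variable names are hypothetical):

# Simplified sketch of the caching idea: rebuild the cos/sin tables only when a
# longer sequence shows up, otherwise serve slices of the cached tables.
import paddle

class CachedRotaryTables:
    def __init__(self, dim, base=10000.0):
        self.inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2).astype("float32") / dim))
        self.cached_len = -1
        self.cos, self.sin = None, None

    def __call__(self, seq_len):
        if seq_len > self.cached_len:                            # rebuild only on growth
            self.cached_len = seq_len
            t = paddle.arange(seq_len, dtype="float32")
            freqs = paddle.einsum("i,j->ij", t, self.inv_freq)   # [s, dim/2]
            emb = paddle.concat([freqs, freqs], axis=-1)         # [s, dim]
            self.cos, self.sin = emb.cos(), emb.sin()
        return self.cos[:seq_len], self.sin[:seq_len]

tables = CachedRotaryTables(64)
cos, sin = tables(128)   # first call builds the tables
cos, sin = tables(16)    # cache hit: only slices the existing tables
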

File tree

1 file changed: +75 -67 lines changed

paddlenlp/transformers/chatglm/modeling.py

Lines changed: 75 additions & 67 deletions
@@ -94,53 +94,60 @@ def forward(self, prefix: paddle.Tensor):
 
 
 class RotaryEmbeddings(nn.Layer):
-    def __init__(self, hidden_size, base=10000.0, learnable=False):
+    def __init__(self, hidden_size, base=10000.0, position_encoding_2d=True):
         super().__init__()
         self.dtype = paddle.get_default_dtype()
         inv_freq = 1.0 / (base ** (paddle.arange(0, hidden_size, 2).astype("float32") / hidden_size))
         inv_freq = inv_freq.astype(self.dtype)
-        self.learnable = learnable
-        if learnable:
-            self.inv_freq = nn.Parameter(inv_freq)
-            self.max_seq_len_cached = None
+        self.position_encoding_2d = position_encoding_2d
+        self.register_buffer("inv_freq", inv_freq)
+        self.max_seq_len_cached = -1
+        self.cos_cached = None
+        self.sin_cached = None
+
+    def get_rotary_embeds(self, cos, sin, position_ids):
+        # [s, b, 1, h/n]
+        cos = cos.squeeze(1)[position_ids].unsqueeze(2)
+        sin = sin.squeeze(1)[position_ids].unsqueeze(2)
+        return paddle.stack([cos, sin], axis=0)
+
+    def forward(self, position_ids):
+        seq_len = position_ids.max() + 1
+        if self.max_seq_len_cached < 0 or seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+
+            # x.shape = [b, s, n, h/n/2]
+            t = paddle.arange(seq_len, dtype=self.inv_freq.dtype)
+            # [s, h/n/2]
+            # TODO: Failed for fp16 when converting to static graph.
+            freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
+            freqs = freqs.cast(self.dtype)
+            # [s, h/n]
+            emb = paddle.concat([freqs, freqs], axis=-1)
+            if self.dtype == paddle.bfloat16:
+                emb = emb.cast("float32")
+            # [s, 1, h/n]
+            cos_cached = emb.cos().unsqueeze(1)
+            sin_cached = emb.sin().unsqueeze(1)
+
+            if self.dtype == paddle.bfloat16:
+                cos_cached = cos_cached.astype(self.dtype)
+                sin_cached = sin_cached.astype(self.dtype)
+
+            self.cos_cached, self.sin_cached = cos_cached, sin_cached
+
+        cos, sin = self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        if self.position_encoding_2d:
+            block_position_ids = position_ids[:, 1, :].transpose([1, 0])
+            position_ids = position_ids[:, 0, :].transpose([1, 0])
+            block_rotary_embeds = self.get_rotary_embeds(cos, sin, block_position_ids)
+            position_rotary_embeds = self.get_rotary_embeds(cos, sin, position_ids)
+            rotary_embeds = paddle.stack([position_rotary_embeds, block_rotary_embeds], axis=0)
         else:
-            self.register_buffer("inv_freq", inv_freq)
-            self.max_seq_len_cached = None
-            self.cos_cached = None
-            self.sin_cached = None
-
-    def forward(self, x, seq_dim=1, seq_len=None):
-        if seq_len is None:
-            seq_len = x.shape[seq_dim]
-
-        # x.shape = [b, s, n, h/n/2]
-        # TODO: Remove the condition for converting to static graph.
-        # if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
-        # self.max_seq_len_cached = None if self.learnable else seq_len
-        # [s]
-        t = paddle.arange(seq_len).astype(self.dtype)
-        # [s, h/n/2]
-        # TODO: Failed for fp16 when converting to static graph.
-        freqs = paddle.einsum("i,j->ij", t.astype("float32"), self.inv_freq.astype("float32"))
-        freqs = freqs.astype(self.dtype)
-        # [s, h/n]
-        emb = paddle.concat([freqs, freqs], axis=-1)
-        if self.dtype == paddle.bfloat16:
-            emb = emb.astype("float32")
-        # [s, 1, h/n]
-        cos_cached = emb.cos().unsqueeze(1)
-        sin_cached = emb.sin().unsqueeze(1)
-
-        if self.dtype == paddle.bfloat16:
-            cos_cached = cos_cached.astype(self.dtype)
-            sin_cached = sin_cached.astype(self.dtype)
-
-        if self.learnable:
-            return cos_cached, sin_cached
-
-        self.cos_cached, self.sin_cached = cos_cached, sin_cached
-
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+            position_ids = position_ids.transpose([1, 0])
+            rotary_embeds = self.get_rotary_embeds(cos, sin, position_ids)
+
+        return rotary_embeds
 
 
 class ChatGLMAttention(nn.Layer):
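
The rework above turns RotaryEmbeddings from a per-call cos/sin generator into a module that takes position_ids directly, caches the tables keyed by the largest position seen so far, and returns the gathered embeddings as one stacked tensor. A minimal usage sketch (hypothetical sizes; both the position and block channels are filled with a plain arange purely for illustration):

import paddle
from paddlenlp.transformers.chatglm.modeling import RotaryEmbeddings

hidden_size, num_heads = 4096, 32   # assumed ChatGLM-6B-like sizes
rotary = RotaryEmbeddings(hidden_size // (num_heads * 2), base=10000.0, position_encoding_2d=True)

batch, seq = 2, 16
position_ids = paddle.tile(paddle.arange(seq).reshape([1, 1, seq]), [batch, 2, 1])  # [b, 2, s]

embeds = rotary(position_ids)   # first call builds and caches the cos/sin tables
# embeds: [2, 2, s, b, 1, 64] -> (position/block channel, cos/sin, s, b, 1, rotary dim)
embeds = rotary(position_ids)   # same max position: served from the cache, no new einsum
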
@@ -161,13 +168,6 @@ def __init__(self, config: ChatGLMConfig):
         self.attention_head_size = config.hidden_size // config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.position_encoding_2d = config.position_encoding_2d
-        self.rotary_embeddings = RotaryEmbeddings(
-            self.hidden_size // (self.num_attention_heads * 2)
-            if self.position_encoding_2d
-            else self.hidden_size // self.num_attention_heads,
-            base=10000.0,
-            learnable=False,
-        )
         self.scale_mask_softmax = False
         self.dtype = paddle.get_default_dtype()
 
@@ -192,40 +192,32 @@ def _rotate_half(self, x):
         x1, x2 = paddle.chunk(x, 2, axis=-1)
         return paddle.concat([-x2, x1], axis=-1)
 
-    def _apply_rotary_position_embed_index(self, q, k, cos, sin, position_ids):
+    def _apply_rotary_position_embed_index(self, q, k, cos, sin):
         # q.shape = [s, b, n, h/n/2], cos.shape = [s, 1, h/n], position_ids.shape = [s, b]
-        # [s, b, 1, h/n]
-        cos = cos.squeeze(1)[position_ids].unsqueeze(2)
-        sin = sin.squeeze(1)[position_ids].unsqueeze(2)
         # [s, b, n, h/n]
         q = q * cos + self._rotate_half(q) * sin
         k = k * cos + self._rotate_half(k) * sin
         return q, k
 
-    def _core_attention(self, q_layer: Tensor, k_layer: Tensor, position_ids: Tensor):
+    def _core_attention(self, q_layer: Tensor, k_layer: Tensor, position_ids: Tensor, rotary_embeds: Tensor):
         # Set store_true, position_encoding_2d=False by default.
         if self.config.position_encoding_2d:
             # [s, b, n, h/n/2]
             q1, q2 = paddle.chunk(q_layer, 2, axis=-1)
             k1, k2 = paddle.chunk(k_layer, 2, axis=-1)
-            # [s, 1, h/n]
-            cos, sin = self.rotary_embeddings(q1, seq_len=position_ids.max() + 1)
-            # [s, b]
-            block_position_ids = position_ids[:, 1, :].transpose([1, 0])
-            position_ids = position_ids[:, 0, :].transpose([1, 0])
+
+            pcos, psin = rotary_embeds[0][0], rotary_embeds[0][1]
+            bcos, bsin = rotary_embeds[1][0], rotary_embeds[1][1]
 
             # [s, b, n, h/n]
-            q1, k1 = self._apply_rotary_position_embed_index(q1, k1, cos, sin, position_ids)
-            q2, k2 = self._apply_rotary_position_embed_index(q2, k2, cos, sin, block_position_ids)
+            q1, k1 = self._apply_rotary_position_embed_index(q1, k1, pcos, psin)
+            q2, k2 = self._apply_rotary_position_embed_index(q2, k2, bcos, bsin)
             q_layer = paddle.concat([q1, q2], axis=-1)
             k_layer = paddle.concat([k1, k2], axis=-1)
         else:
-            # [s, b]
-            position_ids = position_ids.transpose([1, 0])
-            # [s, 1, h/n]
-            cos, sin = self.rotary_embeddings(q_layer, seq_len=position_ids.max() + 1)
+            cos, sin = rotary_embeds[0], rotary_embeds[1]
             # [s, b, n, h/n]
-            q_layer, k_layer = self._apply_rotary_position_embed_index(q_layer, k_layer, cos, sin, position_ids)
+            q_layer, k_layer = self._apply_rotary_position_embed_index(q_layer, k_layer, cos, sin)
         return q_layer, k_layer
 
     def forward(
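
After this hunk, _core_attention no longer gathers cos/sin itself: rotary_embeds[0] carries the (cos, sin) pair already indexed by the position channel and rotary_embeds[1] the pair for the block-position channel, so _apply_rotary_position_embed_index is reduced to the rotate-half arithmetic. A standalone sketch of that arithmetic with hypothetical tensors (sizes are assumptions, not values from the diff):

import paddle

s, b, n, rot = 16, 2, 32, 32               # assumed sizes; rot is the per-channel rotary dim
pack = paddle.randn([2, 2, s, b, 1, rot])  # [position/block, cos/sin, s, b, 1, rot]

def rotate_half(x):
    x1, x2 = paddle.chunk(x, 2, axis=-1)
    return paddle.concat([-x2, x1], axis=-1)

q1 = paddle.randn([s, b, n, rot])
pcos, psin = pack[0][0], pack[0][1]        # position channel; pack[1] holds the block channel
q1 = q1 * pcos + rotate_half(q1) * psin    # broadcasts over the head axis n
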
@@ -236,6 +228,7 @@ def forward(
         use_cache: bool = False,
         cache: Tensor = None,
         layer_id=0,
+        rotary_embeds=None,
     ):
         # [s, b, h]
         query_length, batch_size = hidden_states.shape[:2]
@@ -248,7 +241,7 @@ def forward(
         # [s, b, n, h//n]
         q_layer, k_layer, v_layer = paddle.split(mixed_layer, 3, axis=-1)
         # [s, b, n, h/n]
-        q_layer, k_layer = self._core_attention(q_layer, k_layer, position_ids)
+        q_layer, k_layer = self._core_attention(q_layer, k_layer, position_ids, rotary_embeds)
 
         if cache is not None:
             cache_k, cache_v = cache[0], cache[1]
@@ -342,6 +335,7 @@ def forward(
         position_ids: Tensor,
         use_cache: bool = False,
         cache: Tensor = None,
+        rotary_embeds: Tensor = None,
     ):
         # Layer norm before transformer layer
         attention_input = self.input_layernorm(hidden_states)
@@ -353,6 +347,7 @@ def forward(
             cache=cache,
             use_cache=use_cache,
             layer_id=self.layer_id,
+            rotary_embeds=rotary_embeds,
         )
         # Residual connection
         alpha = (2 * self.config.num_hidden_layers) ** 0.5
@@ -415,6 +410,13 @@ def __init__(self, config: ChatGLMConfig):
         self.position_encoding_2d = config.position_encoding_2d
         self.hidden_size = config.hidden_size
         self.enable_recompute = config.recompute
+        self.num_attention_heads = config.num_attention_heads
+        self.rotary_embeddings = RotaryEmbeddings(
+            self.hidden_size // (self.num_attention_heads * 2)
+            if self.position_encoding_2d
+            else self.hidden_size // self.num_attention_heads,
+            base=10000.0,
+        )
         # self.embedding_dropout = nn.Dropout(config.embedding_dropout_prob)
 
         if self.config.tensor_parallel_degree > 1:
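
The construction removed from ChatGLMAttention reappears here, so a single RotaryEmbeddings instance is owned by the stack and shared by every layer instead of each attention layer building its own tables. The rotary dimension it is built with is unchanged; for typical ChatGLM-6B-like sizes (assumed below, for illustration only) it works out as:

# Assumed sizes, not values taken from this diff.
hidden_size, num_attention_heads, position_encoding_2d = 4096, 32, True

rotary_dim = (
    hidden_size // (num_attention_heads * 2)  # 64: half of the 128-dim head, one half per channel
    if position_encoding_2d
    else hidden_size // num_attention_heads   # 128: the full head dimension
)
print(rotary_dim)  # 64
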
@@ -468,6 +470,7 @@ def recompute_training(
         position_ids: Tensor,
         use_cache: bool,
         cache: Tensor,
+        rotary_embeds: Tensor,
     ):
         def create_custom_forward(module):
             def custom_forward(*inputs):
@@ -482,6 +485,7 @@ def custom_forward(*inputs):
             position_ids,
             use_cache,
             cache,
+            rotary_embeds,
             use_reentrant=False,
         )
         return hidden_states
@@ -509,6 +513,8 @@ def forward(
             inputs_embeds = self.word_embeddings(input_ids)
         inputs_embeds = inputs_embeds.transpose([1, 0, 2])
 
+        rotary_embeds = self.rotary_embeddings(position_ids)
+
         if cache is None:
             if self.config.pre_seq_len is not None:
                 cache = self.get_prompt(batch_size=input_ids.shape[0], dtype=inputs_embeds.dtype)
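
With the tables owned by the stack, rotary_embeds is computed once per forward pass here and handed to every block below. The cache also persists across generation steps, so once the prompt has been processed, later single-token steps only index the cached tables. A small sketch of that behaviour (hypothetical shapes, illustrative only):

import paddle
from paddlenlp.transformers.chatglm.modeling import RotaryEmbeddings

rotary = RotaryEmbeddings(64, base=10000.0, position_encoding_2d=True)

prompt_ids = paddle.tile(paddle.arange(8).reshape([1, 1, 8]), [1, 2, 1])  # [b, 2, s], prompt of 8
_ = rotary(prompt_ids)                                # builds tables up to length 8

step_ids = paddle.full([1, 2, 1], 7, dtype="int64")   # a later decoding step
_ = rotary(step_ids)                                  # max position fits the cache: lookup only
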
@@ -537,6 +543,7 @@ def forward(
                     position_ids=position_ids,
                     use_cache=use_cache,
                     cache=cache_i,
+                    rotary_embeds=rotary_embeds,
                 )
             else:
                 hidden_states, new_cache = layer(
@@ -545,6 +552,7 @@ def forward(
                     position_ids=position_ids,
                     use_cache=use_cache,
                     cache=cache_i,
+                    rotary_embeds=rotary_embeds,
                 )
 
             if use_cache:
