 import torch.nn as nn
 from typing_extensions import Self
 
+from litgpt.attention import DefaultKeysAndValues, MultiHeadSelfAttention
 from litgpt.config import Config as BaseConfig
+from litgpt.kvcache.base import KVCache
 from litgpt.model import GPT as BaseModel
 from litgpt.model import Block as BaseBlock
 from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
@@ -42,8 +44,9 @@ def __init__(self, config: Config) -> None:
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
         )
-        self.mask_cache: Optional[torch.Tensor] = None
+        self.mha = MultiHeadSelfAttention(config)
         self.max_seq_length = self.config.block_size
+        self._default_kv_cache = False
 
     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
@@ -57,56 +60,80 @@ def _init_weights(self, module: nn.Module) -> None:
 
 
 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)
 
 
 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.model.CausalSelfAttention` that adds the attention
     over the adaption prompt."""
 
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        if block_idx >= config.adapter_start_layer:
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(
+            config=config,
+            block_idx=block_idx,
+            kv_cache=kv_cache,
+        )
+        self._extend_forward = block_idx >= config.adapter_start_layer
+        if self._extend_forward:
             # adapter embedding layer
             self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
             # gate for adaption
             self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
             # kv cache for inference
             self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
 
-    def scaled_dot_product_attention(
-        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
+    def _transform_output(
+        self,
+        y: torch.Tensor,
+        query: torch.Tensor,
+        mha: MultiHeadSelfAttention,
     ) -> torch.Tensor:
-        y = super().scaled_dot_product_attention(q, k, v, mask)
-        if self.block_idx < self.config.adapter_start_layer:
-            return y
-
-        aT = self.config.adapter_prompt_length
-        if self.adapter_kv_cache is not None:
-            # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
-            # are the same every call
-            ak, av = self.adapter_kv_cache
-        else:
-            prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
-            aqkv = self.qkv(prefix)
-            q_per_kv = self.config.n_head // self.config.n_query_groups
-            aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
-            aqkv = aqkv.permute(0, 2, 3, 1, 4)
-            _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
-            if self.config.n_query_groups != 1:
-                # for MHA this is a no-op
-                ak = ak.repeat_interleave(q_per_kv, dim=2)
-                av = av.repeat_interleave(q_per_kv, dim=2)
-            ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
-            av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
-            self.adapter_kv_cache = (ak, av)
-
-        T = q.size(2)
-        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
-        ay = super().scaled_dot_product_attention(q, ak, av, amask)
-        return y + self.gating_factor * ay
+        if self._extend_forward:
+            B, T, _ = y.shape
+            y = y.view(B, T, self.config.n_head, self.config.head_size)
+            aT = self.config.adapter_prompt_length
+            if self.adapter_kv_cache is not None:
+                # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
+                # are the same every call
+                ak, av = self.adapter_kv_cache
+            else:
+                prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
+                aqkv = self.qkv(prefix)
+                q_per_kv = self.config.n_head // self.config.n_query_groups
+                aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
+                aqkv = aqkv.permute(0, 2, 3, 1, 4)
+                _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
+                if self.config.n_query_groups != 1:
+                    # for MHA this is a no-op
+                    ak = ak.repeat_interleave(q_per_kv, dim=2)
+                    av = av.repeat_interleave(q_per_kv, dim=2)
+                ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
+                av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
+                self.adapter_kv_cache = (ak, av)
+
+            amask = torch.ones(T, aT, dtype=torch.bool, device=query.device)
+            a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
+            ay, _ = mha.scaled_dot_product_attention(
+                query=query,
+                k_and_v=a_k_and_v,
+                mask=amask,
+                is_causal=False,
+            )
+            y = (y + self.gating_factor * ay).view(B, T, -1)
+
+        return y
 
     def reset_parameters(self) -> None:
         if hasattr(self, "gating_factor"):
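Note on the change above: the adapter-specific `scaled_dot_product_attention` override is replaced by a `_transform_output` hook, but the underlying mechanism is the same: attention over the learned adapter prefix is added to the regular attention output, scaled by the zero-initialized `gating_factor`. The snippet below is a minimal, self-contained sketch of that pattern in plain PyTorch; the function name, shapes, and tensor layout are illustrative only (litgpt keeps `y` in `(B, T, n_head, head_size)` layout and a gate of shape `(1, 1, n_head, 1)`), not part of the litgpt API.

import torch
import torch.nn.functional as F


def adapter_gated_attention(
    q: torch.Tensor,          # (B, n_head, T, head_size) queries for the current tokens
    y_base: torch.Tensor,     # (B, n_head, T, head_size) output of causal self-attention
    adapter_k: torch.Tensor,  # (1, n_head, aT, head_size) keys of the learned prefix
    adapter_v: torch.Tensor,  # (1, n_head, aT, head_size) values of the learned prefix
    gate: torch.Tensor,       # (1, n_head, 1, 1), initialized to zero
) -> torch.Tensor:
    B = q.shape[0]
    # Every token may attend to all aT prefix positions, so this attention is not causal.
    ay = F.scaled_dot_product_attention(
        q, adapter_k.expand(B, -1, -1, -1), adapter_v.expand(B, -1, -1, -1), is_causal=False
    )
    # With a zero-initialized gate, the adapter contributes nothing at the start of training.
    return y_base + gate * ay


# Toy shapes: with the gate at zero, the output equals the base attention output.
B, n_head, T, hs, aT = 2, 4, 8, 16, 10
q = torch.randn(B, n_head, T, hs)
y_base = torch.randn(B, n_head, T, hs)
ak = torch.randn(1, n_head, aT, hs)
av = torch.randn(1, n_head, aT, hs)
out = adapter_gated_attention(q, y_base, ak, av, torch.zeros(1, n_head, 1, 1))
assert torch.allclose(out, y_base)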