 import torch.nn as nn
 from typing_extensions import Self

+from litgpt.attention import DefaultKeysAndValues, MultiHeadSelfAttention
 from litgpt.config import Config as BaseConfig
+from litgpt.kvcache.base import KVCache
 from litgpt.model import GPT as BaseModel
 from litgpt.model import Block as BaseBlock
 from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
@@ -34,16 +36,23 @@ def __init__(self, config: Config) -> None:
         assert config.padded_vocab_size is not None
         self.config = config

-        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
+        self.lm_head = nn.Linear(
+            config.n_embd,
+            config.padded_vocab_size,
+            bias=config.lm_head_bias,
+        )
         self.transformer = nn.ModuleDict(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
         )
-        self.mask_cache: Optional[torch.Tensor] = None
+        self.mha = MultiHeadSelfAttention(config)
         self.max_seq_length = self.config.block_size
+        self._start_of_layer_hook = config.start_of_layer_hook
+        # Have dense KV caches been created by `set_kv_cache`?
+        self._default_kv_cache = False

     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
@@ -57,56 +66,80 @@ def _init_weights(self, module: nn.Module) -> None:


 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)


 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.model.CausalSelfAttention` that adds the attention
     over the adaption prompt."""

-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        if block_idx >= config.adapter_start_layer:
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(
+            config=config,
+            block_idx=block_idx,
+            kv_cache=kv_cache,
+        )
+        self._extend_forward = block_idx >= config.adapter_start_layer
+        if self._extend_forward:
             # adapter embedding layer
             self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
             # gate for adaption
             self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
             # kv cache for inference
             self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

-    def scaled_dot_product_attention(
-        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
+    def _transform_output(
+        self,
+        y: torch.Tensor,
+        query: torch.Tensor,
+        mha: MultiHeadSelfAttention,
     ) -> torch.Tensor:
-        y = super().scaled_dot_product_attention(q, k, v, mask)
-        if self.block_idx < self.config.adapter_start_layer:
-            return y
-
-        aT = self.config.adapter_prompt_length
-        if self.adapter_kv_cache is not None:
-            # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
-            # are the same every call
-            ak, av = self.adapter_kv_cache
-        else:
-            prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
-            aqkv = self.qkv(prefix)
-            q_per_kv = self.config.n_head // self.config.n_query_groups
-            aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
-            aqkv = aqkv.permute(0, 2, 3, 1, 4)
-            _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
-            if self.config.n_query_groups != 1:
-                # for MHA this is a no-op
-                ak = ak.repeat_interleave(q_per_kv, dim=2)
-                av = av.repeat_interleave(q_per_kv, dim=2)
-            ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
-            av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
-            self.adapter_kv_cache = (ak, av)
-
-        T = q.size(2)
-        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
-        ay = super().scaled_dot_product_attention(q, ak, av, amask)
-        return y + self.gating_factor * ay
+        if self._extend_forward:
+            B, T, _ = y.shape
+            y = y.view(B, T, self.config.n_head, self.config.head_size)
+            aT = self.config.adapter_prompt_length
+            if self.adapter_kv_cache is not None:
+                # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
+                # are the same every call
+                ak, av = self.adapter_kv_cache
+            else:
+                prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
+                aqkv = self.qkv(prefix)
+                q_per_kv = self.config.n_head // self.config.n_query_groups
+                aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
+                aqkv = aqkv.permute(0, 2, 3, 1, 4)
+                _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
+                if self.config.n_query_groups != 1:
+                    # for MHA this is a no-op
+                    ak = ak.repeat_interleave(q_per_kv, dim=2)
+                    av = av.repeat_interleave(q_per_kv, dim=2)
+                ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
+                av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
+                self.adapter_kv_cache = (ak, av)
+
+            amask = torch.ones(T, aT, dtype=torch.bool, device=query.device)
+            a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
+            ay, _ = mha.scaled_dot_product_attention(
+                query=query,
+                k_and_v=a_k_and_v,
+                mask=amask,
+                is_causal=False,
+            )
+            y = (y + self.gating_factor * ay).view(B, T, -1)
+
+        return y

     def reset_parameters(self) -> None:
         if hasattr(self, "gating_factor"):