
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import Union, Tuple
+from typing import Optional, Union, Tuple

 import torch
 import torch.nn as nn

+from .config import use_fused_attn
 from .helpers import to_2tuple
+from .pos_embed import resample_abs_pos_embed
 from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
 from .weight_init import trunc_normal_

@@ -27,51 +29,84 @@ class RotAttentionPool2d(nn.Module):
     NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
     train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            out_features: Optional[int] = None,
+            ref_feat_size: Union[int, Tuple[int, int]] = 7,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
     ):
         super().__init__()
         embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        ref_feat_size = to_2tuple(ref_feat_size)
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.num_heads = num_heads
-        assert embed_dim % num_heads == 0
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
         self.scale = self.head_dim ** -0.5
-        self.pos_embed = RotaryEmbedding(self.head_dim)
-
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
+        self.fused_attn = use_fused_attn()
+
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)

     def forward(self, x):
         B, _, H, W = x.shape
         N = H * W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
-
+        x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
-
-        qc, q = q[:, :, :1], q[:, :, 1:]
-        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
-        q = apply_rot_embed(q, sin_emb, cos_emb)
-        q = torch.cat([qc, q], dim=2)
-
-        kc, k = k[:, :, :1], k[:, :, 1:]
-        k = apply_rot_embed(k, sin_emb, cos_emb)
-        k = torch.cat([kc, k], dim=2)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
+
+        rse, rce = self.pos_embed.get_embed((H, W))
+        q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
+        k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
+
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
         x = self.proj(x)
         return x[:, 0]

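As a quick sanity check of the new constructor surface, here is a minimal usage sketch of the updated `RotAttentionPool2d` (assuming the module is importable as `timm.layers.attention_pool2d`, as in recent timm trees; the names and defaults come from the diff above, everything else is illustrative):

```python
import torch
# Assumed import path; adjust to wherever this diff lands in your tree.
from timm.layers.attention_pool2d import RotAttentionPool2d

# in_features=768 is divisible by the default head_dim=64, so num_heads resolves to 12.
pool = RotAttentionPool2d(in_features=768, out_features=512, ref_feat_size=7)
pool.init_weights()

x = torch.randn(2, 768, 7, 7)                    # NCHW feature map
print(pool(x).shape)                             # torch.Size([2, 512]), pooled lead token

# The rotary embedding is generated per (H, W), so other grid sizes still run
# (with the accuracy caveat from the class docstring).
print(pool(torch.randn(2, 768, 12, 12)).shape)   # torch.Size([2, 512])
```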
@@ -85,47 +120,90 @@ class AttentionPool2d(nn.Module):

     NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             in_features: int,
-            feat_size: Union[int, Tuple[int, int]],
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            feat_size: Union[int, Tuple[int, int]] = 7,
+            out_features: Optional[int] = None,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
     ):
         super().__init__()
-
         embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.feat_size = to_2tuple(feat_size)
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.seq_len = self.feat_size[0] * self.feat_size[1]
         self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
         self.scale = self.head_dim ** -0.5
-
-        spatial_dim = self.feat_size[0] * self.feat_size[1]
-        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
+        self.fused_attn = use_fused_attn()
+
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.q = self.k = self.v = None
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
+
+        self.init_weights()
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)

     def forward(self, x):
         B, _, H, W = x.shape
         N = H * W
-        assert self.feat_size[0] == H
-        assert self.feat_size[1] == W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
+        x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
-
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.seq_len != N:
+            pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
+        else:
+            pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
+        x = x + pos_embed
+
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
+
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
         x = self.proj(x)
         return x[:, 0]
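Likewise, a short sketch of the updated `AttentionPool2d`: instead of asserting on the feature size, it now resamples the learned absolute position embedding when the incoming grid differs from `feat_size` (same assumed import path as above; constructor arguments follow the diff):

```python
import torch
from timm.layers.attention_pool2d import AttentionPool2d  # assumed import path

pool = AttentionPool2d(in_features=512, feat_size=7, out_features=1024, num_heads=8)

print(pool(torch.randn(2, 512, 7, 7)).shape)   # torch.Size([2, 1024]); grid matches feat_size
# A 9x9 grid no longer trips an assert: the pos embed is resampled to 81 + 1 tokens on the fly.
print(pool(torch.randn(2, 512, 9, 9)).shape)   # torch.Size([2, 1024])
```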
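Both classes now branch between `F.scaled_dot_product_attention` and the previous explicit path. The two agree numerically, since SDPA's default scaling is also `head_dim ** -0.5`; a standalone check on plain tensors (nothing timm-specific is assumed here):

```python
import torch
import torch.nn.functional as F

B, num_heads, N, head_dim = 2, 4, 50, 64
q, k, v = (torch.randn(B, num_heads, N, head_dim) for _ in range(3))

# Fused path, used when use_fused_attn() reports support.
fused = F.scaled_dot_product_attention(q, k, v)

# Manual fallback from the diff: pre-scale q, softmax over keys, weight the values.
attn = (q * head_dim ** -0.5) @ k.transpose(-2, -1)
manual = attn.softmax(dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-5))  # True
```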