@@ -32,7 +32,6 @@ query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, de
F.scaled_dot_product_attention(query, key, value)
```

- # out
```
tensor([[[-1.3321, -0.3489, 0.3015, -0.3912, 0.9867, 0.3137, -0.0691,
          -1.2593],
@@ -51,7 +50,7 @@ tensor([[[-1.3321, -0.3489, 0.3015, -0.3912, 0.9867, 0.3137, -0.0691,

# Explicit Dispatcher Control
While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch via the use of a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to ensure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through and measure performance.
- ```
+ ``` py
# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
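As a minimal sketch of the sweep described above (the tensor shapes, dtype, and loop are illustrative assumptions; the `benchmark_torch_function_in_microseconds` helper is the one the tutorial builds on `torch.utils.benchmark`), one might iterate over the backends like this:

``` py
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.nn.attention import SDPBackend, sdpa_kernel

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    # Time f(*args, **kwargs) with torch.utils.benchmark and report microseconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    return t0.blocked_autorange().mean * 1e6

# Illustrative shapes: (batch, num_heads, seq_len, head_dim); fp16 so the fused
# kernels are eligible on CUDA. Falls back to CPU if no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
query, key, value = (
    torch.randn(32, 8, 128, 64, dtype=torch.float16, device=device) for _ in range(3)
)

for backend in (SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION):
    # sdpa_kernel restricts dispatch to the chosen implementation only.
    with sdpa_kernel(backend):
        try:
            us = benchmark_torch_function_in_microseconds(
                F.scaled_dot_product_attention, query, key, value
            )
            print(f"{backend.name} runs in {us:.3f} microseconds")
        except RuntimeError:
            print(f"{backend.name} is not supported. See warnings for reasons.")
```

Each `sdpa_kernel(backend)` context limits dispatch to that single implementation, so inputs the backend cannot handle raise a `RuntimeError` instead of silently falling back.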
@@ -97,7 +96,7 @@ with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        print("EfficientAttention is not supported. See warnings for reasons.")
```

- out
+
```
The default implementation runs in 2304.977 microseconds
The math implementation runs in 19249.369 microseconds
@@ -166,7 +165,7 @@ model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension
print(model)
```

- out
+
```
CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
@@ -231,7 +230,7 @@ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        print("FlashAttention is not supported. See warnings for reasons.")
```

- out
+
```
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nested/__init__.py:166: UserWarning:

@@ -260,7 +259,7 @@ print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
```

- out
+
```
The non compiled module runs in 408.207 microseconds
The compiled module runs in 516.612 microseconds
@@ -294,7 +293,7 @@ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# prof.export_chrome_trace("compiled_causal_attention_trace.json").
```

- out
+
```
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name    Self CPU %    Self CPU    CPU total %    CPU total    CPU time avg    Self CUDA    Self CUDA %    CUDA total    CUDA time avg    # of Calls
@@ -399,7 +398,7 @@ compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
```

- out
+
``` py
<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>