Commit 0733aec

Add torch.nn.functional functions

1 parent 3f3765c commit 0733aec

File tree

1 file changed: +7 -8 lines changed


docs/2.0/tutorials/intermediate/scaled_dot_product_attention_tutorial.md

Lines changed: 7 additions & 8 deletions
@@ -32,7 +32,6 @@ query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, de
 F.scaled_dot_product_attention(query, key, value)
 ```
 
-# out
 ```
 tensor([[[-1.3321, -0.3489, 0.3015, -0.3912, 0.9867, 0.3137, -0.0691,
          -1.2593],
@@ -51,7 +50,7 @@ tensor([[[-1.3321, -0.3489, 0.3015, -0.3912, 0.9867, 0.3137, -0.0691,
 
 # Explicit dispatcher control
 While the function implicitly dispatches to one of the three implementations, a user can also explicitly control the dispatch by using a context manager. This context manager lets users explicitly disable certain implementations. To make sure the function really is using the fastest implementation for their specific inputs, the context manager can be used to sweep through the implementations while measuring performance.
-```
+```py
 # Let's define a helpful benchmarking function:
 import torch.utils.benchmark as benchmark
 def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
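
The benchmarking helper is cut off at its signature by the diff context. For reference, a minimal sketch of how such a helper can be written with torch.utils.benchmark.Timer (the body below is an assumption, not the file's exact code):

```py
# Sketch of a microsecond-scale timing helper built on torch.utils.benchmark.
# Only the signature appears in the diff context; this body is an assumption.
import torch.utils.benchmark as benchmark

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    # blocked_autorange() repeats the statement until timings stabilize;
    # .mean is in seconds, so scale to microseconds.
    return t0.blocked_autorange().mean * 1e6
```
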
@@ -97,7 +96,7 @@ with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
 print("EfficientAttention is not supported. See warnings for reasons.")
 ```
 
-out
+
 ```
 The default implementation runs in 2304.977 microseconds
 The math implementation runs in 19249.369 microseconds
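
The timings above come from sweeping the implementations under the context manager. A self-contained sketch of that pattern, using the sdpa_kernel and SDPBackend names visible in this diff (tensor shapes, dtype, and the backend list are illustrative assumptions):

```py
# Sketch: restricting SDPA dispatch with the sdpa_kernel context manager
# and sweeping the available backends. Shapes and dtype are assumptions.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"
query, key, value = (torch.randn(32, 16, 64, 64, device=device) for _ in range(3))

for backend in (SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION):
    with sdpa_kernel(backend):  # dispatch is limited to this backend inside the block
        try:
            F.scaled_dot_product_attention(query, key, value)
            print(f"{backend} ran")
        except RuntimeError:
            print(f"{backend} is not supported. See warnings for reasons.")
```
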
@@ -166,7 +165,7 @@ model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension
 print(model)
 ```
 
-out
+
 ```
 CausalSelfAttention(
   (c_attn): Linear(in_features=512, out_features=1536, bias=False)
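
Only the module's repr survives in this hunk. A minimal sketch of a causal self-attention block consistent with that printout (the 512 -> 1536 c_attn projection follows the repr; the forward logic is an assumption):

```py
# Minimal sketch of a causal self-attention block matching the repr above.
# Layer shapes follow the printout; the forward pass is an assumption.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, num_heads: int, embed_dimension: int, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # One projection producing q, k, v: 512 -> 1536 with bias=False, as in the repr.
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=False)
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=False)
        self.num_heads = num_heads
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()
        q, k, v = self.c_attn(x).chunk(3, dim=-1)
        # (B, T, C) -> (B, num_heads, T, head_dim) for SDPA.
        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,  # mask out attention to future positions
        )
        y = y.transpose(1, 2).reshape(B, T, C)
        return self.c_proj(y)

model = CausalSelfAttention(num_heads=8, embed_dimension=512)
print(model)
```
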
@@ -231,7 +230,7 @@ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
 print("FlashAttention is not supported. See warnings for reasons.")
 ```
 
-out
+
 ```
 /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nested/__init__.py:166: UserWarning:
 
@@ -260,7 +259,7 @@ print(
 f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
 ```
 
-out
+
 ```
 The non compiled module runs in 408.207 microseconds
 The compiled module runs in 516.612 microseconds
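
The eager-vs-compiled numbers above follow the usual compile, warm up, then measure pattern. A self-contained sketch with a stand-in nn.Linear (the tutorial times its CausalSelfAttention model instead):

```py
# Sketch: timing an eager module against its torch.compile-d version.
# The nn.Linear stand-in and shapes are assumptions for a runnable example.
import torch
import torch.utils.benchmark as benchmark

model = torch.nn.Linear(512, 512)
x = torch.randn(64, 512)

compiled_model = torch.compile(model)
compiled_model(x)  # warm-up: the first call pays the one-time compilation cost

for name, fn in (("non compiled", model), ("compiled", compiled_model)):
    t = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
    print(f"The {name} module runs in {t.blocked_autorange().mean * 1e6:.3f} microseconds")
```

The warm-up call matters: the first invocation of a compiled module triggers compilation, which is why it is excluded from the measurement.
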
@@ -294,7 +293,7 @@ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
 # prof.export_chrome_trace("compiled_causal_attention_trace.json").
 ```
 
-out
+
 ```
 ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
 Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
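
A table like the one above can be produced with torch.profiler; a minimal sketch (the profiled call and iteration count are assumptions):

```py
# Sketch: producing a key_averages() table like the one above.
# The profiled call and iteration count are assumptions.
import torch
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"
q, k, v = (torch.randn(8, 8, 128, 64, device=device) for _ in range(3))

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    for _ in range(25):
        F.scaled_dot_product_attention(q, k, v)

sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))
# To view the trace in chrome://tracing, uncomment:
# prof.export_chrome_trace("compiled_causal_attention_trace.json")
```
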
@@ -399,7 +398,7 @@ compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
 out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
 ```
 
-out
+
 ```py
 <class 'torch.nn.attention.bias.CausalBias'>
 <class 'torch.nn.attention.bias.CausalBias'>
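
For the CausalBias printouts in this hunk, a self-contained sketch of building the bias subclass and passing it to the compiled SDPA call (shapes are assumptions; causal_upper_left is the factory in torch.nn.attention.bias):

```py
# Sketch: building the CausalBias subclass passed to compiled SDPA in this
# hunk. Shapes are assumptions; causal_upper_left comes from
# torch.nn.attention.bias.
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_upper_left

batch, heads, seq_len_q, seq_len_kv, head_dim = 2, 4, 8, 8, 16
upper_left_bias = causal_upper_left(seq_len_q, seq_len_kv)
print(type(upper_left_bias))  # <class 'torch.nn.attention.bias.CausalBias'>

query = torch.randn(batch, heads, seq_len_q, head_dim)
key = torch.randn(batch, heads, seq_len_kv, head_dim)
value = torch.randn(batch, heads, seq_len_kv, head_dim)

# CausalBias is accepted where a dense attn_mask would go; fullgraph=True
# asks torch.compile to capture the call without graph breaks.
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
```
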
