@@ -1,60 +1,11 @@
-> Translation task
-
-* No one is translating this page yet; we look forward to your joining
-* Translation rewards: https://github.com/orgs/apachecn/discussions/243
-* Task claiming: https://github.com/apachecn/pytorch-doc-zh/discussions/583
-
-Please follow this template when writing the content:
-
-
-# PyTorch Such-and-Such Page
-
 > Translator: [片刻小哥哥](https://github.com/jiangzhonglian)
 >
 > Project page: <https://pytorch.apachecn.org/2.0/tutorials/intermediate/scaled_dot_product_attention_tutorial#using-sdpa-with-attn-bias-subclasses>
 >
 > Original page: <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#using-sdpa-with-attn-bias-subclasses>

-Start writing the translation of the original page here
-
-
-
-Notes:
-
-1. Code reference:
-
-```py
-import torch
-
-x = torch.ones(5) # input tensor
-y = torch.zeros(3) # expected output
-w = torch.randn(5, 3, requires_grad=True)
-b = torch.randn(3, requires_grad=True)
-z = torch.matmul(x, w)+b
-loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
-```
-
-2. Formula reference:
-
-1) Form that needs no line break:
-
-$\sqrt{w^T*w}$
-
-2) Form that needs its own line:

-$$
-\sqrt{w^T*w}
-$$
-
-3. Image reference (just use the image's actual URL):
-
-<img src='http://data.apachecn.org/img/logo/logo_green.png' width=20% />
-
-4. **After finishing the translation, just delete all of the template content above**
-
-
-
-# Synopsis
+# Summary
 In this tutorial, we introduce a new torch.nn.functional function that is very helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been integrated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.

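 As a first taste, here is a minimal sketch of calling the function directly; the toy shapes below are our own choice, not from the tutorial:

 ```py
 import torch
 import torch.nn.functional as F

 # Toy shapes for illustration: (batch, num_heads, seq_len, head_dim).
 query = torch.randn(2, 8, 16, 64)
 key = torch.randn(2, 8, 16, 64)
 value = torch.randn(2, 8, 16, 64)

 # One call computes softmax(Q @ K^T / sqrt(head_dim)) @ V, dispatching to a
 # fused kernel when one is available for these inputs and this device.
 out = F.scaled_dot_product_attention(query, key, value)
 print(out.shape)  # torch.Size([2, 8, 16, 64])
 ```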

 # Overview
@@ -70,7 +21,7 @@
 This tutorial requires PyTorch 2.0.0 or later.
 ```

-```
+```py
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -162,7 +113,7 @@ The memory efficient implementation runs in 4197.082 microseconds
 # Causal Self-Attention
 Below is an example implementation of a multi-headed causal self-attention block, inspired by Andrej Karpathy's NanoGPT repository.

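 The block below leans on SDPA's is_causal flag; as a minimal illustration of just that flag (toy shapes of our own, not part of the tutorial):

 ```py
 import torch
 import torch.nn.functional as F

 # Toy shapes: (batch, num_heads, seq_len, head_dim).
 q = torch.randn(2, 8, 16, 64)
 k = torch.randn(2, 8, 16, 64)
 v = torch.randn(2, 8, 16, 64)

 # is_causal=True applies a lower-triangular mask internally, so each
 # position attends only to itself and earlier positions.
 out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
 ```
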
-```
+```py
 class CausalSelfAttention(nn.Module):

     def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
@@ -227,7 +178,7 @@ CausalSelfAttention(
 # NestedTensor and Dense Tensor Support
 SDPA supports both NestedTensor and dense tensor inputs. NestedTensors handle the case where the input is a batch of variable-length sequences, without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors, see torch.nested and the NestedTensors tutorial.

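 As a minimal sketch of that idea (the lengths below are our own toy values, and this assumes a device and backend where the nested-tensor SDPA path is available):

 ```py
 import torch
 import torch.nn.functional as F

 # Two sequences of different lengths, shaped (num_heads, seq_len, head_dim).
 seq_a = torch.randn(8, 10, 64)
 seq_b = torch.randn(8, 6, 64)

 # Batch them as a NestedTensor: no padding to a common maximum length.
 qkv = torch.nested.nested_tensor([seq_a, seq_b])

 # Self-attention over the jagged batch; each sequence attends over its
 # own true length.
 out = F.scaled_dot_product_attention(qkv, qkv, qkv)
 ```
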
-```
+```py
 import random
 def generate_rand_batch(
     batch_size,
@@ -317,7 +268,7 @@ The compiled module runs in 516.612 microseconds

 The exact execution times depend on your machine, but my results were: the non-compiled module ran in 166.616 microseconds and the compiled module ran in 166.726 microseconds. That is not what we were expecting. Let's dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.

-```
+```py
 from torch.profiler import profile, record_function, ProfilerActivity
 activities = [ProfilerActivity.CPU]
 if device == 'cuda':
@@ -394,7 +345,7 @@ Self CUDA time total: 20.514ms
 The current argument is_causal in torch.nn.functional.scaled_dot_product_attention is the same as using torch.nn.attention.bias.causal_upper_left.
 ```

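 To make the equivalence in the note above concrete, here is a small sketch with toy shapes of our own (it assumes a PyTorch recent enough to ship torch.nn.attention.bias):

 ```py
 import torch
 import torch.nn.functional as F
 from torch.nn.attention.bias import causal_upper_left

 # Toy shapes: (batch, num_heads, seq_len, head_dim).
 q = torch.randn(2, 8, 16, 64)
 k = torch.randn(2, 8, 16, 64)
 v = torch.randn(2, 8, 16, 64)

 # For equal query/key lengths, an upper-left causal bias over the 16x16
 # attention matrix reproduces what is_causal=True computes.
 bias = causal_upper_left(16, 16)
 out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
 out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
 print(torch.allclose(out_bias, out_flag, atol=1e-6))  # expected: True
 ```
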
-```
+```py
 from torch.nn.attention.bias import causal_lower_right, causal_upper_left

 batch_size = 32
@@ -449,7 +400,7 @@ out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
 ```

 out
-```
+```py
 <class 'torch.nn.attention.bias.CausalBias'>
 <class 'torch.nn.attention.bias.CausalBias'>
 tensor([[ True, False, False, False, False, False, False, False, False, False],