Professional Documents
Culture Documents
hw3-report
hw3-report
hw3-report
吴辰禹
航天航空学院
日期:2023 年 12 月 28 日
摘 要
使用标准的多头注意力机制实现 Transformer。具体实现的方式如下面的公式所示。假设序
列的 query,key 和 value 向量组成的矩阵为:
Q ∈ RS×d , K ∈ RT ×d , V ∈ RT ×d (1)
其中,S 为 query 序列的长度,T 为需要 pay attention to 的序列的长度,d 为 embedding 的维数。
假设 query,key 和 value 的 projection 变换分别为 Wq , Wk , Wv ∈ Rd×d ,多头注意力的头数为 h,
那么的 Q, K, V 将按照下式被 project 和分割:
[Q1 , · · · , Qh ] = Q · Wq , [K1 , · · · , Kh ] = K · Wk , [V1 , · · · , Vh ] = V · Wv (2)
得到 projection 和头分割的结果后,我们按照下面的方法计算每个 head 的 attention:
Q i · Ki
Ai = softmax( √ , dim = 1) (3)
d/h
然后,头 i 的 attention 矩阵 Pi 经过一次 dropout,再和 Vi 矩阵相乘,得到该头的输出 Zi ∈ RS×d/h :
Zi = dropout(Pi ) · Vi (4)
随后,每个头的输出被横向拼接在一起,再经过一次 Projection,得到最终的 attention 输出 Z ∈
S×
Z = [Z1 , · · · , Zh ] · Wo (5)
1
图 1: Full attention 模型训练历史
2
不知道 k + 1 及以后的 token 到底是什么。因此,如果我们在训练阶段允许模型”窥探未来“,将
会导致训练-推断不相容的问题。所以,我们需要在训练阶段就添加 mask,使得即便在输入一整
个待预测序列(可能由占位符组成)的情况下,模型也不能在预测第 k 个 token 时动用第 k + 1
及之后的 token 信息。
1.2.2 具体实现方法
图 4: Encoder layer 0
3
图 5: Encoder layer 1
4
图 7: Decoder cross attention
1.4.1 使用 Linformer
5
图 8: Linformer 训练历史
1.4.2 使用 RevIN
均值和方差都在时间方向上计算。上述方法相当于首先对数据进行均值-方差 scale,然后再逐通
道进行线性的 shifting,而 shifting 的参数 γk , βk 则是可学习的——从数据中显式地学到 shifting
的规律,并存储在 γk , βk 中。对模型的输出,使用反变换得到真实的预测序列:
√
y i − βk
i
ŷkt = Var[xikt ] + ϵ( kt ) + Et (xikt ) (9)
γk
我们直接使用 RevIN 的官方实现 [1],得到的训练过程曲线如图9所示。
图 9: Transformer+RevIN 训练历史
1.4.3 性能对比
6
Test MSE Test MAE
7
2.2 Loss Implementation
在本节,我们使用最基础的交叉熵损失函数。我们将题目中要求最大化的函数转化为最小
化的损失函数,并写为交叉熵形式,生成器的损失函数为:
∑
n
min − yi log(D(G(zi ))), yi = 1 (10)
G
i=1
其中,zi 为第 i 个样本的噪声向量,yi 为第 i 个样本的 label,它恒定为 1,D 为判别器。判别器
的损失函数为:
∑
m ∑
n
min − yi log(D(xi )) − qj log(1 − D(G(zj ))), yi = 1, qj = 1 (11)
D
i=1 j=1
其中,第一个求和在真实样本上进行,第二个求和在生成样本上进行,xi 为第 i 个真实样本,zj
为第 j 个生成样本的噪声向量,yi 和 qj 恒定为 1。对于上述两个损失函数的具体实现,我们都
是用 PyTorch 的 nn.BCEWithLogitsLoss 函数。
2.3 Training
标准 GAN 生成的手写体数字如图13所示。大多数数字是可以辨识的,但是某些生成的结果
十分混乱,例如13左下角的图片。此外,生成的图片中有大量明显的噪点。
8
图 13: 标准 GAN 生成的手写体数字
9
图 14: LSGAN 的 Loss
10
图 16: DCGAN 的结构
11
图 18: DCGAN 生成的手写体数字
参考文献
[1] Ts-Kim. TS-Kim/Revin: RevIN: Reversible instance normalization for accurate time-series forecasting against distribution
shift. 2023. URL: https://github.com/ts-kim/RevIN.
[2] Taesung Kim et al. “Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift”.
In: International Conference on Learning Representations. 2021. URL: https : / / openreview . net / forum ? id =
cGDAkQo1C0p.
[3] tatp22. Causal mask of the decoder, issue 16, linformer-pytorch. 2020. URL: https://github.com/tatp22/linformer-
pytorch/issues/16.
[4] Sinong Wang et al. Linformer: Self-Attention with Linear Complexity. 2020. arXiv: 2006.04768 [cs.LG].
12