hw3-report

You might also like

Download as pdf or txt
Download as pdf or txt
You are on page 1of 12

深度学习第三次作业

吴辰禹
航天航空学院

日期:2023 年 12 月 28 日

摘 要

本报告包含两部分。第一部分包括关于 Transformer 时序预测,第二部分则是 GAN 模型的实


现与在 MNIST 数据集上的应用。

1 Part 1. Transformer 时序预测

在本章,我们使用标准 Transformer 在 ETTh1 数据集上进行时序预测。输入时间步为 96 步,


label 时间步为 48 步,预测时间步为 720 步。

1.1 Task 1. 标准 attention 的实现

使用标准的多头注意力机制实现 Transformer。具体实现的方式如下面的公式所示。假设序
列的 query,key 和 value 向量组成的矩阵为:
Q ∈ RS×d , K ∈ RT ×d , V ∈ RT ×d (1)
其中,S 为 query 序列的长度,T 为需要 pay attention to 的序列的长度,d 为 embedding 的维数。
假设 query,key 和 value 的 projection 变换分别为 Wq , Wk , Wv ∈ Rd×d ,多头注意力的头数为 h,
那么的 Q, K, V 将按照下式被 project 和分割:
[Q1 , · · · , Qh ] = Q · Wq , [K1 , · · · , Kh ] = K · Wk , [V1 , · · · , Vh ] = V · Wv (2)
得到 projection 和头分割的结果后,我们按照下面的方法计算每个 head 的 attention:
Q i · Ki
Ai = softmax( √ , dim = 1) (3)
d/h
然后,头 i 的 attention 矩阵 Pi 经过一次 dropout,再和 Vi 矩阵相乘,得到该头的输出 Zi ∈ RS×d/h :
Zi = dropout(Pi ) · Vi (4)
随后,每个头的输出被横向拼接在一起,再经过一次 Projection,得到最终的 attention 输出 Z ∈

Z = [Z1 , · · · , Zh ] · Wo (5)

使用上述过程计算 attention,在 attention.py 提供的测例上,自相关、mask 自相关和互相关的相


对误差分别为 1.36 × 10−4 , 1.45 × 10− 4 和 1.29 × 10−4 。
训练时 Training-set 和 Validation-set 上的 loss 如图1所示(超参数和 run.py 中的默认设置保
持一致),最佳模型在 Epoch 2 获得。测试集上的 MSE 和 MAE 分别为 0.877,0.690。

1
图 1: Full attention 模型训练历史

最佳模型对测试集上的 0,20,40 和 60 号样本的最后一个 component 的预测效果如图2所


示。可以看到,Transformer 没能准确预测最后一个 component 的变化趋势。第 1166 号样本各个
component 的预测结果如图3所示。Transformer 能够较好地捕捉其他 component 的频率信息,但
是幅值以及趋势信息则不能很好地还原。

图 2: Full attention 模型训练历史

图 3: Full attention 模型训练历史

1.2 Mask 的具体实现

1.2.1 为什么需要加 Mask

训练时,我们往往会一次性地向 decoder 输入一整个待预测的序列(即便输入的只是 0 向量


之类的占位符)
,以保证训练效率。如果按照上一节式子中表达的 attention 以及输出计算方法计
算自注意力,第 k 个 token 的输出不仅依赖于前 k − 1 个 token 的 v,还会依赖于 k + 1 及以后的
token 的 v,attention 值 αkj , j > k 也非 0. 然而,在实际推断过程中,预测第 k 的 token 时我们并

2
不知道 k + 1 及以后的 token 到底是什么。因此,如果我们在训练阶段允许模型”窥探未来“,将
会导致训练-推断不相容的问题。所以,我们需要在训练阶段就添加 mask,使得即便在输入一整
个待预测序列(可能由占位符组成)的情况下,模型也不能在预测第 k 个 token 时动用第 k + 1
及之后的 token 信息。

1.2.2 具体实现方法

我们在计算 Qi 和 Ki 的矩阵乘积 Pi = Qi · KiT ∈ RS×T 之后,根据输入的 mask 矩阵


M ∈ RS×T 的值对矩阵乘积进行调整。具体来讲,M 中值为 0 的地方,Pi 对应的值被调整为
−1e9;M 中值为 1 的地方,Pi 的值保持不变:我们为 decoder 的自相关 attention 层添加 mask,
不对 encoder 自相关和 decoder 互相关添加 mask(因为它们要 pay attention to 的序列在推理阶段
总是完全已知的)对于 decoder 的自相关,M 为一个下三角矩阵,下三角(包括对角线)所有元
素为 1,其他元素为 0。
具体实现上,decoder 自相关的 M 矩阵在 Exp_main.construct_mask 函数中生成,在训练、验
证和测试阶段都被传入 Transformer 模型(因为在这三个阶段我们都使用占位符输入一整个待预
测序列)。

1.3 Attention map 的可视化

在下面的可视化中,图的每一行代表一个 query 对所有 key 的 attention。纵向放置的时间序


列是 query 序列,横向放置的序列是 key 序列。随机挑选样本进行 attention map 的可视化,encoder
layer 0 和 layer 1 的自相关如图4和图5所示。其中从左至右从上至下分别为 8 个 head 的 attention
结果。可以看到,不论是 layer 0 还是 layer 1,query 对强变化附近的 key 的注意力都很强(如
time step 50 附近的一个剧烈变化)
。此外,layer 0 和 layer 1 的 attention 矩阵都较为稀疏,每一行
上 attention 的分布都很集中,并不均匀分布。对比 layer 0 和 layer 1,layer 1 的 attention 在更多
的 key 处有较大值(attention 的集中度没有 layer0 高),说明 layer 1 捕捉到了具有更大尺度的时
序特征。

图 4: Encoder layer 0

3
图 5: Encoder layer 1

decoder 的自注意力和互注意力分别如图6, 7所示。由于输出序列有 720 个 timestep,分配到


各个 token 上的注意力相对较小。为了能更清晰地显示每个 query 对各个 key 的注意力分配,我
们将每一行 attention 的值归一化到 0~1 之间(实际训练和推理中依然保持行之和为 1 的条件)。
可以看到,head 1,4,5 的自注意力 attention map 依然具有较强的稀疏性,所有的 query 对剧烈
变化处(序列开头附近)的 attention 都较大。然而,其他 head 的自注意力则并没有那么稀疏,或
者只是在特定的 query 处较为稀疏。这些 head 试图捕捉的应该是尺度更大一些的信息。对于互
注意力,attention map 依然保持了较为明显的稀疏性,query 更加关注剧烈变化附近的 key。

图 6: Decoder self attention

4
图 7: Decoder cross attention

1.4 一些改进 Transformer 的尝试

1.4.1 使用 Linformer

Linformer 宣称将 Transformer 的时间复杂度从 O(n2 ) 降低到 O(n),并且在一些任务上取得


了和 Transformer 相当的效果 [4]。Linformer 的 attention 计算过程与普通的 Transformer 的不同
在于,在计算 Q 和 K 的矩阵乘积时,Linformer 先将 K ∈ RT ×d 矩阵通过全连接层 E 降维到
Rk×d , k ≪ d:
Q · [E(K)]T
A = softmax( √ , dim = −1) ∈ mathbbRS×k (6)
d
于是,这一步的计算复杂度从 O(n2 ) 下降到 O(nk),而 k 往往很小,所以 [4] 认为 Linformer 的
计算复杂度和序列长度的关系是线性的。此外,Linformer 还使用一个全连接层 F 将 V ∈ RT ×d
降维到 Rk×d ,然后计算输出:
Z = A · F (V ) ∈ RS×d (7)

我们在标准 Transformer 的基础上,将 decoder 的自注意力模块(所操作的序列长度最长)替换


为 Linformer,其他模块保持不变。需要注意的是,[4] 并没有给出 mask 的具体实现方法。社区
中讨论的 [3] 结果是,依然 mask A 矩阵的上半三角(即,最上方的 k × k 方阵的上三角)
。但是,
这样做其实并不能够阻止未来信息的泄露。就作者所知,目前对于 Linformer,还没有一个很好
的方法能够实现 mask。训练过程中 loss 下降曲线如图8所示。可以看到,Linformer 的 validation
loss 要低于相同设置下的标准 Transformer。然而,它在测试集上的效果却较差,MSE 和 MAE 分
别为。这可能是因为 Linformer 的 mask 实现不够完善,导致训练-推理不相容的问题。

5
图 8: Linformer 训练历史

1.4.2 使用 RevIN

RevIN[2] 是由 Kim 在 2022 年提出的一种提升 Transformer 预测效果的数据处理方式,并不


是某种特殊的注意力机制。在将数据输入给 Transformer 之前,RevIN 将数据按照下面的方式做
变换:
xi − Et (xikt )
x̂ikt = γk ( √kt ) + βk (8)
Var[xikt ] + ϵ

均值和方差都在时间方向上计算。上述方法相当于首先对数据进行均值-方差 scale,然后再逐通
道进行线性的 shifting,而 shifting 的参数 γk , βk 则是可学习的——从数据中显式地学到 shifting
的规律,并存储在 γk , βk 中。对模型的输出,使用反变换得到真实的预测序列:

y i − βk
i
ŷkt = Var[xikt ] + ϵ( kt ) + Et (xikt ) (9)
γk
我们直接使用 RevIN 的官方实现 [1],得到的训练过程曲线如图9所示。

图 9: Transformer+RevIN 训练历史

1.4.3 性能对比

Transformer, Transformer+RevIN,和 Linformer 在 1166 号样本上的预测结果对比图如图10所


示;三者的测试集 MSE 和 MAE 如表1所示。可以从表1看到,Transformer+RevIN 的预测效果最
好,Linformer 的预测效果最差。图10则表明,Transformer+RevIN 很好地捕捉了时序的 shifting
特征。Linformer 预测的各个 component 和真实值的偏差都较大,可能是因为 mask 的实现不够

6
Test MSE Test MAE

Transformer 0.877 0.690


Linformer 0.934 0.718
Transformer+RevIN 0.758 0.617

表 1: 三种模型在测试集上的 MSE 和 MAE

完善,导致训练-推理之间存在不相容性;另外,E, F 全连接层也可能会损失掉一部分 attention


信息。

图 10: 三种模型在 1166 号样本上的对比

2 Part 2. GAN 的实现

2.1 Model Implementation

使用 MLP 实现生成器和判别器。对于生成器,它的输入维度为噪声维度,经过两个 1024


维的隐藏层到达 28 × 28 = 784 的输出层。两个 1024 维的隐藏层后分别接有两个 ReLU 激活层,
在输出层后接有一个 tanh 激活层。对于判别器,它的输入维度为 28 × 28 = 784,经过两个 256
维的隐藏层,到达 1 维的输出层。两个 256 维的隐藏层后分别接有两个 ReLU 激活层,在输出
层后接有一个 sigmoid 激活层。生成器和判别器的结构如图 11所示。具体实现上,使用 PyTorch
的 Sequential 容器将各个层组合在一起。

图 11: GAN 的结构

7
2.2 Loss Implementation

在本节,我们使用最基础的交叉熵损失函数。我们将题目中要求最大化的函数转化为最小
化的损失函数,并写为交叉熵形式,生成器的损失函数为:

n
min − yi log(D(G(zi ))), yi = 1 (10)
G
i=1
其中,zi 为第 i 个样本的噪声向量,yi 为第 i 个样本的 label,它恒定为 1,D 为判别器。判别器
的损失函数为:

m ∑
n
min − yi log(D(xi )) − qj log(1 − D(G(zj ))), yi = 1, qj = 1 (11)
D
i=1 j=1
其中,第一个求和在真实样本上进行,第二个求和在生成样本上进行,xi 为第 i 个真实样本,zj
为第 j 个生成样本的噪声向量,yi 和 qj 恒定为 1。对于上述两个损失函数的具体实现,我们都
是用 PyTorch 的 nn.BCEWithLogitsLoss 函数。

2.3 Training

使用上述模型和损失函数(标准 GAN)进行训练。在训练过程中,我们使用 Adam 优化器,


总共优化 20 个 Epoch。优化过程中生成器和判别器的 Loss 如图12所示。可以看到,生成器的
Loss 在训练过程中不断下降,而判别器的 loss 则有上升趋势。这说明,生成器和判别器在训练
过程中都在不断提升自己的能力,二者的对抗十分激烈。最终,生成器的 Loss 下降到 0.7 左右,
判别器的 Loss 上升到 1.3 左右。

图 12: 标准 GAN 的 Loss

标准 GAN 生成的手写体数字如图13所示。大多数数字是可以辨识的,但是某些生成的结果
十分混乱,例如13左下角的图片。此外,生成的图片中有大量明显的噪点。

8
图 13: 标准 GAN 生成的手写体数字

2.4 LSGAN 的实现

本节使用类似于 least square 的损失函数对图11中的 GAN 进行训练。生成器的损失函数为:


1∑
n
min (D(G(zi )) − 1)2 (12)
G 2
i=1
判别器的损失函数为:
1∑ 1∑
m n
min (D(xi ) − 1)2 + (D(G(zj )))2 (13)
D 2 2
i=1 j=1

其中,xi 为第 i 个真实样本,zj 为第 j 个生成样本的噪声向量。训练 20 个 Epoch,得到的 loss 下


降曲线如图14所示。可以看到,生成器的 Loss 最终降低到 0.13 左右,判别器的 loss 上升到 0.25
左右,二者的对抗程度较为激烈。生成器生成的手写体数字如图15所示。和基础 GAN 相比,生
成的图片中噪点依然很多,数字辨识度也没有明显提升,不过没有出现基础 GAN 中出现的混乱
的情况(图13左下角)。

9
图 14: LSGAN 的 Loss

图 15: LSGAN 生成的手写体数字

2.5 DCGAN 的实现

本节使用卷积网络实现 GAN 模型,并使用和 least square GAN 相同的损失函数进行训练。


模型的结构如图16所示。训练过程的 loss 下降曲线如图17所示。在 DCGAN 的训练过程中,生成
器和判别器的 loss 的变化趋势和前两个模型相反,生成器的 loss 上升,判别器的 loss 下降。loss
的变化趋势较为平滑。最终生成的手写体数字如图 18所示。可以看到,生成效果明显优于前两
个模型,生成的图片中噪点较少,数字的辨识度也有明显提升。但是,生成数字的多样性依然
较低,这可能和 GAN 面临的模式坍塌问题有关。

10
图 16: DCGAN 的结构

图 17: DCGAN 的 Loss

11
图 18: DCGAN 生成的手写体数字

参考文献
[1] Ts-Kim. TS-Kim/Revin: RevIN: Reversible instance normalization for accurate time-series forecasting against distribution
shift. 2023. URL: https://github.com/ts-kim/RevIN.

[2] Taesung Kim et al. “Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift”.
In: International Conference on Learning Representations. 2021. URL: https : / / openreview . net / forum ? id =
cGDAkQo1C0p.

[3] tatp22. Causal mask of the decoder, issue 16, linformer-pytorch. 2020. URL: https://github.com/tatp22/linformer-
pytorch/issues/16.

[4] Sinong Wang et al. Linformer: Self-Attention with Linear Complexity. 2020. arXiv: 2006.04768 [cs.LG].

12

You might also like