8th 王振 JTVAE论文解读

You might also like

Download as pdf or txt
Download as pdf or txt
You are on page 1of 39

Junction Tree Variational Autoencoder for

Molecular Graph Generation

王振

2021年11月15日
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Sample
• Tree decoder
• Graph decoder
• Conclusion
Junction Tree Variational Autoencoder for Molecular Graph Generation

本篇论文是基于Graph的,而在生成分子之前的很多工作中,很多都是基于SMILES的。

基于SMILES生成模型的两个关键缺点:
1.分子的 SMILES 表示不是为捕获分子相似度而设计的,会导致生成模型难以学习到平
滑的分子embedding。

2.比起SMILES表示,在图上更容易表达分子的一些重要的化学特性,比如分子的有效性。
这里作者假设,直接在图上操作可以改进有效化学结构的生成性建模。
‣ 基于原子(atom)的分子生成== 基于字母 ‣ 基于官能团(group)的分子生成 == 基
生成句子 于单词生成句子

‣ 原子和键 ‣ 环和键
‣ 基于原子的生成 ‣ 基于官能团的生成
‣ 中间步骤可能不具有化学意义 ‣ 每一步都具有化学意义

Jin et al., Junction Tree Variational Autoencoder for Molecular Graph Generation. arXiv:1802.04364
整体架构

4
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph decoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
分子→连接树

分子 连接树

N N N N N O O Cl S …
官能团词汇
N N N S C …

• group by group生成分子

• 词汇量:处理250K分子得到少于800个
分子→连接树

1.对每个原子编号,提取非环键和单环,划分成两类节点
2.共享原子数大于2的节点合并
3.若有3个及以上的节点共享一个原子,则将该
原子独立成新的节点
4.提取官能团

总词汇数<800
5.找到最小生成树(最短路径)
Tree Decomposition

Vocab
Tree Decomposition

SMILES
MolTree类
self.smiles
self.nodes node1 MolTreeNode类
node2 self.smiles
node3 self.mol
node4 self.clique
…… self.neighbor
self.is_leaf
self.nid
self.label
self.idx
self.wid
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
Tree encoder
Old memory

embedding

𝑚𝑚𝑘𝑘𝑘𝑘
𝑚𝑚𝑖𝑖𝑖𝑖
𝑘𝑘 ∈ 𝑁𝑁 𝑖𝑖 \j
Final
Old memory memory

𝑥𝑥𝑖𝑖
embedding
fnode:存储该batch_size fmess:存储该batch_size中 mess_graph:存储所有边的所有 node_graph:存储所有节点的所
中所有node的word_id 所有边的初始节点的idx 前向边的idx 有前向边的idx(作为尾节点)
torch.size([num_nodes, ]) torch.size([num_edges, ]) torch.size([num_edges, Max_NB]) torch.size([num_nodes, Max_NB])
Max_NB Max_NB
0 73 0 23 0 20 5 0
0 20 5 0
1 5 1 54 1 11 16 45
1 11 16 45
2 65 2 11 2 2 0 0
2 2 0 0
3 12 3 12 3 3 6 0
3 3 6 0
…… …… … …
… …

scope=List[Tuple(int: start_idx, int: len)]


fnode:存储该batch_size中 fmess:存储该batch_size中所 mess_graph:存储所有边的所有
所有node的word_id 有边的初始节点的idx 前向边的idx
torch.size([num_nodes, ]) torch.size( [num_edges, ]) torch.size([num_edges, Max_NB])
Max_NB
初始化 node_graph:存储所有节点的所
0 73 0 0(pad)
所有边的特征 有前向边的idx
0 P P P
1 5 1 54 torch.size([num_nodes, Max_NB])
hidden_size
2 65 1 11 16 45
2 11 Max_NB
3 12 2 2 0 0 0 0 0 0 0
3 12
… … 3 3 6 0 1 0 0 0 0 0 20 5 0
… …
… … 2 0 0 0 0 1 11 16 45
3 0 0 0 0 2 2 0 0
nn.Embedding
(vocab_size, hidden_size) hidden_size=450 … … 3 3 6 0
0 p p p p GRU messages … …
1 torch.size([num_
hidden_size=450 edges, 450 ])
2
0
3 hidden_size=450
1
… … 0
2 1
edge_begin_node_features
3 torch.size([num_edges, 450 ]) 2
… … 3
… …
node_features
torch.size([num_nodes, 450 ]) messages
torch.size([num_edges, 450 ])
node_graph:存储所有节点的所 node_features
messages 有前向边的idx torch.size([num_nodes,
torch.size([num_edges, 450 ]) torch.size([num_nodes, Max_NB]) 450 ])
hidden_size=450
hidden_size=450 0 20 5 0 hidden_size + hidden_size hidden_size
0
0 1 11 16 45 0
1 Cat(dim=1)0
1 2 2 0 0
2
1 Linear 1
2
3 3 6 0 2 2
3 3
… … 3 3
… … … …
… … … …
hidden_size Node_features
hidden_size=450 torch.size([num_nodes,
Max_NB 450 ])
0
mess_nei.sum(dim=1) 1
num_nodes hidden_size
2
scope=List[Tuple( start_idx, len)] 0
3
mess_nei 1
… … 之前设置了每个tree的第1个节
torch.size([num_nodes, Max_NB, 450 ])
点为根节点 2
mess_nei
torch.size([num_nodes, … …
450 ])
Tree_features
torch.size(
[batch_size, 450 ])
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
Graph encoder
fatoms:存储该batch_size fbonds:存储该batch_size agraph:存储所有atom的所有前向 bgraph:存储所有bond的所有
中所有atom的特征 中所有bond的特征 边的idx 前向边的idx
torch.size([num_atoms, 39]) torch.size([num_bonds, 50]) torch_size([num_atoms, Max_NB]) torch_size([num_bonds,
39 Max_NB=6 Max_NB])
39+11=50 Max_NB=6
0 1 0 . 1 0 1 0 1 . 0 0 0 20 5 0 0 20 5 0
1 0 1 . 0 1 1 1 0 . 0 1 1 11 16 45 1 11 16 45
2 0 1 . 1 2 0 1 0 . 1 1 2 2 0 0 2 2 0 0
3 1 0 . 0 3 0 0 1 . 1 0 3 3 6 0 3 3 6 0
…… . …… . … … … …
one-hot encoding one-hot encoding
Symbol:23 Bond Type:5
Degree(NB):6 Stereo :6
Formal Charge:5 scope=List[Tuple(int: start_idx, int: len)]
Chiral:4 Bond_Feature_Dim
Aromatic :1 =5+6=11
Atom_Feature_Dim
=23+6+5+4+1=39
fbonds:存储该batch_size中所有 binput messages bgraph:存储所有bond的
bond的特征 torch.size([num_bonds, torch.size([num_bonds, 所有前向边的idx
torch.size([num_bonds, 50]) hidden_size]) hidden_size]) torch_size([num_bonds,
Max_NB])
hidden_size=450 hidden_size=450 Max_NB=6
39+11=50
0 0 0 20 5 0
0 1 0 1 . 0 0
nn.Linear(50, 1 1 1 11 16 45
1 1 1 0 . 0 1 hidden_sze) ReLU
2 2 2 2 0 0
2 0 1 0 . 1 1
3 0 0 1 . 1 0 3 3 3 3 6 0
… … . … … … … … …
one-hot encoding

hidden_size=450 hidden_size

hidden_size=450 0
6
1 nei_message.
0 nn.Linear(
2 sum(dim=1) num_
1 2*hidden_size, Cat
hidden_sze) (dim=1) 3 bonds
2 Linear
… …
3 nei_message
nei_message torch.size([num_bonds, 6, 450 ])
… …
torch.size([num_bonds, 450 ])
messages
torch.size([num_bonds,
450 ])
fatoms:存储该batch_size中所有
messages agraph:存储所有atom的 atom的特征
torch.size([num_bonds, 所有前向边的idx torch.size([num_atoms, 39]) hidden_size=450
hidden_size]) torch_size([num_atoms, 39
Max_NB]) 0
hidden_size=450 Max_NB=6
0 1 0 . 1 1
0 0 20 5 0 1 0 1 . 0
Cat Linear(489,450) 2
1 1 11 16 45 2 0 1 . 1
3
2 2 2 0 0 3 1 0 . 0
… …
3 3 6 0 … … .
3 atom_features
one-hot encoding
… … … … torch.size([num_atoms,
450 ])
scope=List[Tuple( start_idx,
hidden_size=450
len)]
hidden_size
0 得到每个分子的每个原子
hidden_size
6 1 的特征,取均值
2 0
atom_nei_message.
num_ sum(dim=1) 3 1
atoms
… … 2
… …
atom_nei_message atom_nei_message graph_features
torch.size([num_atoms, torch.size([num_atoms, torch.size(
6, 450 ]) 450 ]) [batch_size, 450 ])
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
latent_size=28
resample
0
hidden_size hidden_size
nn.Linear(hidden_size, 1 𝜇𝜇
0 0 latent_size) 2
torch.size(
1 … … [batch_size, latent_size ])
1
2 2 latent_size=28
… … … …
0
𝜎𝜎 2
Tree_features
torch.size(
graph_features
torch.size(
nn.Linear(hidden_size, 1 𝑙𝑙𝑙𝑙𝑙𝑙
[batch_size, [batch_size, latent_size) 2
torch.size(
hidden_size ]) hidden_size ]) … … [batch_size, latent_size ])

若从𝑁𝑁 𝜇𝜇, 𝜎𝜎 中直接采样一个样本进行解码,则在反向传播的时候会造成梯度断裂。


从𝑁𝑁 0, 1 中采样𝜀𝜀,则𝑁𝑁 𝜇𝜇, 𝜎𝜎 中的样本可表示为𝜀𝜀𝜀𝜀 + 𝜇𝜇

𝜇𝜇
𝑧𝑧𝐺𝐺
𝜎𝜎 2 exp
𝑙𝑙𝑙𝑙𝑙𝑙
[batch_size, latent_size ]
1
𝜀𝜀
[batch_size, latent_size ]
2
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph decoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
Tree Decoder
Tree Decoder

node_ Node_ directi


x y on
扩展到整个batch_size中的所有tree,每行表示1个tree
1 2 1
2 3 1
(1,2,1) (2,3,1) (3,4,1) (4,5,1) (5,4,0) (4,3,0) (3,6,1) (6,3,0) (3,2,0) (2,1,0) (1,7,1) (7,1,0)
dfs 3 4 1
4 5 1
batch_
5 4 0 size
4 3 0
3 6 1
6 3 0
3 2 0
prop_list
2 1 0
1 7 1
7 1 0
Tree Decoder 1
root
root root 1 2
1 1
2
Add node 2 as 3
Get embedding
with wid of 1 2 neighbor of 1
3 4 Add node 4 as
neighbor of 5

Get embedding 4
with wid of 2 5 Get embedding


initial with wid of 4
GRU
message 5 Add node 5 as
GRU neighbor of 4
node_x Node_y directio
n
1 2 1
1 1 Get embedding
2 3 1
with wid of 5 3 4 1
new_message
4 5 1
2 new_message
GRU 5 4 0
4 3 0
4 3 6 1
Stop pre new_message 6 3 0
Label pre(if direction =1) 3 2 0
TP 5 2 1 0
LP TP 1 7 1
LP 7 1 0
Tree Decoder
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
Graph Decoder

Decoding graph
Set of possible candidates
graphs of tree 𝑇𝑇�
Graph Decoder
i

Atoms of subgraph Bond Features


u-v bond(edge)
Atom Features Old message from
neighbors
j

Tree message

Subgraph of ground truth Subgraphs from model

Ground truth Set of possible candidate subgraphs


目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
Conclusion

生成分子的合法性
谢 谢!

You might also like