Merge pull request #134 from THUMNLab/graphcl_ogb · THUMNLab/AutoGL@0f766ca · GitHub

Commit 0f766ca

Merge pull request #134 from THUMNLab/graphcl_ogb
add ogb gnn
2 parents dd9748f + 25beac6 commit 0f766ca

File tree

1 file changed: +302, -0 lines changed


examples/ogb_gnn.py

Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_geometric.nn.inits import uniform

# from conv import GNN_node, GNN_node_Virtualnode

from torch_scatter import scatter_mean

from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch_geometric.utils import degree

import math

### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''

        super(GINConv, self).__init__(aggr = "add")

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))

        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out

### GCN convolution along the graph structure
class GCNConv(MessagePassing):
    def __init__(self, emb_dim):
        super(GCNConv, self).__init__(aggr='add')

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.bond_encoder(edge_attr)

        row, col = edge_index

        #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
        deg = degree(row, x.size(0), dtype = x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out

### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
            emb_dim (int): node embedding dimensionality
            num_layer (int): number of GNN message passing layers
        '''

        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):

            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        ### Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]

        return node_representation

### Virtual GNN to generate node embedding
class GNN_node_Virtualnode(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
            emb_dim (int): node embedding dimensionality
        '''

        super(GNN_node_Virtualnode, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        ### batch norms applied to node embeddings
        self.batch_norms = torch.nn.ModuleList()

        ### List of MLPs to transform virtual node at every layer
        self.mlp_virtualnode_list = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

        for layer in range(num_layer - 1):
            self.mlp_virtualnode_list.append(torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), \
                torch.nn.Linear(2*emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU()))


    def forward(self, batched_data):

        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### virtual node embeddings for graphs
        virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)

            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layer - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
                ### transform virtual nodes using MLP

                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)
                else:
                    virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)

        ### Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]

        return node_representation

class GNN(torch.nn.Module):

    def __init__(self, num_tasks, num_layer = 5, emb_dim = 300,
                    gnn_type = 'gin', virtual_node = False, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add virtual node or not
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)
        else:
            self.gnn_node = GNN_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)


        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)

        return self.graph_pred_linear(h_graph)


if __name__ == '__main__':
    GNN(num_tasks = 10)
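
Usage note (not part of this commit): the __main__ block above only constructs a model. A minimal sketch of how the GNN class could be driven end to end on an OGB graph property prediction dataset is given below; the dataset name (ogbg-molhiv), batch size, learning rate, and the DataLoader import path are illustrative assumptions rather than anything defined in this file.

import torch
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader on older PyG releases

# Illustrative training sketch; hyperparameters and dataset choice are placeholders.
dataset = PygGraphPropPredDataset(name = "ogbg-molhiv")
split_idx = dataset.get_idx_split()
train_loader = DataLoader(dataset[split_idx["train"]], batch_size = 32, shuffle = True)

model = GNN(num_tasks = dataset.num_tasks, gnn_type = 'gin', virtual_node = True)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
criterion = torch.nn.BCEWithLogitsLoss()

model.train()
for batch in train_loader:
    optimizer.zero_grad()
    pred = model(batch)                      # [num_graphs, num_tasks] logits
    loss = criterion(pred, batch.y.float())  # ogbg-molhiv labels are binary and fully observed
    loss.backward()
    optimizer.step()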

0 commit comments