import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from torch_geometric.utils import degree
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr="add")

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.BatchNorm1d(2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, emb_dim))
        # learnable weight on the central node, initialized to 0
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) * x
                       + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        # inject bond features into each incoming message
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
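
# In equation form, GINConv computes
#     h_v' = MLP((1 + eps) * h_v + sum_{u in N(v)} ReLU(h_u + e_{uv})),
# where e_{uv} is the learned bond embedding produced by BondEncoder above.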

### GCN convolution along the graph structure
class GCNConv(MessagePassing):
    def __init__(self, emb_dim):
        super(GCNConv, self).__init__(aggr='add')

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        # learned embedding acting as the bond feature of the implicit self-loop
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.bond_encoder(edge_attr)

        row, col = edge_index

        # symmetric normalization; the +1 accounts for the implicit self-loop
        deg = degree(row, x.size(0), dtype=x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, edge_attr=edge_embedding, norm=norm) \
            + F.relu(x + self.root_emb.weight) * 1. / deg.view(-1, 1)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
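
# In equation form, GCNConv computes
#     h_v' = sum_{u in N(v)} ReLU(W h_u + e_{uv}) / sqrt(d_u * d_v)
#            + ReLU(W h_v + r) / d_v,
# where the degrees d include an implicit self-loop (+1) and r = root_emb.weight
# serves as the learned bond embedding of that self-loop.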


### GNN to generate node embeddings
class GNN_node(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False, gnn_type='gin'):
        '''
            num_layer (int): number of GNN message passing layers
            emb_dim (int): node embedding dimensionality
        '''
        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### compute input node embeddings
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            if layer == self.num_layer - 1:
                # no ReLU after the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

        ### Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]
        else:
            raise ValueError("Invalid JK mode: {}".format(self.JK))

        return node_representation
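
# Minimal usage sketch (hypothetical, not part of the original file):
#     encoder = GNN_node(num_layer=5, emb_dim=300)
#     h = encoder(batched_data)   # -> [num_nodes_in_batch, 300]
# where batched_data is a torch_geometric.data.Batch carrying OGB-style
# integer atom/bond features (x, edge_index, edge_attr, batch).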


### Virtual node GNN to generate node embeddings
class GNN_node_Virtualnode(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False, gnn_type='gin'):
        '''
            num_layer (int): number of GNN message passing layers
            emb_dim (int): node embedding dimensionality
        '''
        super(GNN_node_Virtualnode, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        ### batch norms applied to node embeddings
        self.batch_norms = torch.nn.ModuleList()

        ### List of MLPs to transform the virtual node at every layer
        self.mlp_virtualnode_list = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

        for layer in range(num_layer - 1):
            self.mlp_virtualnode_list.append(torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 * emb_dim),
                torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * emb_dim, emb_dim),
                torch.nn.BatchNorm1d(emb_dim),
                torch.nn.ReLU()))

    def forward(self, batched_data):
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### one virtual node embedding per graph in the batch
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)

            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # no ReLU after the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layer - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
                ### transform virtual nodes using MLP
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + F.dropout(
                        self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                        self.drop_ratio, training=self.training)
                else:
                    virtualnode_embedding = F.dropout(
                        self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                        self.drop_ratio, training=self.training)

        ### Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]
        else:
            raise ValueError("Invalid JK mode: {}".format(self.JK))

        return node_representation
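
# Note: the virtual node is never materialized as an extra graph node. Its state
# lives in virtualnode_embedding (one row per graph); broadcasting it with
# virtualnode_embedding[batch] and aggregating back with global_add_pool is
# equivalent to wiring one extra node to every node of its graph.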

class GNN(torch.nn.Module):

    def __init__(self, num_tasks, num_layer=5, emb_dim=300,
                 gnn_type='gin', virtual_node=False, residual=False,
                 drop_ratio=0.5, JK="last", graph_pooling="mean"):
        '''
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add a virtual node or not
        '''
        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual, gnn_type=gnn_type)
        else:
            self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual, gnn_type=gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 * emb_dim),
                torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set returns a 2*emb_dim graph embedding
        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)

        return self.graph_pred_linear(h_graph)

if __name__ == '__main__':
    model = GNN(num_tasks=10)
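    # Minimal smoke test (a sketch, not from the original file). The dummy
    # shapes -- 9 integer atom features and 3 integer bond features -- are
    # assumptions matching OGB's molecular AtomEncoder/BondEncoder defaults.
    from torch_geometric.data import Batch, Data

    x = torch.zeros(4, 9, dtype=torch.long)             # 4 atoms
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 0, 3, 2]])           # 4 directed bonds
    edge_attr = torch.zeros(4, 3, dtype=torch.long)
    data = Batch.from_data_list([Data(x=x, edge_index=edge_index, edge_attr=edge_attr)])

    model.eval()                                        # disable dropout for a deterministic pass
    print(model(data).shape)                            # expected: torch.Size([1, 10])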