Merge branch 'robustGNAS' of github.com:THUMNLab/AutoGL into robustGNAS · THUMNLab/AutoGL@7755f5b · GitHub

Commit 7755f5b

Merge branch 'robustGNAS' of github.com:THUMNLab/AutoGL into robustGNAS
2 parents c58f48c + 24ecb96 commit 7755f5b

File tree

10 files changed: +487, -185 lines

autogl/module/nas/algorithm/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -59,3 +59,4 @@ def build_nas_algo_from_name(name: str) -> BaseNAS:
 __all__ = ["BaseNAS", "Darts", "Enas", "RandomSearch", "RL", "GraphNasRL","Spos"]
 if not is_dgl():
     __all__.append("Gasso")
+    __all__.append("GRNA")

autogl/module/nas/algorithm/grna.py

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-# "Graph differentiable architecture search with structure optimization" NeurIPS 21'
+# "Adversarially Robust Neural Architecture Search for Graph Neural Networks"
 
 import logging
 
@@ -89,7 +89,7 @@ def _prepare(self):
             self.model.parameters(), lr=self.model_lr, weight_decay=self.model_wd
         )
         # controller
-        self.controller=UniformSampler(self.nas_modules)
+        self.controller = UniformSampler(self.nas_modules)
 
         # Evolution
         self.evolve = Evolution(
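The second hunk is only a spacing fix, but it highlights that GRNA drives the search with a uniform architecture sampler plus evolution rather than a learned controller. AutoGL's UniformSampler and Evolution classes are not shown in this diff; the standalone sketch below only illustrates the sampling idea (pick one candidate per mutable choice uniformly at random), with made-up layer and operation names.

import random

# Hypothetical search space: each mutable choice maps to its candidate operations.
# The real sampler works on self.nas_modules; these names are illustrative only.
SEARCH_SPACE = {
    "op_0": ["gcn", "gat", "sage"],
    "op_1": ["gcn", "gat", "sage"],
    "act": ["relu", "elu", "tanh"],
}

def uniform_sample(space):
    # Single-path, one-shot style controller: every candidate is equally likely.
    return {choice: random.choice(cands) for choice, cands in space.items()}

print(uniform_sample(SEARCH_SPACE))  # e.g. {'op_0': 'gat', 'op_1': 'gcn', 'act': 'elu'}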

autogl/module/nas/algorithm/spos.py

Lines changed: 0 additions & 3 deletions

@@ -22,9 +22,6 @@
 from tqdm import tqdm, trange
 from ....utils import get_logger
 
-from nni.algorithms.nas.pytorch.random import RandomMutator
-from nni.retiarii.strategy import RegularizedEvolution
-
 import numpy as np
 LOGGER = get_logger("SPOS")
 
autogl/module/nas/estimator/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,8 @@ def register_nas_estimator_cls(cls):
     return register_nas_estimator_cls
 
 
-from .one_shot import OneShotEstimator, OneShotEstimator_HardwareAware
+from .one_shot import OneShotEstimator, OneShotEstimator_HardwareAware
+from .grna_estimator import GRNAEstimator
 from .train_scratch import TrainEstimator, TrainEstimator_HardwareAware
 
 def build_nas_estimator_from_name(name: str) -> BaseEstimator:
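With the new import in place, the estimator can be obtained either directly or through the name registry. A minimal sketch, assuming only what this commit shows (the GRNAEstimator constructor defaults and the "grna" registration key); whether build_nas_estimator_from_name forwards extra keyword arguments is not visible here.

from autogl.module.nas.estimator import GRNAEstimator, build_nas_estimator_from_name

# Direct construction with the defaults introduced by this commit.
est = GRNAEstimator(
    lambda_=0.05,           # weight of the robustness term in the final metric
    perturb_type="random",  # random edge flips via deeprobust
    adv_sample_num=10,      # number of perturbed graphs averaged per evaluation
    ptbr=0.05,              # proportion of edges to perturb
)

# Or look it up by its registered name.
est = build_nas_estimator_from_name("grna")
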
autogl/module/nas/estimator/grna_estimator.py

Lines changed: 159 additions & 0 deletions

@@ -0,0 +1,159 @@
+import torch
+from deeprobust.graph.global_attack import Random,DICE
+import numpy as np
+import scipy.sparse as sp
+from tqdm import tqdm
+import time
+import torch.nn.functional as F
+
+from . import register_nas_estimator
+from ..space import BaseSpace
+from .base import BaseEstimator
+from ..backend import *
+from ...train.evaluation import Acc
+from ..utils import get_hardware_aware_metric
+
+@register_nas_estimator("grna")
+class GRNAEstimator(BaseEstimator):
+    """
+    Graph robust neural architecture estimator under adversarial attack.
+
+    Use model directly to get estimations.
+
+    Parameters
+    ----------
+    loss_f : str
+        The name of loss funciton in PyTorch
+    evaluation : list of Evaluation
+        The evaluation metrics in module/train/evaluation
+        GRNA_metric = acc_metric+ robustness_metric
+    lambda: float
+        The hyper-parameter to balance the accuracy metric and robustness metric to perform ultimate evaluation
+    perturb_type: str
+        Perturbation methods to simulate the adversarial attack process
+    adv_sample_num: int
+        Adversarial sample number used in measure architecture robustness.
+    """
+
+    def __init__(self,
+                 loss_f="nll_loss",
+                 evaluation=[Acc()],
+                 lambda_=0.05,
+                 perturb_type='random',
+                 adv_sample_num=10,
+                 dis_type='ce',
+                 ptbr=0.05):
+        super().__init__(loss_f, evaluation)
+        self.evaluation = evaluation
+        self.lambda_ = lambda_
+        self.perturb_type = perturb_type
+        self.dis_type = dis_type
+        self.adv_sample_num = adv_sample_num
+        self.ptbr = ptbr
+        print('initialize GRNA estimator')
+
+    def infer(self, model: BaseSpace, dataset, mask="train"):
+
+        device = next(model.parameters()).device
+        dset = dataset[0].to(device)
+        mask = bk_mask(dset, mask)
+
+        pred = model(dset)[mask]
+        label = bk_label(dset)
+        y = label[mask]
+
+        loss = getattr(F, self.loss_f)(pred, y)
+        probs = F.softmax(pred, dim=1).detach().cpu().numpy()
+
+        # robustness metric
+        dist = 0
+        for _ in range(self.adv_sample_num):
+            modified_adj = self.gen_adversarial_samples(dset.edge_index, num_nodes=dset.num_nodes, perturb_prop=self.ptbr, attack_method=self.perturb_type)
+            d_data = dset.clone()
+            d_data = d_data.to(device)
+            edge_index = torch.LongTensor(np.vstack((modified_adj.tocoo().row,modified_adj.tocoo().col)))
+            d_data.edge_index = edge_index.to(device)
+            perturb_pred = model(d_data)[mask]
+            dist += distance(perturb_pred, pred, dis_type=self.dis_type)
+        dist = dist/self.adv_sample_num
+
+        y = y.cpu()
+        metrics = [eva.evaluate(probs, y)+self.lambda_*dist for eva in self.evaluation]
+
+        return metrics, loss
+
+
+    def gen_adversarial_samples(self, edge_index, num_nodes, perturb_prop, attack_method='random'):
+
+        if num_nodes is None:
+            num_nodes = max(edge_index[0])+1
+
+        edge_index = edge_index.detach().cpu().numpy()
+        delta = int(edge_index.shape[1]//2 * perturb_prop)
+        v = np.ones_like(edge_index[0])
+        adj = sp.csr_matrix((v,(edge_index[0], edge_index[1])), shape=(num_nodes,num_nodes))
+        if attack_method=='random':
+            attacker = Random()
+            attacker.attack(adj, n_perturbations=delta, type='flip')
+        elif attack_method=='dice':
+            labels = self.data.y.cpu().numpy()
+            attacker = DICE()
+            attacker.attack(adj, labels, delta)
+        else:
+            assert False, 'Wrong Type of attack method!'
+
+        modified_adj = attacker.modified_adj # scipy.sparse matrix
+
+        return modified_adj.tocsr()
+
+
+
+
+def distance(perturb, clean, dis_type='ce', data=None, p=2):
+    """
+    Distance between logits of perturbed and clean data.
+    Parameters:
+    ---------
+    perturb: torch.Tensor [n, C]
+    clean: torch.Tensor [n,C]
+    type: loss type
+    labels: ground truth labels, needed when type='cw'
+    p: fro norm, needed when type='fro'
+
+    Return
+    ------
+    Distance: torch.Tensor [n,]
+    """
+    # if type=='cos':
+    #     return perturb*clean / torch.sqrt(torch.norm(perturb,p=2) * torch.norm(clean))
+    if dis_type=='fro':
+        return torch.norm(perturb - clean,p=p)
+
+    elif dis_type=='ce':
+        p_ = F.softmax(clean,-1)
+        logq = F.log_softmax(perturb, -1)
+        return -(p_*logq).mean()
+
+    elif dis_type=='kl':
+        logq = F.log_softmax(perturb,-1)
+        p_ = F.softmax(clean, -1)
+        logp = F.log_softmax(clean,-1)
+        return (p_*(logp-logq)).mean()*100
+
+    elif dis_type=='cw':
+        perturb, clean, labels = perturb[data.train_mask],clean[data.train_mask], data.y[data.train_mask]
+        eye = torch.eye(labels.max() + 1)
+        onehot_mx = eye[labels]
+        one_hot_labels = onehot_mx.to(labels.device)
+        # perturb
+        ptb_best_second_class = (perturb - 1000*one_hot_labels).argmax(1)
+        margin = perturb[np.arange(len(perturb)), labels] - \
+                 perturb[np.arange(len(perturb)), ptb_best_second_class]
+        ptb_loss = -torch.clamp(margin, max = 0, min = -0.1).mean()
+        # clean
+        clean_best_second_class = (perturb - 1000*one_hot_labels).argmax(1)
+        margin = clean[np.arange(len(perturb)), labels] - \
+                 clean[np.arange(len(perturb)), clean_best_second_class]
+        clean_loss = -torch.clamp(margin, max = 0, min = -0.1).mean()
+
+        return (ptb_loss-clean_loss)*100
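The estimator's score is the accuracy metric plus lambda_ times the average prediction shift over adv_sample_num randomly perturbed copies of the graph. The sketch below reproduces only the default 'ce' distance and the final combination on toy tensors; the logits, the placeholder accuracy value, and the lambda_ weight are arbitrary illustrative numbers, not outputs of the estimator.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
clean = torch.randn(5, 3)                  # clean logits for 5 nodes, 3 classes
perturb = clean + 0.1 * torch.randn(5, 3)  # stand-in for logits on an attacked graph

# 'ce' distance as in the commit: cross-entropy of perturbed logits against the
# clean softmax distribution.
p_ = F.softmax(clean, dim=-1)
ce_dist = -(p_ * F.log_softmax(perturb, dim=-1)).mean()

# GRNA-style combination for a single accuracy metric.
lambda_ = 0.05
acc = 0.80  # placeholder standing in for eva.evaluate(probs, y)
grna_metric = acc + lambda_ * ce_dist.item()
print(f"ce distance = {ce_dist.item():.4f}, GRNA metric = {grna_metric:.4f}")
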
Lines changed: 0 additions & 146 deletions

@@ -1,4 +1,3 @@
-import torch.nn.functional as F
 import torch
 from deeprobust.graph.global_attack import Random,DICE
 import numpy as np
@@ -91,148 +90,3 @@ def infer(self, model: BaseSpace, dataset, mask="train"):
         return metrics, loss
 
 
-
-@register_nas_estimator("grna")
-class GRNAEstimator(BaseEstimator):
-    """
-    Graph robust neural architecture estimator under adversarial attack.
-
-    Use model directly to get estimations.
-
-    Parameters
-    ----------
-    loss_f : str
-        The name of loss funciton in PyTorch
-    evaluation : list of Evaluation
-        The evaluation metrics in module/train/evaluation
-        GRNA_metric = acc_metric+ robustness_metric
-    lambda: float
-        The hyper-parameter to balance the accuracy metric and robustness metric to perform ultimate evaluation
-    perturb_type: str
-        Perturbation methods to simulate the adversarial attack process
-    adv_sample_num: int
-        Adversarial sample number used in measure architecture robustness.
-    """
-
-    def __init__(self,
-                 loss_f="nll_loss",
-                 evaluation=[Acc()],
-                 lambda_=0.1,
-                 perturb_type='random',
-                 adv_sample_num=10,
-                 dis_type='ce',
-                 ptbr=0.05):
-        super().__init__(loss_f, evaluation)
-        self.evaluation = evaluation
-        self.lambda_ = lambda_
-        self.perturb_type = perturb_type
-        self.dis_type = dis_type
-        self.adv_sample_num = adv_sample_num
-        self.ptbr = ptbr
-        print('initialize GRNA estimator')
-
-    def infer(self, model: BaseSpace, dataset, mask="train"):
-
-        device = next(model.parameters()).device
-        dset = dataset[0].to(device)
-        mask = bk_mask(dset, mask)
-
-        pred = model(dset)[mask]
-        label = bk_label(dset)
-        y = label[mask]
-
-        loss = getattr(F, self.loss_f)(pred, y)
-        probs = F.softmax(pred, dim=1).detach().cpu().numpy()
-
-        # robustness metric
-        dist = 0
-        for _ in range(self.adv_sample_num):
-            modified_adj = self.gen_adversarial_samples(dset.edge_index, num_nodes=dset.num_nodes, perturb_prop=self.ptbr, attack_method=self.perturb_type)
-            d_data = dset.clone()
-            d_data = d_data.to(device)
-            edge_index = torch.LongTensor(np.vstack((modified_adj.tocoo().row,modified_adj.tocoo().col)))
-            d_data.edge_index = edge_index.to(device)
-            perturb_pred = model(d_data)[mask]
-            dist += distance(perturb_pred, pred, dis_type=self.dis_type)
-        dist = dist/self.adv_sample_num
-
-        y = y.cpu()
-        metrics = [eva.evaluate(probs, y)+self.lambda_*dist for eva in self.evaluation]
-
-        return metrics, loss
-
-
-    def gen_adversarial_samples(self, edge_index, num_nodes, perturb_prop, attack_method='random'):
-
-        if num_nodes is None:
-            num_nodes = max(edge_index[0])+1
-
-        edge_index = edge_index.detach().cpu().numpy()
-        delta = int(edge_index.shape[1]//2 * perturb_prop)
-        v = np.ones_like(edge_index[0])
-        adj = sp.csr_matrix((v,(edge_index[0], edge_index[1])), shape=(num_nodes,num_nodes))
-        if attack_method=='random':
-            attacker = Random()
-            attacker.attack(adj, n_perturbations=delta, type='flip')
-        elif attack_method=='dice':
-            labels = self.data.y.cpu().numpy()
-            attacker = DICE()
-            attacker.attack(adj, labels, delta)
-        else:
-            assert False, 'Wrong Type of attack method!'
-
-        modified_adj = attacker.modified_adj # scipy.sparse matrix
-
-        return modified_adj.tocsr()
-
-
-
-
-def distance(perturb, clean, dis_type='ce', data=None, p=2):
-    """
-    Distance between logits of perturbed and clean data.
-    Parameters:
-    ---------
-    perturb: torch.Tensor [n, C]
-    clean: torch.Tensor [n,C]
-    type: loss type
-    labels: ground truth labels, needed when type='cw'
-    p: fro norm, needed when type='fro'
-
-    Return
-    ------
-    Distance: torch.Tensor [n,]
-    """
-    # if type=='cos':
-    #     return perturb*clean / torch.sqrt(torch.norm(perturb,p=2) * torch.norm(clean))
-    if dis_type=='fro':
-        return torch.norm(perturb - clean,p=p)
-
-    elif dis_type=='ce':
-        p_ = F.softmax(clean,-1)
-        logq = F.log_softmax(perturb, -1)
-        return -(p_*logq).mean()
-
-    elif dis_type=='kl':
-        logq = F.log_softmax(perturb,-1)
-        p_ = F.softmax(clean, -1)
-        logp = F.log_softmax(clean,-1)
-        return (p_*(logp-logq)).mean()*100
-
-    elif dis_type=='cw':
-        perturb, clean, labels = perturb[data.train_mask],clean[data.train_mask], data.y[data.train_mask]
-        eye = torch.eye(labels.max() + 1)
-        onehot_mx = eye[labels]
-        one_hot_labels = onehot_mx.to(labels.device)
-        # perturb
-        ptb_best_second_class = (perturb - 1000*one_hot_labels).argmax(1)
-        margin = perturb[np.arange(len(perturb)), labels] - \
-                 perturb[np.arange(len(perturb)), ptb_best_second_class]
-        ptb_loss = -torch.clamp(margin, max = 0, min = -0.1).mean()
-        # clean
-        clean_best_second_class = (perturb - 1000*one_hot_labels).argmax(1)
-        margin = clean[np.arange(len(perturb)), labels] - \
-                 clean[np.arange(len(perturb)), clean_best_second_class]
-        clean_loss = -torch.clamp(margin, max = 0, min = -0.1).mean()
-
-        return (ptb_loss-clean_loss)*100

configs/nodeclf_nas_grna.yml

Lines changed: 4 additions & 2 deletions

@@ -10,13 +10,15 @@ nas:
   algorithm:
     name: grna
     n_warmup: 1000 # 1000
-    cycles: 20000
+    cycles: 5000
     population_size: 50
     sample_size: 20
     mutation_prob: 0.05
   estimator:
     name: grna
-    lambda_: 0.0
+    lambda_: 0.05
+    adv_sample_num: 10
+    ptbr: 0.05
 ensemble:
   name: null
 feature:
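The new estimator keys mirror the GRNAEstimator keyword arguments introduced above. Below is a minimal sketch of that mapping, assuming PyYAML is available, that the estimator block sits under the top-level nas key as the hunk header suggests, and that the solver simply forwards the block minus its name field; the actual config handling inside AutoGL is not part of this diff.

import yaml
from autogl.module.nas.estimator import GRNAEstimator

with open("configs/nodeclf_nas_grna.yml") as f:
    cfg = yaml.safe_load(f)

est_cfg = dict(cfg["nas"]["estimator"])  # {'name': 'grna', 'lambda_': 0.05, 'adv_sample_num': 10, 'ptbr': 0.05}
est_cfg.pop("name")                      # the name selects the estimator class; the rest are kwargs
estimator = GRNAEstimator(**est_cfg)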

0 commit comments