import torch
from deeprobust.graph.global_attack import Random, DICE
import numpy as np
import scipy.sparse as sp
from tqdm import tqdm
import time
import torch.nn.functional as F

from . import register_nas_estimator
from ..space import BaseSpace
from .base import BaseEstimator
from ..backend import *
from ...train.evaluation import Acc
from ..utils import get_hardware_aware_metric

@register_nas_estimator("grna")
class GRNAEstimator(BaseEstimator):
18+ """
19+ Graph robust neural architecture estimator under adversarial attack.
20+
21+ Use model directly to get estimations.
22+
23+ Parameters
24+ ----------
25+ loss_f : str
26+ The name of loss funciton in PyTorch
27+ evaluation : list of Evaluation
28+ The evaluation metrics in module/train/evaluation
29+ GRNA_metric = acc_metric+ robustness_metric
30+ lambda: float
31+ The hyper-parameter to balance the accuracy metric and robustness metric to perform ultimate evaluation
32+ perturb_type: str
33+ Perturbation methods to simulate the adversarial attack process
34+ adv_sample_num: int
35+ Adversarial sample number used in measure architecture robustness.
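
    Examples
    --------
    A minimal usage sketch (names are illustrative; assumes an AutoGL-style
    search space ``model`` and a PyG-style ``dataset`` are available)::

        estimator = GRNAEstimator(lambda_=0.05, perturb_type='random')
        metrics, loss = estimator.infer(model, dataset, mask='val')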
36+ """
37+
    def __init__(
        self,
        loss_f="nll_loss",
        evaluation=[Acc()],
        lambda_=0.05,
        perturb_type='random',
        adv_sample_num=10,
        dis_type='ce',
        ptbr=0.05,
    ):
        super().__init__(loss_f, evaluation)
        self.evaluation = evaluation
        self.lambda_ = lambda_
        self.perturb_type = perturb_type
        self.dis_type = dis_type
        self.adv_sample_num = adv_sample_num
        self.ptbr = ptbr
        print('initialize GRNA estimator')

    def infer(self, model: BaseSpace, dataset, mask="train"):
        device = next(model.parameters()).device
        dset = dataset[0].to(device)
        mask = bk_mask(dset, mask)

        # Accuracy part: predictions and loss on the clean graph.
        pred = model(dset)[mask]
        label = bk_label(dset)
        y = label[mask]

        loss = getattr(F, self.loss_f)(pred, y)
        probs = F.softmax(pred, dim=1).detach().cpu().numpy()

        # Robustness part: average distance between predictions on the clean
        # graph and on adv_sample_num perturbed graphs.
        dist = 0
        with torch.no_grad():
            for _ in range(self.adv_sample_num):
                modified_adj = self.gen_adversarial_samples(
                    dset.edge_index,
                    num_nodes=dset.num_nodes,
                    perturb_prop=self.ptbr,
                    attack_method=self.perturb_type,
                    labels=label,
                )
                d_data = dset.clone().to(device)
                coo = modified_adj.tocoo()
                edge_index = torch.LongTensor(np.vstack((coo.row, coo.col)))
                d_data.edge_index = edge_index.to(device)
                perturb_pred = model(d_data)[mask]
                dist += distance(perturb_pred, pred, dis_type=self.dis_type)
        dist = float(dist) / self.adv_sample_num

        # GRNA metric: accuracy-style metric plus weighted robustness distance.
        y = y.cpu()
        metrics = [eva.evaluate(probs, y) + self.lambda_ * dist for eva in self.evaluation]

        return metrics, loss

    def gen_adversarial_samples(self, edge_index, num_nodes, perturb_prop, attack_method='random', labels=None):
        """Perturb a proportion of the graph's edges and return the modified adjacency."""
        if num_nodes is None:
            num_nodes = int(edge_index.max()) + 1

        edge_index = edge_index.detach().cpu().numpy()
        # edge_index stores each undirected edge twice, so halve the count
        # when computing the perturbation budget.
        delta = int(edge_index.shape[1] // 2 * perturb_prop)
        v = np.ones_like(edge_index[0])
        adj = sp.csr_matrix((v, (edge_index[0], edge_index[1])), shape=(num_nodes, num_nodes))
        if attack_method == 'random':
            attacker = Random()
            attacker.attack(adj, n_perturbations=delta, type='flip')
        elif attack_method == 'dice':
            # DICE needs node labels to decide which edges to insert/remove.
            assert labels is not None, 'DICE attack requires node labels!'
            attacker = DICE()
            attacker.attack(adj, labels.cpu().numpy(), delta)
        else:
            assert False, 'Wrong type of attack method!'

        modified_adj = attacker.modified_adj  # scipy.sparse matrix
        return modified_adj.tocsr()
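
    # A small sanity check for gen_adversarial_samples (illustrative only;
    # assumes deeprobust is installed). Flipping 25% of the 4 undirected
    # edges of a 4-cycle yields one flip and a 4x4 sparse adjacency:
    #
    #     est = GRNAEstimator()
    #     ei = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
    #                        [1, 0, 2, 1, 3, 2, 0, 3]])
    #     adj = est.gen_adversarial_samples(ei, num_nodes=4, perturb_prop=0.25)
    #     assert adj.shape == (4, 4)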


def distance(perturb, clean, dis_type='ce', data=None, p=2):
113+ """
114+ Distance between logits of perturbed and clean data.
115+ Parameters:
116+ ---------
117+ perturb: torch.Tensor [n, C]
118+ clean: torch.Tensor [n,C]
119+ type: loss type
120+ labels: ground truth labels, needed when type='cw'
121+ p: fro norm, needed when type='fro'
122+
123+ Return
124+ ------
125+ Distance: torch.Tensor [n,]
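
    Examples
    --------
    A quick sanity check (illustrative only): identical logits have zero
    distance::

        >>> logits = torch.randn(4, 3)
        >>> float(distance(logits, logits, dis_type='kl'))
        0.0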
    """
    # if dis_type == 'cos':
    #     return perturb * clean / torch.sqrt(torch.norm(perturb, p=2) * torch.norm(clean))
    if dis_type == 'fro':
        return torch.norm(perturb - clean, p=p)

    elif dis_type == 'ce':
        # Cross-entropy of the perturbed distribution against the clean one.
        p_ = F.softmax(clean, -1)
        logq = F.log_softmax(perturb, -1)
        return -(p_ * logq).mean()

    elif dis_type == 'kl':
        # KL(clean || perturbed), scaled for readability.
        logq = F.log_softmax(perturb, -1)
        p_ = F.softmax(clean, -1)
        logp = F.log_softmax(clean, -1)
        return (p_ * (logp - logq)).mean() * 100

    elif dis_type == 'cw':
        # Carlini-Wagner-style margin between the true class and the best
        # other class, compared between perturbed and clean logits.
        perturb, clean, labels = perturb[data.train_mask], clean[data.train_mask], data.y[data.train_mask]
        eye = torch.eye(labels.max() + 1)
        onehot_mx = eye[labels]
        one_hot_labels = onehot_mx.to(labels.device)
        # perturbed margin
        ptb_best_second_class = (perturb - 1000 * one_hot_labels).argmax(1)
        margin = perturb[np.arange(len(perturb)), labels] - \
            perturb[np.arange(len(perturb)), ptb_best_second_class]
        ptb_loss = -torch.clamp(margin, max=0, min=-0.1).mean()
        # clean margin (note: computed from the clean logits, not the perturbed ones)
        clean_best_second_class = (clean - 1000 * one_hot_labels).argmax(1)
        margin = clean[np.arange(len(perturb)), labels] - \
            clean[np.arange(len(perturb)), clean_best_second_class]
        clean_loss = -torch.clamp(margin, max=0, min=-0.1).mean()

        return (ptb_loss - clean_loss) * 100