@@ -19,7 +19,8 @@ def restoration(gfpgan,
19
19
has_aligned = False ,
20
20
only_center_face = True ,
21
21
suffix = None ,
22
- paste_back = False ):
22
+ paste_back = False ,
23
+ device = 'cuda' ):
23
24
# read image
24
25
img_name = os .path .basename (img_path )
25
26
print (f'Processing { img_name } ...' )
@@ -43,7 +44,7 @@ def restoration(gfpgan,
43
44
# prepare data
44
45
cropped_face_t = img2tensor (cropped_face / 255. , bgr2rgb = True , float32 = True )
45
46
normalize (cropped_face_t , (0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ), inplace = True )
46
- cropped_face_t = cropped_face_t .unsqueeze (0 ).to ('cuda' )
47
+ cropped_face_t = cropped_face_t .unsqueeze (0 ).to (device )
47
48
48
49
try :
49
50
with torch .no_grad ():
@@ -77,17 +78,18 @@ def restoration(gfpgan,
77
78
78
79
if __name__ == '__main__' :
79
80
device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
81
+
80
82
parser = argparse .ArgumentParser ()
81
83
82
- parser .add_argument ('--upscale_factor' , type = int , default = 1 )
84
+ parser .add_argument ('--upscale_factor' , type = int , default = 2 )
83
85
parser .add_argument ('--arch' , type = str , default = 'clean' )
84
86
parser .add_argument ('--channel' , type = int , default = 2 )
85
- parser .add_argument ('--model_path' , type = str , default = 'experiments/pretrained_models/GFPGANv1 .pth' )
87
+ parser .add_argument ('--model_path' , type = str , default = 'experiments/pretrained_models/GFPGANCleanv1-NoCE-C2 .pth' )
86
88
parser .add_argument ('--test_path' , type = str , default = 'inputs/whole_imgs' )
87
89
parser .add_argument ('--suffix' , type = str , default = None , help = 'Suffix of the restored faces' )
88
90
parser .add_argument ('--only_center_face' , action = 'store_true' )
89
91
parser .add_argument ('--aligned' , action = 'store_true' )
90
- parser .add_argument ('--paste_back' , action = 'store_true ' )
92
+ parser .add_argument ('--paste_back' , action = 'store_false ' )
91
93
parser .add_argument ('--save_root' , type = str , default = 'results' )
92
94
93
95
args = parser .parse_args ()
@@ -123,14 +125,17 @@ def restoration(gfpgan,
123
125
narrow = 1 ,
124
126
sft_half = True )
125
127
126
- gfpgan .to (device )
127
- checkpoint = torch .load (args .model_path , map_location = lambda storage , loc : storage )
128
- gfpgan .load_state_dict (checkpoint ['params_ema' ])
129
- gfpgan .eval ()
128
+ gfpgan .load_state_dict (torch .load (args .model_path , map_location = lambda storage , loc : storage )['params_ema' ])
129
+ gfpgan .to (device ).eval ()
130
130
131
131
# initialize face helper
132
132
face_helper = FaceRestoreHelper (
133
- args .upscale_factor , face_size = 512 , crop_ratio = (1 , 1 ), det_model = 'retinaface_resnet50' , save_ext = 'png' )
133
+ args .upscale_factor ,
134
+ face_size = 512 ,
135
+ crop_ratio = (1 , 1 ),
136
+ det_model = 'retinaface_resnet50' ,
137
+ save_ext = 'png' ,
138
+ device = device )
134
139
135
140
img_list = sorted (glob .glob (os .path .join (args .test_path , '*' )))
136
141
for img_path in img_list :
@@ -142,6 +147,7 @@ def restoration(gfpgan,
142
147
has_aligned = args .aligned ,
143
148
only_center_face = args .only_center_face ,
144
149
suffix = args .suffix ,
145
- paste_back = args .paste_back )
150
+ paste_back = args .paste_back ,
151
+ device = device )
146
152
147
153
print (f'Results are in the [{ args .save_root } ] folder.' )
0 commit comments