@@ -119,11 +119,17 @@ def training(opt):
119
119
real .shape [0 ], Z_DIM , 1 , 1 ).to (work_device )
120
120
121
121
# ~~~~~~~~~~~~~~~~~~~ critic loop ~~~~~~~~~~~~~~~~~~~ #
122
-
123
- fake = gen (fixed_noise ) # dim of (N,1,W,H)
122
+ with torch . no_grad ():
123
+ fake = gen (fixed_noise ) # dim of (N,1,W,H)
124
124
125
125
for _ in range (CRITIC_TRAIN_STEPS ):
126
126
127
+ critic .zero_grad ()
128
+ # ~~~~~~~~~~~ weight cliping as per WGAN paper ~~~~~~~~~~ #
129
+
130
+ for p in critic .parameters ():
131
+ p .data .clamp_ (- WEIGHT_CLIP , WEIGHT_CLIP )
132
+
127
133
# ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
128
134
129
135
# make it one dimensional array
@@ -138,16 +144,11 @@ def training(opt):
138
144
139
145
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
140
146
141
- critic .zero_grad ()
142
147
C_loss .backward ()
143
148
critic_optim .step ()
144
149
145
- # ~~~~~~~~~~~ weight cliping as per WGAN paper ~~~~~~~~~~ #
146
-
147
- for p in critic .parameters ():
148
- p .data .clamp_ (- WEIGHT_CLIP , WEIGHT_CLIP )
149
-
150
150
# ~~~~~~~~~~~~~~~~~~~ generator loop ~~~~~~~~~~~~~~~~~~~ #
151
+ gen .zero_grad ()
151
152
152
153
# ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
153
154
@@ -161,7 +162,6 @@ def training(opt):
161
162
162
163
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
163
164
164
- gen .zero_grad ()
165
165
G_loss .backward ()
166
166
gen_optim .step ()
167
167
0 commit comments