@@ -92,6 +92,8 @@ def training(opt):
    G_loss = 0

    for epoch in range(EPOCHS):
+        C_loss_avg = 0
+        G_loss_avg = 0

        for batch_idx, (real, _) in enumerate(tqdm(loader)):
            critic.train()
@@ -115,7 +117,7 @@ def training(opt):
                # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #

                C_loss = -(torch.mean(real_predict) - torch.mean(fake_predict))
-
+                C_loss_avg += C_loss
                # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #

                critic.zero_grad()
@@ -139,7 +141,7 @@ def training(opt):
            # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #

            G_loss = -(torch.mean(fake_predict))
-
+            G_loss_avg += G_loss
            # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #

            gen.zero_grad()
@@ -148,10 +150,13 @@ def training(opt):
            # ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #

-            if batch_idx == 0:
+            if batch_idx == 0 and epoch > 1:
+                C_loss_avg = C_loss_avg / (CRITIC_TRAIN_STEPS * BATCH_SIZE)
+                G_loss_avg = G_loss_avg / (BATCH_SIZE)
+
                print(
-                    f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} \
-                    Loss D: {C_loss:.4f}, loss G: {G_loss:.4f}"
+                    f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)}"
+                    + f" Loss D: {C_loss_avg:.4f}, loss G: {G_loss_avg:.4f}"
                )

                with torch.no_grad():
@@ -181,13 +186,13 @@ def training(opt):
            # ~~~~~~~~~~~~~~~~~~~ saving the weights ~~~~~~~~~~~~~~~~~~~ #

            if opt.weights:
-                if C_loss_prev > C_loss:
-                    C_loss_prev = C_loss
+                if C_loss_prev > C_loss_avg:
+                    C_loss_prev = C_loss_avg
                    weight_path = str(Weight_dir / 'critic.pth')
                    torch.save(critic.state_dict(), weight_path)

-                if G_loss_prev > G_loss:
-                    G_loss_prev = G_loss
+                if G_loss_prev > G_loss_avg:
+                    G_loss_prev = G_loss_avg
                    weight_path = str(Weight_dir / 'generator.pth')
                    torch.save(gen.state_dict(), weight_path)
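For readers skimming the diff, the change boils down to: accumulate the per-step critic loss and per-batch generator loss, reduce them to per-epoch averages, log those averages, and checkpoint when the average improves. The sketch below is a simplified, self-contained illustration of that pattern, not the repository's code: the constants and random losses are stand-ins, it averages over the number of accumulated terms instead of reproducing the commit's `batch_idx == 0 and epoch > 1` gating and `BATCH_SIZE` normalization, and it accumulates plain floats via `.item()` so the running sums stay off the autograd graph.

```python
import torch

# Hypothetical constants standing in for the script's globals.
EPOCHS, NUM_BATCHES, CRITIC_TRAIN_STEPS = 3, 10, 5
C_loss_prev = float("inf")  # best (lowest) critic loss average seen so far
G_loss_prev = float("inf")  # best (lowest) generator loss average seen so far

for epoch in range(EPOCHS):
    C_loss_avg = 0.0
    G_loss_avg = 0.0

    for batch_idx in range(NUM_BATCHES):
        # The critic is updated CRITIC_TRAIN_STEPS times per batch.
        for _ in range(CRITIC_TRAIN_STEPS):
            C_loss = torch.rand(1)       # stand-in for -(mean(real) - mean(fake))
            C_loss_avg += C_loss.item()  # .item() keeps the sum off the autograd graph

        G_loss = torch.rand(1)           # stand-in for -mean(fake_predict)
        G_loss_avg += G_loss.item()

    # Reduce the running sums to per-epoch averages before logging.
    C_loss_avg /= CRITIC_TRAIN_STEPS * NUM_BATCHES
    G_loss_avg /= NUM_BATCHES
    print(f"Epoch [{epoch}/{EPOCHS}] Loss D: {C_loss_avg:.4f}, loss G: {G_loss_avg:.4f}")

    # Checkpoint only on improvement (the script calls torch.save here).
    if C_loss_avg < C_loss_prev:
        C_loss_prev = C_loss_avg
    if G_loss_avg < G_loss_prev:
        G_loss_prev = G_loss_avg
```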