8000 weights are saved now based on the average loss · caffeine-coder1/computer_vision@51cff82 · GitHub
[go: up one dir, main page]

Skip to content

Commit 51cff82

Browse files
weights are saved now based on the average loss
1 parent b76129a commit 51cff82

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

GAN/WGAN/training.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def training(opt):
9292
G_loss = 0
9393

9494
for epoch in range(EPOCHS):
95+
C_loss_avg = 0
96+
G_loss_avg = 0
9597

9698
for batch_idx, (real, _) in enumerate(tqdm(loader)):
9799
critic.train()
@@ -115,7 +117,7 @@ def training(opt):
115117
# ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
116118

117119
C_loss = -(torch.mean(real_predict) - torch.mean(fake_predict))
118-
120+
C_loss_avg += C_loss
119121
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
120122

121123
critic.zero_grad()
@@ -139,7 +141,7 @@ def training(opt):
139141
# ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
140142

141143
G_loss = -(torch.mean(fake_predict))
142-
144+
G_loss_avg += G_loss
143145
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
144146

145147
gen.zero_grad()
@@ -148,10 +150,13 @@ def training(opt):
148150

149151
# ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #
150152

151-
if batch_idx == 0:
153+
if batch_idx == 0 and epoch > 1:
154+
C_loss_avg = C_loss_avg/(CRITIC_TRAIN_STEPS*BATCH_SIZE)
155+
G_loss_avg = G_loss_avg/(BATCH_SIZE)
156+
152157
print(
153-
f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} \
154-
Loss D: {C_loss:.4f}, loss G: {G_loss:.4f}"
158+
f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)}"
159+
+ f"Loss D: {C_loss_avg:.4f}, loss G: {G_loss_avg:.4f}"
155160
)
156161

157162
with torch.no_grad():
@@ -181,13 +186,13 @@ def training(opt):
181186
# ~~~~~~~~~~~~~~~~~~~ saving the weights ~~~~~~~~~~~~~~~~~~~ #
182187

183188
if opt.weights:
184-
if C_loss_prev > C_loss:
185-
C_loss_prev = C_loss
189+
if C_loss_prev > C_loss_avg:
190+
C_loss_prev = C_loss_avg
186191
weight_path = str(Weight_dir/'critic.pth')
187192
torch.save(critic.state_dict(), weight_path)
188193

189-
if G_loss_prev > G_loss:
190-
G_loss_prev = G_loss
194+
if G_loss_prev > G_loss_avg:
195+
G_loss_prev = G_loss_avg
191196
weight_path = str(Weight_dir/'generator.pth')
192197
torch.save(gen.state_dict(), weight_path)
193198

0 commit comments

Comments
 (0)
0