code optimization · caffeine-coder1/computer_vision@b76ca1a · GitHub

Commit b76ca1a

code optimization
1 parent 76f9660 commit b76ca1a

File tree

1 file changed: +9 -9 lines changed


GAN/WGAN/training.py

Lines changed: 9 additions & 9 deletions
@@ -119,11 +119,17 @@ def training(opt):
             real.shape[0], Z_DIM, 1, 1).to(work_device)
 
         # ~~~~~~~~~~~~~~~~~~~ critic loop ~~~~~~~~~~~~~~~~~~~ #
-
-        fake = gen(fixed_noise)  # dim of (N,1,W,H)
+        with torch.no_grad():
+            fake = gen(fixed_noise)  # dim of (N,1,W,H)
 
         for _ in range(CRITIC_TRAIN_STEPS):
 
+            critic.zero_grad()
+            # ~~~~~~~~~~~ weight cliping as per WGAN paper ~~~~~~~~~~ #
+
+            for p in critic.parameters():
+                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)
+
             # ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
 
             # make it one dimensional array
@@ -138,16 +144,11 @@ def training(opt):
 
             # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
 
-            critic.zero_grad()
             C_loss.backward()
             critic_optim.step()
 
-            # ~~~~~~~~~~~ weight cliping as per WGAN paper ~~~~~~~~~~ #
-
-            for p in critic.parameters():
-                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)
-
         # ~~~~~~~~~~~~~~~~~~~ generator loop ~~~~~~~~~~~~~~~~~~~ #
+        gen.zero_grad()
 
         # ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
 
@@ -161,7 +162,6 @@ def training(opt):
 
         # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
 
-        gen.zero_grad()
         G_loss.backward()
         gen_optim.step()
 

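For reference, a minimal sketch of how the critic and generator updates read after this commit: the samples used for critic training are produced once under torch.no_grad(), the critic's gradients are zeroed and its weights clipped at the start of each critic step, and the generator's gradients are zeroed before its own forward pass. The loss expressions and the critic forward passes below are assumptions based on the standard WGAN objective (they are not visible in this diff); only gen, critic, fixed_noise, real, critic_optim, gen_optim, CRITIC_TRAIN_STEPS and WEIGHT_CLIP come from the file.

    # Sketch only; loss terms and the second gen(fixed_noise) call are assumptions.
    with torch.no_grad():                      # no graph needed: these samples only feed the critic
        fake = gen(fixed_noise)                # dim of (N,1,W,H)

    for _ in range(CRITIC_TRAIN_STEPS):
        critic.zero_grad()                     # clear critic grads before the forward pass
        # weight clipping as per the WGAN paper
        for p in critic.parameters():
            p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        critic_real = critic(real).reshape(-1)     # assumed forward pass on real images
        critic_fake = critic(fake).reshape(-1)     # assumed forward pass on generated images
        C_loss = -(torch.mean(critic_real) - torch.mean(critic_fake))   # assumed WGAN critic loss

        C_loss.backward()
        critic_optim.step()

    gen.zero_grad()                            # clear generator grads before its forward pass
    G_loss = -torch.mean(critic(gen(fixed_noise)).reshape(-1))          # assumed WGAN generator loss
    G_loss.backward()
    gen_optim.step()

Because fake is created under no_grad, the generator step in this sketch regenerates samples with autograd enabled so gradients can flow back into gen; how the original file handles that point falls outside the lines shown in the diff.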