8000 bug Fixes, moved entire dataset to GPU · caffeine-coder1/computer_vision@db56be5 · GitHub
[go: up one dir, main page]

Skip to content

Commit db56be5

Browse files
bug Fixes, moved entire dataset to GPU
1 parent ab38112 commit db56be5

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

GAN/WGAN/training.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,14 @@ def training(opt):
4646
# ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #
4747

4848
trans = transforms.Compose(
49-
[transforms.ToPILImage(), transforms.Resize((H, W)),
49+
[transforms.Resize((H, W)),
5050
transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
5151

5252
MNIST_data = MNIST('/datasets', True, transform=trans, download=True)
53+
MNIST_data = MNIST_data.to(work_device)
5354

54-
loader = DataLoader(MNIST_data, BATCH_SIZE, True, num_workers=1)
55+
loader = DataLoader(MNIST_data, BATCH_SIZE, True,
56+
num_workers=1, pin_memory=True)
5557

5658
# ~~~~~~~~~~~~~~~~~~~ creating tensorboard variables ~~~~~~~~~~~~~~~~~~~ #
5759

@@ -95,7 +97,7 @@ def training(opt):
9597
for batch_idx, (real, _) in enumerate(tqdm(loader)):
9698
critic.train()
9799
gen.train()
98-
real = real.to(work_device)
100+
# real = real.to(work_device)
99101
fixed_noise = torch.rand(
100102
real.shape[0], Z_DIM, 1, 1).to(work_device)
101103

0 commit comments

Comments
 (0)
0