8000
We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ab38112 commit db56be5Copy full SHA for db56be5
GAN/WGAN/training.py
@@ -46,12 +46,14 @@ def training(opt):
46
# ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #
47
48
trans = transforms.Compose(
49
- [transforms.ToPILImage(), transforms.Resize((H, W)),
+ [transforms.Resize((H, W)),
50
transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
51
52
MNIST_data = MNIST('/datasets', True, transform=trans, download=True)
53
+ MNIST_data = MNIST_data.to(work_device)
54
- loader = DataLoader(MNIST_data, BATCH_SIZE, True, num_workers=1)
55
+ loader = DataLoader(MNIST_data, BATCH_SIZE, True,
56
+ num_workers=1, pin_memory=True)
57
58
# ~~~~~~~~~~~~~~~~~~~ creating tensorboard variables ~~~~~~~~~~~~~~~~~~~ #
59
@@ -95,7 +97,7 @@ def training(opt):
95
97
for batch_idx, (real, _) in enumerate(tqdm(loader)):
96
98
critic.train()
99
gen.train()
- real = real.to(work_device)
100
+ # real = real.to(work_device)
101
fixed_noise = torch.rand(
102
real.shape[0], Z_DIM, 1, 1).to(work_device)
103
0 commit comments