Skip to content

Commit 64f829c

Browse files
surgan12soumith
authored andcommitted
Examples dcgan (#464)
* mnist added dcgan * mnist added
1 parent 6d08877 commit 64f829c

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

dcgan/.swp

12 KB
Binary file not shown.

dcgan/main.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
parser = argparse.ArgumentParser()
17-
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw | fake')
17+
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | mnist |imagenet | folder | lfw | fake')
1818
parser.add_argument('--dataroot', required=True, help='path to dataset')
1919
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
2020
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
@@ -60,6 +60,7 @@
6060
transforms.ToTensor(),
6161
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
6262
]))
63+
nc=3
6364
elif opt.dataset == 'lsun':
6465
dataset = dset.LSUN(root=opt.dataroot, classes=['bedroom_train'],
6566
transform=transforms.Compose([
@@ -68,16 +69,30 @@
6869
transforms.ToTensor(),
6970
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
7071
]))
72+
nc=3
7173
elif opt.dataset == 'cifar10':
7274
dataset = dset.CIFAR10(root=opt.dataroot, download=True,
7375
transform=transforms.Compose([
7476
transforms.Resize(opt.imageSize),
7577
transforms.ToTensor(),
7678
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
7779
]))
80+
nc=3
81+
82+
elif opt.dataset == 'mnist':
83+
dataset = dset.MNIST(root=opt.dataroot, download=True,
84+
transform=transforms.Compose([
85+
transforms.Resize(opt.imageSize),
86+
transforms.ToTensor(),
87+
transforms.Normalize((0.5,), (0.5,)),
88+
]))
89+
nc=1
90+
7891
elif opt.dataset == 'fake':
7992
dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
8093
transform=transforms.ToTensor())
94+
nc=3
95+
8196
assert dataset
8297
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
8398
shuffle=True, num_workers=int(opt.workers))
@@ -87,7 +102,6 @@
87102
nz = int(opt.nz)
88103
ngf = int(opt.ngf)
89104
ndf = int(opt.ndf)
90-
nc = 3
91105

92106

93107
# custom weights initialization called on netG and netD

0 commit comments

Comments
 (0)