|
14 | 14 |
|
15 | 15 |
|
16 | 16 | parser = argparse.ArgumentParser()
|
17 |
| -parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw | fake') |
| 17 | +parser.add_argument('--dataset', required=True, help='cifar10 | lsun | mnist |imagenet | folder | lfw | fake') |
18 | 18 | parser.add_argument('--dataroot', required=True, help='path to dataset')
|
19 | 19 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
|
20 | 20 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
|
|
60 | 60 | transforms.ToTensor(),
|
61 | 61 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
62 | 62 | ]))
|
| 63 | + nc=3 |
63 | 64 | elif opt.dataset == 'lsun':
|
64 | 65 | dataset = dset.LSUN(root=opt.dataroot, classes=['bedroom_train'],
|
65 | 66 | transform=transforms.Compose([
|
|
68 | 69 | transforms.ToTensor(),
|
69 | 70 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
70 | 71 | ]))
|
| 72 | + nc=3 |
71 | 73 | elif opt.dataset == 'cifar10':
|
72 | 74 | dataset = dset.CIFAR10(root=opt.dataroot, download=True,
|
73 | 75 | transform=transforms.Compose([
|
74 | 76 | transforms.Resize(opt.imageSize),
|
75 | 77 | transforms.ToTensor(),
|
76 | 78 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
77 | 79 | ]))
|
| 80 | + nc=3 |
| 81 | + |
| 82 | +elif opt.dataset == 'mnist': |
| 83 | + dataset = dset.MNIST(root=opt.dataroot, download=True, |
| 84 | + transform=transforms.Compose([ |
| 85 | + transforms.Resize(opt.imageSize), |
| 86 | + transforms.ToTensor(), |
| 87 | + transforms.Normalize((0.5,), (0.5,)), |
| 88 | + ])) |
| 89 | + nc=1 |
| 90 | + |
78 | 91 | elif opt.dataset == 'fake':
|
79 | 92 | dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
|
80 | 93 | transform=transforms.ToTensor())
|
| 94 | + nc=3 |
| 95 | + |
81 | 96 | assert dataset
|
82 | 97 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
|
83 | 98 | shuffle=True, num_workers=int(opt.workers))
|
|
87 | 102 | nz = int(opt.nz)
|
88 | 103 | ngf = int(opt.ngf)
|
89 | 104 | ndf = int(opt.ndf)
|
90 |
| -nc = 3 |
91 | 105 |
|
92 | 106 |
|
93 | 107 | # custom weights initialization called on netG and netD
|
|
0 commit comments