
Commit 70aa9dd

first code
0 parents  commit 70aa9dd

File tree: 13 files changed, +267 -0 lines changed

input_data.py

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
"""Functions for downloading and reading MNIST data."""
import gzip
import os
import urllib
import numpy
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
  """Download the data from Yann's website, unless it's already here."""
  if not os.path.exists(work_directory):
    os.mkdir(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not os.path.exists(filepath):
    filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
    statinfo = os.stat(filepath)
    print 'Successfully downloaded', filename, statinfo.st_size, 'bytes.'
  return filepath


def _read32(bytestream):
  # MNIST files store header fields as big-endian 32-bit unsigned ints.
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)


def extract_images(filename):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
  print 'Extracting', filename
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data


def dense_to_one_hot(labels_dense, num_classes=10):
  """Convert class labels from scalars to one-hot vectors."""
  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot


def extract_labels(filename, one_hot=False):
  """Extract the labels into a 1D uint8 numpy array [index]."""
  print 'Extracting', filename
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError(
          'Invalid magic number %d in MNIST label file: %s' %
          (magic, filename))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels)
    return labels


class DataSet(object):
  def __init__(self, images, labels, fake_data=False):
    if fake_data:
      self._num_examples = 10000
    else:
      assert images.shape[0] == labels.shape[0], (
          "images.shape: %s labels.shape: %s" % (images.shape,
                                                 labels.shape))
      self._num_examples = images.shape[0]
      # Convert shape from [num examples, rows, columns, depth]
      # to [num examples, rows*columns] (assuming depth == 1)
      assert images.shape[3] == 1
      images = images.reshape(images.shape[0],
                              images.shape[1] * images.shape[2])
      # Convert from [0, 255] -> [0.0, 1.0].
      images = images.astype(numpy.float32)
      images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0

  @property
  def images(self):
    return self._images

  @property
  def labels(self):
    return self._labels

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size, fake_data=False):
    """Return the next `batch_size` examples from this data set."""
    if fake_data:
      fake_image = [1.0 for _ in xrange(784)]
      fake_label = 0
      return [fake_image for _ in xrange(batch_size)], [
          fake_label for _ in xrange(batch_size)]
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
      # Finished epoch
      self._epochs_completed += 1
      # Shuffle the data
      perm = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm)
      self._images = self._images[perm]
      self._labels = self._labels[perm]
      # Start next epoch
      start = 0
      self._index_in_epoch = batch_size
      assert batch_size <= self._num_examples
    end = self._index_in_epoch
    return self._images[start:end], self._labels[start:end]


def read_data_sets(train_dir, fake_data=False, one_hot=False):
  class DataSets(object):
    pass
  data_sets = DataSets()
  if fake_data:
    data_sets.train = DataSet([], [], fake_data=True)
    data_sets.validation = DataSet([], [], fake_data=True)
    data_sets.test = DataSet([], [], fake_data=True)
    return data_sets
  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
  VALIDATION_SIZE = 5000
  local_file = maybe_download(TRAIN_IMAGES, train_dir)
  train_images = extract_images(local_file)
  local_file = maybe_download(TRAIN_LABELS, train_dir)
  train_labels = extract_labels(local_file, one_hot=one_hot)
  local_file = maybe_download(TEST_IMAGES, train_dir)
  test_images = extract_images(local_file)
  local_file = maybe_download(TEST_LABELS, train_dir)
  test_labels = extract_labels(local_file, one_hot=one_hot)
  # Hold out the first VALIDATION_SIZE training examples as a validation set.
  validation_images = train_images[:VALIDATION_SIZE]
  validation_labels = train_labels[:VALIDATION_SIZE]
  train_images = train_images[VALIDATION_SIZE:]
  train_labels = train_labels[VALIDATION_SIZE:]
  data_sets.train = DataSet(train_images, train_labels)
  data_sets.validation = DataSet(validation_images, validation_labels)
  data_sets.test = DataSet(test_images, test_labels)
  return data_sets
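A minimal usage sketch (not part of this commit) showing how the loader above is meant to be called; the path, API names, and shapes all follow from the code in this file and from main.py below:

# Hypothetical usage, not part of the commit: download/load MNIST and
# draw one shuffled batch via the DataSet API defined above (Python 2).
import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print mnist.train.num_examples                # 55000 after the 5000-example validation split
batch_images, batch_labels = mnist.train.next_batch(100)
print batch_images.shape, batch_labels.shape  # (100, 784) (100, 10)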

main.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
# Variational autoencoder on MNIST (Python 2, TensorFlow 0.x-era APIs).
import tensorflow as tf
import numpy as np
import input_data
import matplotlib.pyplot as plt
import os
from scipy.misc import imsave as ims

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
n_samples = mnist.train.num_examples

n_hidden = 500
n_z = 20
batchsize = 100

def merge(images, size):
    # Tile a batch of images into one size[0] x size[1] grid image.
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1]))

    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx / size[1]  # Python 2 integer division
        img[j*h:j*h+h, i*w:i*w+w] = image

    return img

# encoder
def recognition(input_images):
    with tf.variable_scope("recognition"):
        w1 = tf.get_variable("w1",[784,n_hidden])
        b1 = tf.get_variable("b1",[n_hidden])
        w2 = tf.get_variable("w2",[n_hidden,n_hidden])
        b2 = tf.get_variable("b2",[n_hidden])
        w_mean = tf.get_variable("w_mean",[n_hidden,n_z])
        b_mean = tf.get_variable("b_mean",[n_z])
        w_stddev = tf.get_variable("w_stddev",[n_hidden,n_z])
        b_stddev = tf.get_variable("b_stddev",[n_z])

        h1 = tf.nn.sigmoid(tf.matmul(input_images,w1) + b1)
        h2 = tf.nn.sigmoid(tf.matmul(h1,w2) + b2)
        o_mean = tf.matmul(h2,w_mean) + b_mean
        o_stddev = tf.matmul(h2,w_stddev) + b_stddev
        return o_mean, o_stddev

# decoder
def generation(z):
    with tf.variable_scope("generation"):
        w1 = tf.get_variable("w1",[n_z,n_hidden])
        b1 = tf.get_variable("b1",[n_hidden])
        w2 = tf.get_variable("w2",[n_hidden,n_hidden])
        b2 = tf.get_variable("b2",[n_hidden])
        w_image = tf.get_variable("w_image",[n_hidden,784])
        b_image = tf.get_variable("b_image",[784])

        h1 = tf.nn.sigmoid(tf.matmul(z,w1) + b1)
        h2 = tf.nn.sigmoid(tf.matmul(h1,w2) + b2)
        o_image = tf.nn.sigmoid(tf.matmul(h2,w_image) + b_image)
        return o_image

images = tf.placeholder(tf.float32, [None, 784])
# Instead of mapping each input directly to a point z, the encoder outputs a
# gaussian over z, parameterized by mean/stddev. Note: the code below uses
# z_stddev directly as the standard deviation; the commented-out latent loss
# further down corresponds to the alternative log(sigma^2) parameterization.
z_mean, z_stddev = recognition(images)
print z_mean.get_shape()

# Reparameterization trick: sample from a unit gaussian, then shift and scale.
samples = tf.random_normal([batchsize,n_z],0,1,dtype=tf.float32)
guessed_z = z_mean + (z_stddev * samples)

generated_images = generation(guessed_z)

# Reconstruction loss: -log p(x|z), i.e. pixel-wise binary cross-entropy.
generation_loss = -tf.reduce_sum(images * tf.log(1e-10 + generated_images) + (1-images) * tf.log(1e-10 + 1 - generated_images),1)

# KL divergence between the approximate posterior q(z|x), the gaussian produced
# by the recognition network, and the unit-gaussian prior p(z).
# latent_loss = 0.5 * tf.reduce_sum(tf.square(z_mean) + tf.exp(z_stddev) - z_stddev - 1,1)
latent_loss = 0.5 * tf.reduce_sum(tf.square(z_mean) + tf.square(z_stddev) - tf.log(tf.square(z_stddev)) - 1,1)

cost = tf.reduce_mean(generation_loss + latent_loss)
optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)

train = True
visualization = mnist.train.next_batch(batchsize)[0]
reshaped_vis = visualization.reshape(batchsize,28,28)
ims("results/base.jpg",merge(reshaped_vis[:64],[8,8]))
# train
saver = tf.train.Saver(max_to_keep=2)
with tf.Session() as sess:
    if train:
        sess.run(tf.initialize_all_variables())
        for epoch in range(10):
            for idx in range(int(n_samples / batchsize)):
                batch = mnist.train.next_batch(batchsize)[0]
                _, real_cost = sess.run((optimizer, cost), feed_dict={images: batch})
                # hack to print the cost once per epoch: n_samples - 3 exceeds
                # the iteration count, so only idx == 0 matches
                if idx % (n_samples - 3) == 0:
                    print "%d: %f" % (epoch, real_cost)
                    saver.save(sess, os.getcwd()+"/training/train",global_step=epoch)
                    generated_test = sess.run(generated_images, feed_dict={images: visualization})
                    generated_test = generated_test.reshape(batchsize,28,28)
                    ims("results/"+str(epoch)+".jpg",merge(generated_test[:64],[8,8]))
    else:
        saver.restore(sess, tf.train.latest_checkpoint(os.getcwd()+"/training/"))
        batch = mnist.train.next_batch(batchsize)[0]
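For reference, a sketch of the objective the two loss lines above implement, under the parameterization the active code uses (z_stddev taken directly as the standard deviation sigma), with x the input image, \hat{x} = generated_images, and (mu, sigma) the recognition network's outputs:

% Reconstruction term (generation_loss): pixel-wise binary cross-entropy,
% i.e. -log p(x|z) under a Bernoulli decoder
-\log p(x \mid z) = -\sum_{i=1}^{784} \big[ x_i \log \hat{x}_i + (1 - x_i) \log(1 - \hat{x}_i) \big]

% Latent term (latent_loss): closed-form KL divergence between the
% approximate posterior q(z|x) = N(mu, sigma^2 I) and the prior p(z) = N(0, I)
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I)\big)
    = \frac{1}{2} \sum_{j=1}^{n_z} \big( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \big)

The cost line averages the sum of both terms over the batch, so minimizing it maximizes the evidence lower bound (ELBO).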

results/0.jpg  25.9 KB
results/1.jpg  8.86 KB
results/2.jpg  8.91 KB
results/3.jpg  9.45 KB
results/4.jpg  9.81 KB
results/5.jpg  10 KB
results/6.jpg  10.2 KB
results/7.jpg  10.3 KB

0 commit comments