Skip to content

Commit 707ebff

Browse files
authored
Add files via upload
1 parent b8afbb3 commit 707ebff

File tree

8 files changed

+12426
-0
lines changed

8 files changed

+12426
-0
lines changed

BatchDatsetReader.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
Code ideas from https://github.com/Newmu/dcgan and tensorflow mnist dataset reader
3+
"""
4+
import numpy as np
5+
# import scipy.misc as misc
6+
import cv2
7+
import os
8+
9+
10+
class BatchDatset:
    """Batched reader for grayscale images with per-column binary labels.

    Each non-blank line of the dataset file is a Python literal whose first
    element is an image file name (relative to ``datadir``) and whose
    remaining elements are integer column indices to mark as positive.
    Images are decoded lazily, one batch at a time.
    """

    # Class-level defaults kept for backward compatibility; every instance
    # rebinds them in __init__ / _read_images, so they are never shared.
    files = []
    images = []
    annotations = []
    image_options = {}
    batch_offset = 0
    epochs_completed = 0

    def __init__(self, datadir='gen_imgs', dataset_file='dataset.txt', image_options=None):
        """
        Initialize a generic file reader with batching for a list of files.

        :param datadir: directory that image file names are joined onto
        :param dataset_file: text file with one record per line
        :param image_options: dictionary of options for the output image;
            defaults to {'resize': True, 'resize_size': (1024, 48)}
            Available options:
            resize = True/ False
            resize_size = (width, height) handed to cv2.resize; the width is
                also the length of every label row
        """
        print("Initializing Batch Dataset Reader...")

        # A fresh dict per instance avoids the shared mutable default.
        if image_options is None:
            image_options = {'resize': True, 'resize_size': (1024, 48)}
        print(image_options)

        # Close the file deterministically and drop blank lines, which would
        # otherwise make eval() raise a SyntaxError.
        with open(dataset_file, 'r') as f:
            self.files = [line for line in f if line.strip()]
        self.image_options = image_options
        self.datadir = datadir
        self._read_images()

    def _read_images(self):
        """Parse the record lines into file-name and annotation arrays."""
        # SECURITY NOTE(review): eval() executes arbitrary code found in the
        # dataset file. Only use trusted files, or switch to
        # ast.literal_eval once the record format is confirmed to be plain
        # literals.
        records = [eval(line) for line in self.files]
        self.images = np.array([record[0] for record in records])
        self.annotations = np.array([record[1:] for record in records])
        print(self.images.shape)
        print(self.annotations.shape)

    def _transform(self, filename):
        """Load one image as a float array of shape (height, width, 1).

        Pixel values are divided by 255.0. Returns None when cv2 cannot read
        the file.
        """
        image = cv2.imread(filename, 0)  # flag 0: decode as grayscale
        if image is None:
            return None

        if self.image_options.get("resize", False):
            image = cv2.resize(image, self.image_options["resize_size"])

        # The decoded image is 2-D, so the channel axis is appended at the
        # end; axis=3 is out of bounds for a 2-D array on current NumPy.
        return np.expand_dims(np.array(image) / 255.0, axis=-1)

    def get_records(self):
        """Return the (file names, annotations) arrays."""
        return self.images, self.annotations

    def reset_batch_offset(self, offset=0):
        """Move the sequential read position used by next_batch."""
        self.batch_offset = offset

    def _build_labels(self, annotations):
        """Turn rows of column indices into labels of shape (n, 1, width, 1)."""
        width = self.image_options["resize_size"][0]
        labels = np.zeros((len(annotations), width))
        for row, positions in zip(labels, annotations):
            # `row` is a view into `labels`, so this writes in place.
            row[np.asarray(positions, dtype=np.intp)] = 1
        return labels[:, np.newaxis, :, np.newaxis]

    def _load_batch(self, im_names, annotations):
        """Decode the named images and build labels for the readable ones.

        Unreadable images are skipped together with their annotations so
        that the two returned arrays always have the same length.
        """
        arr = []
        kept = []
        for position, elem in enumerate(im_names):
            tmp = self._transform(os.path.join(self.datadir, elem))
            if tmp is None:
                continue
            arr.append(tmp)
            kept.append(position)
        imgs = np.array(arr)
        labels = self._build_labels([annotations[position] for position in kept])
        return imgs, labels

    def next_batch(self, batch_size):
        """Return the next sequential (images, labels) batch.

        When the read position runs past the end of the data, the epoch
        counter is incremented, the records are reshuffled and the batch is
        taken from the start of the new order; the leftover tail of the
        previous order is not returned.
        """
        start = self.batch_offset
        self.batch_offset += batch_size
        if self.batch_offset > len(self.images):
            # Finished epoch
            self.epochs_completed += 1
            print("****************** Epochs completed: " + str(self.epochs_completed) + "******************")
            # Shuffle the data
            perm = np.arange(len(self.images))
            np.random.shuffle(perm)
            self.images = self.images[perm]
            self.annotations = self.annotations[perm]
            # Start next epoch
            start = 0
            self.batch_offset = batch_size

        end = self.batch_offset
        return self._load_batch(self.images[start:end], self.annotations[start:end])

    def get_random_batch(self, batch_size):
        """Return (images, labels) for batch_size records drawn with replacement."""
        indexes = np.random.randint(0, len(self.images), size=[batch_size]).tolist()
        return self._load_batch(self.images[indexes], self.annotations[indexes])
118+
119+
120+
# data = BatchDatset()
121+
# a = data.next_batch(3)
122+
# print a

CharacterSegmentTrain.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
from __future__ import print_function
2+
import tensorflow as tf
3+
import numpy as np
4+
from PIL import Image, ImageDraw
5+
import TensorflowUtils as utils
6+
# import read_MITSceneParsingData as scene_parsing
7+
import datetime
8+
import BatchDatsetReader as dataset
9+
from six.moves import xrange
10+
import os
11+
import cv2
12+
13+
# import pydevd
14+
# pydevd.settrace('192.168.50.217',port=8888, stdoutToServer=True, stderrToServer=True)
15+
# Limit TensorFlow to 60% of the memory of each visible GPU.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.6
# NOTE(review): this session is opened at import time and kept as a module
# attribute for backward compatibility; confirm whether anything still needs
# it before removing it.
session = tf.Session(config=config)

# Command-line flags. Defaults are given with their real types rather than
# as strings that the flag library has to parse.
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_integer("batch_size", 80, "batch size for training")
tf.flags.DEFINE_string("logs_dir", "logs/", "path to logs directory")
tf.flags.DEFINE_string("data_dir", "gen_imgs/", "path to dataset")
tf.flags.DEFINE_string("test_data_dir", "test_imgs/", "path to test dataset")
tf.flags.DEFINE_float("learning_rate", 1e-4, "Learning rate for Adam Optimizer")
tf.flags.DEFINE_string("model_dir", "Model_zoo/", "Path to vgg model mat")
tf.flags.DEFINE_bool('debug', False, "Debug mode: True/ False")
tf.flags.DEFINE_string('mode', "inference", "Mode train/ test/ inference")

# Number of training steps; the extra 1 makes step 100000 itself run.
MAX_ITERATION = int(1e5 + 1)
# Not referenced anywhere in the visible part of this file.
NUM_OF_CLASSESS = 2
# (width, height) of the network input, in pixels.
IMAGE_SIZE = (1024, 48)
33+
34+
def vgg_net(weights, image):
    """Build a VGG-style layer stack from pre-trained weights.

    :param weights: per-layer weight container indexed by layer position;
        each conv entry must unpack to (kernels, bias) through
        ``weights[i][0][0][0][0]``
    :param image: input tensor fed to the first layer
    :return: dict mapping every layer name to its output tensor
    """
    layer_names = (
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',

        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',

        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
        'relu3_3', 'conv3_4', 'relu3_4', 'pool3',

        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
        'relu4_3', 'conv4_4', 'relu4_4', 'pool4',

        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
        'relu5_3', 'conv5_4', 'relu5_4'
    )

    activations = {}
    tensor = image
    for index, layer_name in enumerate(layer_names):
        # The first four characters of the name select the layer type.
        prefix = layer_name[:4]
        if prefix == 'pool':
            tensor = utils.avg_pool_2x2(tensor)
        elif prefix == 'relu':
            tensor = tf.nn.relu(tensor, name=layer_name)
            if FLAGS.debug:
                utils.add_activation_summary(tensor)
        elif prefix == 'conv':
            raw_kernels, raw_bias = weights[index][0][0][0][0]
            # matconvnet stores kernels as [width, height, in, out] while
            # tensorflow expects [height, width, in, out]: swap axes 0 and 1.
            kernel_var = utils.get_variable(
                np.transpose(raw_kernels, (1, 0, 2, 3)), name=layer_name + "_w")
            bias_var = utils.get_variable(raw_bias.reshape(-1), name=layer_name + "_b")
            tensor = utils.conv2d_basic(tensor, kernel_var, bias_var)
        activations[layer_name] = tensor

    return activations
70+
71+
72+
def inference(image, keep_prob):
    """
    Column segmentation network definition.

    Five 3x3 conv + ReLU stages downsample the input, then five 1x5
    transposed convolutions upsample back to a single row of 1024 columns.

    :param image: input image batch; the placeholder in main() is
        [batch, 48, 1024, 1] and the readers in this project scale pixels by
        1/255
    :param keep_prob: keep probability for the dropout after the fifth conv
    :return: (annotation_pred, up_conv5) where up_conv5 is the sigmoid output
        with requested shape [batch, 1, 1024, 1] and annotation_pred is
        ``up_conv5 > 0.5``
    """
    # (weight name, bias name, in channels, out channels) per encoder stage.
    encoder = (
        ('down_w_conv1', 'down_b1', 1, 32),
        ('down_w_conv2', 'down_b2', 32, 64),
        ('down_w_conv3', 'down_b3', 64, 128),
        ('down_w_conv4', 'down_b4', 128, 256),
        ('down_w_conv5', 'down_b5', 256, 512),
    )
    # (weight name, bias name, out channels, in channels, out width,
    #  activation) per decoder stage.
    decoder = (
        ("up_w_conv1", "up_b1", 512, 512, 64, tf.nn.relu),
        ("up_w_conv2", "up_b2", 256, 512, 128, tf.nn.relu),
        ("up_w_conv3", "up_b3", 128, 256, 256, tf.nn.relu),
        ("up_w_conv4", "up_b4", 64, 128, 512, tf.nn.relu),
        ("up_w_conv5", "up_b5", 1, 64, 1024, tf.nn.sigmoid),
    )

    with tf.variable_scope("inference"):
        features = image
        last_stage = len(encoder) - 1
        for position, (w_name, b_name, in_ch, out_ch) in enumerate(encoder):
            kernel = utils.weight_variable([3, 3, in_ch, out_ch], name=w_name)
            offset = utils.bias_variable([out_ch], name=b_name)
            features = tf.nn.relu(utils.conv2d_basic(features, kernel, offset))
            if position < last_stage:
                # Original author's shape notes after each pool, as
                # (H, W, C): (24, 512, 32), (12, 256, 64), (6, 128, 128),
                # (3, 64, 256) -- presumably for a 48x1024 input; depends on
                # utils.max_pool_2x2, which is not shown here.
                features = utils.max_pool_2x2(features)

        # The fifth stage applies dropout and then a VALID 2x2 pool instead
        # of the shared pooling helper; author's note: (1, 32, 512).
        features = tf.nn.dropout(features, keep_prob=keep_prob)
        features = tf.nn.max_pool(features, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")

        # now to upscale to actual image size
        for w_name, b_name, out_ch, in_ch, out_width, activation in decoder:
            kernel = utils.weight_variable([1, 5, out_ch, in_ch], name=w_name)
            offset = utils.bias_variable([out_ch], name=b_name)
            target_shape = [tf.shape(image)[0], 1, out_width, out_ch]
            features = activation(
                utils.conv2d_transpose_strided(features, kernel, offset, output_shape=target_shape))

        up_conv5 = features
        annotation_pred = up_conv5 > 0.5

    return annotation_pred, up_conv5
130+
131+
132+
def train(loss_val, var_list):
    """Create the Adam training op for *loss_val*.

    :param loss_val: scalar loss tensor to minimise
    :param var_list: variables whose gradients are computed and applied
    :return: the op returned by ``apply_gradients``
    """
    adam = tf.train.AdamOptimizer(FLAGS.learning_rate)
    grads_and_vars = adam.compute_gradients(loss_val, var_list=var_list)
    if FLAGS.debug:
        # One gradient summary per (gradient, variable) pair in debug mode.
        for gradient, variable in grads_and_vars:
            utils.add_gradient_summary(gradient, variable)
    return adam.apply_gradients(grads_and_vars)
140+
141+
142+
def _transform(filename):
    """Load one image for inference as a (height, width, 1) float array.

    The file is decoded as grayscale, resized to IMAGE_SIZE (width, height)
    and divided by 255.0.

    :param filename: path of the image file to read
    :return: float array with a trailing channel axis of length 1
    :raises IOError: if cv2 cannot read the file
    """
    image = cv2.imread(filename, 0)  # flag 0: decode as grayscale
    if image is None:
        # cv2.imread reports failure by returning None; fail with the file
        # name instead of an opaque error inside cv2.resize.
        raise IOError("Could not read image file: %s" % filename)
    resize_image = cv2.resize(image, IMAGE_SIZE)
    # The resized image is 2-D, so the channel axis goes at the end; axis=3
    # is out of bounds for a 2-D array on current NumPy.
    return np.expand_dims(np.array(resize_image) / 255.0, axis=-1)
148+
149+
150+
def _ensure_directory(path):
    """Create *path* if it does not exist yet (works on Python 2 and 3)."""
    if not os.path.isdir(path):
        os.makedirs(path)


def _as_rgb_image(normalized):
    """Convert one (height, width, 1) array scaled to [0, 1] into an RGB PIL image."""
    return Image.fromarray(255 * np.squeeze(normalized, axis=2)).convert('RGB')


def _mark_columns(img, column_flags):
    """Draw a red vertical line on *img* at every x where column_flags[x] == 1."""
    painter = ImageDraw.Draw(img)
    x_len, y_len = img.size
    for x in range(x_len):
        if column_flags[x] == 1:
            painter.line(((x, 0), (x, y_len)), (250, 0, 0))


def _predict_columns(sess, pred_annotation, image, keep_probability, batch):
    """Run the network on *batch* and return an integer array of shape (n, width)."""
    pred = sess.run(pred_annotation, feed_dict={image: batch,
                                                keep_probability: 1.0})  # [n, 1, 1024, 1]
    pred = np.squeeze(pred, axis=3)
    pred = np.squeeze(pred, axis=1)
    # np.int was removed from NumPy; the builtin int is what it aliased.
    return np.asarray(pred, int)


def main(argv=None):
    """Build the graph, restore a checkpoint if present, then run FLAGS.mode.

    Modes:
      train     - optimise on FLAGS.data_dir, validate and save every 500 steps
      inference - predict on every file in ./real_test_imgs/ and write
                  result/source_*.jpg, result/pred_*.jpg and res.txt
      test      - predict on one random validation batch and write
                  result/pred_*.jpg
    """
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE[1], IMAGE_SIZE[0], 1], name="input_image")
    annotation = tf.placeholder(tf.float32, shape=[None, 1, IMAGE_SIZE[0], 1], name="annotation")
    # `logits` holds the sigmoid output of the network, not raw logits.
    pred_annotation, logits = inference(image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth", tf.cast(annotation, tf.uint8), max_outputs=2)
    tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)

    # Weighted binary cross-entropy: the positive term is weighted by alpha
    # and the negative term by beta. 1e-9 keeps log() away from zero.
    alpha = 0.9
    beta = 0.1
    loss = tf.reduce_mean(tf.add(-alpha * tf.reduce_sum(annotation * tf.log(logits + 1e-9), 1),
                                 -beta * tf.reduce_sum((1 - annotation) * tf.log(1 - logits + 1e-9), 1)))
    tf.summary.scalar("entropy", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up dataset reader")
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(FLAGS.data_dir)
    # The validation reader is used by both train and test mode; creating it
    # only for train made test mode fail with a NameError.
    if FLAGS.mode in ('train', 'test'):
        validation_dataset_reader = dataset.BatchDatset(FLAGS.test_data_dir, dataset_file='dataset_test.txt')

    # Pass the module-level GPU memory limit explicitly to the session that
    # actually runs the graph.
    sess = tf.Session(config=config)

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    if ckpt is not None:
        # NOTE(review): the restore path is pinned to the final training step
        # and ignores both FLAGS.logs_dir and the latest checkpoint recorded
        # in the checkpoint state -- confirm this is intended.
        ckpt.model_checkpoint_path = 'logs/model.ckpt-100000'
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")
    else:
        # get_checkpoint_state returned None: previously this crashed with an
        # AttributeError, which made training from scratch impossible.
        print("No checkpoint found in %s, using freshly initialized variables" % FLAGS.logs_dir)

    if FLAGS.mode == "train":
        for itr in xrange(MAX_ITERATION):
            train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
            feed_dict = {image: train_images, annotation: train_annotations, keep_probability: 0.85}

            sess.run(train_op, feed_dict=feed_dict)

            if itr % 10 == 0:
                train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict)
                print("Step: %d, Train_loss:%g" % (itr, train_loss))
                summary_writer.add_summary(summary_str, itr)

            if itr % 500 == 0:
                valid_images, valid_annotations = validation_dataset_reader.next_batch(FLAGS.batch_size)
                valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations,
                                                       keep_probability: 1.0})
                print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
                saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)

    elif FLAGS.mode == "inference":
        path = './real_test_imgs/'
        # Sorted so that the numeric suffix of each output file maps to a
        # reproducible input file.
        fnames = sorted(os.listdir(path))
        if not fnames:
            print("No images found in %s" % path)
            return
        imgs = np.array([_transform(os.path.join(path, elem)) for elem in fnames])
        pred = _predict_columns(sess, pred_annotation, image, keep_probability, imgs)
        _ensure_directory('result')
        res = []
        for itr in range(len(imgs)):
            pre = pred[itr]
            img = _as_rgb_image(imgs[itr])
            source_name = 'source_%s.jpg' % str(itr)
            img.save('result/' + source_name)
            # The row used to start with the unformatted template
            # 'source_%s.jpg', so every row carried the same name.
            res.append([source_name] + list(pre))
            _mark_columns(img, pre)
            img.save('result/pred_%s.jpg' % str(itr))
        np.savetxt('res.txt', res, fmt='%s')

    elif FLAGS.mode == "test":
        valid_images, valid_annotations = validation_dataset_reader.get_random_batch(FLAGS.batch_size)
        pred = _predict_columns(sess, pred_annotation, image, keep_probability, valid_images)
        _ensure_directory('result')
        # Iterate over what was actually returned; the batch can be shorter
        # than FLAGS.batch_size when the reader skips unreadable images.
        for itr in range(len(valid_images)):
            img = _as_rgb_image(valid_images[itr])
            _mark_columns(img, pred[itr])
            img.save('result/pred_%s.jpg' % str(itr))


if __name__ == "__main__":
    tf.app.run()

0 commit comments

Comments
 (0)