|
| 1 | +from __future__ import print_function |
| 2 | +import tensorflow as tf |
| 3 | +import numpy as np |
| 4 | +from PIL import Image, ImageDraw |
| 5 | +import TensorflowUtils as utils |
| 6 | +# import read_MITSceneParsingData as scene_parsing |
| 7 | +import datetime |
| 8 | +import BatchDatsetReader as dataset |
| 9 | +from six.moves import xrange |
| 10 | +import os |
| 11 | +import cv2 |
| 12 | + |
| 13 | +# import pydevd |
| 14 | +# pydevd.settrace('192.168.50.217',port=8888, stdoutToServer=True, stderrToServer=True) |
| 15 | +config = tf.ConfigProto() |
| 16 | +config.gpu_options.per_process_gpu_memory_fraction = 0.6 # occupy GPU40% |
| 17 | +session = tf.Session(config=config) |
| 18 | + |
| 19 | +FLAGS = tf.flags.FLAGS |
| 20 | +tf.flags.DEFINE_integer("batch_size", "80", "batch size for training") |
| 21 | +tf.flags.DEFINE_string("logs_dir", "logs/", "path to logs directory") |
| 22 | +tf.flags.DEFINE_string("data_dir", "gen_imgs/", "path to dataset") |
| 23 | +tf.flags.DEFINE_string("test_data_dir", "test_imgs/", "path to test dataset") |
| 24 | +tf.flags.DEFINE_float("learning_rate", "1e-4", "Learning rate for Adam Optimizer") |
| 25 | +tf.flags.DEFINE_string("model_dir", "Model_zoo/", "Path to vgg model mat") |
| 26 | +tf.flags.DEFINE_bool('debug', "False", "Debug mode: True/ False") |
| 27 | +tf.flags.DEFINE_string('mode', "inference", "Mode train/ test/ inference") |
| 28 | + |
| 29 | + |
| 30 | +MAX_ITERATION = int(1e5 + 1) |
| 31 | +NUM_OF_CLASSESS = 2 |
| 32 | +IMAGE_SIZE = (1024, 48) |
| 33 | + |
| 34 | +def vgg_net(weights, image): |
| 35 | + layers = ( |
| 36 | + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', |
| 37 | + |
| 38 | + 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', |
| 39 | + |
| 40 | + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', |
| 41 | + 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', |
| 42 | + |
| 43 | + 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', |
| 44 | + 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', |
| 45 | + |
| 46 | + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', |
| 47 | + 'relu5_3', 'conv5_4', 'relu5_4' |
| 48 | + ) |
| 49 | + |
| 50 | + net = {} |
| 51 | + current = image |
| 52 | + for i, name in enumerate(layers): |
| 53 | + kind = name[:4] |
| 54 | + if kind == 'conv': |
| 55 | + kernels, bias = weights[i][0][0][0][0] |
| 56 | + # matconvnet: weights are [width, height, in_channels, out_channels] |
| 57 | + # tensorflow: weights are [height, width, in_channels, out_channels] |
| 58 | + kernels = utils.get_variable(np.transpose(kernels, (1, 0, 2, 3)), name=name + "_w") |
| 59 | + bias = utils.get_variable(bias.reshape(-1), name=name + "_b") |
| 60 | + current = utils.conv2d_basic(current, kernels, bias) |
| 61 | + elif kind == 'relu': |
| 62 | + current = tf.nn.relu(current, name=name) |
| 63 | + if FLAGS.debug: |
| 64 | + utils.add_activation_summary(current) |
| 65 | + elif kind == 'pool': |
| 66 | + current = utils.avg_pool_2x2(current) |
| 67 | + net[name] = current |
| 68 | + |
| 69 | + return net |
| 70 | + |
| 71 | + |
| 72 | +def inference(image, keep_prob): |
| 73 | + """ |
| 74 | + Semantic segmentation network definition |
| 75 | + :param image: input image. Should have values in range 0-255 |
| 76 | + :param keep_prob: |
| 77 | + :return: |
| 78 | + """ |
| 79 | + with tf.variable_scope("inference"): |
| 80 | + down_w_conv1 = utils.weight_variable([3, 3, 1, 32], name='down_w_conv1') |
| 81 | + down_b1 = utils.bias_variable([32], name='down_b1') |
| 82 | + down_conv1 = tf.nn.relu(utils.conv2d_basic(image, down_w_conv1, down_b1)) |
| 83 | + down_pool1 = utils.max_pool_2x2(down_conv1) # (24, 512, 32) |
| 84 | + |
| 85 | + down_w_conv2 = utils.weight_variable([3, 3, 32, 64], name='down_w_conv2') |
| 86 | + down_b2 = utils.bias_variable([64], name='down_b2') |
| 87 | + down_conv2 = tf.nn.relu(utils.conv2d_basic(down_pool1, down_w_conv2, down_b2)) |
| 88 | + down_pool2 = utils.max_pool_2x2(down_conv2) # (12, 256, 64) |
| 89 | + |
| 90 | + down_w_conv3 = utils.weight_variable([3, 3, 64, 128], name='down_w_conv3') |
| 91 | + down_b3 = utils.bias_variable([128], name='down_b3') |
| 92 | + down_conv3 = tf.nn.relu(utils.conv2d_basic(down_pool2, down_w_conv3, down_b3)) |
| 93 | + down_pool3 = utils.max_pool_2x2(down_conv3) # (6, 128, 128) |
| 94 | + |
| 95 | + down_w_conv4 = utils.weight_variable([3, 3, 128, 256], name='down_w_conv4') |
| 96 | + down_b4 = utils.bias_variable([256], name='down_b4') |
| 97 | + down_conv4 = tf.nn.relu(utils.conv2d_basic(down_pool3, down_w_conv4, down_b4)) |
| 98 | + down_pool4 = utils.max_pool_2x2(down_conv4) # (3, 64, 256) |
| 99 | + |
| 100 | + down_w_conv5 = utils.weight_variable([3, 3, 256, 512], name='down_w_conv5') |
| 101 | + down_b5 = utils.bias_variable([512], name='down_b5') |
| 102 | + down_conv5 = tf.nn.relu(utils.conv2d_basic(down_pool4, down_w_conv5, down_b5)) |
| 103 | + dropout5 = tf.nn.dropout(down_conv5, keep_prob=keep_prob) |
| 104 | + # down_pool5 = utils.max_pool_2x2(dropout5) # (1, 32, 512) |
| 105 | + down_pool5 = tf.nn.max_pool(dropout5, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID") |
| 106 | + # now to upscale to actual image size |
| 107 | + up_w_conv1 = utils.weight_variable([1, 5, 512, 512], name="up_w_conv1") |
| 108 | + up_b1 = utils.bias_variable([512], name="up_b1") |
| 109 | + up_conv1 = tf.nn.relu(utils.conv2d_transpose_strided(down_pool5, up_w_conv1, up_b1, output_shape=[tf.shape(image)[0],1,64,512])) |
| 110 | + |
| 111 | + up_w_conv2 = utils.weight_variable([1, 5, 256, 512], name="up_w_conv2") |
| 112 | + up_b2 = utils.bias_variable([256], name="up_b2") |
| 113 | + up_conv2 = tf.nn.relu(utils.conv2d_transpose_strided(up_conv1, up_w_conv2, up_b2, output_shape=[tf.shape(image)[0], 1, 128, 256])) |
| 114 | + |
| 115 | + up_w_conv3 = utils.weight_variable([1, 5, 128, 256], name="up_w_conv3") |
| 116 | + up_b3 = utils.bias_variable([128], name="up_b3") |
| 117 | + up_conv3 = tf.nn.relu(utils.conv2d_transpose_strided(up_conv2, up_w_conv3, up_b3, output_shape=[tf.shape(image)[0], 1, 256, 128])) |
| 118 | + |
| 119 | + up_w_conv4 = utils.weight_variable([1, 5, 64, 128], name="up_w_conv4") |
| 120 | + up_b4 = utils.bias_variable([64], name="up_b4") |
| 121 | + up_conv4 = tf.nn.relu(utils.conv2d_transpose_strided(up_conv3, up_w_conv4, up_b4, output_shape=[tf.shape(image)[0], 1, 512, 64])) |
| 122 | + |
| 123 | + up_w_conv5 = utils.weight_variable([1, 5, 1, 64], name="up_w_conv5") |
| 124 | + up_b5 = utils.bias_variable([1], name="up_b5") |
| 125 | + up_conv5 = tf.nn.sigmoid(utils.conv2d_transpose_strided(up_conv4, up_w_conv5, up_b5, output_shape=[tf.shape(image)[0], 1, 1024, 1])) |
| 126 | + |
| 127 | + annotation_pred = up_conv5 > 0.5 |
| 128 | + |
| 129 | + return annotation_pred, up_conv5 |
| 130 | + |
| 131 | + |
| 132 | +def train(loss_val, var_list): |
| 133 | + optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate) |
| 134 | + grads = optimizer.compute_gradients(loss_val, var_list=var_list) |
| 135 | + if FLAGS.debug: |
| 136 | + # print(len(var_list)) |
| 137 | + for grad, var in grads: |
| 138 | + utils.add_gradient_summary(grad, var) |
| 139 | + return optimizer.apply_gradients(grads) |
| 140 | + |
| 141 | + |
| 142 | +def _transform(filename): |
| 143 | + image = cv2.imread(filename, 0) |
| 144 | + # if self.__channels and len(image.shape) < 3: # make sure images are of shape(h,w,3) |
| 145 | + # image = np.array([image for i in range(3)]) |
| 146 | + resize_image = cv2.resize(image, (1024, 48)) |
| 147 | + return np.expand_dims(np.array(resize_image) / 255.0, axis=3) |
| 148 | + |
| 149 | + |
| 150 | +def main(argv=None): |
| 151 | + keep_probability = tf.placeholder(tf.float32, name="keep_probabilty") |
| 152 | + image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE[1], IMAGE_SIZE[0], 1], name="input_image") |
| 153 | + annotation = tf.placeholder(tf.float32, shape=[None, 1, IMAGE_SIZE[0], 1], name="annotation") |
| 154 | + pred_annotation, logits = inference(image, keep_probability) |
| 155 | + # logits = tf.squeeze(logits, squeeze_dims=[1, 3]) |
| 156 | + tf.summary.image("input_image", image, max_outputs=2) |
| 157 | + tf.summary.image("ground_truth", tf.cast(annotation, tf.uint8), max_outputs=2) |
| 158 | + tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2) |
| 159 | + # loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, |
| 160 | + # labels=tf.squeeze(annotation, squeeze_dims=[3]), |
| 161 | + # name="entropy"))) |
| 162 | + alpha = 0.9 |
| 163 | + belta = 0.1 |
| 164 | + # one sample: -ylog(y)+-(1-y)log(1-y), n samples: mean(one sample) |
| 165 | + loss = tf.reduce_mean(tf.add(-alpha*tf.reduce_sum(annotation * tf.log(logits + 1e-9), 1), |
| 166 | + -belta*tf.reduce_sum((1 - annotation) * tf.log(1 - logits + 1e-9), 1))) |
| 167 | + tf.summary.scalar("entropy", loss) |
| 168 | + |
| 169 | + trainable_var = tf.trainable_variables() |
| 170 | + if FLAGS.debug: |
| 171 | + for var in trainable_var: |
| 172 | + utils.add_to_regularization_and_summary(var) |
| 173 | + train_op = train(loss, trainable_var) |
| 174 | + |
| 175 | + print("Setting up summary op...") |
| 176 | + summary_op = tf.summary.merge_all() |
| 177 | + |
| 178 | + |
| 179 | + print("Setting up dataset reader") |
| 180 | + if FLAGS.mode == 'train': |
| 181 | + train_dataset_reader = dataset.BatchDatset(FLAGS.data_dir) |
| 182 | + validation_dataset_reader = dataset.BatchDatset(FLAGS.test_data_dir, dataset_file='dataset_test.txt') |
| 183 | + |
| 184 | + sess = tf.Session() |
| 185 | + |
| 186 | + print("Setting up Saver...") |
| 187 | + saver = tf.train.Saver() |
| 188 | + summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph) |
| 189 | + |
| 190 | + sess.run(tf.global_variables_initializer()) |
| 191 | + ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir) |
| 192 | + ckpt.model_checkpoint_path = 'logs/model.ckpt-100000' |
| 193 | + if ckpt and ckpt.model_checkpoint_path: |
| 194 | + saver.restore(sess, ckpt.model_checkpoint_path) |
| 195 | + print("Model restored...") |
| 196 | + |
| 197 | + if FLAGS.mode == "train": |
| 198 | + for itr in xrange(MAX_ITERATION): |
| 199 | + train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size) |
| 200 | + feed_dict = {image: train_images, annotation: train_annotations, keep_probability: 0.85} |
| 201 | + |
| 202 | + sess.run(train_op, feed_dict=feed_dict) |
| 203 | + |
| 204 | + if itr % 10 == 0: |
| 205 | + train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict) |
| 206 | + print("Step: %d, Train_loss:%g" % (itr, train_loss)) |
| 207 | + summary_writer.add_summary(summary_str, itr) |
| 208 | + |
| 209 | + if itr % 500 == 0: |
| 210 | + valid_images, valid_annotations = validation_dataset_reader.next_batch(FLAGS.batch_size) |
| 211 | + valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations, |
| 212 | + keep_probability: 1.0}) |
| 213 | + print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss)) |
| 214 | + saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr) |
| 215 | + |
| 216 | + elif FLAGS.mode == "inference": |
| 217 | + path = './real_test_imgs/' |
| 218 | + fnames = os.listdir(path) |
| 219 | + imgs = np.array([_transform(os.path.join(path, elem)) for elem in fnames]) |
| 220 | + pred = sess.run(pred_annotation, feed_dict={image: imgs, |
| 221 | + keep_probability: 1.0}) # [80, 1,1024,1] |
| 222 | + pred = np.squeeze(pred, axis=3) |
| 223 | + pred = np.squeeze(pred, axis=1) |
| 224 | + pred = np.asarray(pred, np.int) |
| 225 | + res = [] |
| 226 | + for itr in range(len(imgs)): |
| 227 | + im = imgs[itr] |
| 228 | + pre = pred[itr] |
| 229 | + im = 255 * np.squeeze(im, axis=2) |
| 230 | + im = Image.fromarray(im) |
| 231 | + # make sure images are of shape(h,w,3) |
| 232 | + img = im.convert('RGB') |
| 233 | + img.save('result/source_%s.jpg' % str(itr)) |
| 234 | + res.append(['source_%s.jpg']+list(pre)) |
| 235 | + img_d = ImageDraw.Draw(img) |
| 236 | + x_len, y_len = img.size |
| 237 | + for x in range(x_len): |
| 238 | + if pre[x] == 1: |
| 239 | + img_d.line(((x, 0), (x, y_len)), (250, 0, 0)) |
| 240 | + img.save('result/pred_%s.jpg' % str(itr)) |
| 241 | + # utils.save_image(im.astype(np.uint8), FLAGS.logs_dir, name="inp_" + str(5 + itr)) |
| 242 | + np.savetxt('res.txt', res, fmt='%s') |
| 243 | + elif FLAGS.mode == "test": |
| 244 | + valid_images, valid_annotations = validation_dataset_reader.get_random_batch(FLAGS.batch_size) |
| 245 | + pred = sess.run(pred_annotation, feed_dict={image: valid_images, |
| 246 | + keep_probability: 1.0}) # [80, 1,1024,1] |
| 247 | + pred = np.squeeze(pred, axis=3) |
| 248 | + pred = np.squeeze(pred, axis=1) |
| 249 | + pred = np.asarray(pred, np.int) |
| 250 | + for itr in range(FLAGS.batch_size): |
| 251 | + im = valid_images[itr] |
| 252 | + pre = pred[itr] |
| 253 | + im = 255*np.squeeze(im,axis=2) |
| 254 | + im = Image.fromarray(im) |
| 255 | + # make sure images are of shape(h,w,3) |
| 256 | + img = im.convert('RGB') |
| 257 | + |
| 258 | + img_d = ImageDraw.Draw(img) |
| 259 | + x_len, y_len = img.size |
| 260 | + for x in range(x_len): |
| 261 | + if pre[x] == 1: |
| 262 | + img_d.line(((x,0),(x,y_len)),(250,0,0)) |
| 263 | + img.save('result/pred_%s.jpg' % str(itr)) |
| 264 | + # utils.save_image(im.astype(np.uint8), FLAGS.logs_dir, name="inp_" + str(5 + itr)) |
| 265 | + |
| 266 | +if __name__ == "__main__": |
| 267 | + tf.app.run() |
0 commit comments