Commit dd1ac6b0 authored by ppwwyyxx's avatar ppwwyyxx

step_per_epoch & compatible feeding

parent 745ad4f0
......@@ -24,6 +24,10 @@ from utils.concurrency import *
from dataflow.dataset import Mnist
from dataflow import *
BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = 500
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
def get_model(inputs, is_training):
"""
Args:
......@@ -43,6 +47,15 @@ def get_model(inputs, is_training):
image, label = inputs
image = tf.expand_dims(image, 3) # add a single channel
if is_training:
# augmentations
image, label = tf.train.slice_input_producer(
[image, label], name='slice_queue')
image = tf.image.random_brightness(image, 0.1)
image, label = tf.train.shuffle_batch(
[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
num_threads=2, enqueue_many=False)
conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
pool0 = MaxPooling('pool0', conv0, 2)
conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
......@@ -86,11 +99,11 @@ def get_config():
logger.set_logger_dir(log_dir)
IMAGE_SIZE = 28
BATCH_SIZE = 128
dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
dataset_train = Mnist('train')
dataset_test = BatchData(Mnist('test'), 256, remainder=True)
#dataset_train = FixedSizeData(dataset_train, 20)
step_per_epoch = dataset_train.size() / BATCH_SIZE
#step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20)
sess_config = tf.ConfigProto()
......@@ -129,6 +142,7 @@ def get_config():
inputs=input_vars,
input_queue=input_queue,
get_model_func=get_model,
step_per_epoch=step_per_epoch,
max_epoch=100,
)
......
......@@ -41,7 +41,9 @@ def start_train(config):
input_queue = config['input_queue']
get_model_func = config['get_model_func']
step_per_epoch = int(config['step_per_epoch'])
max_epoch = int(config['max_epoch'])
assert step_per_epoch > 0 and max_epoch > 0
enqueue_op = input_queue.enqueue(tuple(input_vars))
model_inputs = input_queue.dequeue()
......@@ -79,14 +81,19 @@ def start_train(config):
# start training:
coord = tf.train.Coordinator()
# a thread that keeps filling the queue
th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
input_th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
model_th = tf.train.start_queue_runners(
sess=sess, coord=coord, daemon=True, start=False)
with sess.as_default(), \
coordinator_guard(
sess, coord, th, input_queue):
sess, coord, [input_th] + model_th, input_queue):
callbacks.before_train()
for epoch in xrange(1, max_epoch):
with timed_operation('epoch {}'.format(epoch)):
for step in xrange(dataset_train.size()):
for step in xrange(step_per_epoch):
if coord.should_stop():
return
fetches = [train_op, cost_var] + output_vars + model_inputs
results = sess.run(fetches)
cost = results[1]
......
......@@ -5,8 +5,10 @@
import threading
from contextlib import contextmanager
from itertools import izip
import tensorflow as tf
from .utils import expand_dim_if_necessary
from .naming import *
import logger
......@@ -37,21 +39,26 @@ class EnqueueThread(threading.Thread):
for dp in self.dataflow.get_data():
if self.coord.should_stop():
return
feed = dict(zip(self.input_vars, dp))
feed = {}
for var, data in izip(self.input_vars, dp):
data = expand_dim_if_necessary(var, data)
feed[var] = data
self.sess.run([self.op], feed_dict=feed)
except tf.errors.CancelledError as e:
pass
except Exception:
logger.exception("Exception in EnqueueThread:")
self.coord.request_stop()
@contextmanager
def coordinator_guard(sess, coord, thread, queue):
def coordinator_guard(sess, coord, threads, queue):
"""
Context manager to make sure that:
queue is closed
thread is joined
threads are joined
"""
thread.start()
for th in threads:
th.start()
try:
yield
except (KeyboardInterrupt, Exception) as e:
......@@ -60,4 +67,4 @@ def coordinator_guard(sess, coord, thread, queue):
coord.request_stop()
sess.run(
queue.close(cancel_pending_enqueues=True))
coord.join([thread])
coord.join(threads)
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
def expand_dim_if_necessary(var, dp):
"""
Args:
var: a tensor
dp: a numpy array
Return a reshaped version of dp, if that makes it match the valid dimension of var
"""
shape = var.get_shape().as_list()
valid_shape = [k for k in shape if k]
if dp.shape == tuple(valid_shape):
new_shape = [k if k else 1 for k in shape]
dp = dp.reshape(new_shape)
return dp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment