Commit dd1ac6b0 authored by ppwwyyxx

step_per_epoch & compatible feeding

parent 745ad4f0
In the MNIST example:

```diff
@@ -24,6 +24,10 @@ from utils.concurrency import *
 from dataflow.dataset import Mnist
 from dataflow import *
 
+BATCH_SIZE = 128
+MIN_AFTER_DEQUEUE = 500
+CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
+
 def get_model(inputs, is_training):
     """
     Args:
@@ -43,6 +47,15 @@ def get_model(inputs, is_training):
     image, label = inputs
     image = tf.expand_dims(image, 3)    # add a single channel
 
+    if is_training:
+        # augmentations
+        image, label = tf.train.slice_input_producer(
+            [image, label], name='slice_queue')
+        image = tf.image.random_brightness(image, 0.1)
+        image, label = tf.train.shuffle_batch(
+            [image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
+            num_threads=2, enqueue_many=False)
+
     conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
     pool0 = MaxPooling('pool0', conv0, 2)
     conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
@@ -86,11 +99,11 @@ def get_config():
     logger.set_logger_dir(log_dir)
 
     IMAGE_SIZE = 28
-    BATCH_SIZE = 128
-    dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
+    dataset_train = Mnist('train')
     dataset_test = BatchData(Mnist('test'), 256, remainder=True)
-    #dataset_train = FixedSizeData(dataset_train, 20)
-    #step_per_epoch = 20
+    step_per_epoch = dataset_train.size() / BATCH_SIZE
     #dataset_test = FixedSizeData(dataset_test, 20)
 
     sess_config = tf.ConfigProto()
@@ -129,6 +142,7 @@ def get_config():
         inputs=input_vars,
         input_queue=input_queue,
         get_model_func=get_model,
+        step_per_epoch=step_per_epoch,
         max_epoch=100,
     )
```
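For context, here is a self-contained sketch of the augmentation pipeline this hunk builds inside `get_model`: `tf.train.slice_input_producer` splits the incoming batch back into single examples, the augmentation runs per example, and `tf.train.shuffle_batch` reassembles shuffled batches. The `CAPACITY` formula is the rule of thumb from the TF queue documentation. The random stand-in data and the session boilerplate are mine, not part of the commit; this assumes the same TF 0.x-era `tf.train` queue API the diff uses.

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = 500
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE

# stand-in data; in the commit, inputs come from the input queue instead
images = tf.constant(np.random.rand(1000, 28, 28, 1).astype('float32'))
labels = tf.constant(np.random.randint(10, size=1000).astype('int32'))

# split the batch into single examples and queue them up
image, label = tf.train.slice_input_producer([images, labels])
# per-example augmentation sits between the two queues
image = tf.image.random_brightness(image, 0.1)
# re-batch with shuffling; enqueue_many=False because each enqueue is one example
image_batch, label_batch = tf.train.shuffle_batch(
    [image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
    num_threads=2, enqueue_many=False)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run(image_batch).shape)   # (128, 28, 28, 1)
    coord.request_stop()
    coord.join(threads)
```

The two-queue layout is what lets augmentation run in the queue-runner threads rather than in the training loop.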
In the training loop (`start_train`):

```diff
@@ -41,7 +41,9 @@ def start_train(config):
     input_queue = config['input_queue']
     get_model_func = config['get_model_func']
+    step_per_epoch = int(config['step_per_epoch'])
     max_epoch = int(config['max_epoch'])
+    assert step_per_epoch > 0 and max_epoch > 0
 
     enqueue_op = input_queue.enqueue(tuple(input_vars))
     model_inputs = input_queue.dequeue()
@@ -79,14 +81,19 @@ def start_train(config):
     # start training:
     coord = tf.train.Coordinator()
     # a thread that keeps filling the queue
-    th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
+    input_th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
+    model_th = tf.train.start_queue_runners(
+        sess=sess, coord=coord, daemon=True, start=False)
     with sess.as_default(), \
             coordinator_guard(
-                sess, coord, th, input_queue):
+                sess, coord, [input_th] + model_th, input_queue):
         callbacks.before_train()
         for epoch in xrange(1, max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
-                for step in xrange(dataset_train.size()):
+                for step in xrange(step_per_epoch):
+                    if coord.should_stop():
+                        return
                     fetches = [train_op, cost_var] + output_vars + model_inputs
                     results = sess.run(fetches)
                     cost = results[1]
```
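The explicit step count matters now because `dataset_train` is no longer wrapped in `BatchData`: its `size()` counts examples rather than batches, so the loop runs `step_per_epoch = size / BATCH_SIZE` steps while a separate thread keeps the input queue filled. Below is a toy version of that feeding scheme in the repo's own Python 2 / old-TF style; the `feed_loop` thread and all names are illustrative stand-ins for tensorpack's `EnqueueThread`, not its API.

```python
import threading

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [28, 28])
queue = tf.FIFOQueue(50, [tf.float32], shapes=[[28, 28]])
enqueue_op = queue.enqueue([x])
model_input = queue.dequeue()
output = tf.reduce_mean(model_input)   # stand-in for the real train_op

sess = tf.Session()
coord = tf.train.Coordinator()

def feed_loop():
    # keeps the queue filled, the way EnqueueThread does from a dataflow
    while not coord.should_stop():
        sess.run(enqueue_op, feed_dict={x: np.random.rand(28, 28)})

th = threading.Thread(target=feed_loop)
th.daemon = True   # don't block process exit if an enqueue is stuck
th.start()

step_per_epoch, max_epoch = 10, 3
for epoch in xrange(max_epoch):
    for step in xrange(step_per_epoch):
        if coord.should_stop():   # a feeder died or asked us to stop
            break
        sess.run(output)
coord.request_stop()
```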
In the concurrency utilities:

```diff
@@ -5,8 +5,10 @@
 import threading
 from contextlib import contextmanager
+from itertools import izip
 
 import tensorflow as tf
 
+from .utils import expand_dim_if_necessary
 from .naming import *
 import logger
@@ -37,21 +39,26 @@ class EnqueueThread(threading.Thread):
             for dp in self.dataflow.get_data():
                 if self.coord.should_stop():
                     return
-                feed = dict(zip(self.input_vars, dp))
+                feed = {}
+                for var, data in izip(self.input_vars, dp):
+                    data = expand_dim_if_necessary(var, data)
+                    feed[var] = data
                 self.sess.run([self.op], feed_dict=feed)
         except tf.errors.CancelledError as e:
             pass
         except Exception:
             logger.exception("Exception in EnqueueThread:")
+            self.coord.request_stop()
 
 @contextmanager
-def coordinator_guard(sess, coord, thread, queue):
+def coordinator_guard(sess, coord, threads, queue):
     """
     Context manager to make sure that:
         queue is closed
-        thread is joined
+        threads are joined
     """
-    thread.start()
+    for th in threads:
+        th.start()
     try:
         yield
     except (KeyboardInterrupt, Exception) as e:
@@ -60,4 +67,4 @@ def coordinator_guard(sess, coord, thread, queue):
         coord.request_stop()
         sess.run(
             queue.close(cancel_pending_enqueues=True))
-    coord.join([thread])
+    coord.join(threads)
```
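`start_queue_runners(..., start=False)` returns the runner threads unstarted, which is why `coordinator_guard` now takes a list: the hand-written `EnqueueThread` and TF's own queue runners get started and joined in one place. A stripped-down, self-contained demo of that guard pattern follows; it leaves out the queue-closing step the real guard performs with `cancel_pending_enqueues=True`, and `guard_demo`/`worker` are illustrative names, not tensorpack API.

```python
import threading
import time
from contextlib import contextmanager

import tensorflow as tf

@contextmanager
def guard_demo(coord, threads):
    # start everything on entry; always stop and join on exit
    for th in threads:
        th.start()
    try:
        yield
    finally:
        coord.request_stop()
        coord.join(threads)

coord = tf.train.Coordinator()

def worker():
    while not coord.should_stop():
        time.sleep(0.01)

threads = [threading.Thread(target=worker) for _ in range(2)]
with guard_demo(coord, threads):
    time.sleep(0.1)   # the training loop would run here
# both workers have been stopped and joined at this point
```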
A new file, `utils.py`:

```diff
+#!/usr/bin/env python2
+# -*- coding: UTF-8 -*-
+# File: utils.py
+# Author: Yuxin Wu <ppwwyyxx@gmail.com>
+
+def expand_dim_if_necessary(var, dp):
+    """
+    Args:
+        var: a tensor
+        dp: a numpy array
+    Return a reshaped version of dp, if that makes it match the valid dimension of var
+    """
+    shape = var.get_shape().as_list()
+    valid_shape = [k for k in shape if k]
+    if dp.shape == tuple(valid_shape):
+        new_shape = [k if k else 1 for k in shape]
+        dp = dp.reshape(new_shape)
+    return dp
```
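This helper is the "compatible feeding" half of the commit message: when a placeholder has an unknown (`None`) batch dimension and a datapoint matches the remaining dimensions, `EnqueueThread` can still feed it, because the helper fills each `None` with a singleton axis. A quick usage sketch; the import path is assumed from the relative import in the concurrency file.

```python
import numpy as np
import tensorflow as tf

from utils.utils import expand_dim_if_necessary   # path assumed

var = tf.placeholder(tf.float32, [None, 28, 28])  # unknown batch dim
dp = np.zeros((28, 28), dtype='float32')          # one un-batched datapoint

dp = expand_dim_if_necessary(var, dp)
print(dp.shape)   # (1, 28, 28): the None dim was filled in with 1
```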