Commit 86370a76 authored by ppwwyyxx

checkpoint

parent a78d02ac
@@ -18,6 +18,7 @@ from models import *
 from utils import *
 from utils.symbolic_functions import *
 from utils.summary import *
+from utils.concurrency import *
 from dataflow.dataset import Mnist
 from dataflow import *
@@ -40,12 +41,12 @@ def get_model(inputs):
     image, label = inputs
     image = tf.expand_dims(image, 3)    # add a single channel
-    conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
-    pool0 = MaxPooling('pool0', conv0, 2)
-    conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
-    pool1 = MaxPooling('pool1', conv1, 2)
+    #conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
+    #pool0 = MaxPooling('pool0', conv0, 2)
+    #conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
+    #pool1 = MaxPooling('pool1', conv1, 2)
-    fc0 = FullyConnected('fc0', pool1, 1024)
+    fc0 = FullyConnected('fc0', image, 1024)
     fc0 = tf.nn.dropout(fc0, keep_prob)
     # fc will have activation summary by default. disable this for the output layer
@@ -91,12 +92,14 @@ def get_config():
     sess_config.allow_soft_placement = True

     # prepare model
-    image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
-    label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
+    image_var = tf.placeholder(
+        tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
+    label_var = tf.placeholder(
+        tf.int32, shape=(None,), name='label')
     input_vars = [image_var, label_var]
     output_vars, cost_var = get_model(input_vars)
-    add_histogram_summary('.*/W')   # monitor histogram of all W
+    input_queue = tf.RandomShuffleQueue(100, 50, ['float32', 'int32'], name='queue')
+    add_histogram_summary('.*/W')   # monitor histogram of all W

     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     lr = tf.train.exponential_decay(
         learning_rate=1e-4,
@@ -110,13 +113,13 @@ def get_config():
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=[
             SummaryWriter(LOG_DIR),
-            ValidationError(dataset_test, prefix='test'),
+            #ValidationError(dataset_test, prefix='test'),
             PeriodicSaver(LOG_DIR),
         ],
         session_config=sess_config,
         inputs=input_vars,
         outputs=output_vars,
         cost=cost_var,
+        input_queue=input_queue,
+        get_model_func=get_model,
         max_epoch=100,
     )
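The queue created in get_config and the wiring added in the training-loop file below form one pipeline: the placeholders are fed only by an enqueue thread, and the model graph is built on the dequeued tensors rather than on the placeholders directly. A consolidated sketch of that wiring, under the pre-1.0 TensorFlow API this repo uses (the IMAGE_SIZE value is an assumption):

import tensorflow as tf

IMAGE_SIZE = 28   # assumed; defined elsewhere in the example

# placeholders are only touched by the enqueue thread, not by the model
image_var = tf.placeholder(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), name='input')
label_var = tf.placeholder(tf.int32, (None,), name='label')
input_vars = [image_var, label_var]

# capacity 100, min_after_dequeue 50; one dtype per input component
input_queue = tf.RandomShuffleQueue(100, 50, ['float32', 'int32'], name='queue')

enqueue_op = input_queue.enqueue(tuple(input_vars))   # each queue element is a whole batch
model_inputs = input_queue.dequeue()                  # list of tensors: [image, label]
for qv, v in zip(model_inputs, input_vars):
    qv.set_shape(v.get_shape())   # dequeue from an unshaped queue loses static shapes
# the model is then built from the dequeued tensors, e.g.:
# output_vars, cost_var = get_model(model_inputs)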
@@ -5,6 +5,7 @@
 import tensorflow as tf
 from utils import *
+from utils.concurrency import *
 from dataflow import DataFlow
 from itertools import count
 import argparse
@@ -39,11 +40,17 @@ def start_train(config):
     # a list of input/output variables
     input_vars = config['inputs']
     output_vars = config['outputs']
     cost_var = config['cost']
+    input_queue = config['input_queue']
+    get_model_func = config['get_model_func']
     max_epoch = int(config['max_epoch'])

+    enqueue_op = input_queue.enqueue(tuple(input_vars))
+    model_inputs = input_queue.dequeue()
+    for qv, v in zip(model_inputs, input_vars):
+        qv.set_shape(v.get_shape())
+    output_vars, cost_var = get_model_func(model_inputs)
+
     # build graph
     G = tf.get_default_graph()
     for v in input_vars:
@@ -71,21 +78,26 @@ def start_train(config):
     with sess.as_default():
         sess.run(tf.initialize_all_variables())
         callbacks.before_train()
-        is_training = G.get_tensor_by_name(IS_TRAINING_VAR_NAME)
         for epoch in xrange(1, max_epoch):
-            with timed_operation('epoch {}'.format(epoch)):
-                for dp in dataset_train.get_data():
-                    feed = {is_training: True}
-                    feed.update(dict(zip(input_vars, dp)))
-                    results = sess.run(
-                        [train_op, cost_var] + output_vars, feed_dict=feed)
+            coord = tf.train.Coordinator()
+            th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
+            with timed_operation('epoch {}'.format(epoch)), \
+                    coordinator_context(
+                        sess, coord, th, input_queue):
+                for step in xrange(dataset_train.size()):
+                    # TODO eval dequeue to get dp
+                    fetches = [train_op, cost_var] + output_vars
+                    results = sess.run(fetches,
+                        feed_dict={IS_TRAINING_VAR_NAME: True})
                     cost = results[1]
                     outputs = results[2:]
-                    callbacks.trigger_step(feed, outputs, cost)
+                    print tf.train.global_step(sess, global_step_var), cost
+                    # trigger_step
+            coord.request_stop()
+            # summary will take a data from the queue
             callbacks.trigger_epoch()
+    print "Finish callback"
     sess.close()

 def main(get_config_func):
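The loop above relies on EnqueueThread and coordinator_context from utils.concurrency, neither of which appears in this diff. A plausible minimal reconstruction, for orientation only (the interfaces are taken from the call sites above; the bodies are assumptions, not the repo's actual code):

import threading
from contextlib import contextmanager

class EnqueueThread(threading.Thread):
    # assumed interface: EnqueueThread(sess, coord, enqueue_op, dataflow)
    def __init__(self, sess, coord, enqueue_op, dataflow):
        super(EnqueueThread, self).__init__()
        self.sess = sess
        self.coord = coord
        self.op = enqueue_op
        self.dataflow = dataflow
        # the placeholders are the enqueue op's data inputs
        # (input 0 is the queue handle)
        self.input_vars = self.op.inputs[1:]
        self.daemon = True

    def run(self):
        try:
            while not self.coord.should_stop():
                for dp in self.dataflow.get_data():
                    if self.coord.should_stop():
                        return
                    self.sess.run(self.op,
                                  feed_dict=dict(zip(self.input_vars, dp)))
        except Exception:
            # a cancelled enqueue raises here when the queue is closed
            self.coord.request_stop()

@contextmanager
def coordinator_context(sess, coord, thread, queue):
    thread.start()
    try:
        yield
    finally:
        coord.request_stop()
        # unblock any enqueue the thread is stuck on, then reap it
        sess.run(queue.close(cancel_pending_enqueues=True))
        coord.join([thread])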
@@ -75,15 +75,13 @@ class SummaryWriter(Callback):
         tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
         self.summary_op = tf.merge_all_summaries()

-    def trigger_step(self, inputs, outputs, cost):
-        self.last_dp = inputs
-
     def trigger_epoch(self):
         # check if there is any summary
         if self.summary_op is None:
             return
-        summary_str = self.summary_op.eval(self.last_dp)
+        summary_str = self.summary_op.eval(
+            feed_dict={IS_TRAINING_VAR_NAME: True})
         self.epoch_num += 1
         self.writer.add_summary(summary_str, self.epoch_num)
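SummaryWriter no longer has to cache the last datapoint: the summaries now hang off tensors that come from input_queue.dequeue(), so evaluating the merged summary op pulls a batch out of the queue by itself (hence the "summary will take a data from the queue" comment in the trainer). A toy demonstration of that side effect, using the same pre-1.0 summary API as this repo:

import tensorflow as tf

q = tf.FIFOQueue(10, ['float32'])
x = tf.placeholder(tf.float32, shape=())
enqueue = q.enqueue((x,))
y = q.dequeue()
s = tf.scalar_summary('y', y)      # pre-1.0 API, as used in this commit
merged = tf.merge_all_summaries()

with tf.Session() as sess:
    for v in [1.0, 2.0]:
        sess.run(enqueue, feed_dict={x: v})
    print(sess.run(q.size()))      # 2
    sess.run(merged)               # evaluating the summary dequeues one element
    print(sess.run(q.size()))      # 1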
@@ -34,7 +34,7 @@ def getlogger():
 logger = getlogger()

-for func in ['info', 'warning', 'error', 'critical', 'warn']:
+for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
     locals()[func] = getattr(logger, func)

 def set_file(path):
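The locals() loop works because at module scope locals() is the module's globals(), so each logger method gets published as a module-level function; this commit only extends the list with exception and debug. A standalone illustration (module name is hypothetical):

# mylogger.py -- minimal version of the trick above
import logging

logger = logging.getLogger(__name__)

for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
    # at module scope, locals() is this module's globals(),
    # so this creates module-level functions info(), warning(), ...
    locals()[func] = getattr(logger, func)

# usage from another module:
#   import mylogger
#   mylogger.info("hello")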