Commit bbe4faf4 authored by ppwwyyxx

use train_config

parent 341a5d43
@@ -11,8 +11,6 @@ sys.path.insert(0, os.path.expanduser('~/.local/lib/python2.7/site-packages'))
 import tensorflow as tf
 import numpy as np
-from itertools import count
 from layers import *
 from utils import *
@@ -20,23 +18,21 @@ from dataflow.dataset import Mnist
 from dataflow import *
 IMAGE_SIZE = 28
-NUM_CLASS = 10
-batch_size = 128
 LOG_DIR = 'train_log'
 def get_model(inputs):
     """
     Args:
         inputs: a list of input variables,
-        e.g.: [input, label] with:
-            input: bx28x28
-            label: bx1 integer
+        e.g.: [input_var, label_var] with:
+            input_var: bx28x28
+            label_var: bx1 integer
     Returns:
         (outputs, cost)
         outputs: a list of output variables
         cost: a scalar variable
     """
-    # use this dropout variable! it will be set to 1 at test time
+    # use this variable in dropout! Tensorpack will automatically set it to 1 at test time
     keep_prob = tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
     input, label = inputs
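A sketch of how a layer might consume this placeholder (the conv/fc layers are elided by the next hunk, so the exact call is an assumption):

    # hypothetical: apply dropout through the shared keep_prob placeholder;
    # at test time keep_prob is fed 1.0, which turns dropout into a no-op
    fc0 = tf.nn.dropout(fc0, keep_prob)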
@@ -62,7 +58,7 @@ def get_model(inputs):
     fc1 = FullyConnected('lr', fc0, out_dim=10)
     prob = tf.nn.softmax(fc1, name='output')
-    y = one_hot(label, NUM_CLASS)
+    y = one_hot(label, 10)
     cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y)
     cost = tf.reduce_mean(cost, name='cost')
@@ -74,59 +70,37 @@ def get_model(inputs):
     return [prob, correct], cost
 def main():
-    dataset_train = BatchData(Mnist('train'), batch_size)
-    dataset_test = BatchData(Mnist('test'), batch_size, remainder=True)
-    callbacks = [
-        SummaryWriter(LOG_DIR),
-        AccuracyValidation(
-            dataset_test,
-            prefix='test', period=1),
-        TrainingAccuracy(),
-        PeriodicSaver(LOG_DIR, period=1)
-    ]
-    optimizer = tf.train.AdamOptimizer(1e-4)
-    sess_config = tf.ConfigProto()
-    sess_config.device_count['GPU'] = 1
     with tf.Graph().as_default():
         G = tf.get_default_graph()
+        dataset_train = BatchData(Mnist('train'), 128)
+        dataset_test = BatchData(Mnist('test'), 128, remainder=True)
+        sess_config = tf.ConfigProto()
+        sess_config.device_count['GPU'] = 1
         # prepare model
         image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
         label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
         input_vars = [image_var, label_var]
         for v in input_vars:
             G.add_to_collection(INPUT_VARS_KEY, v)
         output_vars, cost_var = get_model(input_vars)
         for v in output_vars:
             G.add_to_collection(OUTPUT_VARS_KEY, v)
-        train_op = optimizer.minimize(cost_var)
-        sess = tf.Session(config=sess_config)
-        sess.run(tf.initialize_all_variables())
-        with sess.as_default():
-            for ext in callbacks:
-                ext.before_train()
-            keep_prob_var = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
-            for epoch in count(1):
-                for dp in dataset_train.get_data():
-                    feed = {keep_prob_var: 0.5}
-                    feed.update(dict(zip(input_vars, dp)))
-                    results = sess.run(
-                        [train_op, cost_var] + output_vars, feed_dict=feed)
-                    cost = results[1]
-                    outputs = results[2:]
-                    assert len(outputs) == len(output_vars)
-                    for cb in callbacks:
-                        cb.trigger_step(dp, outputs, cost)
-                for cb in callbacks:
-                    cb.trigger_epoch()
-        summary_writer.close()
+        config = dict(
+            dataset_train=dataset_train,
+            optimizer=tf.train.AdamOptimizer(1e-4),
+            callbacks=[
+                TrainingAccuracy(),
+                AccuracyValidation(dataset_test,
+                                   prefix='test', period=1),
+                PeriodicSaver(LOG_DIR, period=1),
+                SummaryWriter(LOG_DIR),
+            ],
+            session_config=sess_config,
+            inputs=input_vars,
+            outputs=output_vars,
+            cost=cost_var,
+            max_epoch=100,
+        )
+        from train import start_train
+        start_train(config)
 if __name__ == '__main__':
......
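With the training loop moved into train.py below, the config dict built in main() becomes the entire interface between an example script and the framework: data, optimizer, callbacks, session config, model variables and the epoch budget all travel through it.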
train.py (new file):
+#!/usr/bin/env python2
+# -*- coding: UTF-8 -*-
+# File: train.py
+# Author: Yuxin Wu <ppwwyyxx@gmail.com>
+import tensorflow as tf
+from utils import *
+from itertools import count
+def start_train(config):
+    """
+    Start training with the given config
+    Args:
+        config: a tensorpack config dictionary
+    """
+    # a Dataflow instance
+    dataset_train = config['dataset_train']
+    # a tf.train.Optimizer instance
+    optimizer = config['optimizer']
+    # a list of Callback instances, wrapped into one Callbacks object
+    callbacks = Callbacks(config.get('callbacks', []))
+    # a tf.ConfigProto instance
+    sess_config = config.get('session_config', None)
+    # lists of input/output variables
+    input_vars = config['inputs']
+    output_vars = config['outputs']
+    cost_var = config['cost']
+    max_epoch = int(config['max_epoch'])
+    # build graph
+    G = tf.get_default_graph()
+    for v in input_vars:
+        G.add_to_collection(INPUT_VARS_KEY, v)
+    for v in output_vars:
+        G.add_to_collection(OUTPUT_VARS_KEY, v)
+    train_op = optimizer.minimize(cost_var)
+    sess = tf.Session(config=sess_config)
+    # start training
+    with sess.as_default():
+        sess.run(tf.initialize_all_variables())
+        callbacks.before_train()
+        keep_prob_var = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
+        for epoch in xrange(1, max_epoch + 1):
+            for dp in dataset_train.get_data():
+                # map the datapoint onto the input placeholders, in order
+                feed = {keep_prob_var: 0.5}
+                feed.update(dict(zip(input_vars, dp)))
+                results = sess.run(
+                    [train_op, cost_var] + output_vars, feed_dict=feed)
+                cost = results[1]
+                outputs = results[2:]
+                callbacks.trigger_step(dp, outputs, cost)
+            callbacks.trigger_epoch()
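Since the config is a plain dict, a typo in a key only surfaces as a KeyError once start_train runs. A small fail-fast sketch (a hypothetical helper, not part of this commit; the key names follow the docstring above):

    def _check_train_config(config):
        # hypothetical: verify required train-config keys up front
        required = ['dataset_train', 'optimizer', 'inputs',
                    'outputs', 'cost', 'max_epoch']
        missing = [k for k in required if k not in config]
        if missing:
            raise KeyError('train config is missing keys: %s' % missing)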
@@ -8,6 +8,7 @@ import sys
 import numpy as np
 import os
 from abc import abstractmethod
+from .stat import *
 from .utils import *
 from .naming import *
@@ -20,12 +21,12 @@ class Callback(object):
     def _before_train(self):
         """
-        Called before training
+        Called before starting iterative training
        """
-    # trigger after every step
     def trigger_step(self, dp, outputs, cost):
         """
+        Callback to be triggered after every step (every backpropagation)
         Args:
             dp: the input datapoint fed into the graph, a list of values matching the input variables
             outputs: list of output values produced by running this dp
@@ -33,8 +34,10 @@ class Callback(object):
         """
         pass
-    # trigger after every epoch
     def trigger_epoch(self):
+        """
+        Callback to be triggered after every epoch (one full pass over the input dataset)
+        """
         pass
 class PeriodicCallback(Callback):
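The three hooks above are the whole callback contract, so a new callback overrides only what it needs. A hedged example (not part of this commit) that averages the per-step cost and reports it once per epoch:

    class MeanCostPrinter(Callback):
        # hypothetical callback: collect per-step cost, report the epoch mean
        def __init__(self):
            self.costs = []

        def trigger_step(self, dp, outputs, cost):
            self.costs.append(cost)

        def trigger_epoch(self):
            if self.costs:
                print "mean training cost this epoch: %.6f" % (
                    sum(self.costs) / len(self.costs))
            self.costs = []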
@@ -77,11 +80,7 @@ class AccuracyValidation(PeriodicCallback):
         self.dropout_var = self.get_tensor(DROPOUT_PROB_VAR_NAME)
         self.correct_var = self.get_tensor(self.correct_var_name)
         self.cost_var = self.get_tensor(self.cost_var_name)
-        try:
-            self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
-        except Exception as e:
-            print "SummaryWriter should be the first extension!"
-            raise
+        self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
     def _trigger(self):
         cnt = 0
@@ -121,11 +120,7 @@ class TrainingAccuracy(Callback):
         self.epoch_num = 0
     def _before_train(self):
-        try:
-            self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
-        except Exception as e:
-            print "SummaryWriter should be the first extension!"
-            raise
+        self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
         output_vars = self.graph.get_collection(OUTPUT_VARS_KEY)
         for idx, var in enumerate(output_vars):
             if var.name == self.correct_var_name:
@@ -194,3 +189,25 @@ class SummaryWriter(Callback):
         self.epoch_num += 1
         self.writer.add_summary(summary_str, self.epoch_num)
+class Callbacks(Callback):
+    def __init__(self, callbacks):
+        # move SummaryWriter to the front: its before_train() must run first,
+        # so later callbacks can find the writer in the collection
+        for idx, cb in enumerate(callbacks):
+            if type(cb) == SummaryWriter:
+                callbacks.insert(0, callbacks.pop(idx))
+                break
+        self.callbacks = callbacks
+    def before_train(self):
+        for cb in self.callbacks:
+            cb.before_train()
+    def trigger_step(self, dp, outputs, cost):
+        for cb in self.callbacks:
+            cb.trigger_step(dp, outputs, cost)
+    def trigger_epoch(self):
+        for cb in self.callbacks:
+            cb.trigger_epoch()
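Usage sketch: construction order does not matter, because __init__ moves the SummaryWriter to the front, so its before_train() populates the summary-writer collection before any other callback looks it up:

    cbs = Callbacks([TrainingAccuracy(), SummaryWriter(LOG_DIR)])
    cbs.before_train()   # SummaryWriter's before_train() runs first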