Commit 6306da7e authored by Yuxin Wu

small refactor in train/base

parent 8bfab811
......@@ -73,3 +73,4 @@ model-*
checkpoint
*.json
*.prototxt
snippet
# tensorpack
Neural Network Toolbox on TensorFlow
Still in development. Underlying design may change.
See some [examples](examples) to learn about the framework.
You can actually train them and reproduce the performance, not just see how the code is written.
......
......@@ -204,9 +204,9 @@ def get_config():
HumanHyperParamSetter('entropy_beta'),
HumanHyperParamSetter('explore_factor'),
master,
StartProcOrThread(master),
PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
]),
extra_threads_procs=[master],
session_config=get_default_sess_config(0.5),
model=M,
step_per_epoch=STEP_PER_EPOCH,
......
......@@ -22,7 +22,7 @@ Identity Mappings in Deep Residual Networks, arxiv:1603.05027
I can reproduce the results on 2 TitanX for
n=5, about 7.1% val error after 67k steps (8.6 step/s)
n=18, about 5.8% val error after 80k steps (2.6 step/s)
n=18, about 5.9% val error after 80k steps (2.6 step/s)
n=30: a 182-layer network, about 5.6% val error after 51k steps (1.55 step/s)
This model uses the whole training set instead of a train-val split.
"""
......
......@@ -22,6 +22,7 @@ import matplotlib.font_manager as fontm
import argparse, sys
from collections import defaultdict
from itertools import chain
import six
from matplotlib import rc
#rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
......@@ -52,11 +53,9 @@ def get_args():
help='title of the graph',
default='')
parser.add_argument('--xlabel',
help='x label',
default = 'x')
help='x label', type=six.text_type)
parser.add_argument('--ylabel',
help='y label',
default='y')
help='y label', type=six.text_type)
parser.add_argument('-s', '--scale',
help='scale of each y, separated by comma')
parser.add_argument('--annotate-maximum',
......@@ -215,8 +214,10 @@ def do_plot(data_xs, data_ys):
if args.annotate_maximum or args.annotate_minimum:
annotate_min_max(truncate_data_x, data_y, ax)
plt.xlabel(args.xlabel.decode('utf-8'), fontsize='xx-large')
plt.ylabel(args.ylabel.decode('utf-8'), fontsize='xx-large')
if args.xlabel:
plt.xlabel(args.xlabel, fontsize='xx-large')
if args.ylabel:
plt.ylabel(args.ylabel, fontsize='xx-large')
plt.legend(loc='best', fontsize='xx-large')
# adjust maxx
......
......@@ -56,13 +56,6 @@ class Callback(object):
Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
"""
@property
def global_step(self):
"""
Access the global step value of this training.
"""
return self.trainer.global_step
def trigger_epoch(self):
"""
Triggered after every epoch.
......@@ -95,7 +88,7 @@ class ProxyCallback(Callback):
self.cb.trigger_epoch()
def __str__(self):
return str(self.cb)
return "Proxy-" + str(self.cb)
class PeriodicCallback(ProxyCallback):
"""
......
......@@ -9,6 +9,7 @@ import re
from .base import Callback
from ..utils import logger
from ..tfutils.varmanip import get_savename_from_varname
from ..tfutils import get_global_step
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
......@@ -72,7 +73,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
self.saver.save(
tf.get_default_session(),
self.path,
global_step=self.global_step,
global_step=get_global_step(),
write_meta_graph=False)
# create a symbolic link for the latest model
......
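The two hunks above go together: the `global_step` property is dropped from the base `Callback`, and callbacks such as `ModelSaver` now read the counter directly through `tfutils.get_global_step()`. A minimal sketch of a custom callback following the same pattern (the `StepLogger` class is a hypothetical example, not part of this commit; the imports mirror the ones shown in the diff):

```python
from tensorpack.callbacks.base import Callback
from tensorpack.tfutils import get_global_step
from tensorpack.utils import logger

class StepLogger(Callback):
    """Hypothetical example callback; not part of this commit."""
    def trigger_epoch(self):
        # read the counter from the graph instead of the removed self.global_step property
        logger.info("epoch finished at global_step={}".format(get_global_step()))
```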
......@@ -22,6 +22,8 @@ class FakeData(RNGDataFlow):
"""
:param shapes: a list of lists/tuples
:param size: size of this DataFlow
:param random: whether to randomly generate data every iteration. Note
that merely generating the data can itself be time-consuming!
"""
super(FakeData, self).__init__()
self.shapes = shapes
......
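A hedged usage sketch of the new flag: the constructor arguments `(shapes, size, random)` are inferred from the docstring above, and `random=False` reuses one generated datapoint instead of regenerating it every iteration, which matters when generation itself is the bottleneck.

```python
from tensorpack.dataflow import FakeData

# two components per datapoint: a batch of 64 fake 224x224 RGB images and 64 fake labels
df = FakeData([[64, 224, 224, 3], [64]], size=1000, random=False)
df.reset_state()
for img, label in df.get_data():
    pass  # feed into a trainer, or just benchmark the input pipeline
```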
......@@ -13,7 +13,6 @@ import tensorflow as tf
from .config import TrainConfig
from ..utils import logger, get_tqdm_kwargs
from ..utils.timer import timed_operation
from ..utils.concurrency import start_proc_mask_signal
from ..callbacks import StatHolder
from ..tfutils import get_global_step, get_global_step_var
from ..tfutils.summary import create_summary
......@@ -32,7 +31,6 @@ class Trainer(object):
summary_writer: a `tf.SummaryWriter`
config: a `TrainConfig`
model: a `ModelDesc`
global_step: a `int`
"""
__metaclass__ = ABCMeta
......@@ -44,7 +42,7 @@ class Trainer(object):
self.config = config
self.model = config.model
self.model.get_input_vars() # ensure they are present
self._extra_threads_procs = config.extra_threads_procs
self.init_session_and_coord()
@abstractmethod
def train(self):
......@@ -84,15 +82,6 @@ class Trainer(object):
""" This is called right after all steps in an epoch are finished"""
pass
def _init_summary(self):
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
self.summary_writer = tf.train.SummaryWriter(
logger.LOG_DIR, graph=self.sess.graph)
self.summary_op = tf.merge_all_summaries()
# create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR)
def _process_summary(self, summary_str):
summary = tf.Summary.FromString(summary_str)
for val in summary.value:
......@@ -107,31 +96,39 @@ class Trainer(object):
get_global_step())
self.stat_holder.add_stat(name, val)
def main_loop(self):
def finalize_graph(self):
# some final operations that might modify the graph
get_global_step_var() # ensure such a variable exists before the graph is finalized
logger.info("Setup callbacks ...")
callbacks = self.config.callbacks
callbacks.setup_graph(weakref.proxy(self))
self._init_summary()
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("logger directory wasn't set!")
self.summary_writer = tf.train.SummaryWriter(logger.LOG_DIR, graph=self.sess.graph)
self.summary_op = tf.merge_all_summaries()
# create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR)
logger.info("Initializing graph variables ...")
self.sess.run(tf.initialize_all_variables())
self.config.session_init.init(self.sess)
tf.get_default_graph().finalize()
self._start_concurrency()
tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True)
def main_loop(self):
self.finalize_graph()
callbacks = self.config.callbacks
with self.sess.as_default():
try:
self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step))
logger.info("Start training with global_step={}".format(get_global_step()))
callbacks.before_train()
for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch+1):
with timed_operation(
'Epoch {} (global_step {})'.format(
self.epoch_num, self.global_step + self.config.step_per_epoch)):
self.epoch_num, get_global_step() + self.config.step_per_epoch)):
for step in tqdm.trange(
self.config.step_per_epoch,
**get_tqdm_kwargs(leave=True)):
......@@ -139,7 +136,6 @@ class Trainer(object):
return
self.run_step() # implemented by subclass
#callbacks.trigger_step() # not useful?
self.global_step += 1
self.trigger_epoch()
except StopTraining:
logger.info("Training was stopped.")
......@@ -155,18 +151,6 @@ class Trainer(object):
self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator()
def _start_concurrency(self):
"""
Run all threads before starting training
"""
logger.info("Starting all threads & procs ...")
tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True)
with self.sess.as_default():
# avoid sigint get handled by other processes
start_proc_mask_signal(self._extra_threads_procs)
def process_grads(self, grads):
g = []
for grad, var in grads:
......
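Taken together, the train/base changes above reshape the trainer lifecycle: the session and coordinator are created once in `Trainer.__init__` (via `init_session_and_coord()`), the per-step `self.global_step` bookkeeping is gone in favour of `get_global_step()`, and everything that must happen between graph construction and the first step now lives in `finalize_graph()`. A rough sketch of the resulting call order (comments only, not verbatim code):

```python
# Trainer.__init__(config)
#     -> init_session_and_coord()          # session + Coordinator created up front
# <subclass>.train()                       # builds ops; no session setup any more
#     -> main_loop()
#         -> finalize_graph()              # callbacks.setup_graph(), SummaryWriter,
#                                          # StatHolder, initialize_all_variables(),
#                                          # graph.finalize(), start_queue_runners()
#         -> run_step()                    # once per training step
#         -> trigger_epoch()               # once per epoch; step count via get_global_step()
```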
......@@ -32,7 +32,6 @@ class TrainConfig(object):
:param max_epoch: maximum number of epoch to run training. default to inf
:param nr_tower: int. number of training towers. default to 1.
:param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
:param extra_threads_procs: list of `Startable` threads or processes
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
......@@ -72,6 +71,10 @@ class TrainConfig(object):
self.tower = [0]
self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
if self.extra_threads_procs:
logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of extra_threads_procs")
from ..callbacks.concurrency import StartProcOrThread
self.callbacks.cbs.append(StartProcOrThread(self.extra_threads_procs))
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def set_tower(self, nr_tower=None, tower=None):
......
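The deprecation shim above keeps old configs working, but new code should register threads and processes through the `StartProcOrThread` callback, as the A3C example at the top of this commit now does. A minimal migration sketch (the `master` object and the surrounding callback list are placeholders taken from that example):

```python
from tensorpack.callbacks.concurrency import StartProcOrThread

# before (deprecated):
#     TrainConfig(..., extra_threads_procs=[master])
# after:
#     TrainConfig(
#         callbacks=Callbacks([
#             ...,
#             StartProcOrThread(master),
#         ]),
#         ...)
```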
......@@ -73,7 +73,6 @@ class MultiGPUTrainer(QueueInputTrainer):
class SyncMultiGPUTrainer(MultiGPUTrainer):
def train(self):
self.init_session_and_coord()
self._build_enque_thread()
grad_list = self._multi_tower_grads()
......@@ -92,7 +91,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer):
class AsyncMultiGPUTrainer(MultiGPUTrainer):
def train(self):
self.init_session_and_coord()
self._build_enque_thread()
grad_list = self._multi_tower_grads()
......
......@@ -76,7 +76,6 @@ class SimpleTrainer(Trainer):
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
avg_maintain_op)
self.init_session_and_coord()
describe_model()
# create an infinite data producer
self.config.dataset.reset_state()
......@@ -196,7 +195,6 @@ class QueueInputTrainer(Trainer):
def train(self):
assert len(self.config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
self.init_session_and_coord()
self._build_enque_thread()
grads = self._single_tower_grad()
......
......@@ -113,9 +113,10 @@ def get_caffe_pb():
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
if not os.path.isfile(caffe_pb_file):
proto_path = download(CAFFE_PROTO_URL, dir)
assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
assert ret == 0, \
"caffe proto compilation failed! Did you install protoc?"
"Command `protoc caffe.proto --python_out .` failed!"
import imp
return imp.load_source('caffepb', caffe_pb_file)
......
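Since the improved assertion above only reports that the `protoc` invocation failed, it can be convenient to check for the compiler up front. An optional pre-check, not part of the commit:

```python
import distutils.spawn

if distutils.spawn.find_executable('protoc') is None:
    raise RuntimeError("`protoc` not found on PATH; install the protobuf compiler "
                       "before calling get_caffe_pb().")
```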