Commit 6306da7e authored by Yuxin Wu

small refactor in train/base

parent 8bfab811
@@ -73,3 +73,4 @@ model-*
 checkpoint
 *.json
 *.prototxt
+snippet

 # tensorpack
 Neural Network Toolbox on TensorFlow
-Still in development. Underlying design may change.
 See some [examples](examples) to learn about the framework.
 You can actually train them and reproduce the performance... not just to see how to write code.
......
@@ -204,9 +204,9 @@ def get_config():
             HumanHyperParamSetter('entropy_beta'),
             HumanHyperParamSetter('explore_factor'),
             master,
+            StartProcOrThread(master),
             PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
         ]),
-        extra_threads_procs=[master],
         session_config=get_default_sess_config(0.5),
         model=M,
         step_per_epoch=STEP_PER_EPOCH,
......
@@ -22,7 +22,7 @@ Identity Mappings in Deep Residual Networks, arxiv:1603.05027
 I can reproduce the results on 2 TitanX for
 n=5, about 7.1% val error after 67k steps (8.6 step/s)
-n=18, about 5.8% val error after 80k steps (2.6 step/s)
+n=18, about 5.9% val error after 80k steps (2.6 step/s)
 n=30: a 182-layer network, about 5.6% val error after 51k steps (1.55 step/s)
 This model uses the whole training set instead of a train-val split.
 """
......
@@ -22,6 +22,7 @@ import matplotlib.font_manager as fontm
 import argparse, sys
 from collections import defaultdict
 from itertools import chain
+import six
 from matplotlib import rc
 #rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})

@@ -52,11 +53,9 @@ def get_args():
                         help='title of the graph',
                         default='')
     parser.add_argument('--xlabel',
-                        help='x label',
-                        default = 'x')
+                        help='x label', type=six.text_type)
     parser.add_argument('--ylabel',
-                        help='y label',
-                        default='y')
+                        help='y label', type=six.text_type)
     parser.add_argument('-s', '--scale',
                         help='scale of each y, separated by comma')
     parser.add_argument('--annotate-maximum',

@@ -215,8 +214,10 @@ def do_plot(data_xs, data_ys):
         if args.annotate_maximum or args.annotate_minimum:
            annotate_min_max(truncate_data_x, data_y, ax)
-    plt.xlabel(args.xlabel.decode('utf-8'), fontsize='xx-large')
-    plt.ylabel(args.ylabel.decode('utf-8'), fontsize='xx-large')
+    if args.xlabel:
+        plt.xlabel(args.xlabel, fontsize='xx-large')
+    if args.ylabel:
+        plt.ylabel(args.ylabel, fontsize='xx-large')
     plt.legend(loc='best', fontsize='xx-large')
     # adjust maxx
......
@@ -56,13 +56,6 @@ class Callback(object):
         Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
         """

-    @property
-    def global_step(self):
-        """
-        Access the global step value of this training.
-        """
-        return self.trainer.global_step
-
     def trigger_epoch(self):
         """
         Triggered after every epoch.

@@ -95,7 +88,7 @@ class ProxyCallback(Callback):
         self.cb.trigger_epoch()

     def __str__(self):
-        return str(self.cb)
+        return "Proxy-" + str(self.cb)

 class PeriodicCallback(ProxyCallback):
     """
......
@@ -9,6 +9,7 @@ import re
 from .base import Callback
 from ..utils import logger
 from ..tfutils.varmanip import get_savename_from_varname
+from ..tfutils import get_global_step

 __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']

@@ -72,7 +73,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
             self.saver.save(
                 tf.get_default_session(),
                 self.path,
-                global_step=self.global_step,
+                global_step=get_global_step(),
                 write_meta_graph=False)
             # create a symbolic link for the latest model
......
@@ -22,6 +22,8 @@ class FakeData(RNGDataFlow):
         """
         :param shapes: a list of lists/tuples
         :param size: size of this DataFlow
+        :param random: whether to randomly generate data every iteration. note
+            that only generating the data could be time-consuming!
         """
         super(FakeData, self).__init__()
         self.shapes = shapes
......
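The new `random` flag controls whether FakeData regenerates its arrays on every iteration or produces them once and reuses them. A minimal usage sketch follows; the `tensorpack.dataflow` import path and any constructor details beyond `shapes`, `size`, and `random` are assumptions for illustration, not part of this diff:

```python
# Hypothetical usage of FakeData with the new `random` argument (exact signature assumed).
from tensorpack.dataflow import FakeData

# 100 fake datapoints, each with a 28x28x3 "image" and a 1-element "label".
df = FakeData(shapes=[[28, 28, 3], [1]], size=100, random=False)  # generate once, reuse every iteration
df.reset_state()
for dp in df.get_data():
    img, label = dp   # each datapoint is a list of arrays matching `shapes`
    break             # only demonstrating the interface here
```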
@@ -13,7 +13,6 @@ import tensorflow as tf
 from .config import TrainConfig
 from ..utils import logger, get_tqdm_kwargs
 from ..utils.timer import timed_operation
-from ..utils.concurrency import start_proc_mask_signal
 from ..callbacks import StatHolder
 from ..tfutils import get_global_step, get_global_step_var
 from ..tfutils.summary import create_summary
@@ -32,7 +31,6 @@ class Trainer(object):
         summary_writer: a `tf.SummaryWriter`
         config: a `TrainConfig`
         model: a `ModelDesc`
-        global_step: a `int`
     """
     __metaclass__ = ABCMeta
@@ -44,7 +42,7 @@ class Trainer(object):
         self.config = config
         self.model = config.model
         self.model.get_input_vars()   # ensure they are present
-        self._extra_threads_procs = config.extra_threads_procs
+        self.init_session_and_coord()

     @abstractmethod
     def train(self):
@@ -84,15 +82,6 @@ class Trainer(object):
         """ This is called right after all steps in an epoch are finished"""
         pass

-    def _init_summary(self):
-        if not hasattr(logger, 'LOG_DIR'):
-            raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
-        self.summary_writer = tf.train.SummaryWriter(
-            logger.LOG_DIR, graph=self.sess.graph)
-        self.summary_op = tf.merge_all_summaries()
-        # create an empty StatHolder
-        self.stat_holder = StatHolder(logger.LOG_DIR)
-
     def _process_summary(self, summary_str):
         summary = tf.Summary.FromString(summary_str)
         for val in summary.value:
@@ -107,31 +96,39 @@ class Trainer(object):
                     get_global_step())
             self.stat_holder.add_stat(name, val)

-    def main_loop(self):
+    def finalize_graph(self):
         # some final operations that might modify the graph
         get_global_step_var()   # ensure there is such var, before finalizing the graph
         logger.info("Setup callbacks ...")
         callbacks = self.config.callbacks
         callbacks.setup_graph(weakref.proxy(self))
-        self._init_summary()
+        if not hasattr(logger, 'LOG_DIR'):
+            raise RuntimeError("logger directory wasn't set!")
+        self.summary_writer = tf.train.SummaryWriter(logger.LOG_DIR, graph=self.sess.graph)
+        self.summary_op = tf.merge_all_summaries()
+        # create an empty StatHolder
+        self.stat_holder = StatHolder(logger.LOG_DIR)

         logger.info("Initializing graph variables ...")
         self.sess.run(tf.initialize_all_variables())
         self.config.session_init.init(self.sess)
         tf.get_default_graph().finalize()
-        self._start_concurrency()
+        tf.train.start_queue_runners(
+            sess=self.sess, coord=self.coord, daemon=True, start=True)

+    def main_loop(self):
+        self.finalize_graph()
+        callbacks = self.config.callbacks
         with self.sess.as_default():
             try:
-                self.global_step = get_global_step()
-                logger.info("Start training with global_step={}".format(self.global_step))
+                logger.info("Start training with global_step={}".format(get_global_step()))
                 callbacks.before_train()
                 for self.epoch_num in range(
                         self.config.starting_epoch, self.config.max_epoch+1):
                     with timed_operation(
                         'Epoch {} (global_step {})'.format(
-                            self.epoch_num, self.global_step + self.config.step_per_epoch)):
+                            self.epoch_num, get_global_step() + self.config.step_per_epoch)):
                         for step in tqdm.trange(
                                 self.config.step_per_epoch,
                                 **get_tqdm_kwargs(leave=True)):
@@ -139,7 +136,6 @@ class Trainer(object):
                                 return
                             self.run_step()   # implemented by subclass
                             #callbacks.trigger_step()   # not useful?
-                            self.global_step += 1
                     self.trigger_epoch()
             except StopTraining:
                 logger.info("Training was stopped.")
@@ -155,18 +151,6 @@ class Trainer(object):
         self.sess = tf.Session(config=self.config.session_config)
         self.coord = tf.train.Coordinator()

-    def _start_concurrency(self):
-        """
-        Run all threads before starting training
-        """
-        logger.info("Starting all threads & procs ...")
-        tf.train.start_queue_runners(
-            sess=self.sess, coord=self.coord, daemon=True, start=True)
-        with self.sess.as_default():
-            # avoid sigint get handled by other processes
-            start_proc_mask_signal(self._extra_threads_procs)
-
     def process_grads(self, grads):
         g = []
         for grad, var in grads:
......
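For readers following the control flow, a condensed sketch of the refactored Trainer lifecycle; the method names come from the diff above, while the bodies are elided or paraphrased and are not the full implementation:

```python
class Trainer(object):
    def __init__(self, config):
        self.config = config
        self.model = config.model
        self.init_session_and_coord()   # session and coordinator are now created at construction time

    def finalize_graph(self):
        # set up callbacks, summary writer and StatHolder, initialize variables,
        # finalize the graph, then start the TF queue runners
        ...

    def main_loop(self):
        self.finalize_graph()
        # per-epoch loop: run_step() for step_per_epoch steps, then trigger_epoch();
        # the step count is read via get_global_step() rather than a counter kept on the trainer
        ...
```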
@@ -32,7 +32,6 @@ class TrainConfig(object):
         :param max_epoch: maximum number of epoch to run training. default to inf
         :param nr_tower: int. number of training towers. default to 1.
         :param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
-        :param extra_threads_procs: list of `Startable` threads or processes
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__

@@ -72,6 +71,10 @@ class TrainConfig(object):
             self.tower = [0]

         self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
+        if self.extra_threads_procs:
+            logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of extra_threads_procs")
+            from ..callbacks.concurrency import StartProcOrThread
+            self.callbacks.cbs.append(StartProcOrThread(self.extra_threads_procs))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

     def set_tower(self, nr_tower=None, tower=None):
......
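The deprecation warning above points users from the `extra_threads_procs` argument to the `StartProcOrThread` callback. A minimal sketch of the new pattern, assuming `Callbacks` and `StartProcOrThread` are importable from `tensorpack.callbacks`; aside from the names taken from the diff, everything here is illustrative:

```python
import threading
from tensorpack.callbacks import Callbacks, StartProcOrThread  # import path assumed

# A background worker that previously would have been passed as extra_threads_procs=[worker].
worker = threading.Thread(target=lambda: None, daemon=True)

# New style: register it as a callback so the trainer starts it when training begins.
callbacks = Callbacks([
    StartProcOrThread(worker),
    # ... other callbacks (e.g. ModelSaver(), PeriodicCallback(...))
])
# `callbacks` then goes into TrainConfig(callbacks=callbacks, ...) as before.
```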
@@ -73,7 +73,6 @@ class MultiGPUTrainer(QueueInputTrainer):
 class SyncMultiGPUTrainer(MultiGPUTrainer):
     def train(self):
-        self.init_session_and_coord()
         self._build_enque_thread()
         grad_list = self._multi_tower_grads()

@@ -92,7 +91,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer):
 class AsyncMultiGPUTrainer(MultiGPUTrainer):
     def train(self):
-        self.init_session_and_coord()
         self._build_enque_thread()
         grad_list = self._multi_tower_grads()
......
@@ -76,7 +76,6 @@ class SimpleTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             avg_maintain_op)
-        self.init_session_and_coord()
         describe_model()
         # create an infinte data producer
         self.config.dataset.reset_state()

@@ -196,7 +195,6 @@ class QueueInputTrainer(Trainer):
     def train(self):
         assert len(self.config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
-        self.init_session_and_coord()
         self._build_enque_thread()
         grads = self._single_tower_grad()
......
@@ -113,9 +113,10 @@ def get_caffe_pb():
     caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
     if not os.path.isfile(caffe_pb_file):
         proto_path = download(CAFFE_PROTO_URL, dir)
+        assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
         ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
         assert ret == 0, \
-            "caffe proto compilation failed! Did you install protoc?"
+            "Command `protoc caffe.proto --python_out .` failed!"
     import imp
     return imp.load_source('caffepb', caffe_pb_file)
......