Commit e04d846a authored by Yuxin Wu

predictorfactory

parent fefdcfb1
@@ -203,7 +203,5 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    SimpleTrainer(config).train()
-    # TODO test if queue trainer works
-    #QueueInputTrainer(config).train()
+    QueueInputTrainer(config).train()
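For readers of this hunk: the switch from `SimpleTrainer` (feed_dict-based) to `QueueInputTrainer` matters because the latter trains from a TF queue filled by a background thread, as the trainer diff further below shows. A minimal sketch of that input path, with a placeholder shape of my own:

```python
import tensorflow as tf

# one input placeholder standing in for the model's input vars
input_var = tf.placeholder(tf.float32, [None, 28, 28], name='input')
# the trainer builds a FIFOQueue much like this one (capacity 50, cf. below)
queue = tf.FIFOQueue(50, [input_var.dtype], name='input_queue')
enqueue_op = queue.enqueue([input_var])  # run repeatedly by EnqueueThread
dequeued = queue.dequeue()               # the training graph consumes this
```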
@@ -5,16 +5,18 @@ I ran into the paper [DisturbLabel: Regularizing CNN on the Loss Layer](https://
 which basically said that noisy data gives you better performance.
 Like many, I didn't believe the method and the results.
 
-This is a simple mnist training script with DisturbLabel. It uses the architecture in the paper and
-hyperparameters in my original [mnist example](../mnist-convnet.py). The results surprised me:
+This is a simple mnist training script with DisturbLabel. It uses the simple architecture in the paper, and
+hyperparameters in my original [mnist example](../mnist-convnet.py).
+The results surprised me; clean labels give the worst accuracy:
 
 ![mnist](mnist.png)
 
 Experiments were repeated 15 times for p=0, 10 times for p=0.02 & 0.05, and 5 times for other values
 of p. All experiments ran for 100 epochs, with lr decay, which is enough for them to converge.
 
-I suppose the disturb method works as a random noise to prevent SGD from getting stuck.
-However it didn't work for harder problems such as SVHN:
+I suppose the disturb method works as random noise that could prevent SGD from getting stuck when the
+training data are too easy or too few.
+It didn't work for harder problems such as SVHN:
 
 ![svhn](svhn.png)
...
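For readers who want the mechanism rather than the plots: DisturbLabel simply replaces each training label, with probability p, by a label drawn uniformly from all classes. A minimal numpy sketch of that idea (my own illustration, not the script in this repo):

```python
import numpy as np

def disturb_labels(labels, p, num_classes=10, rng=np.random):
    """With probability p, replace each label by one drawn uniformly
    at random over all num_classes classes (possibly the true one)."""
    labels = labels.copy()
    mask = rng.uniform(size=labels.shape) < p
    labels[mask] = rng.randint(0, num_classes, size=int(mask.sum()))
    return labels
```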
@@ -156,4 +156,3 @@ if __name__ == '__main__':
     if args.gpu:
         config.nr_tower = len(args.gpu.split(','))
     QueueInputTrainer(config).train()
-    #SimpleTrainer(config).train()
@@ -113,6 +113,7 @@ class Callbacks(Callback):
         self.test_callback_context = TestCallbackContext()
 
     def _setup_graph(self):
+        with tf.name_scope(None):
             for cb in self.cbs:
                 if isinstance(cb.type, TrainCallbackType):
                     cb.setup_graph(self.trainer)
...
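The added `tf.name_scope(None)` is the key line here: in TF 1.x it resets to the graph's root name scope, so ops that callbacks create are not nested under whatever scope happened to be open when `_setup_graph` ran. A small sketch of that behavior:

```python
import tensorflow as tf

with tf.name_scope('tower0'):
    a = tf.constant(1, name='a')        # named 'tower0/a'
    with tf.name_scope(None):           # escape back to the root scope
        b = tf.constant(1, name='b')    # named plain 'b'
print(a.name, b.name)                   # tower0/a:0  b:0
```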
@@ -78,7 +78,7 @@ class InferenceRunner(Callback):
         for v in self.vcs:
             assert isinstance(v, Inferencer), str(v)
 
-    def _before_train(self):
+    def _setup_graph(self):
         self.input_vars = self.trainer.model.reuse_input_vars()
         self._find_output_tensors()
         input_names = [x.name for x in self.input_vars]
...
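Moving this code from `_before_train` to `_setup_graph` shifts it earlier in the callback lifecycle: finding tensors and building inference ops now happens while the graph is still under construction, not after the session exists. A hedged sketch of the lifecycle this rename relies on, with the base-class wrappers reduced to essentials:

```python
# simplified from the Callback base-class pattern used in this codebase
class Callback(object):
    def setup_graph(self, trainer):
        # called while the training graph is still being assembled
        self.trainer = trainer
        self._setup_graph()

    def before_train(self):
        # called later, once the session is ready, just before training
        self._before_train()

    def _setup_graph(self):
        pass

    def _before_train(self):
        pass
```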
@@ -52,6 +52,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     batch_mean = tf.identity(batch_mean, 'mean')
     batch_var = tf.identity(batch_var, 'variance')
 
+    # XXX hack....
     emaname = 'EMA'
     in_main_tower = not batch_mean.name.startswith('towerp')
     if in_main_tower:
...
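The `# XXX hack....` marks the line below it: `in_main_tower` is inferred from the tensor's name prefix, since ops built inside a predictor tower live under a `towerp*` name scope and should reuse the main tower's EMA instead of maintaining their own. A small sketch of the naming rule this depends on:

```python
import tensorflow as tf

with tf.name_scope('towerp0'):
    mean = tf.identity(tf.constant(0.0), name='mean')
print(mean.name)                            # 'towerp0/mean:0'
print(not mean.name.startswith('towerp'))   # False: a predictor tower
```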
@@ -39,7 +39,8 @@ class Trainer(object):
         assert isinstance(config, TrainConfig), type(config)
         self.config = config
         self.model = config.model
-        self.extra_threads_procs = config.extra_threads_procs
+        self.model.get_input_vars()   # ensure they are present
+        self._extra_threads_procs = config.extra_threads_procs
 
     @abstractmethod
     def train(self):
@@ -53,7 +54,7 @@ class Trainer(object):
     @abstractmethod
     def get_predict_func(self, input_names, output_names):
-        """ return a predictor function"""
+        """ return an online predictor"""
         pass
 
     def get_predict_funcs(self, input_names, output_names, n):
@@ -61,8 +62,7 @@ class Trainer(object):
         Can be overwritten by subclasses to exploit more
         parallelism among funcs.
         """
-        return [self.get_predict_func(input_names, output_names)
-                for k in range(n)]
+        return [self.get_predict_func(input_names, output_names) for k in range(n)]
 
     def trigger_epoch(self):
         self._trigger_epoch()
@@ -156,7 +156,7 @@ class Trainer(object):
         with self.sess.as_default():
             # avoid sigint getting handled by other processes
-            start_proc_mask_signal(self.extra_threads_procs)
+            start_proc_mask_signal(self._extra_threads_procs)
 
     def process_grads(self, grads):
         g = []
...
@@ -59,7 +59,7 @@ class MultiGPUTrainer(QueueInputTrainer):
                     tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
                 tf.get_variable_scope().reuse_variables()
                 # avoid repeated summary from each device
-                backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
+                backup = backup_collection(SUMMARY_BACKUP_KEYS)
         restore_collection(backup)
         return grad_list
@@ -78,9 +78,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer):
             summary_moving_average(), name='train_op')
         describe_model()
-        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
-            self._build_predict_tower()
         # [debug]: do nothing in training
         #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
         self.main_loop()
@@ -107,8 +104,6 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
         self._start_async_threads(grad_list)
-        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
-            self._build_predict_tower()
         self.main_loop()
 
     def _start_async_threads(self, grad_list):
...
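Dropping the `self.` prefix works because `SUMMARY_BACKUP_KEYS` is now a module-level constant (see the naming diff at the bottom). For context, a rough sketch of what `backup_collection`/`restore_collection` helpers like these can look like; this is an assumption of mine, not necessarily the library's exact code:

```python
import tensorflow as tf

def backup_collection(keys):
    # snapshot the current contents of the given graph collections
    return {k: list(tf.get_collection(k)) for k in keys}

def restore_collection(backup):
    # discard anything added since the snapshot, then put the snapshot back
    for k, items in backup.items():
        del tf.get_collection_ref(k)[:]
        tf.get_collection_ref(k).extend(items)
```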
@@ -15,10 +15,42 @@ from ..tfutils.modelutils import describe_model
 from ..utils import *
 from ..tfutils import *
+from ..tfutils.summary import add_moving_summary
 from ..predict import OnlinePredictor
 
 __all__ = ['SimpleTrainer', 'QueueInputTrainer']
 
+class PredictorFactory(object):
+    """ Make predictors for a trainer."""
+    PREFIX = 'towerp'
+
+    def __init__(self, trainer, towers):
+        self.trainer = trainer
+        self.towers = towers
+        self.tower_built = False
+
+    def get_predictor(self, input_names, output_names, tower):
+        if not self.tower_built:
+            self._build_predict_tower()
+        tower = self.towers[tower % len(self.towers)]
+        raw_input_vars = get_vars_by_names(input_names)
+        output_names = ['{}{}/'.format(self.PREFIX, tower) + n for n in output_names]
+        output_vars = get_vars_by_names(output_names)
+        return OnlinePredictor(self.trainer.sess, raw_input_vars, output_vars)
+
+    def _build_predict_tower(self):
+        # _build_predict_tower might get called anywhere,
+        # but 'towerp' should always be the outermost name scope
+        with tf.name_scope(None), \
+                freeze_collection(SUMMARY_BACKUP_KEYS):
+            inputs = self.trainer.model.get_input_vars()
+            tf.get_variable_scope().reuse_variables()
+            for k in self.towers:
+                logger.info("Building graph for predictor tower {}...".format(k))
+                with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
+                        tf.name_scope('{}{}'.format(self.PREFIX, k)):
+                    self.trainer.model.build_graph(inputs, False)
+        self.tower_built = True
+
 class SimpleTrainer(Trainer):
     def run_step(self):
         data = next(self.data_producer)
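A hypothetical usage sketch of the new class (the tensor names are placeholders of my own, and the exact `OnlinePredictor` call convention is assumed): the factory builds the `towerp*` towers once, lazily, then rewrites the requested output names into the chosen tower's scope.

```python
factory = PredictorFactory(trainer, towers=[0, 1])
# tower index is taken modulo len(towers): tower=3 -> towers[1]
pred = factory.get_predictor(['input:0'], ['prob:0'], tower=3)
# internally, 'prob:0' is looked up as 'towerp1/prob:0'
outputs = pred([input_batch])   # assumed OnlinePredictor call style
```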
@@ -30,7 +62,7 @@ class SimpleTrainer(Trainer):
         self.input_vars = model.get_input_vars()
         model.build_graph(self.input_vars, True)
         cost_var = model.get_cost()
-        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
+        add_moving_summary(cost_var)
         grads = self.config.optimizer.compute_gradients(cost_var)
         grads = self.process_grads(grads)
@@ -55,11 +87,9 @@ class SimpleTrainer(Trainer):
         self._process_summary(summary_str)
 
     def get_predict_func(self, input_names, output_names):
-        input_vars = get_vars_by_names(input_names)
-        for v in input_vars:
-            assert v in self.input_vars
-        output_vars = get_vars_by_names(output_names)
-        return OnlinePredictor(self.sess, input_vars, output_vars)
+        if not hasattr(self, 'predictor_factory'):
+            self.predictor_factory = PredictorFactory(self, [0])
+        return self.predictor_factory.get_predictor(input_names, output_names, 0)
 
 class EnqueueThread(threading.Thread):
     def __init__(self, trainer):
@@ -102,8 +132,6 @@ class EnqueueThread(threading.Thread):
 class QueueInputTrainer(Trainer):
     """ Single GPU Trainer, takes input from a queue"""
 
-    SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
-
     def __init__(self, config, input_queue=None, predict_tower=None):
         """
         :param config: a `TrainConfig` instance
@@ -120,10 +148,12 @@ class QueueInputTrainer(Trainer):
                 50, [x.dtype for x in self.input_vars], name='input_queue')
         else:
             self.input_queue = input_queue
+
         if predict_tower is None:
             # by default, use the first training gpu for prediction
             predict_tower = [0]
-        self.predict_tower = predict_tower
+        self.predictor_factory = PredictorFactory(self, predict_tower)
+
         self.dequed_inputs = None
 
     def _get_model_inputs(self):
@@ -136,15 +166,6 @@ class QueueInputTrainer(Trainer):
             qv.set_shape(v.get_shape())
         return ret
 
-    def _build_predict_tower(self):
-        inputs = self.model.get_input_vars()
-        tf.get_variable_scope().reuse_variables()
-        for k in self.predict_tower:
-            logger.info("Building graph for predict tower p{}...".format(k))
-            with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                    tf.name_scope('towerp{}'.format(k)):
-                self.model.build_graph(inputs, False)
-
     def _single_tower_grad(self):
         """ Get grad and cost for single-tower"""
         self.dequed_inputs = model_inputs = self._get_model_inputs()
@@ -158,13 +179,13 @@ class QueueInputTrainer(Trainer):
         cost_var = self.model.get_cost()
         grads = self.config.optimizer.compute_gradients(
             cost_var, gate_gradients=0)  # GATE_NONE
-        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
+        add_moving_summary(cost_var)
         return grads
 
     def _build_enque_thread(self):
         """ create a thread that keeps filling the queue """
         self.input_th = EnqueueThread(self)
-        self.extra_threads_procs.append(self.input_th)
+        self._extra_threads_procs.append(self.input_th)
 
     def train(self):
         assert self.config.nr_tower == 1, \
@@ -176,9 +197,6 @@ class QueueInputTrainer(Trainer):
         grads = self.process_grads(grads)
         describe_model()
 
-        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
-            self._build_predict_tower()
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')
@@ -213,14 +231,5 @@ class QueueInputTrainer(Trainer):
         :param tower: return the kth predict_func
         :returns: an `OnlinePredictor`
         """
-        tower = self.predict_tower[tower % len(self.predict_tower)]
-        raw_input_vars = get_vars_by_names(input_names)
-        output_names = ['towerp{}/'.format(tower) + n for n in output_names]
-        output_vars = get_vars_by_names(output_names)
-        return OnlinePredictor(self.sess, raw_input_vars, output_vars)
-
-    def get_predict_funcs(self, input_names, output_names, n):
-        """ return n predictors evenly on each predict_tower"""
-        return [self.get_predict_func(input_names, output_names, k)
-                for k in range(n)]
+        return self.predictor_factory.get_predictor(input_names, output_names, tower)
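With both the tower arithmetic and the `get_predict_funcs` override gone, predictor placement is fully delegated to `PredictorFactory`. A hedged example against this diff's API (tensor names are placeholders of my own):

```python
trainer = QueueInputTrainer(config, predict_tower=[0, 1])
# k % 2 alternates the six predictors between towerp0 and towerp1
preds = [trainer.get_predict_func(['input:0'], ['prob:0'], tower=k)
         for k in range(6)]
```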
@@ -9,6 +9,9 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
 MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
 INPUT_VARS_KEY = 'INPUT_VARIABLES'
 
+import tensorflow as tf
+SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
+
 # export all upper case variables
 all_local_names = locals().keys()
 __all__ = [x for x in all_local_names if x.isupper()]