Commit 8837d748 authored by Yuxin Wu's avatar Yuxin Wu

Expose SimpleTrainer.build_graph; Provide a default Trainer.run_step

parent 064ea7c7
...@@ -71,10 +71,10 @@ class GANTrainer(FeedfreeTrainerBase): ...@@ -71,10 +71,10 @@ class GANTrainer(FeedfreeTrainerBase):
opt = self.model.get_optimizer() opt = self.model.get_optimizer()
# by default, run one d_min after one g_min # by default, run one d_min after one g_min
self.g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op') g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
with tf.control_dependencies([self.g_min]): with tf.control_dependencies([g_min]):
self.d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op') d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
self.train_op = self.d_min self.train_op = d_min
class SeparateGANTrainer(FeedfreeTrainerBase): class SeparateGANTrainer(FeedfreeTrainerBase):
...@@ -137,12 +137,12 @@ class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase): ...@@ -137,12 +137,12 @@ class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
opt = self.model.get_optimizer() opt = self.model.get_optimizer()
# run one d_min after one g_min # run one d_min after one g_min
self.g_min = opt.minimize(g_loss, var_list=self.model.g_vars, g_min = opt.minimize(g_loss, var_list=self.model.g_vars,
colocate_gradients_with_ops=True, name='g_op') colocate_gradients_with_ops=True, name='g_op')
with tf.control_dependencies([self.g_min]): with tf.control_dependencies([g_min]):
self.d_min = opt.minimize(d_loss, var_list=self.model.d_vars, d_min = opt.minimize(d_loss, var_list=self.model.d_vars,
colocate_gradients_with_ops=True, name='d_op') colocate_gradients_with_ops=True, name='d_op')
self.train_op = self.d_min self.train_op = d_min
class RandomZData(DataFlow): class RandomZData(DataFlow):
......
...@@ -95,7 +95,8 @@ class InferenceRunnerBase(Callback): ...@@ -95,7 +95,8 @@ class InferenceRunnerBase(Callback):
self._input_source.setup(self.trainer.model.get_inputs_desc()) self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._tower_handle = self.trainer.predictor_factory.build(self._tower_name, device, self._input_source) self._tower_handle = self.trainer.predictor_factory.build(
self._tower_name, device, self._input_source)
self._hooks = [self._build_hook(inf) for inf in self.infs] self._hooks = [self._build_hook(inf) for inf in self.infs]
cbs = self._input_source.get_callbacks() cbs = self._input_source.get_callbacks()
......
...@@ -84,7 +84,7 @@ class PredictorFactory(object): ...@@ -84,7 +84,7 @@ class PredictorFactory(object):
Args: Args:
tower (int): need the kth tower (not the gpu id, but the id in TrainConfig.predict_tower) tower (int): need the kth tower (not the gpu id, but the id in TrainConfig.predict_tower)
Returns: Returns:
an online predictor (which has to be used under the default session) an online predictor (which has to be used under a default session)
""" """
tower_name = 'towerp{}'.format(tower) tower_name = 'towerp{}'.format(tower)
tower = self._towers[tower] tower = self._towers[tower]
......
...@@ -2,10 +2,8 @@ ...@@ -2,10 +2,8 @@
# File: base.py # File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import ABCMeta, abstractmethod
import time import time
import weakref import weakref
import six
from six.moves import range from six.moves import range
import tensorflow as tf import tensorflow as tf
...@@ -30,7 +28,6 @@ class StopTraining(BaseException): ...@@ -30,7 +28,6 @@ class StopTraining(BaseException):
pass pass
@six.add_metaclass(ABCMeta)
class Trainer(object): class Trainer(object):
""" Base class for a trainer. """ Base class for a trainer.
...@@ -102,10 +99,18 @@ class Trainer(object): ...@@ -102,10 +99,18 @@ class Trainer(object):
self.setup() self.setup()
self.main_loop() self.main_loop()
@abstractmethod
def run_step(self): def run_step(self):
""" Abstract method: run one iteration. Subclass should define what is "iteration".
""" """
Defines what to do in one iteration, by default is:
``self.hooked_sess.run(self.train_op)``.
The behavior can be changed by either defining what is ``train_op``,
or overriding this method.
"""
assert hasattr(self, 'train_op'), \
"Please either set `Trainer.train_op` or provide an implementation " \
"of Trainer.run_step()!"
self.hooked_sess.run(self.train_op)
def _setup_input_source(self, input_source): def _setup_input_source(self, input_source):
""" """
......
...@@ -29,10 +29,6 @@ class FeedfreeTrainerBase(Trainer): ...@@ -29,10 +29,6 @@ class FeedfreeTrainerBase(Trainer):
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source) assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
self._setup_input_source(self._input_source) self._setup_input_source(self._input_source)
def run_step(self):
""" Simply run ``self.train_op``."""
self.hooked_sess.run(self.train_op)
# deprecated # deprecated
class SingleCostFeedfreeTrainer(FeedfreeTrainerBase): class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
......
...@@ -15,7 +15,7 @@ __all__ = ['SimpleTrainer'] ...@@ -15,7 +15,7 @@ __all__ = ['SimpleTrainer']
class SimpleTrainer(Trainer): class SimpleTrainer(Trainer):
""" A naive single-tower single-cost demo trainer. """ A naive single-tower single-cost demo trainer.
Support both InputSource and DataFlow. Support both InputSource and DataFlow.
When DataFlow is given, the InputSource to be used will be ``FeedInput(df)``. When DataFlow is given instead of InputSource, the InputSource to be used will be ``FeedInput(df)``.
""" """
def __init__(self, config): def __init__(self, config):
...@@ -35,13 +35,28 @@ class SimpleTrainer(Trainer): ...@@ -35,13 +35,28 @@ class SimpleTrainer(Trainer):
"Consider QueueInput or other InputSource instead.") "Consider QueueInput or other InputSource instead.")
super(SimpleTrainer, self).__init__(config) super(SimpleTrainer, self).__init__(config)
def run_step(self): @staticmethod
self.hooked_sess.run(self.train_op) def setup_graph(model, input):
"""
Setup graph for simple trainer.
def _setup(self): Args:
self._setup_input_source(self._input_source) model (ModelDesc):
input (InputSource):
Returns:
tf.Operation: the training op
[Callback]: the callbacks to be added
"""
input.setup(model.get_inputs_desc())
cbs = input.get_callbacks()
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
self.model.build_graph(self._input_source) model.build_graph(input)
cost, grads = self.model.get_cost_and_grad() _, grads = model.get_cost_and_grad()
opt = self.model.get_optimizer() opt = model.get_optimizer()
self.train_op = opt.apply_gradients(grads, name='min_op') train_op = opt.apply_gradients(grads, name='min_op')
return train_op, cbs
def _setup(self):
self.train_op, callbacks = SimpleTrainer.setup_graph(self.model, self._input_source)
self.config.callbacks.extend(callbacks)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment