Commit 8837d748 authored by Yuxin Wu's avatar Yuxin Wu

Expose SimpleTrainer.build_graph; Provide a default Trainer.run_step

parent 064ea7c7
......@@ -71,10 +71,10 @@ class GANTrainer(FeedfreeTrainerBase):
opt = self.model.get_optimizer()
# by default, run one d_min after one g_min
self.g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
with tf.control_dependencies([self.g_min]):
self.d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
self.train_op = self.d_min
g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
with tf.control_dependencies([g_min]):
d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
self.train_op = d_min
class SeparateGANTrainer(FeedfreeTrainerBase):
......@@ -137,12 +137,12 @@ class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
opt = self.model.get_optimizer()
# run one d_min after one g_min
self.g_min = opt.minimize(g_loss, var_list=self.model.g_vars,
colocate_gradients_with_ops=True, name='g_op')
with tf.control_dependencies([self.g_min]):
self.d_min = opt.minimize(d_loss, var_list=self.model.d_vars,
colocate_gradients_with_ops=True, name='d_op')
self.train_op = self.d_min
g_min = opt.minimize(g_loss, var_list=self.model.g_vars,
colocate_gradients_with_ops=True, name='g_op')
with tf.control_dependencies([g_min]):
d_min = opt.minimize(d_loss, var_list=self.model.d_vars,
colocate_gradients_with_ops=True, name='d_op')
self.train_op = d_min
class RandomZData(DataFlow):
......
......@@ -95,7 +95,8 @@ class InferenceRunnerBase(Callback):
self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._tower_handle = self.trainer.predictor_factory.build(self._tower_name, device, self._input_source)
self._tower_handle = self.trainer.predictor_factory.build(
self._tower_name, device, self._input_source)
self._hooks = [self._build_hook(inf) for inf in self.infs]
cbs = self._input_source.get_callbacks()
......
......@@ -84,7 +84,7 @@ class PredictorFactory(object):
Args:
tower (int): need the kth tower (not the gpu id, but the id in TrainConfig.predict_tower)
Returns:
an online predictor (which has to be used under the default session)
an online predictor (which has to be used under a default session)
"""
tower_name = 'towerp{}'.format(tower)
tower = self._towers[tower]
......
......@@ -2,10 +2,8 @@
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import ABCMeta, abstractmethod
import time
import weakref
import six
from six.moves import range
import tensorflow as tf
......@@ -30,7 +28,6 @@ class StopTraining(BaseException):
pass
@six.add_metaclass(ABCMeta)
class Trainer(object):
""" Base class for a trainer.
......@@ -102,10 +99,18 @@ class Trainer(object):
self.setup()
self.main_loop()
def run_step(self):
    """
    Run one iteration of training.

    By default this is ``self.hooked_sess.run(self.train_op)``.
    The behavior can be changed by either defining what is ``train_op``,
    or overriding this method.

    Raises:
        AssertionError: if ``self.train_op`` was never set and the
            subclass did not override ``run_step``.
    """
    # This default implementation replaces the former abstract method:
    # a subclass may either provide `train_op` or override run_step().
    assert hasattr(self, 'train_op'), \
        "Please either set `Trainer.train_op` or provide an implementation " \
        "of Trainer.run_step()!"
    self.hooked_sess.run(self.train_op)
def _setup_input_source(self, input_source):
"""
......
......@@ -29,10 +29,6 @@ class FeedfreeTrainerBase(Trainer):
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
self._setup_input_source(self._input_source)
def run_step(self):
""" Simply run ``self.train_op``."""
self.hooked_sess.run(self.train_op)
# deprecated
class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
......
......@@ -15,7 +15,7 @@ __all__ = ['SimpleTrainer']
class SimpleTrainer(Trainer):
""" A naive single-tower single-cost demo trainer.
Support both InputSource and DataFlow.
When DataFlow is given, the InputSource to be used will be ``FeedInput(df)``.
When DataFlow is given instead of InputSource, the InputSource to be used will be ``FeedInput(df)``.
"""
def __init__(self, config):
......@@ -35,13 +35,28 @@ class SimpleTrainer(Trainer):
"Consider QueueInput or other InputSource instead.")
super(SimpleTrainer, self).__init__(config)
def run_step(self):
self.hooked_sess.run(self.train_op)
@staticmethod
def setup_graph(model, input):
    """
    Setup graph for a simple single-tower, single-cost trainer.

    Args:
        model (ModelDesc): the model whose graph will be built.
        input (InputSource): the input source; it is setup with the
            model's input description before the graph is built.

    Returns:
        tf.Operation: the training op
        [Callback]: the callbacks to be added
    """
    # Let the InputSource create its placeholders/queues for this model.
    input.setup(model.get_inputs_desc())
    cbs = input.get_callbacks()
    # A single training tower; the empty name keeps variable names unscoped.
    with TowerContext('', is_training=True):
        model.build_graph(input)
    # Only the gradients are needed to build the train op; the cost is ignored.
    _, grads = model.get_cost_and_grad()
    opt = model.get_optimizer()
    train_op = opt.apply_gradients(grads, name='min_op')
    return train_op, cbs

def _setup(self):
    # Delegate graph construction to the reusable static helper above, and
    # register the callbacks the InputSource needs to run during training.
    # NOTE: setup_graph is a @staticmethod, so it must NOT take `self` --
    # the call below passes exactly (model, input), matching the signature.
    self.train_op, callbacks = SimpleTrainer.setup_graph(self.model, self._input_source)
    self.config.callbacks.extend(callbacks)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment