Commit 4cd01111 authored by Yuxin Wu

[WIP] initial commit of step callbacks. should be compatible with the old examples.

parent 0652d859
...@@ -79,24 +79,28 @@ class GANTrainer(FeedfreeTrainerBase):
         with TowerContext(''):
             actual_inputs = self._get_input_tensors()
             self.model.build_graph(actual_inputs)
+        # optimize G
         grads = self.config.optimizer.compute_gradients(
             self.model.g_loss, var_list=self.model.g_vars)
         grads = apply_grad_processors(
             grads, self.model.get_gradient_processor_g())
         self.g_min = self.config.optimizer.apply_gradients(grads, name='g_op')
+        # optimize D
         with tf.control_dependencies([self.g_min]):
             grads = self.config.optimizer.compute_gradients(
                 self.model.d_loss, var_list=self.model.d_vars)
             grads = apply_grad_processors(
                 grads, self.model.get_gradient_processor_d())
-            self.d_min = self.config.optimizer.apply_gradients(grads, name='d_op')
-
-        self.gs_incr = tf.assign_add(get_global_step_var(), 1, name='global_step_incr')
-        self.summary_op = summary_moving_average()
-        self.train_op = tf.group(self.d_min, self.summary_op, self.gs_incr)
+            self.d_min = self.config.optimizer.apply_gradients(
+                grads, get_global_step_var(), name='d_op')
+        self.train_op = self.d_min

     def run_step(self):
-        self.sess.run(self.train_op)
+        ret = self.sess.run([self.train_op] + self.extra_fetches)
+        return ret[1:]

 class RandomZData(DataFlow):
...
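The manual `tf.assign_add` on the global step and the explicit `summary_moving_average()` op disappear from `train_op` here: passing the global-step variable to `apply_gradients` makes the optimizer increment it as part of the same op, and the moving-average op is now maintained by the new `SummaryMovingAverage` callback (below) through the extra-fetch mechanism. A minimal standalone TF1 sketch of the `apply_gradients` behaviour, using toy variables not taken from this commit:

import tensorflow as tf

x = tf.Variable(0.0, name='x')
loss = tf.square(x - 3.0)
global_step = tf.Variable(0, trainable=False, name='global_step')

opt = tf.train.GradientDescentOptimizer(0.1)
grads = opt.compute_gradients(loss)
# passing global_step makes the increment part of the training op itself
train_op = opt.apply_gradients(grads, global_step=global_step, name='min_op')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    print(sess.run(global_step))   # prints 1: incremented by apply_gradients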
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: steps.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

""" Some common step callbacks. """

from six.moves import zip

from ..utils import logger
from ..tfutils.common import get_op_tensor_name
from ..tfutils.summary import summary_moving_average
from .base import Callback

__all__ = ['StepStatPrinter', 'SummaryMovingAverage']


class StepStatPrinter(Callback):
    """ It prints the value of some tensors in each step.
    It's just a demo of how trigger_step works but you should in general use
    :func:`print_stat` or :func:`tf.Print` instead. """

    def __init__(self, names):
        names = [get_op_tensor_name(n)[1] for n in names]
        logger.warn("Using print_stat or tf.Print in the graph is much faster than StepStatPrinter!")
        self._names = names

    def _extra_fetches(self):
        return self._names

    def _trigger_step(self, *args):
        for n, v in zip(self._names, args):
            logger.info("{}: {}".format(n, v))


class SummaryMovingAverage(Callback):
    """ Maintain the moving average of the tensors added by :func:`summary.add_moving_summary`
    in every step, and summarize them.
    """
    def _setup_graph(self):
        self.ema_op = summary_moving_average()

    def _extra_fetches(self):
        return [self.ema_op]
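The two callbacks above only need the `_extra_fetches` / `_trigger_step` pair to take part in every training step. As a hedged illustration of the same hooks, here is a hypothetical callback, not part of this commit, that keeps a running mean of a scalar tensor; it assumes the same imports as steps.py above:

class RunningMeanPrinter(Callback):
    """ Hypothetical example: print the running mean of a scalar tensor
    every `period` steps, using the same step hooks as the callbacks above. """
    def __init__(self, name, period=100):
        self._name = get_op_tensor_name(name)[1]
        self._period = period
        self._sum, self._cnt = 0.0, 0

    def _extra_fetches(self):
        # fetched by the trainer together with train_op in every step
        return [self._name]

    def _trigger_step(self, value):
        # called after every step with the fetched value
        self._sum += float(value)
        self._cnt += 1
        if self._cnt % self._period == 0:
            logger.info("{} running mean: {:.4f}".format(self._name, self._sum / self._cnt))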
...@@ -40,6 +40,7 @@ class Trainer(object):
         model (ModelDesc)
         sess (tf.Session): the current session in use.
         coord (tf.train.Coordinator)
+        extra_fetches (list): list of tensors/ops to fetch by :meth:`run_step`.
     """

     def __init__(self, config):
...@@ -133,7 +134,7 @@ class Trainer(object):
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
-        self._extra_fetches = self.config.callbacks.extra_fetches()
+        self.extra_fetches = self.config.callbacks.extra_fetches()

         if not hasattr(logger, 'LOG_DIR'):
             raise RuntimeError("logger directory wasn't set!")
...@@ -175,8 +176,9 @@ class Trainer(object):
                         **get_tqdm_kwargs(leave=True)):
                     if self.coord.should_stop():
                         return
-                    self.run_step()  # implemented by subclass
-                    callbacks.trigger_step()   # not useful?
+                    fetch_data = self.run_step()  # implemented by subclass
+                    if fetch_data:
+                        callbacks.trigger_step(*fetch_data)
                 # trigger epoch outside the timing region.
                 self.trigger_epoch()
             except StopTraining:
...
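Schematically, the main loop now routes whatever `run_step` fetched back into the callbacks. A simplified, dependency-free sketch of the control flow; this is illustrative only and not the actual `Trainer` code:

def run_one_epoch(trainer, callbacks, step_per_epoch):
    # the extra fetches were collected once from all callbacks during setup
    for _ in range(step_per_epoch):
        fetch_data = trainer.run_step()          # sess.run([train_op] + extra_fetches)[1:]
        if fetch_data:
            callbacks.trigger_step(*fetch_data)  # hand the fetched values to the callbacks
    trainer.trigger_epoch()                      # epoch-level callbacks run outside the step loop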
...@@ -4,7 +4,7 @@

 import tensorflow as tf

-from ..callbacks.group import Callbacks
+from ..callbacks import Callbacks, SummaryMovingAverage, StatPrinter
 from ..dataflow.base import DataFlow
 from ..models import ModelDesc
 from ..utils import logger
...@@ -21,7 +21,8 @@ class TrainConfig(object):
     """

     def __init__(self, dataflow=None, data=None,
-                 model=None, optimizer=None, callbacks=None,
+                 model=None, optimizer=None,
+                 callbacks=None, extra_callbacks=None,
                  session_config=get_default_sess_config(),
                  session_init=None,
                  starting_epoch=1, step_per_epoch=None, max_epoch=99999,
...@@ -34,7 +35,11 @@ class TrainConfig(object):
                 or ``data`` has to be present.
             model (ModelDesc): the model to train.
             optimizer (tf.train.Optimizer): the optimizer for trainig.
-            callbacks (Callbacks): the callbacks to perform during training.
+            callbacks (list): a list of :class:`Callback` to perform during training.
+            extra_callbacks (list): the same as ``callbacks``. This argument
+                is only used to provide the defaults. The defaults are
+                ``[SummaryMovingAverage(), StatPrinter()]``. The list of
+                callbacks that will be used in the end is ``callbacks + extra_callbacks``.
             session_config (tf.ConfigProto): the config used to instantiate the session.
             session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
             starting_epoch (int): The index of the first epoch.
...@@ -50,6 +55,7 @@ class TrainConfig(object):
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__

+        # process data
         if 'dataset' in kwargs:
             dataflow = kwargs.pop('dataset')
             logger.warn("[Deprecated] TrainConfig.dataset has been deprecated. Use TrainConfig.dataflow instead.")
...@@ -65,8 +71,20 @@ class TrainConfig(object):
         self.optimizer = optimizer
         assert_type(self.optimizer, tf.train.Optimizer)

-        self.callbacks = callbacks
-        assert_type(self.callbacks, Callbacks)
+        if isinstance(callbacks, Callbacks):
+            # keep quiet now because I haven't determined the final API yet.
+            # logger.warn("[Deprecated] API of TrainConfig(callbacks=) has changed!")
+            # logger.warn("[Deprecated] Please change the option 'callbacks=' to a list of "
+            #             "callbacks without StatPrinter().")
+            callbacks = callbacks.cbs[:-1]  # the last one is StatPrinter()
+        assert_type(callbacks, list)
+        if extra_callbacks is None:
+            extra_callbacks = [SummaryMovingAverage(), StatPrinter()]
+        self.callbacks = callbacks + extra_callbacks
+        assert_type(self.callbacks, list)
+        self.callbacks = Callbacks(self.callbacks)

         self.model = model
         assert_type(self.model, ModelDesc)
...@@ -102,12 +120,6 @@ class TrainConfig(object):
         if isinstance(self.predict_tower, int):
             self.predict_tower = [self.predict_tower]

-        # TODO deprecated @Jan20
-        self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
-        if self.extra_threads_procs:
-            logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
-            from ..callbacks.concurrency import StartProcOrThread
-            self.callbacks.append(StartProcOrThread(self.extra_threads_procs))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

     def set_tower(self, nr_tower=None, tower=None):
...
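With this change `TrainConfig` takes a plain list of callbacks and appends the defaults itself. A hedged usage sketch follows; the dataflow, model and optimizer below are placeholders and not implied by this commit:

config = TrainConfig(
    dataflow=my_dataflow,                    # placeholder DataFlow
    model=MyModel(),                         # placeholder ModelDesc
    optimizer=tf.train.AdamOptimizer(1e-3),  # any tf.train.Optimizer
    # a plain list now, instead of a Callbacks(...) wrapper:
    callbacks=[StepStatPrinter(['cost:0'])],
    # extra_callbacks defaults to [SummaryMovingAverage(), StatPrinter()];
    # the list actually used is callbacks + extra_callbacks
    step_per_epoch=100,
    max_epoch=10,
)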
...@@ -9,7 +9,6 @@ from ..utils import logger
 from ..tfutils import get_global_step_var
 from ..tfutils.tower import TowerContext
 from ..tfutils.gradproc import apply_grad_processors
-from ..tfutils.summary import summary_moving_average
 from .input_data import QueueInput, FeedfreeInput

 from .base import Trainer
...@@ -55,7 +54,8 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
     def run_step(self):
         """ Simply run ``self.train_op``, which minimizes the cost."""
-        self.sess.run(self.train_op)
+        ret = self.sess.run([self.train_op] + self.extra_fetches)
+        return ret[1:]
         # if not hasattr(self, 'cnt'):
         #     self.cnt = 0
         # else:
...@@ -101,9 +101,8 @@ class SimpleFeedfreeTrainer(
         cost, grads = self._get_cost_and_grad()
         grads = apply_grad_processors(grads, self.model.get_gradient_processor())

-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            summary_moving_average(), name='train_op')
+        self.train_op = self.config.optimizer.apply_gradients(
+            grads, get_global_step_var(), name='min_op')

         # skip training
         # self.train_op = tf.group(*self.dequed_inputs)
...
...@@ -11,7 +11,6 @@ from six.moves import zip, range
 from ..utils import logger
 from ..utils.naming import SUMMARY_BACKUP_KEYS
 from ..utils.concurrency import LoopThread
-from ..tfutils.summary import summary_moving_average
 from ..tfutils import (backup_collection, restore_collection,
                        get_global_step_var, TowerContext)
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
...@@ -113,13 +112,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
         grads = SyncMultiGPUTrainer._average_grads(grad_list)
         grads = apply_grad_processors(grads, self.model.get_gradient_processor())

-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            summary_moving_average(), name='train_op')
-
-    def run_step(self):
-        self.sess.run(self.train_op)
+        self.train_op = self.config.optimizer.apply_gradients(
+            grads, get_global_step_var(), name='min_op')


 class AsyncMultiGPUTrainer(MultiGPUTrainer,
...@@ -169,10 +163,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]

         # use grad from the first tower for iteration in main thread
-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(
-                grad_list[0], get_global_step_var()),
-            summary_moving_average(), name='train_op')
+        self.train_op = self.config.optimizer.apply_gradients(
+            grad_list[0], get_global_step_var(), name='min_op')

         self._start_async_threads(grad_list)
...@@ -199,7 +191,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
             for th in self.training_threads:  # resume all threads
                 th.resume()
         next(self.async_step_counter)
-        self.sess.run(self.train_op)
+        return super(AsyncMultiGPUTrainer, self).run_step()

     def _trigger_epoch(self):
         self.async_running = False
...
...@@ -9,7 +9,6 @@ from .base import Trainer
 from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
 from ..tfutils import (get_tensors_by_names, freeze_collection,
                        get_global_step_var, TowerContext)
-from ..tfutils.summary import summary_moving_average
 from ..predict import OnlinePredictor, build_prediction_graph
 from ..tfutils.gradproc import apply_grad_processors
 from .input_data import FeedInput
...@@ -43,10 +42,10 @@ class PredictorFactory(object):
         return OnlinePredictor(self.sess, raw_input_vars, output_vars)

     def _build_predict_tower(self):
-        tf.get_variable_scope().reuse_variables()
         # build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
         with tf.name_scope(None), \
-                freeze_collection(SUMMARY_BACKUP_KEYS):
+                freeze_collection(SUMMARY_BACKUP_KEYS), \
+                tf.variable_scope(tf.get_variable_scope(), reuse=True):
             def fn(_):
                 self.model.build_graph(self.model.get_input_vars())
             build_prediction_graph(fn, self.towers)
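The reuse handling changes from calling `reuse_variables()`, which flips the reuse flag for the rest of the current scope, to re-entering the current scope with `reuse=True`, which confines variable reuse to the `with` block. A small standalone TF1 sketch of the scoped idiom, with toy scope names not taken from this commit:

import tensorflow as tf

with tf.variable_scope('model'):
    w = tf.get_variable('w', shape=[1])

# reuse applies only inside this block; outside it, new variables can still be created
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    with tf.variable_scope('model'):
        w_again = tf.get_variable('w', shape=[1])

assert w is w_again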
...@@ -73,7 +72,9 @@ class SimpleTrainer(Trainer):
     def run_step(self):
         """ Feed data into the graph and run the updates. """
         feed = self._input_method.next_feed()
-        self.sess.run([self.train_op], feed_dict=feed)    # faster since train_op return None
+        ret = self.sess.run([self.train_op] + self.extra_fetches,
+                            feed_dict=feed)
+        return ret[1:]

     def _setup(self):
         self._input_method._setup(self)
...@@ -87,9 +88,8 @@ class SimpleTrainer(Trainer):
         grads = apply_grad_processors(grads,
                                       self.model.get_gradient_processor())

-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            summary_moving_average(), name='train_op')
+        self.train_op = self.config.optimizer.apply_gradients(
+            grads, get_global_step_var(), name='min_op')

     def _trigger_epoch(self):
         if self.summary_op is not None:
...