Commit 4cd01111 authored by Yuxin Wu

[WIP] Initial commit of step callbacks; should be compatible with the old examples.

parent 0652d859
......@@ -79,24 +79,28 @@ class GANTrainer(FeedfreeTrainerBase):
with TowerContext(''):
actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs)
# optimize G
grads = self.config.optimizer.compute_gradients(
self.model.g_loss, var_list=self.model.g_vars)
grads = apply_grad_processors(
grads, self.model.get_gradient_processor_g())
self.g_min = self.config.optimizer.apply_gradients(grads, name='g_op')
# optimize D
with tf.control_dependencies([self.g_min]):
grads = self.config.optimizer.compute_gradients(
self.model.d_loss, var_list=self.model.d_vars)
grads = apply_grad_processors(
grads, self.model.get_gradient_processor_d())
self.d_min = self.config.optimizer.apply_gradients(grads, name='d_op')
self.d_min = self.config.optimizer.apply_gradients(
grads, get_global_step_var(), name='d_op')
self.gs_incr = tf.assign_add(get_global_step_var(), 1, name='global_step_incr')
self.summary_op = summary_moving_average()
self.train_op = tf.group(self.d_min, self.summary_op, self.gs_incr)
self.train_op = self.d_min
def run_step(self):
self.sess.run(self.train_op)
ret = self.sess.run([self.train_op] + self.extra_fetches)
return ret[1:]
class RandomZData(DataFlow):
......
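The GAN trainer change above relies on a standard TensorFlow behaviour: when ``global_step`` is passed to ``Optimizer.apply_gradients``, the returned op increments it, so the separate ``tf.assign_add`` is redundant. A minimal standalone sketch (plain TF, not part of this commit; the loss and variable are placeholders):

import tensorflow as tf

x = tf.Variable(1.0)
loss = tf.square(x)                        # placeholder loss
global_step = tf.Variable(0, trainable=False, name='global_step')

opt = tf.train.GradientDescentOptimizer(0.1)
grads = opt.compute_gradients(loss)
# Passing global_step makes this op also increment the step counter,
# so a separate tf.assign_add(global_step, 1) is unnecessary.
min_op = opt.apply_gradients(grads, global_step=global_step, name='min_op')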
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: steps.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
""" Some common step callbacks. """
from six.moves import zip
from ..utils import logger
from ..tfutils.common import get_op_tensor_name
from ..tfutils.summary import summary_moving_average
from .base import Callback
__all__ = ['StepStatPrinter', 'SummaryMovingAverage']
class StepStatPrinter(Callback):
""" It prints the value of some tensors in each step.
It's just a demo of how trigger_step works but you should in general use
:func:`print_stat` or :func:`tf.Print` instead. """
def __init__(self, names):
names = [get_op_tensor_name(n)[1] for n in names]
logger.warn("Using print_stat or tf.Print in the graph is much faster than StepStatPrinter!")
self._names = names
def _extra_fetches(self):
return self._names
def _trigger_step(self, *args):
for n, v in zip(self._names, args):
logger.info("{}: {}".format(n, v))
class SummaryMovingAverage(Callback):
""" Maintain the moving average of the tensors added by :func:`summary.add_moving_summary`
in every step, and summarize them.
"""
def _setup_graph(self):
self.ema_op = summary_moving_average()
def _extra_fetches(self):
return [self.ema_op]
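For illustration, here is a hypothetical third step callback written against the same protocol as the two above (it reuses the imports of steps.py; the default tensor name is a made-up placeholder): ``_extra_fetches`` declares what :meth:`run_step` should additionally fetch, and ``_trigger_step`` receives the fetched values after every step.

class ScalarLogger(Callback):
    """ Hypothetical example: log one scalar tensor after every step. """
    def __init__(self, name='total_cost:0'):
        # placeholder tensor name; any scalar tensor in the graph would do
        self._name = get_op_tensor_name(name)[1]

    def _extra_fetches(self):
        # appended by the trainer to the fetch list used in run_step
        return [self._name]

    def _trigger_step(self, value):
        # called with the fetched value after each step
        logger.info("step value of {}: {}".format(self._name, value))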
......@@ -40,6 +40,7 @@ class Trainer(object):
model (ModelDesc)
sess (tf.Session): the current session in use.
coord (tf.train.Coordinator)
extra_fetches (list): list of tensors/ops to fetch by :meth:`run_step`.
"""
def __init__(self, config):
......@@ -133,7 +134,7 @@ class Trainer(object):
# some final operations that might modify the graph
logger.info("Setup callbacks ...")
self.config.callbacks.setup_graph(weakref.proxy(self))
self._extra_fetches = self.config.callbacks.extra_fetches()
self.extra_fetches = self.config.callbacks.extra_fetches()
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("logger directory wasn't set!")
......@@ -175,8 +176,9 @@ class Trainer(object):
**get_tqdm_kwargs(leave=True)):
if self.coord.should_stop():
return
self.run_step() # implemented by subclass
callbacks.trigger_step() # not useful?
fetch_data = self.run_step() # implemented by subclass
if fetch_data:
callbacks.trigger_step(*fetch_data)
# trigger epoch outside the timing region.
self.trigger_epoch()
except StopTraining:
......
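The ``run_step`` implementations changed below all follow the pattern ``ret = sess.run([train_op] + extra_fetches); return ret[1:]``. A tiny standalone reminder (not tensorpack code) of why ``ret[1:]`` is exactly the extra-fetch values: ``Session.run`` returns results in the same order as the fetch list.

import tensorflow as tf

a = tf.constant(1)
b = tf.constant(2)
with tf.Session() as sess:
    ret = sess.run([a, b])   # results come back in fetch order: [1, 2]
    print(ret[1:])           # everything after the first fetch: [2]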
......@@ -4,7 +4,7 @@
import tensorflow as tf
from ..callbacks.group import Callbacks
from ..callbacks import Callbacks, SummaryMovingAverage, StatPrinter
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
......@@ -21,7 +21,8 @@ class TrainConfig(object):
"""
def __init__(self, dataflow=None, data=None,
model=None, optimizer=None, callbacks=None,
model=None, optimizer=None,
callbacks=None, extra_callbacks=None,
session_config=get_default_sess_config(),
session_init=None,
starting_epoch=1, step_per_epoch=None, max_epoch=99999,
......@@ -34,7 +35,11 @@ class TrainConfig(object):
or ``data`` has to be present.
model (ModelDesc): the model to train.
optimizer (tf.train.Optimizer): the optimizer for training.
callbacks (Callbacks): the callbacks to perform during training.
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): similar to ``callbacks``, but meant to hold the
defaults, which are ``[SummaryMovingAverage(), StatPrinter()]``. The list
of callbacks actually used is ``callbacks + extra_callbacks``.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
starting_epoch (int): The index of the first epoch.
......@@ -50,6 +55,7 @@ class TrainConfig(object):
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
# process data
if 'dataset' in kwargs:
dataflow = kwargs.pop('dataset')
logger.warn("[Deprecated] TrainConfig.dataset has been deprecated. Use TrainConfig.dataflow instead.")
......@@ -65,8 +71,20 @@ class TrainConfig(object):
self.optimizer = optimizer
assert_type(self.optimizer, tf.train.Optimizer)
self.callbacks = callbacks
assert_type(self.callbacks, Callbacks)
if isinstance(callbacks, Callbacks):
# keep quiet now because I haven't determined the final API yet.
# logger.warn("[Deprecated] API of TrainConfig(callbacks=) has changed!")
# logger.warn("[Deprecated] Please change the option 'callbacks=' to a list of "
# "callbacks without StatPrinter().")
callbacks = callbacks.cbs[:-1] # the last one is StatPrinter()
assert_type(callbacks, list)
if extra_callbacks is None:
extra_callbacks = [SummaryMovingAverage(), StatPrinter()]
self.callbacks = callbacks + extra_callbacks
assert_type(self.callbacks, list)
self.callbacks = Callbacks(self.callbacks)
self.model = model
assert_type(self.model, ModelDesc)
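A sketch of how a training script could call the changed ``TrainConfig`` API; ``MyModel`` and ``my_dataflow`` are hypothetical placeholders, and the tensor name passed to ``StepStatPrinter`` is made up:

config = TrainConfig(
    dataflow=my_dataflow,                    # hypothetical DataFlow
    model=MyModel(),                         # hypothetical ModelDesc subclass
    optimizer=tf.train.AdamOptimizer(1e-3),
    # new style: a plain list of callbacks, no Callbacks() wrapper and
    # no manually appended StatPrinter()
    callbacks=[StepStatPrinter(['total_cost:0'])],
    # optional; this is exactly the default appended after `callbacks`
    extra_callbacks=[SummaryMovingAverage(), StatPrinter()],
    step_per_epoch=1000,
    max_epoch=100,
)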
......@@ -102,12 +120,6 @@ class TrainConfig(object):
if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower]
# TODO deprecated @Jan20
self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
if self.extra_threads_procs:
logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
from ..callbacks.concurrency import StartProcOrThread
self.callbacks.append(StartProcOrThread(self.extra_threads_procs))
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def set_tower(self, nr_tower=None, tower=None):
......
......@@ -9,7 +9,6 @@ from ..utils import logger
from ..tfutils import get_global_step_var
from ..tfutils.tower import TowerContext
from ..tfutils.gradproc import apply_grad_processors
from ..tfutils.summary import summary_moving_average
from .input_data import QueueInput, FeedfreeInput
from .base import Trainer
......@@ -55,7 +54,8 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
def run_step(self):
""" Simply run ``self.train_op``, which minimizes the cost."""
self.sess.run(self.train_op)
ret = self.sess.run([self.train_op] + self.extra_fetches)
return ret[1:]
# if not hasattr(self, 'cnt'):
# self.cnt = 0
# else:
......@@ -101,9 +101,8 @@ class SimpleFeedfreeTrainer(
cost, grads = self._get_cost_and_grad()
grads = apply_grad_processors(grads, self.model.get_gradient_processor())
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average(), name='train_op')
self.train_op = self.config.optimizer.apply_gradients(
grads, get_global_step_var(), name='min_op')
# skip training
# self.train_op = tf.group(*self.dequed_inputs)
......
......@@ -11,7 +11,6 @@ from six.moves import zip, range
from ..utils import logger
from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils import (backup_collection, restore_collection,
get_global_step_var, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
......@@ -113,13 +112,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
grads = SyncMultiGPUTrainer._average_grads(grad_list)
grads = apply_grad_processors(grads, self.model.get_gradient_processor())
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average(), name='train_op')
def run_step(self):
self.sess.run(self.train_op)
self.train_op = self.config.optimizer.apply_gradients(
grads, get_global_step_var(), name='min_op')
class AsyncMultiGPUTrainer(MultiGPUTrainer,
......@@ -169,10 +163,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]
# use grad from the first tower for iteration in main thread
self.train_op = tf.group(
self.config.optimizer.apply_gradients(
grad_list[0], get_global_step_var()),
summary_moving_average(), name='train_op')
self.train_op = self.config.optimizer.apply_gradients(
grad_list[0], get_global_step_var(), name='min_op')
self._start_async_threads(grad_list)
......@@ -199,7 +191,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
for th in self.training_threads: # resume all threads
th.resume()
next(self.async_step_counter)
self.sess.run(self.train_op)
return super(AsyncMultiGPUTrainer, self).run_step()
def _trigger_epoch(self):
self.async_running = False
......
......@@ -9,7 +9,6 @@ from .base import Trainer
from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import (get_tensors_by_names, freeze_collection,
get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average
from ..predict import OnlinePredictor, build_prediction_graph
from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput
......@@ -43,10 +42,10 @@ class PredictorFactory(object):
return OnlinePredictor(self.sess, raw_input_vars, output_vars)
def _build_predict_tower(self):
tf.get_variable_scope().reuse_variables()
# build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
with tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS):
freeze_collection(SUMMARY_BACKUP_KEYS), \
tf.variable_scope(tf.get_variable_scope(), reuse=True):
def fn(_):
self.model.build_graph(self.model.get_input_vars())
build_prediction_graph(fn, self.towers)
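The change above swaps an unconditional ``reuse_variables()`` call for a scoped ``reuse=True`` context. A minimal standalone sketch of the difference (plain TF, not tensorpack code): ``reuse_variables()`` turns on reuse for the remainder of the current scope, while re-entering the scope with ``reuse=True`` confines reuse to the ``with`` block.

import tensorflow as tf

with tf.variable_scope('model') as scope:
    w = tf.get_variable('w', [1])        # creates model/w

# scoped reuse: existing variables are reused only inside this block
with tf.variable_scope(scope, reuse=True):
    w_again = tf.get_variable('w', [1])  # returns the same model/w

# calling scope.reuse_variables() inside the first block instead would have
# forced every later get_variable() in that scope to reuse an existing variable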
......@@ -73,7 +72,9 @@ class SimpleTrainer(Trainer):
def run_step(self):
""" Feed data into the graph and run the updates. """
feed = self._input_method.next_feed()
self.sess.run([self.train_op], feed_dict=feed) # faster since train_op returns None
ret = self.sess.run([self.train_op] + self.extra_fetches,
feed_dict=feed)
return ret[1:]
def _setup(self):
self._input_method._setup(self)
......@@ -87,9 +88,8 @@ class SimpleTrainer(Trainer):
grads = apply_grad_processors(grads,
self.model.get_gradient_processor())
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average(), name='train_op')
self.train_op = self.config.optimizer.apply_gradients(
grads, get_global_step_var(), name='min_op')
def _trigger_epoch(self):
if self.summary_op is not None:
......