Commit f80843dc authored by Yuxin Wu

distinguish between sess.run call and run_step call. fix WGAN examples. (#147)

parent eee05770
@@ -77,10 +77,6 @@ class GANTrainer(FeedfreeTrainerBase):
         self.d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
         self.train_op = self.d_min
 
-    def run_step(self):
-        ret = self.sess.run([self.train_op] + self.get_extra_fetches())
-        return ret[1:]
-
 class RandomZData(DataFlow):
     def __init__(self, shape):
@@ -88,9 +88,8 @@ class WGANTrainer(FeedfreeTrainerBase):
     def run_step(self):
         for k in range(5):
-            self.sess.run(self.d_min)
-        ret = self.sess.run([self.g_min] + self.get_extra_fetches())
-        return ret[1:]
+            self.monitored_sess.run(self.d_min)
+        self.monitored_sess.run(self.g_min)
 
 if __name__ == '__main__':
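
The two hunks above are the example fix: both GAN trainers used to fetch callback extras themselves via `self.sess.run([op] + self.get_extra_fetches())` and return them, bypassing the hooks attached to the `MonitoredSession`. After the change, `run_step` issues only the raw training ops through `self.monitored_sess`, and extra fetches travel through the hook mechanism instead. A minimal sketch of the WGAN schedule this encodes (`wgan_run_step` is a hypothetical helper, not code from the repo):

```python
def wgan_run_step(sess, d_min, g_min, n_critic=5):
    """One logical WGAN step: n_critic critic/discriminator updates,
    then a single generator update. Each sess.run fires the registered
    hooks, but the trainer still counts all of this as one run_step."""
    for _ in range(n_critic):
        sess.run(d_min)   # critic update
    sess.run(g_min)       # generator update
```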
@@ -114,11 +114,6 @@ class Callback(object):
     def epoch_num(self):
         return self.trainer.epoch_num
 
-    @property
-    def local_step(self):
-        # inside trainer, we're still in the 'local_step' loop, so the number is off by 1
-        return self.trainer.local_step + 1
-
     @property
     def global_step(self):
         return self.trainer.global_step
@@ -102,7 +102,6 @@ class Callbacks(Callback):
             traceback.print_exc()
 
     def get_hooks(self):
-        # TODO skip
         return [CallbackHook(cb) for cb in self.cbs]
 
     def _trigger_epoch(self):
@@ -12,7 +12,9 @@ import tqdm
 from ..utils import logger, get_tqdm_kwargs
 from ..utils.naming import (GLOBAL_STEP_INCR_OP_NAME,
                             LOCAL_STEP_OP_NAME)
-from ..tfutils.common import get_op_tensor_name, get_global_step_var, get_global_step_value
+from ..tfutils.common import (
+    get_op_tensor_name, get_global_step_var,
+    get_global_step_value, get_op_or_tensor_by_name)
 from .base import Callback
 
 __all__ = ['StepTensorPrinter', 'MaintainStepCounter',
@@ -33,8 +35,11 @@ class StepTensorPrinter(Callback):
         logger.warn("Using print_stat or tf.Print in the graph is much faster than StepTensorPrinter!")
         self._names = names
 
+    def _before_train(self):
+        self._fetches = get_op_or_tensor_by_name(self._names)
+
     def _extra_fetches(self):
-        return self._names
+        return self._fetches
 
     def _trigger_step(self, *args):
         assert len(args) == len(self._names), len(args)
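
Resolving the configured names to concrete ops/tensors once in `_before_train`, rather than handing strings to the session every step, matches how `MonitoredSession` hooks expect real fetches. The lookup is the kind the TensorFlow graph already provides; a minimal sketch, assuming TF 1.x:

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    t = tf.constant(1.0, name='my_tensor')

# Name -> tensor resolution of the sort get_op_or_tensor_by_name performs:
fetched = g.get_tensor_by_name('my_tensor:0')
assert fetched is t
```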
@@ -63,9 +68,15 @@ class MaintainStepCounter(Callback):
         gs_val = get_global_step_value()
         if gs_val != 0:
             logger.info("Start training with global_step={}".format(gs_val))
+        self._last_updated = self.trainer.local_step
 
     def _extra_fetches(self):
-        return [self.gs_incr_var.op]
+        # increase global_step only when trainer.local_step has changed
+        if self.trainer.local_step != self._last_updated:
+            self._last_updated = self.trainer.local_step
+            return [self.gs_incr_var.op]
+        else:
+            return []
 
 class ProgressBar(Callback):
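
This guard is the heart of the sess.run/run_step distinction: `_extra_fetches` is invoked by a hook on every session call, but `global_step` must advance once per logical step. Without the `local_step` check, a `run_step` that issues several session calls (like the WGAN trainer's six) would inflate the counter. A self-contained sketch of the dedup, with assumed numbers:

```python
calls_per_step = 6          # e.g. 5 critic updates + 1 generator update
increments, last_updated = 0, -1
for local_step in range(2):             # two logical steps
    for _ in range(calls_per_step):     # each session call fires the hook
        if local_step != last_updated:  # only the first call of a step...
            last_updated = local_step
            increments += 1             # ...fetches gs_incr_var.op
print(increments)                       # 2: one increment per run_step
```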
@@ -80,21 +91,33 @@ class ProgressBar(Callback):
         self._names = [get_op_tensor_name(n)[1] for n in names]
         self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]
 
-    def _extra_fetches(self):
-        return self._names
-
     def _before_train(self):
+        self._fetches = get_op_or_tensor_by_name(self._names)
+        self._last_updated = self.trainer.local_step
+
         self._total = self.trainer.config.steps_per_epoch
         self._tqdm_args = get_tqdm_kwargs(leave=True)
         if len(self._names):
             self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
 
+    def _extra_fetches(self):
+        if self.trainer.local_step != self._last_updated:
+            # local_step == number of steps that have finished in this epoch
+            self._last_updated = self.trainer.local_step
+            if self.trainer.local_step == 0:
+                self._bar = tqdm.trange(self._total, **self._tqdm_args)
+            else:
+                self._bar.update()
+            # XXX TODO move this to trigger_step after rename
+            if self.trainer.local_step == self._total - 1:
+                self._bar.close()
+            return self._fetches
+        else:
+            return []
+
     def _trigger_step(self, *args):
-        if self.local_step == 1:
-            self._bar = tqdm.trange(self._total, **self._tqdm_args)
-        if len(self._names):
+        if len(args):
             self._bar.set_postfix(zip(self._tags, args))
-        self._bar.update()
-        if self.local_step == self._total:
-            self._bar.close()
@@ -34,7 +34,9 @@ class PeriodicTrigger(ProxyCallback):
     def _trigger_step(self, *args):
         if self._step_k is None:
             return
-        if self.local_step % self._step_k == 0:
+        # trigger_step is triggered after run_step, so
+        # local_step + 1 is the number of steps that have finished
+        if (self.trainer.local_step + 1) % self._step_k == 0:
             self.cb.trigger()
 
     def _trigger_epoch(self, *args):
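
With the `Callback.local_step` property removed above, callbacks read the raw 0-based `trainer.local_step`, hence the explicit `+ 1` here. A quick check of the new condition (assuming `_step_k == 10`):

```python
step_k = 10
fired_after = [s for s in range(25) if (s + 1) % step_k == 0]
print(fired_after)   # [9, 19]: fires after the 10th and 20th finished steps
```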
@@ -141,7 +141,7 @@ def Deconv2D(x, out_shape, kernel_shape,
     for k in out_shape:
         if not isinstance(k, int):
             raise ValueError("[Deconv2D] out_shape {} is invalid!".format(k))
-    out_channel = out_shape[channel_axis]
+    out_channel = out_shape[channel_axis - 1]  # out_shape doesn't have batch
     shp3_static = shp3_dyn = out_shape
     filter_shape = kernel_shape + [out_channel, in_channel]
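
The off-by-one fixed here: `channel_axis` indexes the full batched NHWC/NCHW shape, while `out_shape` is the per-sample shape without the batch dimension. For NHWC:

```python
channel_axis = 3          # 'C' in a batched NHWC shape [N, H, W, C]
out_shape = [32, 32, 64]  # per-sample [H, W, C] -- no batch dimension
out_channel = out_shape[channel_axis - 1]
assert out_channel == 64  # indexing without the -1 would raise IndexError
```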
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
 import tensorflow as tf
+from six.moves import map
 
 from ..utils.naming import (
     GLOBAL_STEP_VAR_NAME,
@@ -133,7 +134,7 @@ def get_op_or_tensor_by_name(name):
     if not isinstance(name, list):
         return f(name)
     else:
-        return map(f, name)
+        return list(map(f, name))
 
 def get_name_scope_name():
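
`six.moves.map` is a lazy iterator on both Python 2 and 3, so a bare `map` result would be consumed on first use and look empty afterwards; callers such as the callbacks above use the fetches repeatedly, hence the `list(...)`. A small demonstration:

```python
from six.moves import map

fetches = map(str.upper, ['loss', 'accuracy'])
first = list(fetches)
again = list(fetches)    # the iterator is already exhausted
print(first, again)      # ['LOSS', 'ACCURACY'] []
```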
@@ -54,7 +54,7 @@ class Trainer(object):
         self.model = config.model
 
         self.epoch_num = self.config.starting_epoch - 1
-        self.local_step = 0
+        self.local_step = -1
 
     def train(self):
         """ Start training """
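
Initializing `local_step` to -1 makes it the 0-based index of the step in progress, assuming the epoch loop increments it before calling `run_step` (the loop body itself is not part of this diff). A sketch of the implied sequencing:

```python
steps_per_epoch = 3
local_step = -1                 # initial value after this commit
for _ in range(steps_per_epoch):
    local_step += 1             # 0, 1, 2: index of the step in progress
    # run_step() here: hooks see local_step == steps finished this epoch
    # trigger_step() here: local_step + 1 steps have now finished
```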
@@ -46,21 +46,6 @@ class FeedfreeTrainerBase(Trainer):
         assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
         self._input_method._setup(self)
 
-class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
-    """ A feedfree Trainer which assumes a single cost. """
-    def _get_cost_and_grad(self):
-        """ get the cost and gradient"""
-        self.build_train_tower()
-        cost = self.model.get_cost()
-        opt = self.config.optimizer
-        # GATE_NONE faster?
-        grads = opt.compute_gradients(
-            cost,
-            gate_gradients=tf.train.Optimizer.GATE_NONE,
-            colocate_gradients_with_ops=True)
-        return cost, grads
-
     def run_step(self):
         """ Simply run ``self.train_op``, which minimizes the cost."""
         self.monitored_sess.run(self.train_op)
@@ -82,6 +67,21 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         # import sys; sys.exit()
 
+class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
+    """ A feedfree Trainer which assumes a single cost. """
+    def _get_cost_and_grad(self):
+        """ get the cost and gradient"""
+        self.build_train_tower()
+        cost = self.model.get_cost()
+        opt = self.config.optimizer
+        # GATE_NONE faster?
+        grads = opt.compute_gradients(
+            cost,
+            gate_gradients=tf.train.Optimizer.GATE_NONE,
+            colocate_gradients_with_ops=True)
+        return cost, grads
+
 class SimpleFeedfreeTrainer(
         SingleCostFeedfreeTrainer,
         MultiPredictorTowerTrainer):
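
Net effect of the move: `run_step` lives on `FeedfreeTrainerBase`, so any subclass that only needs a `train_op` (like `GANTrainer` after the first hunk) inherits a correct single-call step, while trainers with multi-call steps (like `WGANTrainer`) override it. A hypothetical subclass relying on the inherited behavior (names from the diff; `MinimalTrainer` itself is illustration only):

```python
class MinimalTrainer(FeedfreeTrainerBase):   # hypothetical subclass
    def _setup(self):
        super(MinimalTrainer, self)._setup()
        self.train_op = self.config.optimizer.minimize(self.model.get_cost())
    # no run_step override: the base class runs self.train_op through
    # self.monitored_sess, so registered hooks fire exactly once per step
```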