Commit 77c8bde9 authored by Yuxin Wu's avatar Yuxin Wu

trainer.build_train_tower(). and some doc fix.

parent a313662f
......@@ -23,19 +23,17 @@ os.environ['TENSORPACK_DOC_BUILDING'] = '1'
MOCK_MODULES = ['scipy',
'tensorflow', 'tensorflow.contrib',
'tensorflow.python.ops',
'tensorflow.contrib.framework',
'tensorflow.models',
'tensorflow.models.rnn',
'tensorflow.models.rnn.ptb',
'tensorflow.python',
'tensorflow.python.training',
'sklearn.datasets',
#'tensorflow', 'tensorflow.contrib',
#'tensorflow.python.ops',
#'tensorflow.contrib.framework',
#'tensorflow.python',
#'tensorflow.python.training',
'sklearn.datasets', 'sklearn',
'scipy.misc', 'h5py', 'nltk',
'cv2', 'scipy.io', 'dill', 'zmq', 'subprocess32', 'lmdb', 'tornado.concurrent',
'tornado', 'msgpack', 'msgpack_numpy', 'ale_python_interface',
'sklearn', 'functools32']
'cv2', 'scipy.io', 'dill', 'zmq', 'subprocess32', 'lmdb',
'tornado.concurrent', 'tornado',
'msgpack', 'msgpack_numpy',
'gym', 'functools32']
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock(name=mod_name)
......
......@@ -2,5 +2,6 @@ termcolor
numpy
tqdm
decorator
tensorflow
Sphinx==1.5.1
recommonmark==0.4.0
......@@ -73,9 +73,7 @@ class GANTrainer(FeedfreeTrainerBase):
def _setup(self):
super(GANTrainer, self)._setup()
with TowerContext(''):
actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs)
self.build_train_tower()
opt = self.model.get_optimizer()
self.g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
......
......@@ -11,6 +11,9 @@ __all__ = ['apply_grad_processors', 'ProxyOptimizer',
class ProxyOptimizer(tf.train.Optimizer):
"""
A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
"""
def __init__(self, opt):
self._opt = opt
......@@ -54,8 +57,8 @@ def apply_grad_processors(opt, gradprocs):
class PostProcessVariablesOptimizer(ProxyOptimizer):
"""
An optimizer which applies an operation to variables (e.g. clipping,
quantization) after updating the gradient.
An optimizer which applies an operation to variables
(e.g. clipping, quantization) after updating the gradient.
"""
def __init__(self, opt, func, colocate=True):
"""
......
......@@ -102,6 +102,10 @@ class TowerContext(object):
self._scope.__exit__(exc_type, exc_val, exc_tb)
return False
def __str__(self):
return "TowerContext(name={}, is_training={})".format(
self._name, self._is_training)
def get_current_tower_context():
global _CurrentTowerContext
......
......@@ -6,7 +6,7 @@
import tensorflow as tf
from ..utils import log_deprecated
from ..tfutils.tower import TowerContext
from ..tfutils.tower import TowerContext, get_current_tower_context
from .input_data import QueueInput, FeedfreeInput
from .base import Trainer
......@@ -27,8 +27,20 @@ class FeedfreeTrainerBase(Trainer):
summary_str = self.summary_op.eval()
self.add_summary(summary_str)
def _get_input_tensors(self):
return self._input_method.get_input_tensors()
def build_train_tower(self):
"""
Get input tensors from `self.input_method` and build the graph.
"""
def f():
inputs = self._input_method.get_input_tensors()
self.model.build_graph(inputs)
ctx = get_current_tower_context()
if ctx is None:
with TowerContext(''):
f()
else:
assert ctx.is_training, ctx
f()
def _setup(self):
assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
......@@ -39,16 +51,15 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" A feedfree Trainer which assumes a single cost. """
def _get_cost_and_grad(self):
""" get the cost and gradient"""
actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs)
cost_var = self.model.get_cost()
self.build_train_tower()
cost = self.model.get_cost()
opt = self.config.optimizer
# GATE_NONE faster?
grads = opt.compute_gradients(
cost_var,
cost,
gate_gradients=tf.train.Optimizer.GATE_NONE,
colocate_gradients_with_ops=True)
return cost_var, grads
return cost, grads
def run_step(self):
""" Simply run ``self.train_op``, which minimizes the cost."""
......
......@@ -141,8 +141,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
# grads = grad_list[0]
else:
def get_cost():
actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs)
self.build_train_tower()
return self.model.get_cost()
cost_list = MultiGPUTrainer._multi_tower_costs(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment