Commit 77c8bde9 authored by Yuxin Wu

Add trainer.build_train_tower(), and some doc fixes.

parent a313662f
...@@ -23,19 +23,17 @@ os.environ['TENSORPACK_DOC_BUILDING'] = '1' ...@@ -23,19 +23,17 @@ os.environ['TENSORPACK_DOC_BUILDING'] = '1'
MOCK_MODULES = ['scipy', MOCK_MODULES = ['scipy',
'tensorflow', 'tensorflow.contrib', #'tensorflow', 'tensorflow.contrib',
'tensorflow.python.ops', #'tensorflow.python.ops',
'tensorflow.contrib.framework', #'tensorflow.contrib.framework',
'tensorflow.models', #'tensorflow.python',
'tensorflow.models.rnn', #'tensorflow.python.training',
'tensorflow.models.rnn.ptb', 'sklearn.datasets', 'sklearn',
'tensorflow.python',
'tensorflow.python.training',
'sklearn.datasets',
'scipy.misc', 'h5py', 'nltk', 'scipy.misc', 'h5py', 'nltk',
'cv2', 'scipy.io', 'dill', 'zmq', 'subprocess32', 'lmdb', 'tornado.concurrent', 'cv2', 'scipy.io', 'dill', 'zmq', 'subprocess32', 'lmdb',
'tornado', 'msgpack', 'msgpack_numpy', 'ale_python_interface', 'tornado.concurrent', 'tornado',
'sklearn', 'functools32'] 'msgpack', 'msgpack_numpy',
'gym', 'functools32']
for mod_name in MOCK_MODULES: for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock(name=mod_name) sys.modules[mod_name] = mock.Mock(name=mod_name)
......
...@@ -2,5 +2,6 @@ termcolor ...@@ -2,5 +2,6 @@ termcolor
numpy numpy
tqdm tqdm
decorator decorator
tensorflow
Sphinx==1.5.1 Sphinx==1.5.1
recommonmark==0.4.0 recommonmark==0.4.0
...@@ -73,9 +73,7 @@ class GANTrainer(FeedfreeTrainerBase): ...@@ -73,9 +73,7 @@ class GANTrainer(FeedfreeTrainerBase):
def _setup(self): def _setup(self):
super(GANTrainer, self)._setup() super(GANTrainer, self)._setup()
with TowerContext(''): self.build_train_tower()
actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs)
opt = self.model.get_optimizer() opt = self.model.get_optimizer()
self.g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op') self.g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
......
...@@ -11,6 +11,9 @@ __all__ = ['apply_grad_processors', 'ProxyOptimizer', ...@@ -11,6 +11,9 @@ __all__ = ['apply_grad_processors', 'ProxyOptimizer',
class ProxyOptimizer(tf.train.Optimizer): class ProxyOptimizer(tf.train.Optimizer):
"""
A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
"""
def __init__(self, opt): def __init__(self, opt):
self._opt = opt self._opt = opt
...@@ -54,8 +57,8 @@ def apply_grad_processors(opt, gradprocs): ...@@ -54,8 +57,8 @@ def apply_grad_processors(opt, gradprocs):
class PostProcessVariablesOptimizer(ProxyOptimizer): class PostProcessVariablesOptimizer(ProxyOptimizer):
""" """
An optimizer which applies an operation to variables (e.g. clipping, An optimizer which applies an operation to variables
quantization) after updating the gradient. (e.g. clipping, quantization) after updating the gradient.
""" """
def __init__(self, opt, func, colocate=True): def __init__(self, opt, func, colocate=True):
""" """
......
...@@ -102,6 +102,10 @@ class TowerContext(object): ...@@ -102,6 +102,10 @@ class TowerContext(object):
self._scope.__exit__(exc_type, exc_val, exc_tb) self._scope.__exit__(exc_type, exc_val, exc_tb)
return False return False
def __str__(self):
return "TowerContext(name={}, is_training={})".format(
self._name, self._is_training)
def get_current_tower_context(): def get_current_tower_context():
global _CurrentTowerContext global _CurrentTowerContext
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import tensorflow as tf import tensorflow as tf
from ..utils import log_deprecated from ..utils import log_deprecated
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext, get_current_tower_context
from .input_data import QueueInput, FeedfreeInput from .input_data import QueueInput, FeedfreeInput
from .base import Trainer from .base import Trainer
...@@ -27,8 +27,20 @@ class FeedfreeTrainerBase(Trainer): ...@@ -27,8 +27,20 @@ class FeedfreeTrainerBase(Trainer):
summary_str = self.summary_op.eval() summary_str = self.summary_op.eval()
self.add_summary(summary_str) self.add_summary(summary_str)
def _get_input_tensors(self): def build_train_tower(self):
return self._input_method.get_input_tensors() """
Get input tensors from `self.input_method` and build the graph.
"""
def f():
inputs = self._input_method.get_input_tensors()
self.model.build_graph(inputs)
ctx = get_current_tower_context()
if ctx is None:
with TowerContext(''):
f()
else:
assert ctx.is_training, ctx
f()
def _setup(self): def _setup(self):
assert isinstance(self._input_method, FeedfreeInput), type(self._input_method) assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
...@@ -39,16 +51,15 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase): ...@@ -39,16 +51,15 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" A feedfree Trainer which assumes a single cost. """ """ A feedfree Trainer which assumes a single cost. """
def _get_cost_and_grad(self): def _get_cost_and_grad(self):
""" get the cost and gradient""" """ get the cost and gradient"""
actual_inputs = self._get_input_tensors() self.build_train_tower()
self.model.build_graph(actual_inputs) cost = self.model.get_cost()
cost_var = self.model.get_cost()
opt = self.config.optimizer opt = self.config.optimizer
# GATE_NONE faster? # GATE_NONE faster?
grads = opt.compute_gradients( grads = opt.compute_gradients(
cost_var, cost,
gate_gradients=tf.train.Optimizer.GATE_NONE, gate_gradients=tf.train.Optimizer.GATE_NONE,
colocate_gradients_with_ops=True) colocate_gradients_with_ops=True)
return cost_var, grads return cost, grads
def run_step(self): def run_step(self):
""" Simply run ``self.train_op``, which minimizes the cost.""" """ Simply run ``self.train_op``, which minimizes the cost."""
......
...@@ -141,8 +141,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -141,8 +141,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
# grads = grad_list[0] # grads = grad_list[0]
else: else:
def get_cost(): def get_cost():
actual_inputs = self._get_input_tensors() self.build_train_tower()
self.model.build_graph(actual_inputs)
return self.model.get_cost() return self.model.get_cost()
cost_list = MultiGPUTrainer._multi_tower_costs( cost_list = MultiGPUTrainer._multi_tower_costs(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment