Commit 12192d59 authored by Yuxin Wu's avatar Yuxin Wu

clean-ups

parent f67a1aff
...@@ -19,6 +19,7 @@ class InputSource(object): ...@@ -19,6 +19,7 @@ class InputSource(object):
""" Base class for the abstract InputSource. """ """ Base class for the abstract InputSource. """
_name_scope = None _name_scope = None
_setup_done = False
def get_input_tensors(self): def get_input_tensors(self):
""" """
...@@ -42,11 +43,15 @@ class InputSource(object): ...@@ -42,11 +43,15 @@ class InputSource(object):
list[Callback]: extra callbacks needed by this InputSource. list[Callback]: extra callbacks needed by this InputSource.
""" """
self._setup(inputs_desc) self._setup(inputs_desc)
self._setup_done = True
return self.get_callbacks() return self.get_callbacks()
def _setup(self, inputs_desc): def _setup(self, inputs_desc):
pass pass
def setup_done(self):
return self._setup_done
@memoized @memoized
def get_callbacks(self): def get_callbacks(self):
""" """
......
...@@ -45,6 +45,7 @@ class SimpleBuilder(GraphBuilder): ...@@ -45,6 +45,7 @@ class SimpleBuilder(GraphBuilder):
Returns: Returns:
tf.Operation: the training op tf.Operation: the training op
""" """
assert input.setup_done()
with TowerContext('', is_training=True) as ctx: with TowerContext('', is_training=True) as ctx:
cost = get_cost_fn(*input.get_input_tensors()) cost = get_cost_fn(*input.get_input_tensors())
...@@ -132,6 +133,7 @@ class DataParallelBuilder(GraphBuilder): ...@@ -132,6 +133,7 @@ class DataParallelBuilder(GraphBuilder):
@staticmethod @staticmethod
def _make_fn(input, get_cost_fn, get_opt_fn): def _make_fn(input, get_cost_fn, get_opt_fn):
# internal use only # internal use only
assert input.setup_done()
get_opt_fn = memoized(get_opt_fn) get_opt_fn = memoized(get_opt_fn)
def get_grad_fn(): def get_grad_fn():
......
...@@ -20,7 +20,7 @@ from ..tfutils.sessinit import JustCurrentSession ...@@ -20,7 +20,7 @@ from ..tfutils.sessinit import JustCurrentSession
from ..graph_builder.predictor_factory import PredictorFactory from ..graph_builder.predictor_factory import PredictorFactory
__all__ = ['Trainer', 'StopTraining', 'launch_train'] __all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException): class StopTraining(BaseException):
...@@ -295,6 +295,8 @@ def launch_train( ...@@ -295,6 +295,8 @@ def launch_train(
session_creator=None, session_config=None, session_init=None, session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999): starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
""" """
** Work In Progress! Don't use**
This is another trainer interface, to start training **after** the graph has been built already. This is another trainer interface, to start training **after** the graph has been built already.
You can build the graph however you like You can build the graph however you like
(with or without tensorpack), and invoke this function to start training with callbacks & monitors. (with or without tensorpack), and invoke this function to start training with callbacks & monitors.
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import tensorflow as tf import tensorflow as tf
from ..callbacks.graph import RunOp from ..callbacks.graph import RunOp
from ..utils.develop import log_deprecated
from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
from ..graph_builder.training import ( from ..graph_builder.training import (
...@@ -27,6 +28,8 @@ class MultiGPUTrainerBase(Trainer): ...@@ -27,6 +28,8 @@ class MultiGPUTrainerBase(Trainer):
For backward compatibility only For backward compatibility only
""" """
def build_on_multi_tower(towers, func, devices=None, use_vs=None): def build_on_multi_tower(towers, func, devices=None, use_vs=None):
log_deprecated("MultiGPUTrainerBase.build_on_multitower",
"Please use DataParallelBuilder.build_on_towers", "2018-01-31")
return DataParallelBuilder.build_on_towers(towers, func, devices, use_vs) return DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment