Commit 12192d59 authored by Yuxin Wu's avatar Yuxin Wu

clean-ups

parent f67a1aff
......@@ -19,6 +19,7 @@ class InputSource(object):
""" Base class for the abstract InputSource. """
_name_scope = None
_setup_done = False
def get_input_tensors(self):
"""
......@@ -42,11 +43,15 @@ class InputSource(object):
list[Callback]: extra callbacks needed by this InputSource.
"""
self._setup(inputs_desc)
self._setup_done = True
return self.get_callbacks()
def _setup(self, inputs_desc):
pass
def setup_done(self):
return self._setup_done
@memoized
def get_callbacks(self):
"""
......
......@@ -45,6 +45,7 @@ class SimpleBuilder(GraphBuilder):
Returns:
tf.Operation: the training op
"""
assert input.setup_done()
with TowerContext('', is_training=True) as ctx:
cost = get_cost_fn(*input.get_input_tensors())
......@@ -132,6 +133,7 @@ class DataParallelBuilder(GraphBuilder):
@staticmethod
def _make_fn(input, get_cost_fn, get_opt_fn):
# internal use only
assert input.setup_done()
get_opt_fn = memoized(get_opt_fn)
def get_grad_fn():
......
......@@ -20,7 +20,7 @@ from ..tfutils.sessinit import JustCurrentSession
from ..graph_builder.predictor_factory import PredictorFactory
__all__ = ['Trainer', 'StopTraining', 'launch_train']
__all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException):
......@@ -295,6 +295,8 @@ def launch_train(
session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
"""
** Work In Progress! Don't use**
This is another trainer interface, to start training **after** the graph has been built already.
You can build the graph however you like
(with or without tensorpack), and invoke this function to start training with callbacks & monitors.
......
......@@ -6,6 +6,7 @@
import tensorflow as tf
from ..callbacks.graph import RunOp
from ..utils.develop import log_deprecated
from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
from ..graph_builder.training import (
......@@ -27,6 +28,8 @@ class MultiGPUTrainerBase(Trainer):
For backward compatibility only
"""
def build_on_multi_tower(towers, func, devices=None, use_vs=None):
log_deprecated("MultiGPUTrainerBase.build_on_multitower",
"Please use DataParallelBuilder.build_on_towers", "2018-01-31")
return DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment