Commit b82a1fda authored by Yuxin Wu's avatar Yuxin Wu

fix import; introducing new interface to train after graph has been built.

parent f7993410
...@@ -8,8 +8,9 @@ import numpy as np ...@@ -8,8 +8,9 @@ import numpy as np
import time import time
from tensorpack import (Trainer, QueueInput, from tensorpack import (Trainer, QueueInput,
ModelDescBase, DataFlow, StagingInputWrapper, ModelDescBase, DataFlow, StagingInputWrapper,
MultiGPUTrainerBase, LeastLoadedDeviceSetter, MultiGPUTrainerBase,
TowerContext) TowerContext)
from tensorpack.train.utility import LeastLoadedDeviceSetter
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized
......
...@@ -32,30 +32,11 @@ class InputDesc( ...@@ -32,30 +32,11 @@ class InputDesc(
shape (tuple): shape (tuple):
name (str): name (str):
""" """
shape = tuple(shape) # has to be tuple for self to be hashable shape = tuple(shape) # has to be tuple for "self" to be hashable
self = super(InputDesc, cls).__new__(cls, type, shape, name) self = super(InputDesc, cls).__new__(cls, type, shape, name)
self._cached_placeholder = None self._cached_placeholder = None
return self return self
# TODO in serialization, skip _cached_placeholder
# def dumps(self):
# """
# Returns:
# str: serialized string
# """
# return pickle.dumps(self)
# @staticmethod
# def loads(buf):
# """
# Args:
# buf (str): serialized string
# Returns:
# InputDesc:
# """
# return pickle.loads(buf)
def build_placeholder(self, prefix=''): def build_placeholder(self, prefix=''):
""" """
Build a tf.placeholder from the metadata, with an optional prefix. Build a tf.placeholder from the metadata, with an optional prefix.
......
...@@ -284,3 +284,51 @@ class Trainer(object): ...@@ -284,3 +284,51 @@ class Trainer(object):
self._predictor_factory = PredictorFactory( self._predictor_factory = PredictorFactory(
self.model, self.vs_name_for_predictor) self.model, self.vs_name_for_predictor)
return self._predictor_factory return self._predictor_factory
def launch_train(
        run_step, callbacks=None, extra_callbacks=None, monitors=None,
        session_creator=None, session_config=None, session_init=None,
        starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
    """
    A simpler interface to start training after the graph has been built already.
    You can build the graph however you like
    (with or without tensorpack), and invoke this function to start training.

    This provides the flexibility to define the training config after the graph
    has been built. One typical use is that callbacks often depend on names
    that are unknown until the graph has been built.

    Args:
        run_step (tf.Tensor or function): Defines one training iteration.
            If given a Tensor/Operation, it will be evaluated in every step.
            If given a function, it will be invoked (with no arguments)
            under the default session in every step.

    Other arguments are the same as in :class:`TrainConfig`.

    Examples:

    .. code-block:: python

        model = MyModelDesc()
        train_op, cbs = SimpleTrainer.setup_graph(model, QueueInput(mydataflow))
        launch_train(train_op, callbacks=[...] + cbs, steps_per_epoch=mydataflow.size())

        # the above is equivalent to:
        config = TrainConfig(model=MyModelDesc(), data=QueueInput(mydataflow), callbacks=[...])
        SimpleTrainer(config).train()
    """
    assert steps_per_epoch is not None, steps_per_epoch
    trainer = Trainer(TrainConfig(
        callbacks=callbacks,
        extra_callbacks=extra_callbacks,
        monitors=monitors,
        session_creator=session_creator,
        session_config=session_config,
        session_init=session_init,
        starting_epoch=starting_epoch,
        steps_per_epoch=steps_per_epoch,
        max_epoch=max_epoch))
    if isinstance(run_step, (tf.Tensor, tf.Operation)):
        trainer.train_op = run_step
    else:
        assert callable(run_step), run_step
        # Assigning a function to an instance attribute does NOT create a
        # bound method (binding only happens for attributes found on the
        # class), so the callable stored here must take no arguments.
        # The previous `lambda self: run_step()` would raise TypeError when
        # the trainer invokes `self.run_step()`.
        trainer.run_step = run_step
    trainer.train()
...@@ -9,7 +9,6 @@ from ..callbacks import ( ...@@ -9,7 +9,6 @@ from ..callbacks import (
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger from ..utils import logger
from ..utils.develop import log_deprecated
from ..tfutils import (JustCurrentSession, from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit) get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
...@@ -69,9 +68,6 @@ class TrainConfig(object): ...@@ -69,9 +68,6 @@ class TrainConfig(object):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
# process data & model # process data & model
if 'dataset' in kwargs:
dataflow = kwargs.pop('dataset')
log_deprecated("TrainConfig.dataset", "Use TrainConfig.dataflow instead.", "2017-09-11")
assert data is None or dataflow is None, "dataflow and data cannot be both presented in TrainConfig!" assert data is None or dataflow is None, "dataflow and data cannot be both presented in TrainConfig!"
if dataflow is not None: if dataflow is not None:
assert_type(dataflow, DataFlow) assert_type(dataflow, DataFlow)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment