Commit b82a1fda authored by Yuxin Wu's avatar Yuxin Wu

fix import; introduce a new interface to train after the graph has been built.

parent f7993410
......@@ -8,8 +8,9 @@ import numpy as np
import time
from tensorpack import (Trainer, QueueInput,
ModelDescBase, DataFlow, StagingInputWrapper,
MultiGPUTrainerBase, LeastLoadedDeviceSetter,
MultiGPUTrainerBase,
TowerContext)
from tensorpack.train.utility import LeastLoadedDeviceSetter
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized
......
......@@ -32,30 +32,11 @@ class InputDesc(
shape (tuple):
name (str):
"""
shape = tuple(shape) # has to be tuple for self to be hashable
shape = tuple(shape) # has to be tuple for "self" to be hashable
self = super(InputDesc, cls).__new__(cls, type, shape, name)
self._cached_placeholder = None
return self
# TODO in serialization, skip _cached_placeholder
# def dumps(self):
# """
# Returns:
# str: serialized string
# """
# return pickle.dumps(self)
# @staticmethod
# def loads(buf):
# """
# Args:
# buf (str): serialized string
# Returns:
# InputDesc:
# """
# return pickle.loads(buf)
def build_placeholder(self, prefix=''):
"""
Build a tf.placeholder from the metadata, with an optional prefix.
......
......@@ -284,3 +284,51 @@ class Trainer(object):
self._predictor_factory = PredictorFactory(
self.model, self.vs_name_for_predictor)
return self._predictor_factory
def launch_train(
        run_step, callbacks=None, extra_callbacks=None, monitors=None,
        session_creator=None, session_config=None, session_init=None,
        starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
    """
    A simpler interface to start training after the graph has been built already.
    You can build the graph however you like
    (with or without tensorpack), and invoke this function to start training.

    This provides the flexibility to define the training config after the graph
    has been built. One typical use is that callbacks often depend on names that
    are unknown until the graph has been built.

    Args:
        run_step (tf.Tensor or function): Define what the training iteration is.
            If given a Tensor/Operation, will eval it every step.
            If given a function, will invoke this function under the default
            session in every step.

    Other arguments are the same as in :class:`TrainConfig`.

    Examples:

    .. code-block:: python

        model = MyModelDesc()
        train_op, cbs = SimpleTrainer.setup_graph(model, QueueInput(mydataflow))
        launch_train(train_op, callbacks=[...] + cbs, steps_per_epoch=mydataflow.size())

        # the above is equivalent to:
        config = TrainConfig(model=MyModelDesc(), data=QueueInput(mydataflow), callbacks=[...])
        SimpleTrainer(config).train()
    """
    assert steps_per_epoch is not None, "launch_train() requires steps_per_epoch!"
    trainer = Trainer(TrainConfig(
        callbacks=callbacks,
        extra_callbacks=extra_callbacks,
        monitors=monitors,
        session_creator=session_creator,
        session_config=session_config,
        session_init=session_init,
        starting_epoch=starting_epoch,
        steps_per_epoch=steps_per_epoch,
        max_epoch=max_epoch))
    if isinstance(run_step, (tf.Tensor, tf.Operation)):
        trainer.train_op = run_step
    else:
        assert callable(run_step), run_step
        # NOTE: an attribute assigned on the instance is a plain function, not a
        # bound method, so Python will NOT pass `self` when the trainer invokes
        # `self.run_step()`. The previous `lambda self: run_step()` therefore
        # raised TypeError on the function path. Accept (and ignore) any
        # positional args so the override works regardless of how it is called.
        trainer.run_step = lambda *_args: run_step()
    trainer.train()
......@@ -9,7 +9,6 @@ from ..callbacks import (
from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger
from ..utils.develop import log_deprecated
from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator
......@@ -69,9 +68,6 @@ class TrainConfig(object):
assert isinstance(v, tp), v.__class__
# process data & model
if 'dataset' in kwargs:
dataflow = kwargs.pop('dataset')
log_deprecated("TrainConfig.dataset", "Use TrainConfig.dataflow instead.", "2017-09-11")
assert data is None or dataflow is None, "dataflow and data cannot be both presented in TrainConfig!"
if dataflow is not None:
assert_type(dataflow, DataFlow)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment