Commit d43c8a28 authored by Yuxin Wu's avatar Yuxin Wu

updates about `launch_train`

parent b82a1fda
@@ -8,12 +8,13 @@ A High Level Glance
 .. image:: https://user-images.githubusercontent.com/1381301/29187907-2caaa740-7dc6-11e7-8220-e20ca52c3ca6.png
-* DataFlow is a library to load data efficiently in Python.
+* ``DataFlow`` is a library to load data efficiently in Python.
   Apart from DataFlow, native TF operators can be used for data loading as well.
-  They will eventually be wrapped under the same interface and go through prefetching.
+  They will eventually be wrapped under the same ``InputSource`` interface and go through prefetching.
 * You can use any TF-based symbolic function library to define a model, including
-  a small set of models within tensorpack. ``ModelDesc`` is an interface to connect symbolic graph to tensorpack trainers.
+  a small set of models within tensorpack. ``ModelDesc`` is an interface to connect the graph with the
+  ``InputSource`` interface.
 * tensorpack trainers manage the training loops for you.
   They also include data parallel logic for multi-GPU or distributed training.
@@ -22,6 +23,13 @@ A High Level Glance
 * Callbacks are like ``tf.train.SessionRunHook``, or plugins. During training,
   everything you want to do other than the main iterations can be defined through callbacks and easily reused.
+* All the components, though they work perfectly together, are highly decoupled: you can:
+
+  * Use DataFlow alone as a data loading library, without TensorFlow at all.
+  * Use tensorpack to build the graph with multi-GPU or distributed support,
+    then train it with your own loops.
+  * Build the graph on your own, and train it with tensorpack callbacks.
 
 User Tutorials
 ========================
@@ -58,6 +58,8 @@ class InferenceRunnerBase(Callback):
     """ Base class for inference runner.
     Please note that InferenceRunner will use `input.size()` to determine
     how many iterations to run, so you want it to be accurate.
+    Also, InferenceRunner assumes that `trainer.model` exists.
     """
     def __init__(self, input, infs, extra_hooks=None):
         """
@@ -120,6 +122,7 @@ class InferenceRunner(InferenceRunnerBase):
         return InferencerToHook(inf, fetches)
     def _setup_graph(self):
+        assert self.trainer.model is not None
         # use the predict_tower from the train config: either a GPU id, or -1 for CPU
         tower_id = self.trainer.config.predict_tower[0]
         device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
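For reference, the device-string convention used with `predict_tower` above is simple: a non-negative entry names a GPU, while -1 means CPU. A minimal standalone sketch of that mapping (`tower_device` is a hypothetical helper name, not part of tensorpack):

```python
def tower_device(tower_id):
    """Map a predict_tower entry to a TF device string:
    a non-negative id selects that GPU, -1 selects the CPU."""
    return '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'

print(tower_device(0))   # /gpu:0
print(tower_device(-1))  # /cpu:0
```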
@@ -178,6 +181,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         self._gpus = gpus
     def _setup_graph(self):
+        assert self.trainer.model is not None
         cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
         # build each predict tower
         self._handles = []
@@ -20,7 +20,7 @@ from ..tfutils.sessinit import JustCurrentSession
 from ..graph_builder.predictor_factory import PredictorFactory
-__all__ = ['Trainer', 'StopTraining']
+__all__ = ['Trainer', 'StopTraining', 'launch_train']
 class StopTraining(BaseException):
@@ -287,21 +287,25 @@ class Trainer(object):
 def launch_train(
-        run_step, callbacks=None, extra_callbacks=None, monitors=None,
+        run_step, model=None, callbacks=None, extra_callbacks=None, monitors=None,
         session_creator=None, session_config=None, session_init=None,
         starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
     """
-    This is a simpler interface to start training after the graph has been built already.
+    This is another trainer interface, to start training **after** the graph has already been built.
     You can build the graph however you like
-    (with or without tensorpack), and invoke this function to start training.
+    (with or without tensorpack), and invoke this function to start training with callbacks & monitors.
     This provides the flexibility to define the training config after the graph has been built.
-    One typical use is that callbacks often depend on names that are uknown
-    only after the graph has been built.
+    One typical use is that callbacks often depend on names that are not known
+    until the graph has been built.
     Args:
         run_step (tf.Tensor or function): Define what the training iteration is.
            If given a Tensor/Operation, will eval it every step.
            If given a function, will invoke this function under the default session in every step.
+        model (None or ModelDesc): Certain callbacks (e.g. InferenceRunner) depend on
+            the existence of :class:`ModelDesc`. If you used a :class:`ModelDesc` to
+            build the graph, pass it here to allow those callbacks to work.
+            If you didn't use :class:`ModelDesc`, leave it as None.
        Other arguments are the same as in :class:`TrainConfig`.
     Examples:
@@ -310,13 +314,14 @@ def launch_train(
         model = MyModelDesc()
         train_op, cbs = SimpleTrainer.setup_graph(model, QueueInput(mydataflow))
-        launch_train(train_op, callbacks=[...] + cbs, steps_per_epoch=mydataflow.size())
+        launch_train(train_op, model=model, callbacks=[...] + cbs, steps_per_epoch=mydataflow.size())
         # the above is equivalent to:
         config = TrainConfig(model=MyModelDesc(), data=QueueInput(mydataflow), callbacks=[...])
         SimpleTrainer(config).train()
     """
     assert steps_per_epoch is not None, steps_per_epoch
     trainer = Trainer(TrainConfig(
+        model=model,
         callbacks=callbacks,
         extra_callbacks=extra_callbacks,
         monitors=monitors,
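The `run_step` dispatch described in the docstring (eval a Tensor/Operation vs. invoke a function) can be sketched without TensorFlow. `make_run_step` and `FakeSession` below are hypothetical names used only for illustration; they are not tensorpack APIs:

```python
class FakeSession:
    """Stand-in for a tf.Session, only to illustrate the dispatch."""
    def run(self, fetch):
        return ('ran', fetch)

def make_run_step(run_step, sess):
    # A callable is invoked as-is every step; anything else is treated
    # as a Tensor/Operation and evaluated with sess.run() every step.
    if callable(run_step):
        return run_step
    return lambda: sess.run(run_step)

sess = FakeSession()
fn_step = make_run_step(lambda: 'stepped', sess)   # function case
op_step = make_run_step('train_op', sess)          # Tensor/Operation case
```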
@@ -64,7 +64,7 @@ def _set_file(path):
     if os.path.isfile(path):
         backup_name = path + '.' + _get_time_str()
         shutil.move(path, backup_name)
-        info("Log file '{}' backuped to '{}'".format(path, backup_name))  # noqa: F821
+        _logger.info("Existing log file '{}' backed up to '{}'".format(path, backup_name))
     hdl = logging.FileHandler(
         filename=path, encoding='utf-8', mode='w')
     hdl.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
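The backup-then-reopen pattern in `_set_file` above (move an existing log aside with a timestamp suffix before opening a fresh file) can be sketched standalone; `backup_existing` is a hypothetical helper name, and the timestamp format is an assumption:

```python
import os
import shutil
import tempfile
import time

def backup_existing(path):
    """If `path` is an existing file, move it aside with a timestamp
    suffix and return the backup name; otherwise return None."""
    if not os.path.isfile(path):
        return None
    backup_name = path + '.' + time.strftime('%m%d-%H%M%S')
    shutil.move(path, backup_name)
    return backup_name

# Demo in a scratch directory.
log_path = os.path.join(tempfile.mkdtemp(), 'log.log')
first = backup_existing(log_path)        # nothing to back up yet -> None
with open(log_path, 'w') as f:
    f.write('old run\n')
backup = backup_existing(log_path)       # old file moved aside
```

After the backup, the original path is free again, so a new `logging.FileHandler` can open it in `'w'` mode without clobbering the previous run's log.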