Commit a6a2aba4 authored by Yuxin Wu's avatar Yuxin Wu

fix import

parent f409fbf0
......@@ -21,6 +21,6 @@ if _HAS_TF:
from tensorpack.trainv2 import *
else:
from tensorpack.train import *
from tensorpack.graph_builder import *
from tensorpack.graph_builder import InputDesc, ModelDesc, ModelDescBase
from tensorpack.input_source import *
from tensorpack.predict import *
......@@ -20,7 +20,7 @@ def global_import(name):
_CURR_DIR = os.path.dirname(__file__)
_SKIP = ['utils']
_SKIP = []
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
......
......@@ -319,58 +319,3 @@ def _get_property(name):
for name in ['global_step', 'local_step', 'steps_per_epoch',
'epoch_num', 'starting_epoch', 'max_epoch']:
setattr(Trainer, name, _get_property(name))
def launch_train(
run_step, model=None, callbacks=None, extra_callbacks=None, monitors=None,
session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
"""
** Work In Progress! Don't use**
This is another trainer interface, to start training **after** the graph has been built already.
You can build the graph however you like
(with or without tensorpack), and invoke this function to start training with callbacks & monitors.
This provides the flexibility to define the training config after graph has been buit.
One typical use is that callbacks often depend on names that are not known
only until the graph has been built.
Args:
run_step (tf.Tensor or function): Define what the training iteration is.
If given a Tensor/Operation, will eval it every step.
If given a function, will invoke this function under the default session in every step.
model (None or ModelDesc): Certain callbacks (e.g. InferenceRunner) depends on
the existence of :class:`ModelDesc`. If you use a :class:`ModelDesc` to
build the graph, add it here to to allow those callbacks to work.
If you didn't use :class:`ModelDesc`, leave it empty.
Other arguments are the same as in :class:`TrainConfig`.
Examples:
.. code-block:: python
model = MyModelDesc()
train_op, cbs = SimpleTrainer.setup_graph(model, QueueInput(mydataflow))
launch_train(train_op, model=model, callbacks=[...] + cbs, steps_per_epoch=mydataflow.size())
# the above is equivalent to:
config = TrainConfig(model=MyModelDesc(), data=QueueInput(mydataflow) callbacks=[...])
SimpleTrainer(config).train()
"""
assert steps_per_epoch is not None, steps_per_epoch
trainer = Trainer(TrainConfig(
model=model,
callbacks=callbacks,
extra_callbacks=extra_callbacks,
monitors=monitors,
session_creator=session_creator,
session_config=session_config,
session_init=session_init,
starting_epoch=starting_epoch,
steps_per_epoch=steps_per_epoch,
max_epoch=max_epoch))
if isinstance(run_step, (tf.Tensor, tf.Operation)):
trainer.train_op = run_step
else:
assert callable(run_step), run_step
trainer.run_step = lambda self: run_step()
trainer.train()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment