Commit b4085f77 authored by Yuxin Wu's avatar Yuxin Wu

update docs. remove _register_monitor

parent a21eb14f
......@@ -67,6 +67,11 @@ class InputDesc(
return self._cached_placeholder
return self.build_placeholder()
@staticmethod
def from_tensor(t):
return InputDesc(
t.dtype, t.shape.as_list(), t.name[:-2])
@six.add_metaclass(ABCMeta)
class ModelDescBase(object):
......
......@@ -76,13 +76,13 @@ class Trainer(object):
"""
self._callbacks = []
self.loop = TrainLoop()
self._monitors = [] # Clarify the type. Don't change from list to monitors.
# Hacks!
if config is not None:
logger.warn("You're initializing new trainer with old trainer API!")
logger.warn("This could happen if you wrote a custom trainer before.")
logger.warn("It may work now through some hacks, but please switch to the new API!")
logger.warn("See https://github.com/ppwwyyxx/tensorpack/issues/458 for more information.")
self._config = config
self.inputs_desc = config.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(
......@@ -117,19 +117,6 @@ class Trainer(object):
else:
self._callbacks.append(cb)
def _register_monitor(self, mon):
"""
Register a monitor to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
"""
assert isinstance(mon, TrainingMonitor), mon
assert not isinstance(self._monitors, Monitors), \
"Cannot register more monitors after trainer was setup!"
if not self.is_chief and mon.chief_only:
logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
else:
self._register_callback(mon)
def run_step(self):
"""
Defines what to do in one iteration. The default is:
......@@ -154,8 +141,10 @@ class Trainer(object):
self._register_callback(MaintainStepCounter())
for cb in callbacks:
self._register_callback(cb)
for cb in self._callbacks:
assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!"
for m in monitors:
self._register_monitor(m)
self._register_callback(m)
self.monitors = Monitors(monitors)
self._register_callback(self.monitors) # monitors is also a callback
......@@ -283,8 +272,7 @@ class Trainer(object):
else:
logger.warn("You're calling new trainers with old trainer API!")
logger.warn("Now it returns the old trainer for you, please switch to use new trainers soon!")
logger.warn("'SomeTrainer(config, ...).train()' should be equivalent to "
"'launch_train_with_config(config, SomeTrainer(...))' in the new API.")
logger.warn("See https://github.com/ppwwyyxx/tensorpack/issues/458 for more information.")
return old_trainer(*args, **kwargs)
else:
return super(Trainer, cls).__new__(cls)
......
......@@ -123,13 +123,13 @@ class SingleCostTrainer(TowerTrainer):
get_opt_fn = memoized(get_opt_fn)
self.set_tower_func(get_cost_fn)
# TODO setup may want to register monitor as well??
input_callbacks = self._setup_input(inputs_desc, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
internal_callbacks = input_callbacks + train_callbacks
for cb in internal_callbacks:
self._register_callback(cb)
# TODO register directly instead of return?
@abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment