Commit b4085f77 authored by Yuxin Wu's avatar Yuxin Wu

update docs. remove _register_monitor

parent a21eb14f
...@@ -67,6 +67,11 @@ class InputDesc( ...@@ -67,6 +67,11 @@ class InputDesc(
return self._cached_placeholder return self._cached_placeholder
return self.build_placeholder() return self.build_placeholder()
@staticmethod
def from_tensor(t):
return InputDesc(
t.dtype, t.shape.as_list(), t.name[:-2])
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class ModelDescBase(object): class ModelDescBase(object):
......
...@@ -76,13 +76,13 @@ class Trainer(object): ...@@ -76,13 +76,13 @@ class Trainer(object):
""" """
self._callbacks = [] self._callbacks = []
self.loop = TrainLoop() self.loop = TrainLoop()
self._monitors = [] # Clarify the type. Don't change from list to monitors.
# Hacks! # Hacks!
if config is not None: if config is not None:
logger.warn("You're initializing new trainer with old trainer API!") logger.warn("You're initializing new trainer with old trainer API!")
logger.warn("This could happen if you wrote a custom trainer before.") logger.warn("This could happen if you wrote a custom trainer before.")
logger.warn("It may work now through some hacks, but please switch to the new API!") logger.warn("It may work now through some hacks, but please switch to the new API!")
logger.warn("See https://github.com/ppwwyyxx/tensorpack/issues/458 for more information.")
self._config = config self._config = config
self.inputs_desc = config.model.get_inputs_desc() self.inputs_desc = config.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper( self.tower_func = TowerFuncWrapper(
...@@ -117,19 +117,6 @@ class Trainer(object): ...@@ -117,19 +117,6 @@ class Trainer(object):
else: else:
self._callbacks.append(cb) self._callbacks.append(cb)
def _register_monitor(self, mon):
"""
Register a monitor to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
"""
assert isinstance(mon, TrainingMonitor), mon
assert not isinstance(self._monitors, Monitors), \
"Cannot register more monitors after trainer was setup!"
if not self.is_chief and mon.chief_only:
logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
else:
self._register_callback(mon)
def run_step(self): def run_step(self):
""" """
Defines what to do in one iteration. The default is: Defines what to do in one iteration. The default is:
...@@ -154,8 +141,10 @@ class Trainer(object): ...@@ -154,8 +141,10 @@ class Trainer(object):
self._register_callback(MaintainStepCounter()) self._register_callback(MaintainStepCounter())
for cb in callbacks: for cb in callbacks:
self._register_callback(cb) self._register_callback(cb)
for cb in self._callbacks:
assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!"
for m in monitors: for m in monitors:
self._register_monitor(m) self._register_callback(m)
self.monitors = Monitors(monitors) self.monitors = Monitors(monitors)
self._register_callback(self.monitors) # monitors is also a callback self._register_callback(self.monitors) # monitors is also a callback
...@@ -283,8 +272,7 @@ class Trainer(object): ...@@ -283,8 +272,7 @@ class Trainer(object):
else: else:
logger.warn("You're calling new trainers with old trainer API!") logger.warn("You're calling new trainers with old trainer API!")
logger.warn("Now it returns the old trainer for you, please switch to use new trainers soon!") logger.warn("Now it returns the old trainer for you, please switch to use new trainers soon!")
logger.warn("'SomeTrainer(config, ...).train()' should be equivalent to " logger.warn("See https://github.com/ppwwyyxx/tensorpack/issues/458 for more information.")
"'launch_train_with_config(config, SomeTrainer(...))' in the new API.")
return old_trainer(*args, **kwargs) return old_trainer(*args, **kwargs)
else: else:
return super(Trainer, cls).__new__(cls) return super(Trainer, cls).__new__(cls)
......
...@@ -123,13 +123,13 @@ class SingleCostTrainer(TowerTrainer): ...@@ -123,13 +123,13 @@ class SingleCostTrainer(TowerTrainer):
get_opt_fn = memoized(get_opt_fn) get_opt_fn = memoized(get_opt_fn)
self.set_tower_func(get_cost_fn) self.set_tower_func(get_cost_fn)
# TODO setup may want to register monitor as well??
input_callbacks = self._setup_input(inputs_desc, input) input_callbacks = self._setup_input(inputs_desc, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn) train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
internal_callbacks = input_callbacks + train_callbacks internal_callbacks = input_callbacks + train_callbacks
for cb in internal_callbacks: for cb in internal_callbacks:
self._register_callback(cb) self._register_callback(cb)
# TODO register directly instead of return?
@abstractmethod @abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn): def _setup_graph(self, input, get_cost_fn, get_opt_fn):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment