Commit fc079e63 authored by Yuxin Wu

Fix name conflict in regularize_cost's name scope; add more default arguments in trainer

parent db573204
...@@ -60,7 +60,7 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -60,7 +60,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
to_regularize = [] to_regularize = []
with tf.name_scope('regularize_cost'): with tf.name_scope(name + '_internals'):
costs = [] costs = []
for p in params: for p in params:
para_name = p.op.name para_name = p.op.name
......
...@@ -7,6 +7,7 @@ import weakref ...@@ -7,6 +7,7 @@ import weakref
import time import time
from six.moves import range from six.moves import range
import six import six
import sys
from ..callbacks import ( from ..callbacks import (
Callback, Callbacks, Monitors, TrainingMonitor, Callback, Callbacks, Monitors, TrainingMonitor,
...@@ -137,6 +138,10 @@ class Trainer(object): ...@@ -137,6 +138,10 @@ class Trainer(object):
def setup_callbacks(self, callbacks, monitors): def setup_callbacks(self, callbacks, monitors):
""" """
Setup callbacks and monitors. Must be called after the main graph is built. Setup callbacks and monitors. Must be called after the main graph is built.
Args:
callbacks ([Callback]):
monitors ([TrainingMonitor]):
""" """
describe_trainable_vars() # TODO weird describe_trainable_vars() # TODO weird
...@@ -160,6 +165,10 @@ class Trainer(object): ...@@ -160,6 +165,10 @@ class Trainer(object):
""" """
Initialize self.sess and self.hooked_sess. Initialize self.sess and self.hooked_sess.
Must be called after callbacks are setup. Must be called after callbacks are setup.
Args:
session_creator (tf.train.SessionCreator):
session_init (sessinit.SessionInit):
""" """
session_init._setup_graph() session_init._setup_graph()
...@@ -181,9 +190,12 @@ class Trainer(object): ...@@ -181,9 +190,12 @@ class Trainer(object):
logger.info("Graph Finalized.") logger.info("Graph Finalized.")
@call_only_once @call_only_once
def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999): def main_loop(self, steps_per_epoch, starting_epoch, max_epoch):
""" """
Run the main training loop. Run the main training loop.
Args:
steps_per_epoch, starting_epoch, max_epoch (int):
""" """
with self.sess.as_default(): with self.sess.as_default():
self.loop.config(steps_per_epoch, starting_epoch, max_epoch) self.loop.config(steps_per_epoch, starting_epoch, max_epoch)
...@@ -223,7 +235,7 @@ class Trainer(object): ...@@ -223,7 +235,7 @@ class Trainer(object):
def train(self, def train(self,
callbacks, monitors, callbacks, monitors,
session_creator, session_init, session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch): steps_per_epoch, starting_epoch=1, max_epoch=sys.maxint - 1):
""" """
Implemented by: Implemented by:
...@@ -242,7 +254,7 @@ class Trainer(object): ...@@ -242,7 +254,7 @@ class Trainer(object):
def train_with_defaults( def train_with_defaults(
self, callbacks=None, monitors=None, self, callbacks=None, monitors=None,
session_creator=None, session_init=None, session_creator=None, session_init=None,
steps_per_epoch=None, starting_epoch=1, max_epoch=9999): steps_per_epoch=None, starting_epoch=1, max_epoch=sys.maxint - 1):
""" """
Same as :meth:`train()`, but will: Same as :meth:`train()`, but will:
......
...@@ -200,7 +200,7 @@ class DistributedTrainerReplicated(SingleCostTrainer): ...@@ -200,7 +200,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
if not isinstance(session_creator, NewSessionCreator) or \ if not isinstance(session_creator, NewSessionCreator) or \
session_creator.user_provided_config: session_creator.user_provided_config:
raise ValueError( raise ValueError(
"Cannot set session_creator or session_config for distributed training! " "You are not allowed to set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to tf.train.Server.") "To use a custom session config, pass it to tf.train.Server.")
super(DistributedTrainerReplicated, self).initialize( super(DistributedTrainerReplicated, self).initialize(
get_distributed_session_creator(), session_init) get_distributed_session_creator(), session_init)
...@@ -239,7 +239,7 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -239,7 +239,7 @@ class HorovodTrainer(SingleCostTrainer):
def initialize(self, session_creator, session_init): def initialize(self, session_creator, session_init):
if not isinstance(session_creator, NewSessionCreator): if not isinstance(session_creator, NewSessionCreator):
raise ValueError( raise ValueError(
"Cannot set session_creator for horovod training! ") "session_creator has to be `NewSessionCreator` for horovod training! ")
session_creator.config.gpu_options.visible_device_list = str(self._local_rank) session_creator.config.gpu_options.visible_device_list = str(self._local_rank)
super(HorovodTrainer, self).initialize( super(HorovodTrainer, self).initialize(
session_creator, session_init) session_creator, session_init)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment