Commit fc079e63 authored by Yuxin Wu

name conflict in regularize. more defaults in trainer

parent db573204
@@ -60,7 +60,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
     to_regularize = []
-    with tf.name_scope('regularize_cost'):
+    with tf.name_scope(name + '_internals'):
         costs = []
         for p in params:
             para_name = p.op.name
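
The fix above derives the internal name scope from the caller-supplied `name` instead of hardcoding 'regularize_cost', so the intermediate ops of two differently-named calls no longer collide. A minimal standalone sketch of that pattern under TF1-style graph mode (`regularize_sketch` is a made-up name, not tensorpack's actual implementation):

import tensorflow as tf  # assumes the TF1-style graph API used by this codebase

def regularize_sketch(tensors, func, name='regularize_cost'):
    # Intermediate ops live under a scope derived from `name`, as in the fix above.
    with tf.name_scope(name + '_internals'):
        costs = [func(t) for t in tensors]
    # The summed cost keeps the caller-visible name.
    return tf.add_n(costs, name=name)

with tf.Graph().as_default():
    ws = [tf.constant([1.0, 2.0]), tf.constant([3.0])]
    a = regularize_sketch(ws, tf.nn.l2_loss, name='reg_a')
    b = regularize_sketch(ws, tf.nn.l2_loss, name='reg_b')
    print(a.name, b.name)  # reg_a:0 reg_b:0 -- internals sit in separate scopes
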
@@ -7,6 +7,7 @@ import weakref
 import time
 from six.moves import range
 import six
+import sys
 from ..callbacks import (
     Callback, Callbacks, Monitors, TrainingMonitor,
@@ -137,6 +138,10 @@ class Trainer(object):
     def setup_callbacks(self, callbacks, monitors):
         """
         Setup callbacks and monitors. Must be called after the main graph is built.
+
+        Args:
+            callbacks ([Callback]):
+            monitors ([TrainingMonitor]):
         """
         describe_trainable_vars()   # TODO weird
@@ -160,6 +165,10 @@ class Trainer(object):
         """
         Initialize self.sess and self.hooked_sess.
         Must be called after callbacks are setup.
+
+        Args:
+            session_creator (tf.train.SessionCreator):
+            session_init (sessinit.SessionInit):
         """
         session_init._setup_graph()
@@ -181,9 +190,12 @@ class Trainer(object):
         logger.info("Graph Finalized.")

     @call_only_once
-    def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
+    def main_loop(self, steps_per_epoch, starting_epoch, max_epoch):
         """
         Run the main training loop.
+
+        Args:
+            steps_per_epoch, starting_epoch, max_epoch (int):
         """
         with self.sess.as_default():
             self.loop.config(steps_per_epoch, starting_epoch, max_epoch)
@@ -223,7 +235,7 @@ class Trainer(object):
     def train(self,
               callbacks, monitors,
               session_creator, session_init,
-              steps_per_epoch, starting_epoch, max_epoch):
+              steps_per_epoch, starting_epoch=1, max_epoch=sys.maxsize - 1):
         """
         Implemented by:
@@ -242,7 +254,7 @@ class Trainer(object):
     def train_with_defaults(
             self, callbacks=None, monitors=None,
             session_creator=None, session_init=None,
-            steps_per_epoch=None, starting_epoch=1, max_epoch=9999):
+            steps_per_epoch=None, starting_epoch=1, max_epoch=sys.maxsize - 1):
         """
         Same as :meth:`train()`, but will:
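
With the new defaults, `train()` and `train_with_defaults()` run for an effectively unbounded number of epochs (Python 3 spells the limit `sys.maxsize`) until something stops training. A standalone sketch of that defaulting pattern (illustrative only, not tensorpack code; the `- 1` presumably leaves room for the `+ 1` used when building the epoch range):

import sys

def train_sketch(steps_per_epoch, starting_epoch=1, max_epoch=sys.maxsize - 1):
    # Effectively unbounded: the caller either bounds max_epoch or relies on
    # a callback-style mechanism to end training.
    for epoch in range(starting_epoch, max_epoch + 1):
        for _ in range(steps_per_epoch):
            pass  # one training step would run here
        if epoch >= 2:  # stand-in for a callback-driven stop condition
            break

train_sketch(steps_per_epoch=10)  # omits starting_epoch/max_epoch, per the new defaults
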
@@ -200,7 +200,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
         if not isinstance(session_creator, NewSessionCreator) or \
                 session_creator.user_provided_config:
             raise ValueError(
-                "Cannot set session_creator or session_config for distributed training! "
+                "You are not allowed to set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it to tf.train.Server.")
         super(DistributedTrainerReplicated, self).initialize(
             get_distributed_session_creator(), session_init)
@@ -239,7 +239,7 @@ class HorovodTrainer(SingleCostTrainer):
     def initialize(self, session_creator, session_init):
         if not isinstance(session_creator, NewSessionCreator):
             raise ValueError(
-                "Cannot set session_creator for horovod training! ")
+                "session_creator has to be `NewSessionCreator` for horovod training! ")
         session_creator.config.gpu_options.visible_device_list = str(self._local_rank)
         super(HorovodTrainer, self).initialize(
             session_creator, session_init)
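
The Horovod change above also pins each worker process to a single GPU via `visible_device_list`. A standalone sketch of the same pattern using Horovod's public TF1 API (`hvd.init()` and `hvd.local_rank()` are real Horovod calls; the rest is illustrative and not tensorpack code):

import tensorflow as tf            # TF1-style API (tf.ConfigProto, tf.Session)
import horovod.tensorflow as hvd

hvd.init()
config = tf.ConfigProto()
# Each process sees only the GPU matching its local rank, mirroring
# HorovodTrainer.initialize() above.
config.gpu_options.visible_device_list = str(hvd.local_rank())
sess = tf.Session(config=config)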