Commit 148d7dd9 authored by Yuxin Wu

bug fix

parent c279dbfe
@@ -20,7 +20,7 @@ class StartProcOrThread(Callback):
         self._procs_threads = procs_threads
 
     def _before_train(self):
-        logger.info("Starting threads & procs: " + \
-                ' .'.join([k.name for k in self._procs_threads]))
+        logger.info("Starting " + \
+                ', '.join([k.name for k in self._procs_threads]))
         # avoid sigint get handled by other processes
         start_proc_mask_signal(self._procs_threads)
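Note: the bug fixed in this hunk is the join separator. The old message glued names together with `' .'` (space-dot) instead of `', '`. A standalone sketch of the difference (the thread names below are made up for illustration):

```python
# hypothetical thread names, only to show the separator behavior
names = ['EnqueueThread', 'SimulatorMaster']

print(' .'.join(names))   # old, buggy:  EnqueueThread .SimulatorMaster
print(', '.join(names))   # fixed:       EnqueueThread, SimulatorMaster
```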
@@ -116,13 +116,15 @@ class ScaleGradient(MapGradient):
     """
     Scale certain gradient by a multiplier
     """
-    def __init__(self, multipliers):
+    def __init__(self, multipliers, log=True):
         """
         :param multipliers: list of (regex, float)
+        :param log: whether to do logging or not
         """
         if not isinstance(multipliers, list):
             multipliers = [multipliers]
         self.multipliers = multipliers
+        self._log = log
         super(ScaleGradient, self).__init__(self._mapper)
 
     def _mapper(self, grad, var):
@@ -133,6 +135,7 @@ class ScaleGradient(MapGradient):
                 regex = regex + '$'
             if re.match(regex, varname):
-                logger.info("Apply lr multiplier {} for {}".format(val, varname))
+                if self._log:
+                    logger.info("Apply lr multiplier {} for {}".format(val, varname))
                 if val != 0:   # skip zero to speed up
                     return grad * val
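With the new `log` flag, callers can apply a multiplier without emitting one log line per matched variable. A usage sketch based only on the signature shown above (the regexes and values are made up):

```python
# a single (regex, float) pair; the constructor wraps it into a list
quiet = ScaleGradient(('conv0.*', 0.1), log=False)

# default log=True keeps the old behavior: one
# "Apply lr multiplier ..." line per matched variable
verbose = ScaleGradient([('fc.*', 2.0), ('conv.*', 0.5)])
```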
@@ -15,6 +15,7 @@ from ..utils import logger, get_tqdm_kwargs
 from ..utils.timer import timed_operation
 from ..callbacks import StatHolder
 from ..tfutils import get_global_step, get_global_step_var
+from ..tfutils.modelutils import describe_model
 from ..tfutils.summary import create_summary
 
 __all__ = ['Trainer', 'StopTraining']
@@ -94,6 +95,7 @@ class Trainer(object):
     def setup(self):
         self._setup()
+        describe_model()
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
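The hunks above and below implement one refactor: `describe_model()` used to be called by every trainer after building its graph, and is now called exactly once in `Trainer.setup()`; the remaining hunks delete the duplicated calls. A minimal sketch of the resulting template-method pattern (names simplified from the diff):

```python
class Trainer(object):
    def setup(self):
        self._setup()       # each subclass builds its own graph here
        describe_model()    # model summary logged once, for all trainers

    def _setup(self):
        raise NotImplementedError()

class SimpleTrainer(Trainer):
    def _setup(self):
        # builds train_op; no longer needs its own describe_model() call
        pass
```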
@@ -10,11 +10,10 @@ from six.moves import zip, range
 from ..utils import logger
 from ..utils.naming import *
 from ..utils.concurrency import LoopThread
-from ..tfutils.summary import summary_moving_average
-from ..tfutils.modelutils import describe_model
+from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..tfutils import (backup_collection, restore_collection,
         get_global_step_var, TowerContext)
-from ..tfutils.gradproc import apply_grad_processors
+from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
 from .trainer import QueueInputTrainer
@@ -89,7 +88,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer):
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')
-        describe_model()
 
         # [debug]: do nothing in training
         #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
@@ -101,7 +99,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
         gradprocs = self.model.get_gradient_processor()
         # pretend to average the grads, in order to make async and
         # sync have consistent effective learning rate
-        gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower)))
+        if self.config.nr_tower > 1:
+            gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False))
         grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]
 
         # use grad from the first tower for iteration in main thread
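The scaling here keeps async and sync training at the same effective learning rate: sync applies one update with the average of `nr_tower` gradients, while async applies `nr_tower` separate updates, so each async gradient is pre-scaled by `1.0 / nr_tower`. The fix skips the scaling (and its logging) when there is only one tower. A small numeric sketch of the equivalence, with made-up values:

```python
nr_tower = 4
lr = 0.1
tower_grads = [1.0, 2.0, 3.0, 4.0]   # one gradient per tower, same variable

# sync: a single update with the averaged gradient
sync_step = lr * sum(tower_grads) / nr_tower

# async: nr_tower independent updates, each gradient pre-scaled by 1/nr_tower
async_step = sum(lr * g / nr_tower for g in tower_grads)

assert abs(sync_step - async_step) < 1e-12   # same total step size
```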
@@ -109,7 +108,6 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
             self.config.optimizer.apply_gradients(
                 grad_list[0], get_global_step_var()),
             summary_moving_average(), name='train_op')
-        describe_model()
         self._start_async_threads(grad_list)
@@ -15,7 +15,6 @@ from ..utils import logger, SUMMARY_BACKUP_KEYS
 from ..tfutils import (get_vars_by_names, freeze_collection,
         get_global_step_var, TowerContext)
 from ..tfutils.summary import summary_moving_average, add_moving_summary
-from ..tfutils.modelutils import describe_model
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
 from ..callbacks.concurrency import StartProcOrThread
 from ..tfutils.gradproc import apply_grad_processors
@@ -81,7 +80,6 @@ class SimpleTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average())
-        describe_model()
 
         # create an infinite data producer
         self.config.dataset.reset_state()
         self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
@@ -204,7 +202,6 @@ class QueueInputTrainer(Trainer):
         grads = self._single_tower_grad()
         grads = apply_grad_processors(grads,
             self.model.get_gradient_processor())
-        describe_model()
 
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),