Commit 148d7dd9 authored by Yuxin Wu

bug fix

parent c279dbfe
@@ -20,7 +20,7 @@ class StartProcOrThread(Callback):
         self._procs_threads = procs_threads

     def _before_train(self):
-        logger.info("Starting threads & procs: " + \
-            ' .'.join([k.name for k in self._procs_threads]))
+        logger.info("Starting " + \
+            ', '.join([k.name for k in self._procs_threads]))
         # avoid sigint get handled by other processes
         start_proc_mask_signal(self._procs_threads)
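Note: the actual fix here is the join separator (' .' became ', '). With made-up thread names, the log line changes roughly as follows (illustrative only):

    names = ['SimulatorMaster', 'EnqueueThread']
    ' .'.join(names)    # "SimulatorMaster .EnqueueThread"  (old, garbled separator)
    ', '.join(names)    # "SimulatorMaster, EnqueueThread"  (new)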
@@ -116,13 +116,15 @@ class ScaleGradient(MapGradient):
     """
     Scale certain gradient by a multiplier
     """
-    def __init__(self, multipliers):
+    def __init__(self, multipliers, log=True):
         """
         :param multipliers: list of (regex, float)
+        :param log: whether to do logging or not
         """
         if not isinstance(multipliers, list):
             multipliers = [multipliers]
         self.multipliers = multipliers
+        self._log = log
         super(ScaleGradient, self).__init__(self._mapper)

     def _mapper(self, grad, var):
@@ -133,6 +135,7 @@ class ScaleGradient(MapGradient):
                 regex = regex + '$'
             if re.match(regex, varname):
-                logger.info("Apply lr multiplier {} for {}".format(val, varname))
+                if self._log:
+                    logger.info("Apply lr multiplier {} for {}".format(val, varname))
                 if val != 0:    # skip zero to speed up
                     return grad * val
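Note: a minimal usage sketch of the new log flag; the import path is inferred from the "from ..tfutils.gradproc import ..." lines elsewhere in this diff, and the regexes and multiplier values are made up for illustration.

    from tensorpack.tfutils.gradproc import ScaleGradient

    # default: logs "Apply lr multiplier ..." once per matched variable
    procs = [ScaleGradient([('conv0.*', 0.1), ('fc.*', 2.0)])]

    # quiet variant, useful when the same processor is rebuilt for every tower
    # and repeating the log line would only add noise
    procs_quiet = [ScaleGradient(('.*', 0.5), log=False)]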
@@ -15,6 +15,7 @@ from ..utils import logger, get_tqdm_kwargs
 from ..utils.timer import timed_operation
 from ..callbacks import StatHolder
 from ..tfutils import get_global_step, get_global_step_var
+from ..tfutils.modelutils import describe_model
 from ..tfutils.summary import create_summary

 __all__ = ['Trainer', 'StopTraining']

@@ -94,6 +95,7 @@ class Trainer(object):
     def setup(self):
         self._setup()
+        describe_model()
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
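Note: together with the removals in the hunks below, the effect is that the model summary is printed once from the common setup path rather than from each trainer's _setup(). A simplified sketch (not the full method body):

    class Trainer(object):
        def setup(self):
            self._setup()        # subclass builds the graph and train_op
            describe_model()     # model summary now logged here, once, for every trainer
            # callbacks are set up afterwards, since they may still modify the graph
            self.config.callbacks.setup_graph(weakref.proxy(self))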
@@ -10,11 +10,10 @@ from six.moves import zip, range
 from ..utils import logger
 from ..utils.naming import *
 from ..utils.concurrency import LoopThread
-from ..tfutils.summary import summary_moving_average
-from ..tfutils.modelutils import describe_model
+from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..tfutils import (backup_collection, restore_collection,
         get_global_step_var, TowerContext)
-from ..tfutils.gradproc import apply_grad_processors
+from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
 from .trainer import QueueInputTrainer

@@ -89,7 +88,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer):
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')
-        describe_model()

         # [debug]: do nothing in training
         #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]

@@ -101,7 +99,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
         gradprocs = self.model.get_gradient_processor()
         # pretend to average the grads, in order to make async and
         # sync have consistent effective learning rate
-        gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower)))
+        if self.config.nr_tower > 1:
+            gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False))
         grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]

         # use grad from the first tower for iteration in main thread

@@ -109,7 +108,6 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
             self.config.optimizer.apply_gradients(
                 grad_list[0], get_global_step_var()),
             summary_moving_average(), name='train_op')
-        describe_model()
         self._start_async_threads(grad_list)
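Note: the 1/nr_tower scaling keeps the async trainer's effective learning rate in line with the sync trainer, which averages gradients across towers. A standalone sketch of the arithmetic with made-up numbers (not code from the repository):

    nr_tower = 4
    tower_grads = [0.8, 1.0, 1.2, 1.0]             # one gradient component per tower

    sync_update = sum(tower_grads) / nr_tower      # sync trainer averages: 1.0
    async_updates = [g * (1.0 / nr_tower) for g in tower_grads]
    # async applies each scaled gradient separately; after one round over all
    # towers the total update matches the synchronous average
    assert abs(sum(async_updates) - sync_update) < 1e-9

    # with a single tower there is nothing to compensate for, hence the new
    # "if self.config.nr_tower > 1" guard (and log=False to keep the log clean)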
@@ -15,7 +15,6 @@ from ..utils import logger, SUMMARY_BACKUP_KEYS
 from ..tfutils import (get_vars_by_names, freeze_collection,
         get_global_step_var, TowerContext)
 from ..tfutils.summary import summary_moving_average, add_moving_summary
-from ..tfutils.modelutils import describe_model
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
 from ..callbacks.concurrency import StartProcOrThread
 from ..tfutils.gradproc import apply_grad_processors

@@ -81,7 +80,6 @@ class SimpleTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average())
-        describe_model()

         # create an infinte data producer
         self.config.dataset.reset_state()
         self.data_producer = RepeatedData(self.config.dataset, -1).get_data()

@@ -204,7 +202,6 @@ class QueueInputTrainer(Trainer):
         grads = self._single_tower_grad()
         grads = apply_grad_processors(grads,
             self.model.get_gradient_processor())
-        describe_model()
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),