Commit 10cc1962 authored by Yuxin Wu

extract common utilities out of train/

parent a64d25cf
docs/conf.py
@@ -366,7 +366,8 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
             'dump_chkpt_vars',
             'VisualQA',
             'huber_loss',
-            'DumpTensor'
+            'DumpTensor',
+            'StepTensorPrinter'
             ]:
         return True
     if name in ['get_data', 'size', 'reset_state']:
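The names added above are ones Sphinx autodoc should skip, so the deprecated alias introduced below does not show up in the generated docs. For context, a skip function like this is normally registered through the `autodoc-skip-member` event; a minimal sketch, assuming `autodoc_skip_member` is the function shown in this hunk (the actual `setup()` in tensorpack's conf.py may differ):

```python
# Minimal sketch: returning True from the handler tells Sphinx autodoc
# to skip that member when building the API docs.
def setup(app):
    app.connect('autodoc-skip-member', autodoc_skip_member)
```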
tensorpack/callbacks/steps.py
@@ -10,14 +10,15 @@ import tqdm

 from ..utils import logger
 from ..utils.utils import get_tqdm_kwargs
+from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
 from ..tfutils.common import (
-    get_op_tensor_name, get_op_or_tensor_by_name)
+    get_op_tensor_name, get_op_or_tensor_by_name, get_global_step_var)
 from .base import Callback

-__all__ = ['StepTensorPrinter', 'ProgressBar']
+__all__ = ['TensorPrinter', 'StepTensorPrinter', 'ProgressBar']


-class StepTensorPrinter(Callback):
+class TensorPrinter(Callback):
     """ Prints the value of some tensors in each step.
     It's an example of how ``before_run/after_run`` works.
     """
@@ -44,6 +45,9 @@ class StepTensorPrinter(Callback):
             logger.info("{}: {}".format(n, v))


+StepTensorPrinter = TensorPrinter
+
+
 class ProgressBar(Callback):
     """ A progress bar based on tqdm. Enabled by default. """
@@ -96,3 +100,37 @@ class ProgressBar(Callback):
     def _after_train(self):
         if self._bar:       # training may get killed before the first step
             self._bar.close()
+
+
+class MaintainStepCounter(Callback):
+    """
+    It maintains the global step in the graph, making sure it's increased by one.
+    This callback is used by the trainer, you don't need to worry about it.
+    """
+
+    _chief_only = False
+    """
+    In distributed training, we let each worker maintain its local global_step.
+    """
+
+    def _setup_graph(self):
+        # ensure it exists
+        gs_var = get_global_step_var()
+        with tf.name_scope(None):
+            with self.graph.colocate_with(gs_var):
+                self.gs_incr_op = tf.assign_add(
+                    gs_var, 1,
+                    name=GLOBAL_STEP_INCR_OP_NAME).op
+        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
+
+    def _before_train(self):
+        if self.global_step != 0:
+            logger.info("Start training with global_step={}".format(self.global_step))
+
+    def _before_run(self, _):
+        # always increase global_step when hooked_sess.run is called
+        return self._fetches
+
+    def _after_run(self, _, __):
+        # Keep python-side global_step in agreement with TF-side
+        self.trainer._global_step += 1
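The rename keeps backward compatibility: `StepTensorPrinter` is now just an alias of `TensorPrinter`, so existing user code keeps working while the docs advertise only the new name (the old one is hidden via the conf.py change above). A minimal usage sketch, assuming the constructor still takes a list of tensor names as before; `'cost:0'` is a hypothetical tensor name:

```python
# Sketch only: both names refer to the same callback class after this commit.
from tensorpack.callbacks import TensorPrinter, StepTensorPrinter

assert StepTensorPrinter is TensorPrinter   # old name kept as an alias
printer = TensorPrinter(['cost:0'])         # prints this tensor after every step
```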
tensorpack/tfutils/__init__.py
@@ -35,4 +35,5 @@ for _, module_name, _ in iter_modules(
     if module_name in _TO_IMPORT:
         _global_import(module_name)  # import the content to tfutils.*
 __all__.extend(['sessinit', 'summary', 'optimizer',
-                'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions'])
+                'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions',
+                'distributed'])
tensorpack/tfutils/distributed.py (new file)
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# File: distributed.py
+
+import tensorflow as tf
+
+
+def get_distributed_session_creator(server):
+    """
+    Args:
+        server (tf.train.Server):
+
+    Returns:
+        tf.train.SessionCreator
+    """
+
+    server_def = server.server_def
+    is_chief = (server_def.job_name == 'worker') and (server_def.task_index == 0)
+
+    init_op = tf.global_variables_initializer()
+    local_init_op = tf.local_variables_initializer()
+    ready_op = tf.report_uninitialized_variables()
+    sm = tf.train.SessionManager(
+        local_init_op=local_init_op,
+        ready_op=ready_op, graph=tf.get_default_graph())
+
+    # to debug wrong variable collection
+    # print("GLOBAL:")
+    # print(tf.global_variables())
+    # print("LOCAL:")
+    # print(tf.local_variables())
+
+    class _Creator(tf.train.SessionCreator):
+        def create_session(self):
+            if is_chief:
+                return sm.prepare_session(master=server.target, init_op=init_op)
+            else:
+                return sm.wait_for_session(master=server.target)
+
+    return _Creator()
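This new helper holds the chief-election and `tf.train.SessionManager` logic that previously lived inside the distributed trainer (see the train/distributed.py hunk below). A rough usage sketch with the TF 1.x cluster API; the host names and the single-worker setup are made up for illustration:

```python
import tensorflow as tf
from tensorpack.tfutils.distributed import get_distributed_session_creator

# Hypothetical two-job cluster; in practice this comes from your launcher.
cluster = tf.train.ClusterSpec({
    'ps':     ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
})
server = tf.train.Server(cluster, job_name='worker', task_index=0)

# ... build the training graph here ...

creator = get_distributed_session_creator(server)
sess = creator.create_session()   # worker 0 initializes; other tasks wait for it
```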
tensorpack/train/base.py
@@ -10,15 +10,15 @@ import tensorflow as tf

 from .config import TrainConfig
 from ..utils import logger
-from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
 from ..callbacks import Callback, Callbacks
 from ..callbacks.monitor import Monitors, TrainingMonitor
-from ..tfutils import get_global_step_value, get_global_step_var
+from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sesscreate import ReuseSessionCreator
 from ..tfutils.sessinit import JustCurrentSession
 from ..graph_builder.predictor_factory import PredictorFactory
+from ..callbacks.steps import MaintainStepCounter

 __all__ = ['Trainer', 'StopTraining']

@@ -30,40 +30,6 @@ class StopTraining(BaseException):
     pass


-class MaintainStepCounter(Callback):
-    """
-    It maintains the global step in the graph, making sure it's increased by one.
-    This callback is always enabled by the trainer, you don't need to worry about it.
-    """
-
-    chief_only = False
-    """
-    In distributed training, we let each worker maintain its local global_step.
-    """
-
-    def _setup_graph(self):
-        # ensure it exists
-        gs_var = get_global_step_var()
-        with tf.name_scope(None):
-            with self.graph.colocate_with(gs_var):
-                self.gs_incr_op = tf.assign_add(
-                    gs_var, 1,
-                    name=GLOBAL_STEP_INCR_OP_NAME).op
-        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)
-
-    def _before_train(self):
-        if self.global_step != 0:
-            logger.info("Start training with global_step={}".format(self.global_step))
-
-    def _before_run(self, _):
-        # always increase global_step when hooked_sess.run is called
-        return self._fetches
-
-    def _after_run(self, _, __):
-        # Keep python-side global_step in agreement with TF-side
-        self.trainer._global_step += 1
-
-
 class Trainer(object):
     """ Base class for a trainer.
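`MaintainStepCounter` itself is unchanged; it only moves from `train/base.py` into `callbacks/steps.py`, so anything that imported it from the old location needs the new path. The trainer presumably still registers it automatically alongside user callbacks, roughly as in the sketch below (not the exact trainer code):

```python
from tensorpack.callbacks.steps import MaintainStepCounter   # new location

user_callbacks = []                                    # whatever TrainConfig received
callbacks = [MaintainStepCounter()] + user_callbacks   # step counter runs every step
```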
tensorpack/train/distributed.py
@@ -2,13 +2,13 @@
 # -*- coding: utf-8 -*-
 # File: distributed.py

-import tensorflow as tf
 import os

 from ..utils import logger
 from ..callbacks import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils import get_global_step_var
+from ..tfutils.distributed import get_distributed_session_creator

 from ..graph_builder.distributed import DistributedReplicatedBuilder
 from ..graph_builder.utils import override_to_local_variable
@@ -63,9 +63,6 @@ class DistributedTrainerReplicated(Trainer):
         if self.job_name == 'worker':
             # ps doesn't build any graph
             self._builder = DistributedReplicatedBuilder(config.tower, server)
-            self.is_chief = self._builder.is_chief
-        else:
-            self.is_chief = False

         logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
         self._input_source = config.data
@@ -117,29 +114,7 @@
                 "Cannot set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it with tf.train.Server.")

-        init_op = tf.global_variables_initializer()
-        local_init_op = tf.local_variables_initializer()
-        ready_op = tf.report_uninitialized_variables()
-        sm = tf.train.SessionManager(
-            local_init_op=local_init_op,
-            ready_op=ready_op, graph=tf.get_default_graph())
-
-        # to debug wrong variable collection
-        # print("GLOBAL:")
-        # print(tf.global_variables())
-        # print("LOCAL:")
-        # print(tf.local_variables())
-
-        def _create_session():
-            if self.is_chief:
-                return sm.prepare_session(master=self.server.target, init_op=init_op)
-            else:
-                return sm.wait_for_session(master=self.server.target)
-
-        class _Creator(tf.train.SessionCreator):
-            def create_session(self):
-                return _create_session()
-
-        self.config.session_creator = _Creator()
+        self.config.session_creator = get_distributed_session_creator(self.server)

     @property
     def vs_name_for_predictor(self):