Commit 10cc1962 authored by Yuxin Wu

extract common utilities out of train/

parent a64d25cf
@@ -366,7 +366,8 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'dump_chkpt_vars',
'VisualQA',
'huber_loss',
'DumpTensor'
'DumpTensor',
'StepTensorPrinter'
]:
return True
if name in ['get_data', 'size', 'reset_state']:
......
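For context, a handler like the one above only takes effect once it is registered for Sphinx's autodoc-skip-member event; a minimal sketch of that wiring (the setup function below is illustrative, not part of this commit):

def setup(app):
    # members for which the handler returns True are skipped in the generated docs
    app.connect('autodoc-skip-member', autodoc_skip_member)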
@@ -10,14 +10,15 @@ import tqdm
from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
from ..tfutils.common import (
get_op_tensor_name, get_op_or_tensor_by_name)
get_op_tensor_name, get_op_or_tensor_by_name, get_global_step_var)
from .base import Callback
__all__ = ['StepTensorPrinter', 'ProgressBar']
__all__ = ['TensorPrinter', 'StepTensorPrinter', 'ProgressBar']
class StepTensorPrinter(Callback):
class TensorPrinter(Callback):
""" Prints the value of some tensors in each step.
It's an example of how ``before_run/after_run`` works.
"""
@@ -44,6 +45,9 @@ class StepTensorPrinter(Callback):
logger.info("{}: {}".format(n, v))
StepTensorPrinter = TensorPrinter
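To illustrate the rename above: the old name is kept as an alias, so existing code keeps working. A small hedged sketch (it assumes the callbacks package re-exports these names and that the constructor takes a list of tensor names, which this hunk does not show; 'total_cost' is a made-up name):

from tensorpack.callbacks import TensorPrinter, StepTensorPrinter

# both names refer to the same class after this commit
assert StepTensorPrinter is TensorPrinter

# assumption: the constructor takes the names of the tensors to print each step
printer = TensorPrinter(['total_cost'])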
class ProgressBar(Callback):
""" A progress bar based on tqdm. Enabled by default. """
@@ -96,3 +100,37 @@ class ProgressBar(Callback):
def _after_train(self):
if self._bar: # training may get killed before the first step
self._bar.close()
class MaintainStepCounter(Callback):
    """
    It maintains the global step in the graph, making sure it is increased by one
    in every ``hooked_sess.run`` call. This callback is used by the trainer;
    you don't need to worry about it.
    """

    _chief_only = False
    """
    In distributed training, we let each worker maintain its local global_step.
    """

    def _setup_graph(self):
        # ensure it exists
        gs_var = get_global_step_var()
        with tf.name_scope(None):
            with self.graph.colocate_with(gs_var):
                self.gs_incr_op = tf.assign_add(
                    gs_var, 1,
                    name=GLOBAL_STEP_INCR_OP_NAME).op
        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)

    def _before_train(self):
        if self.global_step != 0:
            logger.info("Start training with global_step={}".format(self.global_step))

    def _before_run(self, _):
        # always increase global_step when hooked_sess.run is called
        return self._fetches

    def _after_run(self, _, __):
        # Keep python-side global_step in agreement with TF-side
        self.trainer._global_step += 1
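The callback above is tensorpack's wrapper around TF's hook mechanism; the same step-counting idea, written directly against the SessionRunHook interface, looks roughly like this (an illustrative sketch, not code from this commit):

import tensorflow as tf

class StepIncrementHook(tf.train.SessionRunHook):
    """Piggy-back a global_step increment onto every hooked session.run() call."""

    def begin(self):
        gs_var = tf.train.get_or_create_global_step()
        self._incr_op = tf.assign_add(gs_var, 1).op
        self._fetches = tf.train.SessionRunArgs(self._incr_op)

    def before_run(self, run_context):
        # returning SessionRunArgs adds the increment op to the fetches of this run
        return self._fetches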
@@ -35,4 +35,5 @@ for _, module_name, _ in iter_modules(
if module_name in _TO_IMPORT:
_global_import(module_name) # import the content to tfutils.*
__all__.extend(['sessinit', 'summary', 'optimizer',
'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions'])
'sesscreate', 'gradproc', 'varreplace', 'symbolic_functions',
'distributed'])
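The loop above relies on a small helper whose body is outside this hunk; a typical re-export helper of this kind looks roughly like the following (an illustrative sketch living in the package's __init__.py, not tensorpack's actual _global_import):

import importlib

def _global_import(name):
    # import tfutils.<name> and lift its public symbols into the package namespace
    mod = importlib.import_module('.' + name, package=__name__)
    for k in getattr(mod, '__all__', []):
        globals()[k] = getattr(mod, k)
        __all__.append(k)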
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: distributed.py

import tensorflow as tf


def get_distributed_session_creator(server):
    """
    Args:
        server (tf.train.Server):

    Returns:
        tf.train.SessionCreator
    """
    server_def = server.server_def
    is_chief = (server_def.job_name == 'worker') and (server_def.task_index == 0)

    init_op = tf.global_variables_initializer()
    local_init_op = tf.local_variables_initializer()
    ready_op = tf.report_uninitialized_variables()
    sm = tf.train.SessionManager(
        local_init_op=local_init_op,
        ready_op=ready_op, graph=tf.get_default_graph())

    # to debug wrong variable collection
    # print("GLOBAL:")
    # print(tf.global_variables())
    # print("LOCAL:")
    # print(tf.local_variables())

    class _Creator(tf.train.SessionCreator):
        def create_session(self):
            if is_chief:
                return sm.prepare_session(master=server.target, init_op=init_op)
            else:
                return sm.wait_for_session(master=server.target)

    return _Creator()
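A hedged usage sketch for the new helper (the cluster layout, ports, and job/task chosen below are placeholders, not values from this commit):

import tensorflow as tf
from tensorpack.tfutils.distributed import get_distributed_session_creator

cluster = tf.train.ClusterSpec({
    'ps': ['localhost:2222'],
    'worker': ['localhost:2223', 'localhost:2224']})
server = tf.train.Server(cluster, job_name='worker', task_index=0)

# worker 0 acts as chief and runs init_op; other workers block in wait_for_session()
sess = get_distributed_session_creator(server).create_session()

This is the same path DistributedTrainerReplicated now takes in the train/distributed.py hunk further down, instead of building the SessionManager inline.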
@@ -10,15 +10,15 @@ import tensorflow as tf
from .config import TrainConfig
from ..utils import logger
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value, get_global_step_var
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession
from ..graph_builder.predictor_factory import PredictorFactory
from ..callbacks.steps import MaintainStepCounter
__all__ = ['Trainer', 'StopTraining']
@@ -30,40 +30,6 @@ class StopTraining(BaseException):
pass
class MaintainStepCounter(Callback):
    """
    It maintains the global step in the graph, making sure it's increased by one.
    This callback is always enabled by the trainer, you don't need to worry about it.
    """

    chief_only = False
    """
    In distributed training, we let each worker maintain its local global_step.
    """

    def _setup_graph(self):
        # ensure it exists
        gs_var = get_global_step_var()
        with tf.name_scope(None):
            with self.graph.colocate_with(gs_var):
                self.gs_incr_op = tf.assign_add(
                    gs_var, 1,
                    name=GLOBAL_STEP_INCR_OP_NAME).op
        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)

    def _before_train(self):
        if self.global_step != 0:
            logger.info("Start training with global_step={}".format(self.global_step))

    def _before_run(self, _):
        # always increase global_step when hooked_sess.run is called
        return self._fetches

    def _after_run(self, _, __):
        # Keep python-side global_step in agreement with TF-side
        self.trainer._global_step += 1
class Trainer(object):
""" Base class for a trainer.
......
@@ -2,13 +2,13 @@
# -*- coding: utf-8 -*-
# File: distributed.py
import tensorflow as tf
import os
from ..utils import logger
from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..graph_builder.distributed import DistributedReplicatedBuilder
from ..graph_builder.utils import override_to_local_variable
@@ -63,9 +63,6 @@ class DistributedTrainerReplicated(Trainer):
if self.job_name == 'worker':
# ps doesn't build any graph
self._builder = DistributedReplicatedBuilder(config.tower, server)
self.is_chief = self._builder.is_chief
else:
self.is_chief = False
logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
self._input_source = config.data
@@ -117,29 +114,7 @@
"Cannot set session_creator or session_config for distributed training! "
"To use a custom session config, pass it with tf.train.Server.")
init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()
ready_op = tf.report_uninitialized_variables()
sm = tf.train.SessionManager(
local_init_op=local_init_op,
ready_op=ready_op, graph=tf.get_default_graph())
# to debug wrong variable collection
# print("GLOBAL:")
# print(tf.global_variables())
# print("LOCAL:")
# print(tf.local_variables())
def _create_session():
if self.is_chief:
return sm.prepare_session(master=self.server.target, init_op=init_op)
else:
return sm.wait_for_session(master=self.server.target)
class _Creator(tf.train.SessionCreator):
def create_session(self):
return _create_session()
self.config.session_creator = _Creator()
self.config.session_creator = get_distributed_session_creator(self.server)
@property
def vs_name_for_predictor(self):
......