Commit 00811100 authored by Yuxin Wu's avatar Yuxin Wu

check tf version in multigpu.

parent 4a88dfc3
...@@ -10,4 +10,4 @@ os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' # issue#9339 ...@@ -10,4 +10,4 @@ os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' # issue#9339
os.environ['TF_AUTOTUNE_THRESHOLD'] = '3' # use more warm-up os.environ['TF_AUTOTUNE_THRESHOLD'] = '3' # use more warm-up
os.environ['TF_AVGPOOL_USE_CUDNN'] = '1' # issue#8566 os.environ['TF_AVGPOOL_USE_CUDNN'] = '1' # issue#8566
__version__ = '0.1.9' __version__ = '0.2.0'
...@@ -10,7 +10,6 @@ from ..utils.argtools import graph_memoized ...@@ -10,7 +10,6 @@ from ..utils.argtools import graph_memoized
from ..utils.naming import GLOBAL_STEP_OP_NAME from ..utils.naming import GLOBAL_STEP_OP_NAME
__all__ = ['get_default_sess_config', __all__ = ['get_default_sess_config',
'get_global_step_value', 'get_global_step_value',
'get_global_step_var', 'get_global_step_var',
'get_op_tensor_name', 'get_op_tensor_name',
......
...@@ -27,6 +27,11 @@ __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer', ...@@ -27,6 +27,11 @@ __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
'SyncMultiGPUTrainerParameterServer'] 'SyncMultiGPUTrainerParameterServer']
def _check_tf_version():
ver = float('.'.join(tf.VERSION.split('.')[:2]))
assert ver >= 1.1, "TF version {} is too old to run multi GPU training!".format(ver)
def apply_prefetch_policy(config, use_stage=True): def apply_prefetch_policy(config, use_stage=True):
if config.data is None and config.dataflow is not None: if config.data is None and config.dataflow is not None:
config.data = QueueInput(config.dataflow) config.data = QueueInput(config.dataflow)
...@@ -55,6 +60,8 @@ class MultiGPUTrainerBase(Trainer): ...@@ -55,6 +60,8 @@ class MultiGPUTrainerBase(Trainer):
List of outputs of ``func``, evaluated on each tower. List of outputs of ``func``, evaluated on each tower.
""" """
logger.info("Training a model of {} tower".format(len(towers))) logger.info("Training a model of {} tower".format(len(towers)))
if len(towers) > 1:
_check_tf_version()
ret = [] ret = []
if devices is not None: if devices is not None:
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# File: predict.py # File: predict.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from ..predict import (OnlinePredictor, from ..predict import (OnlinePredictor,
PredictorTowerBuilder) PredictorTowerBuilder)
...@@ -34,9 +33,8 @@ class PredictorFactory(object): ...@@ -34,9 +33,8 @@ class PredictorFactory(object):
an online predictor (which has to be used under a default session) an online predictor (which has to be used under a default session)
""" """
tower = self.towers[tower] tower = self.towers[tower]
with tf.variable_scope(tf.get_variable_scope(), reuse=True): # just ensure the tower exists. won't rebuild (memoized)
# just ensure the tower exists. won't rebuild self._tower_builder.build(tower)
self._tower_builder.build(tower)
placeholder_names = set([k.name for k in self.model.get_inputs_desc()]) placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment