Commit 00811100 authored by Yuxin Wu's avatar Yuxin Wu

check tf version in multigpu.

parent 4a88dfc3
......@@ -10,4 +10,4 @@ os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' # issue#9339
os.environ['TF_AUTOTUNE_THRESHOLD'] = '3' # use more warm-up
os.environ['TF_AVGPOOL_USE_CUDNN'] = '1' # issue#8566
__version__ = '0.1.9'
__version__ = '0.2.0'
......@@ -10,7 +10,6 @@ from ..utils.argtools import graph_memoized
from ..utils.naming import GLOBAL_STEP_OP_NAME
__all__ = ['get_default_sess_config',
'get_global_step_value',
'get_global_step_var',
'get_op_tensor_name',
......
......@@ -27,6 +27,11 @@ __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
'SyncMultiGPUTrainerParameterServer']
def _check_tf_version():
ver = float('.'.join(tf.VERSION.split('.')[:2]))
assert ver >= 1.1, "TF version {} is too old to run multi GPU training!".format(ver)
def apply_prefetch_policy(config, use_stage=True):
if config.data is None and config.dataflow is not None:
config.data = QueueInput(config.dataflow)
......@@ -55,6 +60,8 @@ class MultiGPUTrainerBase(Trainer):
List of outputs of ``func``, evaluated on each tower.
"""
logger.info("Training a model of {} tower".format(len(towers)))
if len(towers) > 1:
_check_tf_version()
ret = []
if devices is not None:
......
......@@ -3,7 +3,6 @@
# File: predict.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from ..predict import (OnlinePredictor,
PredictorTowerBuilder)
......@@ -34,9 +33,8 @@ class PredictorFactory(object):
an online predictor (which has to be used under a default session)
"""
tower = self.towers[tower]
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
# just ensure the tower exists. won't rebuild
self._tower_builder.build(tower)
# just ensure the tower exists. won't rebuild (memoized)
self._tower_builder.build(tower)
placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment