Commit 1b98fe58 authored by Philipp Werner, committed by GitHub

Improve compatibility to tensorflow 2.3 (#1487)

* fix gfile not found error occurring with tensorflow 2.3

* fix as_list() not found error occurring with tensorflow 2.3
Co-authored-by: Philipp Werner <pw_post@gmx.de>
parent a12872dc
...@@ -33,13 +33,21 @@ def build_or_reuse_placeholder(tensor_spec): ...@@ -33,13 +33,21 @@ def build_or_reuse_placeholder(tensor_spec):
assert "Placeholder" in tensor.op.type, "Tensor {} exists but is not a placeholder!".format(name) assert "Placeholder" in tensor.op.type, "Tensor {} exists but is not a placeholder!".format(name)
assert tensor_spec.is_compatible_with(tensor), \ assert tensor_spec.is_compatible_with(tensor), \
"Tensor {} exists but is not compatible with the signature!".format(tensor) "Tensor {} exists but is not compatible with the signature!".format(tensor)
if tensor.shape.as_list() == tensor_spec.shape.as_list():
# It might be desirable to use a placeholder of a different shape in some tower
# (e.g., a less specific shape)
# Comparing `tensor.shape` directly doesn't work, because # It might be desirable to use a placeholder of a different shape in some tower
# tensorflow thinks `tf.Dimension(None)` and `tf.Dimension(None)` are not equal. # (e.g., a less specific shape)
return tensor try:
if tensor.shape.as_list() == tensor_spec.shape.as_list():
# Comparing `tensor.shape` directly doesn't work in older versions of tensorflow,
# because tensorflow thinks `tf.Dimension(None)` and `tf.Dimension(None)` are not
# equal. Newer versions of tensorflow, e.g. 2.3, do not support as_list() for
# `tf.Dimension(None)` and raise a `ValueError`
return tensor
except ValueError:
if tensor.shape == tensor_spec.shape:
# With the newer version of tensorflow, comparing `tensor.shape` directly seems
# to work fine.
return tensor
except KeyError: except KeyError:
pass pass
with tfv1.name_scope(None): # clear any name scope it might get called in with tfv1.name_scope(None): # clear any name scope it might get called in
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
# File: config.py # File: config.py
import os import os
import tensorflow as tf
from ..compat import tfv1
from ..callbacks import ( from ..callbacks import (
JSONWriter, MergeAllSummaries, MovingAverageSummary, ProgressBar, RunUpdateOps, ScalarPrinter, TFEventWriter) JSONWriter, MergeAllSummaries, MovingAverageSummary, ProgressBar, RunUpdateOps, ScalarPrinter, TFEventWriter)
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
...@@ -237,6 +237,6 @@ class AutoResumeTrainConfig(TrainConfig): ...@@ -237,6 +237,6 @@ class AutoResumeTrainConfig(TrainConfig):
if not dir: if not dir:
return None return None
path = os.path.join(dir, 'checkpoint') path = os.path.join(dir, 'checkpoint')
if not tf.gfile.Exists(path): if not tfv1.gfile.Exists(path):
return None return None
return SaverRestore(path) return SaverRestore(path)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment