Commit c1f8042d authored by Yuxin Wu's avatar Yuxin Wu

fix bug in 3d1a30ff

parent 4888d1ea
...@@ -61,5 +61,5 @@ except ImportError: ...@@ -61,5 +61,5 @@ except ImportError:
# These lines will be programatically read/write by setup.py # These lines will be programatically read/write by setup.py
# Don't touch them. # Don't touch them.
__version__ = '0.9.2' __version__ = '0.9.3'
__git_version__ = __version__ __git_version__ = __version__
...@@ -77,6 +77,8 @@ class SessionUpdate(object): ...@@ -77,6 +77,8 @@ class SessionUpdate(object):
# fix some common type incompatibility problems, but not all # fix some common type incompatibility problems, but not all
def upcast(vartype, valtype): def upcast(vartype, valtype):
# vartype: a tf dtype
# valtype: a numpy dtype
# allow up-casting # allow up-casting
if vartype == tf.float64 and valtype == np.float32: if vartype == tf.float64 and valtype == np.float32:
return np.float64 return np.float64
...@@ -85,7 +87,7 @@ class SessionUpdate(object): ...@@ -85,7 +87,7 @@ class SessionUpdate(object):
return None return None
if hasattr(value, 'dtype'): if hasattr(value, 'dtype'):
vartype = var.dtype vartype = var.dtype.as_numpy_dtype
if vartype != value.dtype: if vartype != value.dtype:
msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, value.dtype) msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, value.dtype)
newtype = upcast(var.dtype.base_dtype, value.dtype) newtype = upcast(var.dtype.base_dtype, value.dtype)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment