Commit c1f8042d authored by Yuxin Wu's avatar Yuxin Wu

fix bug in 3d1a30ff

parent 4888d1ea
......@@ -61,5 +61,5 @@ except ImportError:
# These lines will be programatically read/write by setup.py
# Don't touch them.
__version__ = '0.9.2'
__version__ = '0.9.3'
__git_version__ = __version__
......@@ -77,6 +77,8 @@ class SessionUpdate(object):
# fix some common type incompatibility problems, but not all
def upcast(vartype, valtype):
# vartype: a tf dtype
# valtype: a numpy dtype
# allow up-casting
if vartype == tf.float64 and valtype == np.float32:
return np.float64
......@@ -85,7 +87,7 @@ class SessionUpdate(object):
return None
if hasattr(value, 'dtype'):
vartype = var.dtype
vartype = var.dtype.as_numpy_dtype
if vartype != value.dtype:
msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, value.dtype)
newtype = upcast(var.dtype.base_dtype, value.dtype)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment