Commit ba965954 authored by Yuxin Wu's avatar Yuxin Wu

fix SaverRestoreRelaxed

parent 2aefaf78
......@@ -82,7 +82,7 @@ class SessionUpdate(object):
vartype = var.value().dtype
if vartype != val.dtype:
msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, val.dtype)
newtype = upcast(var.dtype, val.dtype)
newtype = upcast(var.dtype.base_dtype, val.dtype)
if newtype is not None:
val = newtype(val)
logger.warn(msg + " Load it after casting!")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment