Commit f7fecef9 authored by Yuxin Wu

More permissive type casting when loading a model.

parent defa5c61
...
@@ -79,15 +79,15 @@ def do_visualize(model, model_path, nr_visualize=100, output_dir='output'):
 def do_evaluate(pred_config, output_file):
-    num_gpu = cfg.TRAIN.NUM_GPUS
+    num_tower = max(cfg.TRAIN.NUM_GPUS, 1)
     graph_funcs = MultiTowerOfflinePredictor(
-        pred_config, list(range(num_gpu))).get_predictors()
+        pred_config, list(range(num_tower))).get_predictors()
     for dataset in cfg.DATA.VAL:
         logger.info("Evaluating {} ...".format(dataset))
         dataflows = [
-            get_eval_dataflow(dataset, shard=k, num_shards=num_gpu)
-            for k in range(num_gpu)]
+            get_eval_dataflow(dataset, shard=k, num_shards=num_tower)
+            for k in range(num_tower)]
         all_results = multithread_predict_dataflow(dataflows, graph_funcs)
         output = output_file + '-' + dataset
         DatasetRegistry.get(dataset).eval_inference_results(all_results, output)
...
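
For context (not part of the commit): the only behavioral difference between the old num_gpu and the new num_tower is the max(..., 1) clamp, which presumably keeps evaluation working when cfg.TRAIN.NUM_GPUS is 0. A minimal sketch with hypothetical values:

    num_gpus = 0                   # hypothetical cfg.TRAIN.NUM_GPUS on a CPU-only machine
    num_tower = max(num_gpus, 1)   # new code: always at least one tower
    print(list(range(num_gpus)))   # []  -> old code would build no predictors and no dataflow shards
    print(list(range(num_tower)))  # [0] -> evaluation still runs on a single tower
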
...
@@ -75,25 +75,30 @@ class SessionUpdate(object):
                 value.shape, varshape, name))
             value = value.reshape(varshape)
 
-        # fix some common type incompatibility problems, but not all
-        def upcast(vartype, valtype):
-            # vartype: a tf dtype
-            # valtype: a numpy dtype
-            # allow up-casting
-            if vartype == tf.float64 and valtype == np.float32:
-                return np.float64
-            if vartype in [tf.int64, tf.int32] and valtype in [np.int32, np.int16, np.int8]:
-                return np.int64 if vartype == tf.int64 else np.int32
-            return None
+        # Be permissive, and allow some common type incompatibility problems
+        def allow_cast(to_type, from_type):
+            # to_type: a tf dtype
+            # from_type: a numpy dtype
+            from_type = tf.as_dtype(from_type)
+
+            # allow up/down casting between floating points
+            if from_type.is_floating and to_type.is_floating:
+                return True
+
+            if from_type.is_integer and to_type.is_integer:
+                # only allow up-casting between integers
+                if to_type.min <= from_type.min and to_type.max >= from_type.max:
+                    return True
+            return False
 
         if hasattr(value, 'dtype'):
             vartype = var.dtype.as_numpy_dtype
             if vartype != value.dtype:
-                msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, value.dtype)
-                newtype = upcast(var.dtype.base_dtype, value.dtype)
-                if newtype is not None:
-                    value = newtype(value)
-                    logger.warn(msg + " Load it after casting!")
+                msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, var.dtype, value.dtype)
+                if allow_cast(var.dtype.base_dtype, value.dtype):
+                    value = vartype(value)
+                    logger.warn(msg + " The value will be loaded after casting!")
                 else:
                     assert vartype == value.dtype, msg
         return value
...
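
The new casting policy can be exercised in isolation. Below is a minimal sketch (a standalone copy of the allow_cast logic for illustration only, not a tensorpack API) showing which casts are now accepted when a checkpoint value's dtype differs from the variable's dtype:

    import numpy as np
    import tensorflow as tf

    def allow_cast(to_type, from_type):
        # to_type: tf dtype of the variable; from_type: numpy dtype of the checkpoint value
        from_type = tf.as_dtype(from_type)
        if from_type.is_floating and to_type.is_floating:
            return True  # any float-to-float cast, up or down, is allowed
        if from_type.is_integer and to_type.is_integer:
            # integers may only be up-cast: the target range must cover the source range
            if to_type.min <= from_type.min and to_type.max >= from_type.max:
                return True
        return False

    print(allow_cast(tf.float32, np.float64))  # True: down-casting floats is now permitted
    print(allow_cast(tf.int64, np.int32))      # True: int32 fits inside int64
    print(allow_cast(tf.int32, np.int64))      # False: would narrow the integer range
    print(allow_cast(tf.float32, np.int32))    # False: int-to-float is still rejected
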