Commit c44b65fc authored by Yuxin Wu's avatar Yuxin Wu

improve error messages in TrainConfig/PredictConfig type checking (#1029)

parent 2ce43d70
......@@ -19,18 +19,18 @@ The steps are:
1. train the model by
python export.py
python export-model.py
2. export the model by
python export.py --export serving --load train_log/export/checkpoint
python export.py --export compact --load train_log/export/checkpoint
python export-model.py --export serving --load train_log/export/checkpoint
python export-model.py --export compact --load train_log/export/checkpoint
3. run inference by
python export.py --apply default --load train_log/export/checkpoint
python export.py --apply inference_graph --load train_log/export/checkpoint
python export.py --apply compact --load /tmp/compact_graph.pb
python export-model.py --apply default --load train_log/export/checkpoint
python export-model.py --apply inference_graph --load train_log/export/checkpoint
python export-model.py --apply compact --load /tmp/compact_graph.pb
"""
......
......@@ -20,7 +20,7 @@ if __name__ == '__main__':
params = dict(np.load(fpath))
dic = {k: v.shape for k, v in six.iteritems(params)}
else:
path = get_checkpoint_path(sys.argv[1])
path = get_checkpoint_path(fpath)
reader = tf.train.NewCheckpointReader(path)
dic = reader.get_variable_to_shape_map()
pprint.pprint(dic)
......@@ -53,10 +53,13 @@ class PredictConfig(object):
create_graph (bool): create a new graph, or use the default graph
when predictor is first initialized.
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
def assert_type(v, tp, name):
assert isinstance(v, tp), \
"{} has to be type '{}', but an object of type '{}' found.".format(
name, tp.__name__, v.__class__.__name__)
if model is not None:
assert_type(model, ModelDescBase)
assert_type(model, ModelDescBase, 'model')
assert inputs_desc is None and tower_func is None
self.inputs_desc = model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc)
......@@ -70,7 +73,7 @@ class PredictConfig(object):
if session_init is None:
session_init = JustCurrentSession()
self.session_init = session_init
assert_type(self.session_init, SessionInit)
assert_type(self.session_init, SessionInit, 'session_init')
if session_creator is None:
self.session_creator = tf.train.ChiefSessionCreator(config=get_default_sess_config())
......@@ -82,13 +85,13 @@ class PredictConfig(object):
if self.input_names is None:
self.input_names = [k.name for k in self.inputs_desc]
self.output_names = output_names
assert_type(self.output_names, list)
assert_type(self.input_names, list)
assert_type(self.output_names, list, 'output_names')
assert_type(self.input_names, list, 'input_names')
if len(self.input_names) == 0:
logger.warn('PredictConfig receives empty "input_names".')
# assert len(self.input_names), self.input_names
for v in self.input_names:
assert_type(v, six.string_types)
assert_type(v, six.string_types, 'Each item in input_names')
assert len(self.output_names), self.output_names
self.return_input = bool(return_input)
......
......@@ -248,6 +248,7 @@ def get_model_loader(filename):
SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise).
"""
assert isinstance(filename, six.string_types), filename
if filename.endswith('.npy'):
assert tf.gfile.Exists(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item())
......
......@@ -98,33 +98,35 @@ class TrainConfig(object):
"""
# TODO type checker decorator
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
def assert_type(v, tp, name):
assert isinstance(v, tp), \
"{} has to be type '{}', but an object of type '{}' found.".format(
name, tp.__name__, v.__class__.__name__)
# process data & model
assert data is None or dataflow is None, "dataflow and data cannot be both presented in TrainConfig!"
if dataflow is not None:
assert_type(dataflow, DataFlow)
assert_type(dataflow, DataFlow, 'dataflow')
if data is not None:
assert_type(data, InputSource)
assert_type(data, InputSource, 'data')
self.dataflow = dataflow
self.data = data
if model is not None:
assert_type(model, ModelDescBase)
assert_type(model, ModelDescBase, 'model')
self.model = model
if callbacks is not None:
assert_type(callbacks, list)
assert_type(callbacks, list, 'callbacks')
self.callbacks = callbacks
if extra_callbacks is not None:
assert_type(extra_callbacks, list)
assert_type(extra_callbacks, list, 'extra_callbacks')
self.extra_callbacks = extra_callbacks
if monitors is not None:
assert_type(monitors, list)
assert_type(monitors, list, 'monitors')
self.monitors = monitors
if session_init is not None:
assert_type(session_init, SessionInit)
assert_type(session_init, SessionInit, 'session_init')
self.session_init = session_init
if session_creator is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment