Commit c44b65fc authored by Yuxin Wu's avatar Yuxin Wu

improve error messages in TrainConfig/PredictConfig type checking (#1029)

parent 2ce43d70
...@@ -19,18 +19,18 @@ The steps are: ...@@ -19,18 +19,18 @@ The steps are:
1. train the model by 1. train the model by
python export.py python export-model.py
2. export the model by 2. export the model by
python export.py --export serving --load train_log/export/checkpoint python export-model.py --export serving --load train_log/export/checkpoint
python export.py --export compact --load train_log/export/checkpoint python export-model.py --export compact --load train_log/export/checkpoint
3. run inference by 3. run inference by
python export.py --apply default --load train_log/export/checkpoint python export-model.py --apply default --load train_log/export/checkpoint
python export.py --apply inference_graph --load train_log/export/checkpoint python export-model.py --apply inference_graph --load train_log/export/checkpoint
python export.py --apply compact --load /tmp/compact_graph.pb python export-model.py --apply compact --load /tmp/compact_graph.pb
""" """
......
...@@ -20,7 +20,7 @@ if __name__ == '__main__': ...@@ -20,7 +20,7 @@ if __name__ == '__main__':
params = dict(np.load(fpath)) params = dict(np.load(fpath))
dic = {k: v.shape for k, v in six.iteritems(params)} dic = {k: v.shape for k, v in six.iteritems(params)}
else: else:
path = get_checkpoint_path(sys.argv[1]) path = get_checkpoint_path(fpath)
reader = tf.train.NewCheckpointReader(path) reader = tf.train.NewCheckpointReader(path)
dic = reader.get_variable_to_shape_map() dic = reader.get_variable_to_shape_map()
pprint.pprint(dic) pprint.pprint(dic)
...@@ -53,10 +53,13 @@ class PredictConfig(object): ...@@ -53,10 +53,13 @@ class PredictConfig(object):
create_graph (bool): create a new graph, or use the default graph create_graph (bool): create a new graph, or use the default graph
when predictor is first initialized. when predictor is first initialized.
""" """
def assert_type(v, tp): def assert_type(v, tp, name):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), \
"{} has to be type '{}', but an object of type '{}' found.".format(
name, tp.__name__, v.__class__.__name__)
if model is not None: if model is not None:
assert_type(model, ModelDescBase) assert_type(model, ModelDescBase, 'model')
assert inputs_desc is None and tower_func is None assert inputs_desc is None and tower_func is None
self.inputs_desc = model.get_inputs_desc() self.inputs_desc = model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc) self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc)
...@@ -70,7 +73,7 @@ class PredictConfig(object): ...@@ -70,7 +73,7 @@ class PredictConfig(object):
if session_init is None: if session_init is None:
session_init = JustCurrentSession() session_init = JustCurrentSession()
self.session_init = session_init self.session_init = session_init
assert_type(self.session_init, SessionInit) assert_type(self.session_init, SessionInit, 'session_init')
if session_creator is None: if session_creator is None:
self.session_creator = tf.train.ChiefSessionCreator(config=get_default_sess_config()) self.session_creator = tf.train.ChiefSessionCreator(config=get_default_sess_config())
...@@ -82,13 +85,13 @@ class PredictConfig(object): ...@@ -82,13 +85,13 @@ class PredictConfig(object):
if self.input_names is None: if self.input_names is None:
self.input_names = [k.name for k in self.inputs_desc] self.input_names = [k.name for k in self.inputs_desc]
self.output_names = output_names self.output_names = output_names
assert_type(self.output_names, list) assert_type(self.output_names, list, 'output_names')
assert_type(self.input_names, list) assert_type(self.input_names, list, 'input_names')
if len(self.input_names) == 0: if len(self.input_names) == 0:
logger.warn('PredictConfig receives empty "input_names".') logger.warn('PredictConfig receives empty "input_names".')
# assert len(self.input_names), self.input_names # assert len(self.input_names), self.input_names
for v in self.input_names: for v in self.input_names:
assert_type(v, six.string_types) assert_type(v, six.string_types, 'Each item in input_names')
assert len(self.output_names), self.output_names assert len(self.output_names), self.output_names
self.return_input = bool(return_input) self.return_input = bool(return_input)
......
...@@ -248,6 +248,7 @@ def get_model_loader(filename): ...@@ -248,6 +248,7 @@ def get_model_loader(filename):
SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise). :class:`SaverRestore` (otherwise).
""" """
assert isinstance(filename, six.string_types), filename
if filename.endswith('.npy'): if filename.endswith('.npy'):
assert tf.gfile.Exists(filename), filename assert tf.gfile.Exists(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item()) return DictRestore(np.load(filename, encoding='latin1').item())
......
...@@ -98,33 +98,35 @@ class TrainConfig(object): ...@@ -98,33 +98,35 @@ class TrainConfig(object):
""" """
# TODO type checker decorator # TODO type checker decorator
def assert_type(v, tp): def assert_type(v, tp, name):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), \
"{} has to be type '{}', but an object of type '{}' found.".format(
name, tp.__name__, v.__class__.__name__)
# process data & model # process data & model
assert data is None or dataflow is None, "dataflow and data cannot be both presented in TrainConfig!" assert data is None or dataflow is None, "dataflow and data cannot be both presented in TrainConfig!"
if dataflow is not None: if dataflow is not None:
assert_type(dataflow, DataFlow) assert_type(dataflow, DataFlow, 'dataflow')
if data is not None: if data is not None:
assert_type(data, InputSource) assert_type(data, InputSource, 'data')
self.dataflow = dataflow self.dataflow = dataflow
self.data = data self.data = data
if model is not None: if model is not None:
assert_type(model, ModelDescBase) assert_type(model, ModelDescBase, 'model')
self.model = model self.model = model
if callbacks is not None: if callbacks is not None:
assert_type(callbacks, list) assert_type(callbacks, list, 'callbacks')
self.callbacks = callbacks self.callbacks = callbacks
if extra_callbacks is not None: if extra_callbacks is not None:
assert_type(extra_callbacks, list) assert_type(extra_callbacks, list, 'extra_callbacks')
self.extra_callbacks = extra_callbacks self.extra_callbacks = extra_callbacks
if monitors is not None: if monitors is not None:
assert_type(monitors, list) assert_type(monitors, list, 'monitors')
self.monitors = monitors self.monitors = monitors
if session_init is not None: if session_init is not None:
assert_type(session_init, SessionInit) assert_type(session_init, SessionInit, 'session_init')
self.session_init = session_init self.session_init = session_init
if session_creator is None: if session_creator is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment