Commit 035f597d authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent 5a461be1
......@@ -20,3 +20,5 @@ Usage Questions, e.g.:
"Why certain examples need to be written in this way?"
We don't answer general machine learning questions like:
"I want to do [this machine learning task]. What specific things do I need to do?"
You can also use gitter (https://gitter.im/tensorpack/users) for more casual discussions.
......@@ -16,35 +16,16 @@ If you think:
Then it is a good time to open an issue.
## How to dump/inspect a model
## How to print/dump intermediate results during training
When you enable `ModelSaver` as a callback,
trained models will be stored in TensorFlow checkpoint format, which typically includes a
`.data-xxxxx` file and a `.index` file. Both are necessary.
1. Learn `tf.Print`.
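   For example, a minimal sketch (the placeholder and message are illustrative):

   ```python
   import tensorflow as tf

   x = tf.placeholder(tf.float32, [None, 3], name='input')
   # tf.Print is an identity op: it returns `x` unchanged, but prints the
   # given list of tensors (to stderr) every time the op is evaluated.
   x = tf.Print(x, [tf.shape(x), tf.reduce_mean(x)], message='shape and mean: ')
   ```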
To inspect a checkpoint, the easiest tool is `tf.train.NewCheckpointReader`. Please note that it
expects a model path without the extension.
2. Know the [DumpTensors](http://tensorpack.readthedocs.io/en/latest/modules/callbacks.html#tensorpack.callbacks.DumpTensors) and
[ProcessTensors](http://tensorpack.readthedocs.io/en/latest/modules/callbacks.html#tensorpack.callbacks.ProcessTensors) callbacks.
It's also easy to write your own version of them.
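   A hedged usage sketch (the tensor names are hypothetical placeholders):

   ```python
   from tensorpack.callbacks import DumpTensors

   # Fetch the listed tensors during training and save their values
   # to a file under the logger directory.
   callbacks = [DumpTensors(['conv0/output:0', 'loss:0'])]
   ```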
You can dump a cleaner version of the model (without unnecessary variables) using
`scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy/npz format.
The script expects a metagraph file which is also saved by `ModelSaver`.
## How to load a model / do transfer learning
All model loading (in either training or testing) is through the `session_init` initializer
in `TrainConfig` or `PredictConfig`.
The common choices for this option are `SaverRestore` which restores a
TF checkpoint, or `DictRestore` which restores a dict. (`get_model_loader` is a small helper to
decide which one to use from a file name.)
Doing transfer learning is trivial.
Variable restoring is completely based on name matching between
the current graph and the `SessionInit` initializer.
Therefore, if you want to load a model, just use the same variable names
so the old values will be loaded into the variables.
If you want to re-train some layer, just rename it.
Unmatched variables on both sides will be printed as a warning.
3. The [ProgressBar](http://tensorpack.readthedocs.io/en/latest/modules/callbacks.html#tensorpack.callbacks.ProgressBar)
callback can print some scalar statistics, though this is not enabled by default.
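   A minimal sketch (the tensor name is a hypothetical placeholder):

   ```python
   from tensorpack.callbacks import ProgressBar

   # Display the scalar tensor named 'cost' on the training progress bar.
   callbacks = [ProgressBar(['cost'])]
   ```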
## How to freeze some variables in training
......
......@@ -43,6 +43,7 @@ User Tutorials
trainer
training-interface
callback
save-load
summary
faq
......
# Save and Load models
## Work with TF Checkpoint
The `ModelSaver` callback saves the model to `logger.get_logger_dir()`,
in TensorFlow checkpoint format.
One checkpoint typically includes a `.data-xxxxx` file and a `.index` file.
Both are necessary.
To inspect a checkpoint, the easiest tool is `tf.train.NewCheckpointReader`.
For example, [scripts/ls-checkpoint.py](../scripts/ls-checkpoint.py)
uses it to print all variables and their shapes in a checkpoint.
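For example, a minimal sketch along the same lines (the checkpoint path is a
hypothetical placeholder, given as a prefix without the extension):

```python
import tensorflow as tf

reader = tf.train.NewCheckpointReader('train_log/model-10000')
# Print every variable and its shape:
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)
# reader.get_tensor(name) returns the value of one variable as a numpy array.
```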
[scripts/dump-model-params.py](../scripts/dump-model-params.py) can be used to remove unnecessary variables from a checkpoint.
It takes a metagraph file (which is also saved by `ModelSaver`) and saves only the variables that the model needs at inference time.
It dumps the model to a `var-name: value` dict saved in npy/npz format.
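A dump in this format can be loaded back with plain numpy. A minimal sketch,
assuming a hypothetical `model.npz` produced by the script:

```python
import numpy as np

params = dict(np.load('model.npz'))  # {variable name: numpy array}
for name, value in params.items():
    print(name, value.shape)
```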
## Load a Model
Model loading (in either training or testing) is through the `session_init` interface.
Currently there are two ways a session can be restored:
`session_init=SaverRestore(...)` which restores a
TF checkpoint, or `session_init=DictRestore(...)` which restores a dict.
(`get_model_loader` is a small helper to decide which one to use from a file name.)
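A hedged sketch of these options (paths and file names are hypothetical, and
the usual top-level imports are assumed):

```python
import numpy as np
from tensorpack import SaverRestore, DictRestore, get_model_loader

# Restore from a TF checkpoint (a path prefix, without the extension):
init = SaverRestore('train_log/model-10000')
# Restore from a dict of numpy arrays, e.g. one dumped to npz format:
init = DictRestore(dict(np.load('model.npz')))
# Or let the helper pick a loader based on the file name:
init = get_model_loader('model.npz')
# Then pass it to a config, e.g. TrainConfig(..., session_init=init).
```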
Variable restoring is completely based on name matching between
variables in the current graph and variables in the `session_init` initializer.
Variables that appear on only one side will be printed as a warning.
## Transfer Learning
Because restoring is purely name-based, transfer learning is trivial.
If you want to load some model, just use the same variable names.
If you want to re-train some layer, just rename it.
......@@ -72,6 +72,8 @@ class ProgressBar(Callback):
self._fetches = self.get_tensors_maybe_in_tower(self._names) or None
if self._fetches:
for t in self._fetches:
assert t.shape.ndims == 0, "ProgressBar can only print scalars, not {}".format(t)
self._fetches = tf.train.SessionRunArgs(self._fetches)
self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
......