Commit 708e07b0 authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent 443cb84d
...@@ -380,6 +380,7 @@ _DEPRECATED_NAMES = set([ ...@@ -380,6 +380,7 @@ _DEPRECATED_NAMES = set([
'prediction_incorrect', 'huber_loss', 'prediction_incorrect', 'huber_loss',
# internal only # internal only
'SessionUpdate',
'apply_default_prefetch', 'apply_default_prefetch',
'average_grads', 'average_grads',
'aggregate_grads', 'aggregate_grads',
......
...@@ -21,10 +21,14 @@ You can use this predicate to choose a different code path in inference mode. ...@@ -21,10 +21,14 @@ You can use this predicate to choose a different code path in inference mode.
## Inference After Training ## Inference After Training
Tensorpack is a training interface -- it doesn't care what happened after training. Tensorpack is a training interface -- __it doesn't care what happened after training__.
It saves models to standard checkpoint format. You have everything needed for inference or model diagnosis after
So you can build the graph for inference, load the checkpoint, and then use whatever deployment methods TensorFlow supports. training:
But you'll need to read TF docs and __do it on your own__. 1. The trained weights: tensorpack saves them in standard TF checkpoint format.
2. The model: you've already written it yourself with TF symbolic functions.
Therefore, you can build the graph for inference, load the checkpoint, and then use whatever deployment methods TensorFlow supports.
And you'll need to read TF docs and __do it on your own__.
### Don't Use Training Metagraph for Inference ### Don't Use Training Metagraph for Inference
......
...@@ -8,8 +8,11 @@ in TensorFlow checkpoint format. ...@@ -8,8 +8,11 @@ in TensorFlow checkpoint format.
A TF checkpoint typically includes a `.data-xxxxx` file and a `.index` file. A TF checkpoint typically includes a `.data-xxxxx` file and a `.index` file.
Both are necessary. Both are necessary.
`tf.train.NewCheckpointReader` is the best tool to parse TensorFlow checkpoint. `tf.train.NewCheckpointReader` is the offical tool to parse TensorFlow checkpoint.
We have two example scripts to demo its usage, but read [TF docs](https://www.tensorflow.org/api_docs/python/tf/train/NewCheckpointReader) for details. Read [TF docs](https://www.tensorflow.org/api_docs/python/tf/train/NewCheckpointReader) for details.
Tensorpack also provides some small tools to work with checkpoints, see
[documentation](../modules/tfutils.html#tensorpack.tfutils.varmanip.load_chkpt_vars)
for details.
[scripts/ls-checkpoint.py](../scripts/ls-checkpoint.py) [scripts/ls-checkpoint.py](../scripts/ls-checkpoint.py)
demos how to print all variables and their shapes in a checkpoint. demos how to print all variables and their shapes in a checkpoint.
......
This diff is collapsed.
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: varmanip.py # File: varmanip.py
import six import six
import os import os
import pprint import pprint
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from ..utils.develop import deprecated
from ..utils import logger from ..utils import logger
from .common import get_op_tensor_name from .common import get_op_tensor_name
...@@ -50,7 +48,7 @@ class SessionUpdate(object): ...@@ -50,7 +48,7 @@ class SessionUpdate(object):
@staticmethod @staticmethod
def load_value_to_var(var, val, strict=False): def load_value_to_var(var, val, strict=False):
""" """
Call `var.load(val)` with the default session. Call `var.load(val)` with the default session, with some type checks.
Args: Args:
var (tf.Variable): var (tf.Variable):
...@@ -111,7 +109,7 @@ class SessionUpdate(object): ...@@ -111,7 +109,7 @@ class SessionUpdate(object):
def dump_session_params(path): def dump_session_params(path):
""" """
Dump value of all TRAINABLE + MODEL variables to a dict, and save as Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npz format (loadable by :class:`DictRestore`). npz format (loadable by :func:`sessinit.get_model_loader`).
Args: Args:
path(str): the file name to save the parameters. Must ends with npz. path(str): the file name to save the parameters. Must ends with npz.
...@@ -203,11 +201,6 @@ def load_chkpt_vars(model_path): ...@@ -203,11 +201,6 @@ def load_chkpt_vars(model_path):
return result return result
@deprecated("Renamed to 'load_chkpt_vars!'", "2018-04-20")
def dump_chkpt_vars(model_path):
return load_chkpt_vars(model_path)
def is_training_name(name): def is_training_name(name):
""" """
**Guess** if this variable is only used in training. **Guess** if this variable is only used in training.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment