Commit 6640f9bb authored by Yuxin Wu

support old checkpoint format.

parent b6df5567
......@@ -20,7 +20,7 @@ Tutorials are not fully finished. See some [examples](examples) to learn about t
+ [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/A3C-Gym)
### Unsupervised Learning:
+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, Image to Image.
+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, Image to Image.
### Speech / NLP:
......
......@@ -7,7 +7,7 @@ Training examples with __reproducible__ and meaningful performance.
+ [An illustrative mnist example with explanation of the framework](mnist-convnet.py)
+ [A tiny SVHN ConvNet with 97.8% accuracy](svhn-digit-convnet.py)
+ [DoReFa-Net: training binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
+ [ResNet for ImageNet/Cifar10/SVHN](ResNet)
+ [Train ResNet for ImageNet/Cifar10/SVHN](ResNet)
+ [Inception-BN with 71% accuracy](Inception/inception-bn.py)
+ [InceptionV3 with 74% accuracy (similar to the official code)](Inception/inceptionv3.py)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](HED)
......@@ -21,8 +21,8 @@ Training examples with __reproducible__ and meaningful performance.
+ [Deep Q-Network(DQN) variants on Atari games](DeepQNetwork)
+ [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](A3C-Gym)
## Unsupervised:
+ [Generative Adversarial Network(GAN) variants, including DCGAN, Image2Image, InfoGAN](GAN)
## Unsupervised Learning:
+ [Generative Adversarial Network(GAN) variants](GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, Image to Image.
## Speech / NLP:
+ [LSTM-CTC for speech recognition](CTC-TIMIT)
......
......@@ -49,6 +49,32 @@ class NewSession(SessionInit):
sess.run(tf.global_variables_initializer())
class CheckpointReaderAdapter(object):
    """
    An adapter to work around old checkpoint format, where the keys are op
    names instead of tensor names (with :0).

    The adapter wraps a checkpoint reader and exposes every variable under a
    name that ends with ``:0``, whichever form the wrapped reader stores.

    Args:
        reader: an object providing ``get_variable_to_shape_map()``,
            ``has_tensor(name)`` and ``get_tensor(name)``.
    """

    def __init__(self, reader):
        self._reader = reader
        shape_map = self._reader.get_variable_to_shape_map()
        # Normalize every key to end with ':0'.
        # Use items() rather than iteritems(): the latter does not exist on
        # Python 3 dicts, while items() works on both Python 2 and 3.
        self._map = {k if k.endswith(':0') else k + ':0': v
                     for k, v in shape_map.items()}

    def get_variable_to_shape_map(self):
        """
        Returns:
            dict: the wrapped reader's shape map, with every key ending in ``:0``.
        """
        return self._map

    def get_tensor(self, name):
        """
        Fetch the value stored for ``name``.

        The name is first tried verbatim on the wrapped reader. If the reader
        does not have it but the normalized map does, the trailing ``:0`` is
        stripped and the lookup is retried. Any other name is passed to the
        wrapped reader unchanged, so the reader's own error handling applies.
        """
        if self._reader.has_tensor(name):
            return self._reader.get_tensor(name)
        if name in self._map:
            # Every key of self._map ends with ':0' by construction.
            assert name.endswith(':0'), name
            name = name[:-2]
        return self._reader.get_tensor(name)

    def has_tensor(self, name):
        """
        Returns:
            bool: whether ``name`` is a key of the normalized (``:0``) shape map.
        """
        return name in self._map
class SaverRestore(SessionInit):
"""
Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
......@@ -92,6 +118,7 @@ class SaverRestore(SessionInit):
def _read_checkpoint_vars(model_path):
""" return a set of strings """
reader = tf.train.NewCheckpointReader(model_path)
reader = CheckpointReaderAdapter(reader)
ckpt_vars = reader.get_variable_to_shape_map().keys()
for v in ckpt_vars:
if v.startswith(PREDICT_TOWER):
......
......@@ -107,8 +107,11 @@ class MultiPredictorTowerTrainer(Trainer):
def get_predict_func(self, input_names, output_names, tower=0):
"""
:param tower: return the kth predict_func
:returns: an `OnlinePredictor`
Args:
tower (int): return the kth predict_func
Returns:
an OnlinePredictor instance
"""
return self._predictor_factory.get_predictor(input_names, output_names, tower)
......
......@@ -48,7 +48,7 @@ def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
* x: execute ``sys.exit()``
* s: save image to "out.png"
"""
name = 'random_window_name'
name = 'tensorpack_viz_window'
cv2.imshow(name, img)
def mouse_cb(event, x, y, *args):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment