Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6640f9bb
Commit
6640f9bb
authored
Feb 16, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
support old checkpoint format.
parent
b6df5567
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
37 additions
and
7 deletions
+37
-7
README.md
README.md
+1
-1
examples/README.md
examples/README.md
+3
-3
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+27
-0
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+5
-2
tensorpack/utils/viz.py
tensorpack/utils/viz.py
+1
-1
No files found.
README.md
View file @
6640f9bb
...
...
@@ -20,7 +20,7 @@ Tutorials are not fully finished. See some [examples](examples) to learn about t
+
[
Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym
](
examples/A3C-Gym
)
### Unsupervised Learning:
+
[
Generative Adversarial Network(GAN) variants
](
examples/GAN
)
, including DCGAN, InfoGAN, Conditional GAN, Image to Image.
+
[
Generative Adversarial Network(GAN) variants
](
examples/GAN
)
, including DCGAN, InfoGAN, Conditional GAN,
WGAN,
Image to Image.
### Speech / NLP:
...
...
examples/README.md
View file @
6640f9bb
...
...
@@ -7,7 +7,7 @@ Training examples with __reproducible__ and meaningful performance.
+
[
An illustrative mnist example with explanation of the framework
](
mnist-convnet.py
)
+
[
A tiny SVHN ConvNet with 97.8% accuracy
](
svhn-digit-convnet.py
)
+
[
DoReFa-Net: training binary / low-bitwidth CNN on ImageNet
](
DoReFa-Net
)
+
[
ResNet for ImageNet/Cifar10/SVHN
](
ResNet
)
+
[
Train
ResNet for ImageNet/Cifar10/SVHN
](
ResNet
)
+
[
Inception-BN with 71% accuracy
](
Inception/inception-bn.py
)
+
[
InceptionV3 with 74% accuracy (similar to the official code)
](
Inception/inceptionv3.py
)
+
[
Fully-convolutional Network for Holistically-Nested Edge Detection(HED)
](
HED
)
...
...
@@ -21,8 +21,8 @@ Training examples with __reproducible__ and meaningful performance.
+
[
Deep Q-Network(DQN) variants on Atari games
](
DeepQNetwork
)
+
[
Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym
](
A3C-Gym
)
## Unsupervised:
+
[
Generative Adversarial Network(GAN) variants
, including DCGAN, Image2Image, InfoGAN
](
GAN
)
## Unsupervised
Learning
:
+
[
Generative Adversarial Network(GAN) variants
](
GAN
)
, including DCGAN, InfoGAN, Conditional GAN, WGAN, Image to Image.
## Speech / NLP:
+
[
LSTM-CTC for speech recognition
](
CTC-TIMIT
)
...
...
tensorpack/tfutils/sessinit.py
View file @
6640f9bb
...
...
@@ -49,6 +49,32 @@ class NewSession(SessionInit):
sess
.
run
(
tf
.
global_variables_initializer
())
class
CheckpointReaderAdapter
(
object
):
"""
An adapter to work around old checkpoint format, where the keys are op
names instead of tensor names (with :0).
"""
def
__init__
(
self
,
reader
):
self
.
_reader
=
reader
m
=
self
.
_reader
.
get_variable_to_shape_map
()
self
.
_map
=
{
k
if
k
.
endswith
(
':0'
)
else
k
+
':0'
:
v
for
k
,
v
in
m
.
iteritems
()}
def
get_variable_to_shape_map
(
self
):
return
self
.
_map
def
get_tensor
(
self
,
name
):
if
self
.
_reader
.
has_tensor
(
name
):
return
self
.
_reader
.
get_tensor
(
name
)
if
name
in
self
.
_map
:
assert
name
.
endswith
(
':0'
),
name
name
=
name
[:
-
2
]
return
self
.
_reader
.
get_tensor
(
name
)
def
has_tensor
(
self
,
name
):
return
name
in
self
.
_map
class
SaverRestore
(
SessionInit
):
"""
Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
...
...
@@ -92,6 +118,7 @@ class SaverRestore(SessionInit):
def
_read_checkpoint_vars
(
model_path
):
""" return a set of strings """
reader
=
tf
.
train
.
NewCheckpointReader
(
model_path
)
reader
=
CheckpointReaderAdapter
(
reader
)
ckpt_vars
=
reader
.
get_variable_to_shape_map
()
.
keys
()
for
v
in
ckpt_vars
:
if
v
.
startswith
(
PREDICT_TOWER
):
...
...
tensorpack/train/trainer.py
View file @
6640f9bb
...
...
@@ -107,8 +107,11 @@ class MultiPredictorTowerTrainer(Trainer):
def
get_predict_func
(
self
,
input_names
,
output_names
,
tower
=
0
):
"""
:param tower: return the kth predict_func
:returns: an `OnlinePredictor`
Args:
tower (int): return the kth predict_func
Returns:
an OnlinePredictor instance
"""
return
self
.
_predictor_factory
.
get_predictor
(
input_names
,
output_names
,
tower
)
...
...
tensorpack/utils/viz.py
View file @
6640f9bb
...
...
@@ -48,7 +48,7 @@ def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
* x: execute ``sys.exit()``
* s: save image to "out.png"
"""
name
=
'
random_window_name
'
name
=
'
tensorpack_viz_window
'
cv2
.
imshow
(
name
,
img
)
def
mouse_cb
(
event
,
x
,
y
,
*
args
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment