Commit 7b9d9b8c authored Feb 21, 2019 by Yuxin Wu
Do not require ALE if using DQN with gym only (fix #1091)
parent e4aca035
Showing 5 changed files with 25 additions and 10 deletions.
docs/conf.py                         +1  -0
docs/tutorial/save-load.md           +17 -7
examples/DeepQNetwork/DQN.py         +1  -1
examples/DeepQNetwork/expreplay.py   +5  -1
examples/DoReFa-Net/README.md        +1  -1
docs/conf.py

@@ -401,6 +401,7 @@ _DEPRECATED_NAMES = set([
     'average_grads',
     'aggregate_grads',
     'allreduce_grads',
+    'get_checkpoint_path'
 ])

 def autodoc_skip_member(app, what, name, obj, skip, options):
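
The hunk adds `get_checkpoint_path` to `_DEPRECATED_NAMES`, the set consulted by the `autodoc_skip_member` handler so deprecated symbols are left out of the generated API docs. Below is a minimal sketch of how such a handler is typically registered in a Sphinx `conf.py`; only the deprecated-name set mirrors the hunk above, while the `setup()` wiring uses the standard Sphinx `autodoc-skip-member` event and is an illustration, not a quote from tensorpack's conf.py.

```python
# Sketch: hide deprecated names from Sphinx autodoc.
_DEPRECATED_NAMES = set([
    'average_grads',
    'aggregate_grads',
    'allreduce_grads',
    'get_checkpoint_path',
])


def autodoc_skip_member(app, what, name, obj, skip, options):
    # Force-skip anything on the deprecation list; otherwise keep the default decision.
    if name in _DEPRECATED_NAMES:
        return True
    return skip


def setup(app):
    # Register the handler for the standard 'autodoc-skip-member' event.
    app.connect('autodoc-skip-member', autodoc_skip_member)
```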
docs/tutorial/save-load.md

@@ -10,8 +10,8 @@ Both are necessary.
 `tf.train.NewCheckpointReader` is the offical tool to parse TensorFlow checkpoint.
 Read [TF docs](https://www.tensorflow.org/api_docs/python/tf/train/NewCheckpointReader) for details.
-Tensorpack also provides some small tools to work with checkpoints, see
-[documentation](../modules/tfutils.html#tensorpack.tfutils.varmanip.load_chkpt_vars)
+Tensorpack also provides a small tool to load checkpoints, see
+[load_chkpt_vars](../modules/tfutils.html#tensorpack.tfutils.varmanip.load_chkpt_vars)
 for details.

 [scripts/ls-checkpoint.py](../scripts/ls-checkpoint.py)
@@ -19,19 +19,28 @@ demos how to print all variables and their shapes in a checkpoint.
 [scripts/dump-model-params.py](../scripts/dump-model-params.py)
 can be used to remove unnecessary variables in a checkpoint.
 It takes a metagraph file (which is also saved by `ModelSaver`) and only saves variables that the model needs at inference time.
-It can dump the model to a `var-name: value` dict saved in npz format.
+It dumps the model to a `var-name: value` dict saved in npz format.

 ## Load a Model to a Session

-Model loading (in either training or inference) is through the `session_init` interface.
+Model loading (in both training and inference) is through the `session_init` interface.
 Currently there are two ways a session can be restored:
 [session_init=SaverRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore)
 which restores a TF checkpoint,
 or [session_init=DictRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore) which restores a dict.
-[get_model_loader](../modules/tfutils.html#tensorpack.tfutils.sessinit.get_model_loader)
-is a small helper to decide which one to use from a file name.
 To load multiple models, use [ChainInit](../modules/tfutils.html#tensorpack.tfutils.sessinit.ChainInit).

+Many models in tensorpack model zoo are provided in the form of numpy dictionary (`.npz`),
+because it is easier to load and manipulate without requiring TensorFlow.
+To load such files to a session, use `DictRestore(dict(np.load(filename)))`.
+You can also use
+[get_model_loader](../modules/tfutils.html#tensorpack.tfutils.sessinit.get_model_loader),
+a small helper to create a `SaverRestore` or `DictRestore` based on the file name.
+
+`DictRestore` is the most general loader because you can make arbitrary changes
+you need (e.g., remove variables, rename variables) to the dict.
+To load a TF checkpoint into a dict in order to make changes, use
+[load_chkpt_vars](../modules/tfutils.html#tensorpack.tfutils.varmanip.load_chkpt_vars).
+
 Variable restoring is completely based on __name match__ between
 variables in the current graph and variables in the `session_init` initializer.
@@ -40,4 +49,5 @@ Variables that appear in only one side will be printed as warning.
 ## Transfer Learning

 Therefore, transfer learning is trivial.
 If you want to load a pre-trained model, just use the same variable names.
-If you want to re-train some layer, just rename it.
+If you want to re-train some layer, just rename either the variables in the
+graph or the variables in your loader.
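
The new paragraphs describe loading `.npz` model dictionaries with `DictRestore`, editing the dict (e.g. renaming variables), and turning a TF checkpoint into a dict with `load_chkpt_vars`. A minimal sketch of that workflow, assuming the tensorpack helpers referenced in the tutorial; the file names and the renamed prefix below are hypothetical:

```python
# Sketch: build a session_init from an .npz model dict, with an optional rename.
import numpy as np
from tensorpack.tfutils.sessinit import DictRestore, get_model_loader
from tensorpack.tfutils.varmanip import load_chkpt_vars

# A model stored as a numpy dictionary: var-name -> value.
params = dict(np.load('model.npz'))                  # hypothetical file name

# DictRestore allows arbitrary edits, e.g. renaming a layer you want to re-train:
params = {name.replace('linear0/', 'linear0_new/'): value
          for name, value in params.items()}         # hypothetical variable prefix

session_init = DictRestore(params)

# A TF checkpoint can first be converted into the same kind of dict:
#   params = load_chkpt_vars('/path/to/checkpoint')  # hypothetical path
# Or let tensorpack choose SaverRestore vs. DictRestore from the file name:
#   session_init = get_model_loader('model.npz')
```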
examples/DeepQNetwork/DQN.py

@@ -12,7 +12,6 @@ import tensorflow as tf
 from tensorpack import *

-from atari import AtariPlayer
 from atari_wrapper import FireResetEnv, FrameStack, LimitLength, MapState
 from common import Evaluator, eval_model_multithread, play_n_episodes
 from DQNModel import Model as DQNModel
@@ -52,6 +51,7 @@ def get_player(viz=False, train=False):
     if USE_GYM:
         env = gym.make(ENV_NAME)
     else:
+        from atari import AtariPlayer
         env = AtariPlayer(ENV_NAME, frame_skip=ACTION_REPEAT, viz=viz,
                           live_lost_as_eoe=train, max_num_frames=60000)
     env = FireResetEnv(env)
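
The effect of this change is that `AtariPlayer` (and therefore ALE) is only imported on the non-gym code path, so a gym-only run of DQN.py no longer needs ALE installed. A minimal standalone sketch of the same deferred-import pattern; the function name and fallback module are placeholders, not tensorpack's API:

```python
# Sketch: import an optional dependency only on the branch that needs it.
def make_env(name, use_gym=True):
    if use_gym:
        import gym                      # common dependency, always importable
        return gym.make(name)
    else:
        # Deferred import: the optional backend (ALE via AtariPlayer in the
        # commit above) is only required when this branch actually runs.
        from atari import AtariPlayer   # placeholder optional backend
        return AtariPlayer(name)
```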
examples/DeepQNetwork/expreplay.py

@@ -31,7 +31,11 @@ class ReplayMemory(object):
         self._shape3d = (state_shape[0], state_shape[1], self._channel * (history_len + 1))
         self.history_len = int(history_len)

-        self.state = np.zeros((self.max_size,) + state_shape, dtype='uint8')
+        state_shape = (self.max_size,) + state_shape
+        logger.info("Creating experience replay buffer of {:.1f} GB ... "
+                    "use a smaller buffer if you don't have enough CPU memory.".format(
+                        np.prod(state_shape) / 1024.0**3))
+        self.state = np.zeros(state_shape, dtype='uint8')
         self.action = np.zeros((self.max_size,), dtype='int32')
         self.reward = np.zeros((self.max_size,), dtype='float32')
         self.isOver = np.zeros((self.max_size,), dtype='bool')
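
The added `logger.info` estimates the buffer size by taking the element count of the full `uint8` state array and dividing by 1024³ (one byte per element). A self-contained sketch of the same arithmetic with illustrative Atari-style numbers; the shapes below are assumptions, not values from expreplay.py:

```python
# Sketch: estimate replay-buffer memory the same way as the logging line above.
import numpy as np

max_size = 10 ** 6                     # illustrative buffer length
state_shape = (84, 84, 4)              # illustrative per-step state shape
full_shape = (max_size,) + state_shape

# uint8 -> one byte per element, so the element count equals the byte count.
size_gb = np.prod(full_shape, dtype=np.int64) / 1024.0 ** 3
print("state buffer: {:.1f} GB".format(size_gb))   # ~26.3 GB for these numbers
```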
examples/DoReFa-Net/README.md

@@ -45,7 +45,7 @@ In this implementation, quantized operations are all performed through `tf.float
 + Look at the docstring in `*-dorefa.py` to see detailed usage and performance.

-Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
+Pretrained model for (1,4,32)-ResNet18 and several AlexNet are available at
 [tensorpack model zoo](http://models.tensorpack.com/DoReFa-Net/).
 They're provided in the format of numpy dictionary.

 The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy.
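
Since the pretrained models are distributed as numpy dictionaries, they can be inspected without TensorFlow. A quick sketch, using a placeholder file name for a downloaded model-zoo file:

```python
# Sketch: list the variables stored in a downloaded .npz model file.
import numpy as np

params = np.load('model.npz')          # placeholder name for a model-zoo file
for name in sorted(params.files):
    print(name, params[name].shape)
```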