Commit 61384a65
authored Nov 13, 2016 by Yuxin Wu

small bugfix

parent b95ea88f

Showing 7 changed files with 18 additions and 17 deletions (+18 -17)
README.md                          +1  -1
examples/Atari2600/DQN.py          +6  -4
examples/Atari2600/README.md       +5  -4
examples/HED/hed.py                +1  -3
examples/README.md                 +1  -1
tensorpack/models/batch_norm.py    +3  -3
tensorpack/tfutils/sessinit.py     +1  -1
README.md
@@ -8,7 +8,7 @@ You can actually train them and reproduce the performance... not just to see how
 + [DoReFa-Net: training binary / low bitwidth CNN](examples/DoReFa-Net)
 + [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
-+ [ResNet for ImageNet/Cifar10 classification](examples/ResNet)
++ [ResNet for ImageNet/Cifar10/SVHN classification](examples/ResNet)
 + [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
 + [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
 + [Double DQN plays Atari games](examples/Atari2600)
examples/Atari2600/DQN.py
@@ -23,9 +23,6 @@ import common
 from common import play_model, Evaluator, eval_model_multithread
 from atari import AtariPlayer
-METHOD = ['DQN', 'Double', 'Dueling'][1]
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
@@ -48,6 +45,7 @@ EVAL_EPISODE = 50
 NUM_ACTIONS = None
 ROM_FILE = None
+METHOD = None
 def get_player(viz=False, train=False):
     pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
@@ -123,7 +121,8 @@ class Model(ModelDesc):
         target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
-        self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value), BATCH_SIZE, name='cost')
+        self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value),
+                               tf.cast(BATCH_SIZE, tf.float32), name='cost')
         summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
                                    ('fc.*/W', ['histogram', 'rms'])])   # monitor all W
@@ -188,6 +187,8 @@ if __name__ == '__main__':
     parser.add_argument('--task', help='task to perform',
                         choices=['play', 'eval', 'train'], default='train')
     parser.add_argument('--rom', help='atari rom', required=True)
+    parser.add_argument('--algo', help='algorithm',
+                        choices=['DQN', 'Double', 'Dueling'], default='Double')
     args = parser.parse_args()
     if args.gpu:
@@ -195,6 +196,7 @@ if __name__ == '__main__':
     if args.task != 'train':
         assert args.load is not None
     ROM_FILE = args.rom
+    METHOD = args.algo
     if args.task != 'train':
         cfg = PredictConfig(
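One of the DQN.py changes makes the dtype of the cost division explicit: the Huber loss is a float32 tensor while BATCH_SIZE is a plain Python int, so the divisor is now cast with tf.cast(BATCH_SIZE, tf.float32). The sketch below only reproduces that pattern; huber_loss here is a hypothetical stand-in for tensorpack's symbf.huber_loss, and it uses current eager-mode TensorFlow calls rather than the 2016-era graph API.

import tensorflow as tf

BATCH_SIZE = 64  # a plain Python int, as in DQN.py

def huber_loss(x, delta=1.0):
    # Hypothetical stand-in for tensorpack's symbf.huber_loss: summed Huber penalty.
    abs_x = tf.abs(x)
    quadratic = tf.minimum(abs_x, delta)
    linear = abs_x - quadratic
    return tf.reduce_sum(0.5 * tf.square(quadratic) + delta * linear)

# td_error plays the role of (target - pred_action_value), a float32 tensor.
td_error = tf.random.normal([BATCH_SIZE])
cost = tf.truediv(huber_loss(td_error),
                  tf.cast(BATCH_SIZE, tf.float32),  # the explicit cast added by this commit
                  name='cost')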
examples/Atari2600/README.md
@@ -2,7 +2,7 @@
 [video demo](https://youtu.be/o21mddZtE5Y)
-Reproduce the following reinforcement learning papers:
+Reproduce the following reinforcement learning methods:
 + Nature-DQN in: [Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
@@ -29,12 +29,13 @@ D-DQN is faster at the beginning but will converge to 12it/s due of exploration
 ## How to use
-Download [atari roms](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to `$TENSORPACK_DATASET/atari_rom` (defaults to tensorpack/dataflow/dataset/atari_rom).
+Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to `$TENSORPACK_DATASET/atari_rom/` (defaults to tensorpack/dataflow/dataset/atari_rom/).
 To train:
 ```
-./DQN.py --rom breakout.bin --gpu 0
+./DQN.py --rom breakout.bin
+# use `--algo` to select other DQN algorithms
 ```
 To visualize the agent:
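The `--algo` comment added to this README corresponds to the new command-line flag in DQN.py. As a rough, self-contained sketch (not the DQN.py source), the flag simply replaces the previously hard-coded `METHOD = ['DQN', 'Double', 'Dueling'][1]` with a value chosen at launch time:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--rom', help='atari rom', required=True)
parser.add_argument('--algo', help='algorithm',
                    choices=['DQN', 'Double', 'Dueling'], default='Double')

# Example invocation; in the real script the arguments come from the command line.
args = parser.parse_args(['--rom', 'breakout.bin', '--algo', 'DQN'])
METHOD = args.algo   # later used to pick which Q-learning variant to build
print(METHOD)        # -> DQN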
examples/HED/hed.py
@@ -160,7 +160,6 @@ def get_config():
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size() * 40
     dataset_val = get_data('val')
-    #dataset_test = get_data('test')
     lr = tf.Variable(3e-5, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
@@ -169,8 +168,7 @@ def get_config():
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
-            StatPrinter(),
-            ModelSaver(),
+            StatPrinter(), ModelSaver(),
             ScheduledHyperParamSetter('learning_rate',
                                       [(30, 6e-6), (45, 1e-6), (60, 8e-7)]),
             HumanHyperParamSetter('learning_rate'),
             InferenceRunner(dataset_val,
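For reference, the schedule passed to `ScheduledHyperParamSetter` above is a list of `(epoch, value)` pairs. A plain-Python illustration of how such a schedule is typically read (this is not tensorpack code, and the "set at the listed epoch, keep until the next entry" rule is an assumption):

# Illustrative only: look up the learning rate implied by an (epoch, value) schedule.
SCHEDULE = [(30, 6e-6), (45, 1e-6), (60, 8e-7)]
BASE_LR = 3e-5   # the initial value of the 'learning_rate' variable in hed.py

def lr_at_epoch(epoch, schedule=SCHEDULE, base=BASE_LR):
    lr = base
    for ep, value in schedule:
        if epoch >= ep:
            lr = value
    return lr

assert lr_at_epoch(10) == 3e-5
assert lr_at_epoch(30) == 6e-6
assert lr_at_epoch(50) == 1e-6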
examples/README.md
@@ -8,7 +8,7 @@ Training examples with __reproducible__ and meaningful performance.
 + [Inception-BN with 71% accuracy](Inception/inception-bn.py)
 + [InceptionV3 with 74% accuracy (similar to the official code)](Inception/inceptionv3.py)
 + [DoReFa-Net: binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
-+ [ResNet for Cifar10 and SVHN](ResNet)
++ [ResNet for ImageNet/Cifar10/SVHN](ResNet)
 + [Holistically-Nested Edge Detection](HED)
 + [Spatial Transformer Networks on MNIST addition](SpatialTransformer)
 + [DisturbLabel, because I don't believe the paper](DisturbLabel)
tensorpack/models/batch_norm.py
@@ -59,9 +59,10 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     ctx = get_current_tower_context()
     if use_local_stat is None:
         use_local_stat = ctx.is_training
-    assert use_local_stat == ctx.is_training
+    if use_local_stat != ctx.is_training:
+        logger.warn("[BatchNorm] use_local_stat != is_training")
-    if ctx.is_training:
+    if use_local_stat:
         # training tower
         with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
@@ -72,7 +73,6 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
             tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)
             tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
     else:
-        assert not use_local_stat
         if ctx.is_main_tower:
             # not training, but main tower. need to create the vars
             with tf.name_scope(None):
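The batch_norm.py change relaxes a hard assertion into a warning and branches on `use_local_stat` rather than `ctx.is_training`, so a caller may deliberately request per-batch statistics in a non-training tower (or the moving average in a training tower). A minimal sketch of just that control flow, with a hypothetical `TowerContext` stand-in rather than tensorpack's real tower context:

import logging

logger = logging.getLogger("BatchNorm-sketch")

class TowerContext:
    # Hypothetical stand-in for tensorpack's tower context; only is_training is modelled.
    def __init__(self, is_training):
        self.is_training = is_training

def resolve_local_stat(ctx, use_local_stat=None):
    # Default to the tower's training flag, but only warn (instead of asserting)
    # when the caller overrides it, mirroring the new behaviour in this commit.
    if use_local_stat is None:
        use_local_stat = ctx.is_training
    if use_local_stat != ctx.is_training:
        logger.warning("[BatchNorm] use_local_stat != is_training")
    return use_local_stat

# An inference tower that explicitly asks for per-batch statistics is now allowed:
assert resolve_local_stat(TowerContext(is_training=False), use_local_stat=True) is True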
tensorpack/tfutils/sessinit.py
@@ -159,7 +159,7 @@ class ParamRestore(SessionInit):
         logger.info("Params to restore: {}".format(', '.join(map(str, intersect))))
         for k in variable_names - param_names:
-            if not is_training_specific_name(k):
+            if not is_training_name(k):
                 logger.warn("Variable {} in the graph not found in the dict!".format(k))
         for k in param_names - variable_names:
             logger.warn("Variable {} in the dict not found in the graph!".format(k))
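The sessinit.py hunk is a one-word rename, `is_training_specific_name` to `is_training_name`. To make the surrounding bookkeeping easier to follow out of context, here is a self-contained sketch of the same mismatch reporting over two name sets; the `is_training_name` heuristic below is purely illustrative and is not tensorpack's implementation:

import logging

logger = logging.getLogger("ParamRestore-sketch")

def is_training_name(name):
    # Illustrative heuristic only: treat optimizer/EMA bookkeeping as training-specific.
    return name.endswith('/Adam') or 'EMA' in name

def report_restore_mismatch(variable_names, param_names):
    # Log what can be restored, then warn about names present on only one side.
    intersect = variable_names & param_names
    logger.info("Params to restore: {}".format(', '.join(sorted(map(str, intersect)))))
    for k in variable_names - param_names:
        if not is_training_name(k):
            logger.warning("Variable {} in the graph not found in the dict!".format(k))
    for k in param_names - variable_names:
        logger.warning("Variable {} in the dict not found in the graph!".format(k))

report_restore_mismatch({'conv0/W', 'conv0/W/Adam', 'fc0/b'}, {'conv0/W'})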