Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6d67faf9
Commit
6d67faf9
authored
Jan 14, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add dir option in saver. fix bug in GANTrainer
parent
7ce3d7ab
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
23 additions
and
19 deletions
+23
-19
docs/casestudies/colorize.md
docs/casestudies/colorize.md
+1
-0
examples/DeepQNetwork/README.md
examples/DeepQNetwork/README.md
+2
-2
examples/GAN/GAN.py
examples/GAN/GAN.py
+1
-2
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+19
-15
No files found.
docs/casestudies/colorize.md
View file @
6d67faf9
...
@@ -287,6 +287,7 @@ lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True)
...
@@ -287,6 +287,7 @@ lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True)
```
```
This essentially creates a non-trainable variable with initial value
`1e-4`
and also track this value inside TensorBoard.
This essentially creates a non-trainable variable with initial value
`1e-4`
and also track this value inside TensorBoard.
You can certainly just use
`lr = 1e-4`
, but then you'll lose the ability to modify it during training (through callbacks).
Let's have a look at the entire code:
Let's have a look at the entire code:
```
python
```
python
...
...
examples/DeepQNetwork/README.md
View file @
6d67faf9
...
@@ -35,7 +35,7 @@ Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/
...
@@ -35,7 +35,7 @@ Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/
To train:
To train:
```
```
./DQN.py --rom breakout.bin
./DQN.py --rom breakout.bin
# use `--algo` to select other DQN algorithms
# use `--algo` to select other DQN algorithms
. See `-h` for more options.
```
```
To visualize the agent:
To visualize the agent:
...
@@ -43,4 +43,4 @@ To visualize the agent:
...
@@ -43,4 +43,4 @@ To visualize the agent:
./DQN.py --rom breakout.bin --task play --load trained.model
./DQN.py --rom breakout.bin --task play --load trained.model
```
```
A3C code and models for Atari games in OpenAI Gym are released in
[
examples/
OpenAIGym
](
../OpenAI
Gym
)
A3C code and models for Atari games in OpenAI Gym are released in
[
examples/
A3C-Gym
](
../A3C-
Gym
)
examples/GAN/GAN.py
View file @
6d67faf9
...
@@ -13,9 +13,8 @@ from tensorpack.dataflow import DataFlow
...
@@ -13,9 +13,8 @@ from tensorpack.dataflow import DataFlow
class
GANTrainer
(
FeedfreeTrainerBase
):
class
GANTrainer
(
FeedfreeTrainerBase
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
self
.
_input_method
=
QueueInput
(
config
.
data
set
)
self
.
_input_method
=
QueueInput
(
config
.
data
flow
)
super
(
GANTrainer
,
self
)
.
__init__
(
config
)
super
(
GANTrainer
,
self
)
.
__init__
(
config
)
def
_setup
(
self
):
def
_setup
(
self
):
...
...
tensorpack/callbacks/saver.py
View file @
6d67faf9
...
@@ -16,15 +16,17 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
...
@@ -16,15 +16,17 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class
ModelSaver
(
Callback
):
class
ModelSaver
(
Callback
):
"""
"""
Save the model
to ``logger.LOG_DIR`` directory
every epoch.
Save the model every epoch.
"""
"""
def
__init__
(
self
,
keep_recent
=
10
,
keep_freq
=
0.5
,
def
__init__
(
self
,
keep_recent
=
10
,
keep_freq
=
0.5
,
checkpoint_dir
=
None
,
var_collections
=
tf
.
GraphKeys
.
GLOBAL_VARIABLES
):
var_collections
=
tf
.
GraphKeys
.
GLOBAL_VARIABLES
):
"""
"""
Args:
Args:
keep_recent(int): see ``tf.train.Saver`` documentation.
keep_recent(int): see ``tf.train.Saver`` documentation.
keep_freq(int): see ``tf.train.Saver`` documentation.
keep_freq(int): see ``tf.train.Saver`` documentation.
checkpoint_dir (str): Defaults to ``logger.LOG_DIR``.
var_collections (str or list): the variable collection (or list of collections) o save.
var_collections (str or list): the variable collection (or list of collections) o save.
"""
"""
self
.
keep_recent
=
keep_recent
self
.
keep_recent
=
keep_recent
...
@@ -32,23 +34,20 @@ class ModelSaver(Callback):
...
@@ -32,23 +34,20 @@ class ModelSaver(Callback):
if
not
isinstance
(
var_collections
,
list
):
if
not
isinstance
(
var_collections
,
list
):
var_collections
=
[
var_collections
]
var_collections
=
[
var_collections
]
self
.
var_collections
=
var_collections
self
.
var_collections
=
var_collections
if
checkpoint_dir
is
None
:
checkpoint_dir
=
logger
.
LOG_DIR
self
.
checkpoint_dir
=
checkpoint_dir
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
vars
=
[]
vars
=
[]
for
key
in
self
.
var_collections
:
for
key
in
self
.
var_collections
:
vars
.
extend
(
tf
.
get_collection
(
key
))
vars
.
extend
(
tf
.
get_collection
(
key
))
self
.
path
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'model'
)
self
.
path
=
os
.
path
.
join
(
self
.
checkpoint_dir
,
'model'
)
try
:
self
.
saver
=
tf
.
train
.
Saver
(
self
.
saver
=
tf
.
train
.
Saver
(
var_list
=
ModelSaver
.
_get_var_dict
(
vars
),
var_list
=
ModelSaver
.
_get_var_dict
(
vars
),
max_to_keep
=
self
.
keep_recent
,
max_to_keep
=
self
.
keep_recent
,
keep_checkpoint_every_n_hours
=
self
.
keep_freq
,
keep_checkpoint_every_n_hours
=
self
.
keep_freq
,
write_version
=
tf
.
train
.
SaverDef
.
V2
)
write_version
=
tf
.
train
.
SaverDef
.
V2
)
except
:
self
.
saver
=
tf
.
train
.
Saver
(
var_list
=
ModelSaver
.
_get_var_dict
(
vars
),
max_to_keep
=
self
.
keep_recent
,
keep_checkpoint_every_n_hours
=
self
.
keep_freq
)
self
.
meta_graph_written
=
False
self
.
meta_graph_written
=
False
@
staticmethod
@
staticmethod
...
@@ -70,7 +69,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
...
@@ -70,7 +69,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
try
:
try
:
if
not
self
.
meta_graph_written
:
if
not
self
.
meta_graph_written
:
self
.
saver
.
export_meta_graph
(
self
.
saver
.
export_meta_graph
(
os
.
path
.
join
(
logger
.
LOG_DIR
,
os
.
path
.
join
(
self
.
checkpoint_dir
,
'graph-{}.meta'
.
format
(
logger
.
get_time_str
())),
'graph-{}.meta'
.
format
(
logger
.
get_time_str
())),
collection_list
=
self
.
graph
.
get_all_collection_keys
())
collection_list
=
self
.
graph
.
get_all_collection_keys
())
self
.
meta_graph_written
=
True
self
.
meta_graph_written
=
True
...
@@ -97,11 +96,16 @@ class MinSaver(Callback):
...
@@ -97,11 +96,16 @@ class MinSaver(Callback):
Example:
Example:
Save the model with minimum validation error to
Save the model with minimum validation error to
"min-val-error.tfmodel"
under ``logger.LOG_DIR``
:
"min-val-error.tfmodel":
.. code-block:: python
.. code-block:: python
MinSaver('val-error')
MinSaver('val-error')
Note:
It assumes that :class:`ModelSaver` is used with
``checkpoint_dir=logger.LOG_DIR`` (the default). And it will save
the model to that directory as well.
"""
"""
self
.
monitor_stat
=
monitor_stat
self
.
monitor_stat
=
monitor_stat
self
.
reverse
=
reverse
self
.
reverse
=
reverse
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment