Commit 6d67faf9 authored by Yuxin Wu

add dir option in saver. fix bug in GANTrainer

parent 7ce3d7ab
......@@ -287,6 +287,7 @@ lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True)
```
This essentially creates a non-trainable variable with initial value `1e-4`, and also tracks this value in TensorBoard.
You can certainly just use `lr = 1e-4`, but then you'll lose the ability to modify it during training (through callbacks).
Let's have a look at the entire code:
```python
......
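For context, a minimal sketch of how such a variable is modified from a callback, assuming tensorpack's `ScheduledHyperParamSetter` (which assigns new values to a named hyperparameter variable on an epoch schedule):
```python
from tensorpack.callbacks import ScheduledHyperParamSetter

# Decay the 'learning_rate' variable created above at fixed epochs.
# The schedule is a list of (epoch, value) pairs; the callback assigns
# each value once training reaches the corresponding epoch.
lr_schedule = ScheduledHyperParamSetter(
    'learning_rate',
    [(0, 1e-4), (100, 1e-5), (200, 1e-6)])
```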
......@@ -35,7 +35,7 @@ Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/
To train:
```
./DQN.py --rom breakout.bin
# use `--algo` to select other DQN algorithms
# use `--algo` to select other DQN algorithms. See `-h` for more options.
```
To visualize the agent:
......@@ -43,4 +43,4 @@ To visualize the agent:
./DQN.py --rom breakout.bin --task play --load trained.model
```
A3C code and models for Atari games in OpenAI Gym are released in [examples/OpenAIGym](../OpenAIGym)
A3C code and models for Atari games in OpenAI Gym are released in [examples/A3C-Gym](../A3C-Gym)
......@@ -13,9 +13,8 @@ from tensorpack.dataflow import DataFlow
class GANTrainer(FeedfreeTrainerBase):
def __init__(self, config):
self._input_method = QueueInput(config.dataset)
self._input_method = QueueInput(config.dataflow)
super(GANTrainer, self).__init__(config)
def _setup(self):
......
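For reference, a hedged usage sketch of what the fixed line expects: `TrainConfig` exposes the input `DataFlow` under the `dataflow` attribute, which is what `GANTrainer` must now read (`my_dataflow` and `my_gan_model` below are placeholders, not part of this commit):
```python
from tensorpack import TrainConfig

# GANTrainer reads the DataFlow via config.dataflow, so it must be
# passed under that name when building the config.
config = TrainConfig(
    dataflow=my_dataflow,   # placeholder: any tensorpack DataFlow
    model=my_gan_model)     # placeholder: a GAN ModelDesc
GANTrainer(config).train()
```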
......@@ -16,15 +16,17 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Callback):
"""
Save the model to ``logger.LOG_DIR`` directory every epoch.
Save the model every epoch.
"""
def __init__(self, keep_recent=10, keep_freq=0.5,
checkpoint_dir=None,
var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
"""
Args:
keep_recent(int): see ``tf.train.Saver`` documentation.
keep_freq(float): see ``tf.train.Saver`` documentation.
checkpoint_dir (str): Defaults to ``logger.LOG_DIR``.
var_collections (str or list): the variable collection (or list of collections) to save.
"""
self.keep_recent = keep_recent
......@@ -32,23 +34,20 @@ class ModelSaver(Callback):
if not isinstance(var_collections, list):
var_collections = [var_collections]
self.var_collections = var_collections
if checkpoint_dir is None:
checkpoint_dir = logger.LOG_DIR
self.checkpoint_dir = checkpoint_dir
def _setup_graph(self):
vars = []
for key in self.var_collections:
vars.extend(tf.get_collection(key))
self.path = os.path.join(logger.LOG_DIR, 'model')
try:
self.saver = tf.train.Saver(
var_list=ModelSaver._get_var_dict(vars),
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq,
write_version=tf.train.SaverDef.V2)
except:
self.saver = tf.train.Saver(
var_list=ModelSaver._get_var_dict(vars),
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq)
self.path = os.path.join(self.checkpoint_dir, 'model')
self.saver = tf.train.Saver(
var_list=ModelSaver._get_var_dict(vars),
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq,
write_version=tf.train.SaverDef.V2)
self.meta_graph_written = False
@staticmethod
......@@ -70,7 +69,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
try:
if not self.meta_graph_written:
self.saver.export_meta_graph(
os.path.join(logger.LOG_DIR,
os.path.join(self.checkpoint_dir,
'graph-{}.meta'.format(logger.get_time_str())),
collection_list=self.graph.get_all_collection_keys())
self.meta_graph_written = True
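A short usage sketch of the new option (the directory path below is illustrative, not from this commit):
```python
# Default behavior is unchanged: checkpoints go to logger.LOG_DIR.
saver_cb = ModelSaver()

# New in this commit: write checkpoints to an explicit directory instead.
saver_cb = ModelSaver(checkpoint_dir='/path/to/checkpoints')
```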
......@@ -97,11 +96,16 @@ class MinSaver(Callback):
Example:
Save the model with minimum validation error to
"min-val-error.tfmodel" under ``logger.LOG_DIR``:
"min-val-error.tfmodel":
.. code-block:: python
MinSaver('val-error')
Note:
It assumes that :class:`ModelSaver` is used with
``checkpoint_dir=logger.LOG_DIR`` (the default), and it will save
the model to that directory as well.
"""
self.monitor_stat = monitor_stat
self.reverse = reverse
......
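Putting the two callbacks together, a minimal sketch (assuming `val-error` is a statistic your training actually produces):
```python
# MinSaver relies on ModelSaver writing checkpoints to logger.LOG_DIR
# (the default), and saves the best model to that directory as well.
callbacks = [
    ModelSaver(),            # default checkpoint_dir=logger.LOG_DIR
    MinSaver('val-error'),   # keeps the checkpoint with the lowest val-error
]
```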