Commit 6d67faf9 authored by Yuxin Wu's avatar Yuxin Wu

add dir option in saver. fix bug in GANTrainer

parent 7ce3d7ab
...@@ -287,6 +287,7 @@ lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True) ...@@ -287,6 +287,7 @@ lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True)
``` ```
This essentially creates a non-trainable variable with initial value `1e-4` and also tracks this value inside TensorBoard. This essentially creates a non-trainable variable with initial value `1e-4` and also tracks this value inside TensorBoard.
You can certainly just use `lr = 1e-4`, but then you'll lose the ability to modify it during training (through callbacks).
Let's have a look at the entire code: Let's have a look at the entire code:
```python ```python
......
...@@ -35,7 +35,7 @@ Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/ ...@@ -35,7 +35,7 @@ Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/
To train: To train:
``` ```
./DQN.py --rom breakout.bin ./DQN.py --rom breakout.bin
# use `--algo` to select other DQN algorithms # use `--algo` to select other DQN algorithms. See `-h` for more options.
``` ```
To visualize the agent: To visualize the agent:
...@@ -43,4 +43,4 @@ To visualize the agent: ...@@ -43,4 +43,4 @@ To visualize the agent:
./DQN.py --rom breakout.bin --task play --load trained.model ./DQN.py --rom breakout.bin --task play --load trained.model
``` ```
A3C code and models for Atari games in OpenAI Gym are released in [examples/OpenAIGym](../OpenAIGym) A3C code and models for Atari games in OpenAI Gym are released in [examples/A3C-Gym](../A3C-Gym)
...@@ -13,9 +13,8 @@ from tensorpack.dataflow import DataFlow ...@@ -13,9 +13,8 @@ from tensorpack.dataflow import DataFlow
class GANTrainer(FeedfreeTrainerBase): class GANTrainer(FeedfreeTrainerBase):
def __init__(self, config): def __init__(self, config):
self._input_method = QueueInput(config.dataset) self._input_method = QueueInput(config.dataflow)
super(GANTrainer, self).__init__(config) super(GANTrainer, self).__init__(config)
def _setup(self): def _setup(self):
......
...@@ -16,15 +16,17 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver'] ...@@ -16,15 +16,17 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Callback): class ModelSaver(Callback):
""" """
Save the model to ``logger.LOG_DIR`` directory every epoch. Save the model every epoch.
""" """
def __init__(self, keep_recent=10, keep_freq=0.5, def __init__(self, keep_recent=10, keep_freq=0.5,
checkpoint_dir=None,
var_collections=tf.GraphKeys.GLOBAL_VARIABLES): var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
""" """
Args: Args:
keep_recent(int): see ``tf.train.Saver`` documentation. keep_recent(int): see ``tf.train.Saver`` documentation.
keep_freq(float): see ``tf.train.Saver`` documentation. keep_freq(float): see ``tf.train.Saver`` documentation.
checkpoint_dir (str): Defaults to ``logger.LOG_DIR``.
var_collections (str or list): the variable collection (or list of collections) to save. var_collections (str or list): the variable collection (or list of collections) to save.
""" """
self.keep_recent = keep_recent self.keep_recent = keep_recent
...@@ -32,23 +34,20 @@ class ModelSaver(Callback): ...@@ -32,23 +34,20 @@ class ModelSaver(Callback):
if not isinstance(var_collections, list): if not isinstance(var_collections, list):
var_collections = [var_collections] var_collections = [var_collections]
self.var_collections = var_collections self.var_collections = var_collections
if checkpoint_dir is None:
checkpoint_dir = logger.LOG_DIR
self.checkpoint_dir = checkpoint_dir
def _setup_graph(self): def _setup_graph(self):
vars = [] vars = []
for key in self.var_collections: for key in self.var_collections:
vars.extend(tf.get_collection(key)) vars.extend(tf.get_collection(key))
self.path = os.path.join(logger.LOG_DIR, 'model') self.path = os.path.join(self.checkpoint_dir, 'model')
try: self.saver = tf.train.Saver(
self.saver = tf.train.Saver( var_list=ModelSaver._get_var_dict(vars),
var_list=ModelSaver._get_var_dict(vars), max_to_keep=self.keep_recent,
max_to_keep=self.keep_recent, keep_checkpoint_every_n_hours=self.keep_freq,
keep_checkpoint_every_n_hours=self.keep_freq, write_version=tf.train.SaverDef.V2)
write_version=tf.train.SaverDef.V2)
except:
self.saver = tf.train.Saver(
var_list=ModelSaver._get_var_dict(vars),
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq)
self.meta_graph_written = False self.meta_graph_written = False
@staticmethod @staticmethod
...@@ -70,7 +69,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name)) ...@@ -70,7 +69,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
try: try:
if not self.meta_graph_written: if not self.meta_graph_written:
self.saver.export_meta_graph( self.saver.export_meta_graph(
os.path.join(logger.LOG_DIR, os.path.join(self.checkpoint_dir,
'graph-{}.meta'.format(logger.get_time_str())), 'graph-{}.meta'.format(logger.get_time_str())),
collection_list=self.graph.get_all_collection_keys()) collection_list=self.graph.get_all_collection_keys())
self.meta_graph_written = True self.meta_graph_written = True
...@@ -97,11 +96,16 @@ class MinSaver(Callback): ...@@ -97,11 +96,16 @@ class MinSaver(Callback):
Example: Example:
Save the model with minimum validation error to Save the model with minimum validation error to
"min-val-error.tfmodel" under ``logger.LOG_DIR``: "min-val-error.tfmodel":
.. code-block:: python .. code-block:: python
MinSaver('val-error') MinSaver('val-error')
Note:
It assumes that :class:`ModelSaver` is used with
``checkpoint_dir=logger.LOG_DIR`` (the default). And it will save
the model to that directory as well.
""" """
self.monitor_stat = monitor_stat self.monitor_stat = monitor_stat
self.reverse = reverse self.reverse = reverse
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment