Commit 7b9d9b8c authored by Yuxin Wu's avatar Yuxin Wu

Do not require ALE if using DQN with gym only (fix #1091)

parent e4aca035
......@@ -401,6 +401,7 @@ _DEPRECATED_NAMES = set([
'average_grads',
'aggregate_grads',
'allreduce_grads',
'get_checkpoint_path'
])
def autodoc_skip_member(app, what, name, obj, skip, options):
......
......@@ -10,8 +10,8 @@ Both are necessary.
`tf.train.NewCheckpointReader` is the official tool to parse TensorFlow checkpoints.
Read [TF docs](https://www.tensorflow.org/api_docs/python/tf/train/NewCheckpointReader) for details.
Tensorpack also provides some small tools to work with checkpoints, see
[documentation](../modules/tfutils.html#tensorpack.tfutils.varmanip.load_chkpt_vars)
Tensorpack also provides a small tool to load checkpoints, see
[load_chkpt_vars](../modules/tfutils.html#tensorpack.tfutils.varmanip.load_chkpt_vars)
for details.
[scripts/ls-checkpoint.py](../scripts/ls-checkpoint.py)
......@@ -19,19 +19,28 @@ demos how to print all variables and their shapes in a checkpoint.
[scripts/dump-model-params.py](../scripts/dump-model-params.py) can be used to remove unnecessary variables in a checkpoint.
It takes a metagraph file (which is also saved by `ModelSaver`) and only saves variables that the model needs at inference time.
It can dump the model to a `var-name: value` dict saved in npz format.
It dumps the model to a `var-name: value` dict saved in npz format.
## Load a Model to a Session
Model loading (in either training or inference) is through the `session_init` interface.
Model loading (in both training and inference) is through the `session_init` interface.
Currently there are two ways a session can be restored:
[session_init=SaverRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore)
which restores a TF checkpoint,
or [session_init=DictRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore) which restores a dict.
[get_model_loader](../modules/tfutils.html#tensorpack.tfutils.sessinit.get_model_loader)
is a small helper to decide which one to use from a file name.
To load multiple models, use [ChainInit](../modules/tfutils.html#tensorpack.tfutils.sessinit.ChainInit).
Many models in the tensorpack model zoo are provided as numpy dictionaries (`.npz`),
because they are easier to load and manipulate without requiring TensorFlow.
To load such files to a session, use `DictRestore(dict(np.load(filename)))`.
You can also use
[get_model_loader](../modules/tfutils.html#tensorpack.tfutils.sessinit.get_model_loader),
a small helper to create a `SaverRestore` or `DictRestore` based on the file name.
`DictRestore` is the most general loader, because you can make whatever changes
you need (e.g., remove or rename variables) to the dict.
To load a TF checkpoint into a dict in order to make changes, use
[load_chkpt_vars](../modules/tfutils.html#tensorpack.tfutils.varmanip.load_chkpt_vars).
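As a rough sketch of this workflow (the variable names and shapes below are made up, and the `DictRestore` step is only indicated in a comment), editing a loaded dict before handing it to `session_init` might look like:

```python
import numpy as np

# Hypothetical dict of checkpoint variables, as returned by e.g.
# load_chkpt_vars(model_path) or dict(np.load('model.npz')).
params = {
    'conv0/W': np.random.rand(3, 3, 3, 64).astype('float32'),
    'fc/W': np.random.rand(512, 1000).astype('float32'),
    'global_step': np.asarray(1000),
}

# Remove variables the new graph does not need:
params.pop('global_step', None)

# The edited dict would then be passed to the trainer or predictor as
# session_init=DictRestore(params).
```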
Variable restoring is completely based on __name match__ between
variables in the current graph and variables in the `session_init` initializer.
......@@ -40,4 +49,5 @@ Variables that appear in only one side will be printed as warning.
## Transfer Learning
Therefore, transfer learning is trivial.
If you want to load a pre-trained model, just use the same variable names.
If you want to re-train some layer, just rename it.
If you want to re-train some layer, just rename either the variables in the
graph or the variables in your loader.
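A minimal sketch of the loader-side variant (layer and variable names here are hypothetical): dropping a layer's keys from the dict has the same effect as renaming, since unmatched variables are simply not restored and keep their fresh initialization.

```python
import numpy as np

# Hypothetical pre-trained weights; suppose 'fc' is the layer to re-train.
params = {
    'conv0/W': np.zeros((3, 3, 3, 64), dtype='float32'),
    'fc/W': np.zeros((512, 1000), dtype='float32'),
}

# Drop the 'fc' variables so they no longer match anything in the graph;
# the graph's 'fc' layer will then keep its random initialization.
params = {k: v for k, v in params.items() if not k.startswith('fc/')}

# Pass to DictRestore(params); only 'conv0/W' will be restored.
```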
......@@ -12,7 +12,6 @@ import tensorflow as tf
from tensorpack import *
from atari import AtariPlayer
from atari_wrapper import FireResetEnv, FrameStack, LimitLength, MapState
from common import Evaluator, eval_model_multithread, play_n_episodes
from DQNModel import Model as DQNModel
......@@ -52,6 +51,7 @@ def get_player(viz=False, train=False):
if USE_GYM:
env = gym.make(ENV_NAME)
else:
from atari import AtariPlayer
env = AtariPlayer(ENV_NAME, frame_skip=ACTION_REPEAT, viz=viz,
live_lost_as_eoe=train, max_num_frames=60000)
env = FireResetEnv(env)
......
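The hunk above moves the `AtariPlayer` import inside the `else` branch, so the ALE dependency is only imported when gym is not used. The same deferred-import pattern in isolation (the function and modules below are stand-ins, using only stdlib modules):

```python
def get_backend(use_json=True):
    """Return a serialization module, importing the optional one lazily."""
    if use_json:
        import json
        return json
    else:
        # Imported only when actually requested, so the program still
        # starts even if the optional dependency were missing.
        import pickle
        return pickle

backend = get_backend(use_json=True)
```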
......@@ -31,7 +31,11 @@ class ReplayMemory(object):
self._shape3d = (state_shape[0], state_shape[1], self._channel * (history_len + 1))
self.history_len = int(history_len)
self.state = np.zeros((self.max_size,) + state_shape, dtype='uint8')
state_shape = (self.max_size,) + state_shape
logger.info("Creating experience replay buffer of {:.1f} GB ... "
"use a smaller buffer if you don't have enough CPU memory.".format(
np.prod(state_shape) / 1024.0**3))
self.state = np.zeros(state_shape, dtype='uint8')
self.action = np.zeros((self.max_size,), dtype='int32')
self.reward = np.zeros((self.max_size,), dtype='float32')
self.isOver = np.zeros((self.max_size,), dtype='bool')
......
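To sanity-check the logged size: since the buffer is `uint8` (one byte per element), the element count equals the byte count. For a hypothetical buffer of one million states of shape (84, 84) — a common Atari setting, though the real shape depends on the config:

```python
import numpy as np

max_size = 1_000_000
state_shape = (84, 84)              # hypothetical single-state shape
buf_shape = (max_size,) + state_shape

# uint8 -> one byte per element; same formula as the log message above
gb = np.prod(buf_shape) / 1024.0 ** 3
print("{:.1f} GB".format(gb))       # about 6.6 GB
```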
......@@ -45,7 +45,7 @@ In this implementation, quantized operations are all performed through `tf.float
+ Look at the docstring in `*-dorefa.py` to see detailed usage and performance.
Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
Pretrained models for (1,4,32)-ResNet18 and several AlexNet variants are available at
[tensorpack model zoo](http://models.tensorpack.com/DoReFa-Net/).
They are provided in the form of numpy dictionaries.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy.
......