Commit 93beba57 authored by Yuxin Wu's avatar Yuxin Wu

update readme

parent 6dc04278
......@@ -13,7 +13,7 @@ It's provided in the format of numpy dictionary, so it should be very easy to po
To use the script. You'll need:
+ [TensorFlow](https://tensorflow.org) >= 0.10
+ TensorFlow 0.10,0.11rc1,0.11rc2. 0.11 is not supported due to [TF bug](https://github.com/tensorflow/tensorflow/issues/5888)
+ OpenCV bindings for Python
......@@ -22,7 +22,7 @@ To use the script. You'll need:
```
git clone https://github.com/ppwwyyxx/tensorpack
pip install --user -r tensorpack/requirements.txt
pip install --user pyzmq scipy
pip install --user scipy
export PYTHONPATH=$PYTHONPATH:`readlink -f tensorpack`
```
......
......@@ -6,6 +6,8 @@ Use A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.o
`./train-atari.py --env Breakout-v0 --gpu 0`
The pre-trained models are all trained with 4 GPUs.
### To run a pretrained Atari model for 100 episodes:
1. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk)
......
......@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from termcolor import colored
from ..utils import logger
......@@ -20,8 +21,9 @@ def describe_model():
msg.append("{}: shape={}, dim={}".format(
v.name, shape.as_list(), ele))
size_mb = total * 4 / 1024.0**2
msg.append("Total param={} ({:01f} MB assuming all float32)".format(total, size_mb))
logger.info("Model Parameters: {}".format('\n'.join(msg)))
msg.append(colored(
"Total param={} ({:01f} MB assuming all float32)".format(total, size_mb), 'cyan'))
logger.info(colored("Model Parameters: ", 'cyan') + '\n'.join(msg))
def get_shape_str(tensors):
......
......@@ -197,8 +197,8 @@ def get_model_loader(filename):
Get a corresponding model loader by looking at the file name
:return: either a ParamRestore or SaverRestore
"""
assert os.path.isfile(filename), filename
if filename.endswith('.npy'):
assert os.path.isfile(filename), filename
return ParamRestore(np.load(filename, encoding='latin1').item())
else:
return SaverRestore(filename)
......@@ -42,6 +42,7 @@ def download(url, dir):
raise
assert size > 0, "Download an empty file!"
sys.stdout.write('\n')
# TODO human-readable size
print('Succesfully downloaded ' + fname + " " + str(size) + ' bytes.')
return fpath
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment