Commit 93beba57 authored by Yuxin Wu's avatar Yuxin Wu

update readme

parent 6dc04278
...@@ -13,7 +13,7 @@ It's provided in the format of numpy dictionary, so it should be very easy to po ...@@ -13,7 +13,7 @@ It's provided in the format of numpy dictionary, so it should be very easy to po
To use the script. You'll need: To use the script. You'll need:
+ [TensorFlow](https://tensorflow.org) >= 0.10 + TensorFlow 0.10,0.11rc1,0.11rc2. 0.11 is not supported due to [TF bug](https://github.com/tensorflow/tensorflow/issues/5888)
+ OpenCV bindings for Python + OpenCV bindings for Python
...@@ -22,7 +22,7 @@ To use the script. You'll need: ...@@ -22,7 +22,7 @@ To use the script. You'll need:
``` ```
git clone https://github.com/ppwwyyxx/tensorpack git clone https://github.com/ppwwyyxx/tensorpack
pip install --user -r tensorpack/requirements.txt pip install --user -r tensorpack/requirements.txt
pip install --user pyzmq scipy pip install --user scipy
export PYTHONPATH=$PYTHONPATH:`readlink -f tensorpack` export PYTHONPATH=$PYTHONPATH:`readlink -f tensorpack`
``` ```
......
...@@ -6,6 +6,8 @@ Use A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.o ...@@ -6,6 +6,8 @@ Use A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.o
`./train-atari.py --env Breakout-v0 --gpu 0` `./train-atari.py --env Breakout-v0 --gpu 0`
The pre-trained models are all trained with 4 GPUs.
### To run a pretrained Atari model for 100 episodes: ### To run a pretrained Atari model for 100 episodes:
1. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk) 1. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
from termcolor import colored
from ..utils import logger from ..utils import logger
...@@ -20,8 +21,9 @@ def describe_model(): ...@@ -20,8 +21,9 @@ def describe_model():
msg.append("{}: shape={}, dim={}".format( msg.append("{}: shape={}, dim={}".format(
v.name, shape.as_list(), ele)) v.name, shape.as_list(), ele))
size_mb = total * 4 / 1024.0**2 size_mb = total * 4 / 1024.0**2
msg.append("Total param={} ({:01f} MB assuming all float32)".format(total, size_mb)) msg.append(colored(
logger.info("Model Parameters: {}".format('\n'.join(msg))) "Total param={} ({:01f} MB assuming all float32)".format(total, size_mb), 'cyan'))
logger.info(colored("Model Parameters: ", 'cyan') + '\n'.join(msg))
def get_shape_str(tensors): def get_shape_str(tensors):
......
...@@ -197,8 +197,8 @@ def get_model_loader(filename): ...@@ -197,8 +197,8 @@ def get_model_loader(filename):
Get a corresponding model loader by looking at the file name Get a corresponding model loader by looking at the file name
:return: either a ParamRestore or SaverRestore :return: either a ParamRestore or SaverRestore
""" """
assert os.path.isfile(filename), filename
if filename.endswith('.npy'): if filename.endswith('.npy'):
assert os.path.isfile(filename), filename
return ParamRestore(np.load(filename, encoding='latin1').item()) return ParamRestore(np.load(filename, encoding='latin1').item())
else: else:
return SaverRestore(filename) return SaverRestore(filename)
...@@ -42,6 +42,7 @@ def download(url, dir): ...@@ -42,6 +42,7 @@ def download(url, dir):
raise raise
assert size > 0, "Download an empty file!" assert size > 0, "Download an empty file!"
sys.stdout.write('\n') sys.stdout.write('\n')
# TODO human-readable size
print('Succesfully downloaded ' + fname + " " + str(size) + ' bytes.') print('Succesfully downloaded ' + fname + " " + str(size) + ' bytes.')
return fpath return fpath
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment