Commit 92ee69dc authored by Yuxin Wu's avatar Yuxin Wu

Use default session config in predict

parent 149ad4eb
......@@ -15,14 +15,15 @@ Reproduce the following reinforcement learning methods:
+ A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). (I
used a modified version where each batch contains transitions from different simulators, which I called "Batch-A3C".)
## Performance & Speed
Claimed performance in the paper can be reproduced, on several games I've tested with.
![DQN](curve-breakout.png)
On one TitanX, Double-DQN took 1 day of training to reach a score of 400 on breakout game.
Batch-A3C implementation only took <2 hours. (Both are trained with a larger network noted in the code).
Batch-A3C implementation only took <2 hours.
Double-DQN runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on TitanX.
Double-DQN runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on (Maxwell) TitanX.
## How to use
......
......@@ -98,7 +98,7 @@ def get_config():
ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]),
ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]),
],
steps_per_epoch=5000,
max_epoch=110,
......@@ -112,7 +112,7 @@ if __name__ == '__main__':
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model')
parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101])
type=int, default=50, choices=[50, 101])
parser.add_argument('--eval', action='store_true')
args = parser.parse_args()
......@@ -121,7 +121,7 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.eval:
BATCH_SIZE = 64 # something that can run on one gpu
BATCH_SIZE = 128 # something that can run on one gpu
ds = get_data('val')
eval_on_ILSVRC12(Model(), args.load, ds)
sys.exit()
......
......@@ -6,7 +6,6 @@ import tensorflow as tf
import six
from ..graph_builder import ModelDesc
from ..utils.develop import log_deprecated
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSessionCreator
......@@ -22,7 +21,6 @@ class PredictConfig(object):
output_names=None,
return_input=False,
create_graph=True,
session_config=None, # deprecated
):
"""
Args:
......@@ -50,18 +48,13 @@ class PredictConfig(object):
assert_type(self.session_init, SessionInit)
if session_creator is None:
if session_config is not None:
log_deprecated("PredictConfig(session_config=)", "Use session_creator instead!", "2017-04-20")
self.session_creator = NewSessionCreator(config=session_config)
else:
self.session_creator = NewSessionCreator(config=get_default_sess_config(0.4))
self.session_creator = NewSessionCreator(config=get_default_sess_config())
else:
self.session_creator = session_creator
# inputs & outputs
self.input_names = input_names
if self.input_names is None:
# neither options is set, assume all inputs
raw_tensors = self.model.get_inputs_desc()
self.input_names = [k.name for k in raw_tensors]
self.output_names = output_names
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment