Commit 92ee69dc authored by Yuxin Wu

Use default session config in predict

parent 149ad4eb
...@@ -15,14 +15,15 @@ Reproduce the following reinforcement learning methods: ...@@ -15,14 +15,15 @@ Reproduce the following reinforcement learning methods:
+ A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). (I + A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). (I
used a modified version where each batch contains transitions from different simulators, which I called "Batch-A3C".) used a modified version where each batch contains transitions from different simulators, which I called "Batch-A3C".)
## Performance & Speed
Claimed performance in the paper can be reproduced, on several games I've tested with. Claimed performance in the paper can be reproduced, on several games I've tested with.
![DQN](curve-breakout.png) ![DQN](curve-breakout.png)
On one TitanX, Double-DQN took 1 day of training to reach a score of 400 on breakout game. On one TitanX, Double-DQN took 1 day of training to reach a score of 400 on breakout game.
Batch-A3C implementation only took <2 hours. (Both are trained with a larger network noted in the code). Batch-A3C implementation only took <2 hours.
Double-DQN runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on TitanX. Double-DQN runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on (Maxwell) TitanX.
## How to use ## How to use
......
...@@ -98,7 +98,7 @@ def get_config(): ...@@ -98,7 +98,7 @@ def get_config():
ClassificationError('wrong-top1', 'val-error-top1'), ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]), ClassificationError('wrong-top5', 'val-error-top5')]),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]), [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]),
], ],
steps_per_epoch=5000, steps_per_epoch=5000,
max_epoch=110, max_epoch=110,
...@@ -112,7 +112,7 @@ if __name__ == '__main__': ...@@ -112,7 +112,7 @@ if __name__ == '__main__':
parser.add_argument('--data', help='ILSVRC dataset dir') parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('-d', '--depth', help='resnet depth', parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101]) type=int, default=50, choices=[50, 101])
parser.add_argument('--eval', action='store_true') parser.add_argument('--eval', action='store_true')
args = parser.parse_args() args = parser.parse_args()
...@@ -121,7 +121,7 @@ if __name__ == '__main__': ...@@ -121,7 +121,7 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.eval: if args.eval:
BATCH_SIZE = 64 # something that can run on one gpu BATCH_SIZE = 128 # something that can run on one gpu
ds = get_data('val') ds = get_data('val')
eval_on_ILSVRC12(Model(), args.load, ds) eval_on_ILSVRC12(Model(), args.load, ds)
sys.exit() sys.exit()
......
...@@ -6,7 +6,6 @@ import tensorflow as tf ...@@ -6,7 +6,6 @@ import tensorflow as tf
import six import six
from ..graph_builder import ModelDesc from ..graph_builder import ModelDesc
from ..utils.develop import log_deprecated
from ..tfutils import get_default_sess_config from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
...@@ -22,7 +21,6 @@ class PredictConfig(object): ...@@ -22,7 +21,6 @@ class PredictConfig(object):
output_names=None, output_names=None,
return_input=False, return_input=False,
create_graph=True, create_graph=True,
session_config=None, # deprecated
): ):
""" """
Args: Args:
...@@ -50,18 +48,13 @@ class PredictConfig(object): ...@@ -50,18 +48,13 @@ class PredictConfig(object):
assert_type(self.session_init, SessionInit) assert_type(self.session_init, SessionInit)
if session_creator is None: if session_creator is None:
if session_config is not None: self.session_creator = NewSessionCreator(config=get_default_sess_config())
log_deprecated("PredictConfig(session_config=)", "Use session_creator instead!", "2017-04-20")
self.session_creator = NewSessionCreator(config=session_config)
else:
self.session_creator = NewSessionCreator(config=get_default_sess_config(0.4))
else: else:
self.session_creator = session_creator self.session_creator = session_creator
# inputs & outputs # inputs & outputs
self.input_names = input_names self.input_names = input_names
if self.input_names is None: if self.input_names is None:
# neither options is set, assume all inputs
raw_tensors = self.model.get_inputs_desc() raw_tensors = self.model.get_inputs_desc()
self.input_names = [k.name for k in raw_tensors] self.input_names = [k.name for k in raw_tensors]
self.output_names = output_names self.output_names = output_names
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment