Commit 5828a161 authored by Yuxin Wu

call model without scope & fix atari.py

parent e457e2db
@@ -72,7 +72,8 @@ class Model(ModelDesc):
     def _get_DQN_prediction(self, image):
         """ image: [0,255]"""
         image = image / 255.0
-        with argscope(Conv2D, nl=PReLU.f, use_bias=True):
+        with argscope(Conv2D, nl=PReLU.f, use_bias=True), \
+                argscope(LeakyReLU, alpha=0.01):
             l = (LinearWrap(image)
                  .Conv2D('conv0', out_channel=32, kernel_shape=5)
                  .MaxPooling('pool0', 2)
@@ -87,7 +88,7 @@ class Model(ModelDesc):
                  #.Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2)
                  #.Conv2D('conv2', out_channel=64, kernel_shape=3)
-                 .FullyConnected('fc0', 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))())
+                 .FullyConnected('fc0', 512, nl=LeakyReLU)())
         if METHOD != 'Dueling':
             Q = FullyConnected('fct', l, NUM_ACTIONS, nl=tf.identity)
         else:
...
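Note: the `nl=LeakyReLU` shorthand works because this commit registers
LeakyReLU with use_scope=False (see the LeakyReLU registration change
below), so it fits tensorpack's nl(x, name=...) convention directly,
with alpha supplied by the surrounding argscope. A minimal sketch of the
two equivalent forms:

    nl=lambda x, name: LeakyReLU.f(x, 0.01, name)  # old: unwrap .f by hand
    nl=LeakyReLU                                   # new: alpha comes from argscope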
@@ -10,8 +10,7 @@ from collections import deque
 import threading
 import six
 from six.moves import range
-from tensorpack.utils import (get_rng, logger, memoized,
-        get_dataset_path, execute_only_once)
+from tensorpack.utils import (get_rng, logger, get_dataset_path, execute_only_once)
 from tensorpack.utils.stat import StatCounter
 from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
...
@@ -50,17 +50,13 @@ class Model(ModelDesc):
         with argscope(Conv2D, nl=tf.identity, kernel_shape=5, stride=2), \
                 argscope(LeakyReLU, alpha=0.2):
             l = (LinearWrap(imgs)
-                 .Conv2D('conv0', 64)
-                 .LeakyReLU('lr0')
+                 .Conv2D('conv0', 64, nl=LeakyReLU)
                  .Conv2D('conv1', 64*2)
-                 .BatchNorm('bn1')
-                 .LeakyReLU('lr1')
+                 .BatchNorm('bn1').LeakyReLU()
                  .Conv2D('conv2', 64*4)
-                 .BatchNorm('bn2')
-                 .LeakyReLU('lr2')
+                 .BatchNorm('bn2').LeakyReLU()
                  .Conv2D('conv3', 64*8)
-                 .BatchNorm('bn3')
-                 .LeakyReLU('lr3')
+                 .BatchNorm('bn3').LeakyReLU()
                  .FullyConnected('fct', 1, nl=tf.identity)())
         return l
...
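Note: the chained `.BatchNorm('bn1').LeakyReLU()` form relies on the new
scope-free calling convention: LeakyReLU no longer requires a scope name,
and alpha is picked up from the enclosing argscope. A minimal sketch:

    with argscope(LeakyReLU, alpha=0.2):
        y = LeakyReLU(x)         # same as LeakyReLU(x, alpha=0.2)
        y = LeakyReLU('lr0', x)  # an explicit name is still accepted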
@@ -56,7 +56,7 @@ class ExpReplay(DataFlow, Callback):
             setattr(self, k, v)
         self.num_actions = player.get_action_space().num_actions()
         logger.info("Number of Legal actions: {}".format(self.num_actions))
-        self.mem = deque(maxlen=memory_size)
+        self.mem = deque(maxlen=int(memory_size))
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
         self._predictor_io_names = predictor_io_names
...
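Note: the int() cast matters because a replay memory size written in
scientific notation (e.g. 1e6, an assumption about typical configs) is a
Python float, and deque rejects a non-integer maxlen:

    from collections import deque
    deque(maxlen=1e6)       # TypeError: an integer is required
    deque(maxlen=int(1e6))  # OK: bounded buffer of one million transitions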
@@ -5,6 +5,7 @@
 from pkgutil import walk_packages
 from types import ModuleType
 import tensorflow as tf
+import six
 import os
 import os.path
 from ..utils import logger
@@ -49,12 +50,17 @@ class LinearWrap(object):
             layer = eval(layer_name)
             if hasattr(layer, 'f'):
                 # this is a registered tensorpack layer
+                # parse arguments by tensorpack model convention
                 if layer.use_scope:
                     def f(name, *args, **kwargs):
                         ret = layer(name, self._t, *args, **kwargs)
                         return LinearWrap(ret)
                 else:
                     def f(*args, **kwargs):
-                        ret = layer(self._t, *args, **kwargs)
+                        if len(args) and isinstance(args[0], six.string_types):
+                            name, args = args[0], args[1:]
+                            ret = layer(name, self._t, *args, **kwargs)
+                        else:
+                            ret = layer(self._t, *args, **kwargs)
                         return LinearWrap(ret)
                 return f
...
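Note: the len(args) guard is needed because a chained call may pass no
positional arguments at all (as in `.LeakyReLU()` above). With this
dispatch, a scope-free layer can be chained with or without a leading
name string; a usage sketch (names are illustrative, alpha assumed to
come from an argscope as in the hunks above):

    l = LinearWrap(x).Conv2D('conv0', 64).LeakyReLU()()       # anonymous
    l = LinearWrap(x).Conv2D('conv0', 64).LeakyReLU('lr0')()  # named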
@@ -34,6 +34,7 @@ def layer_register(
         Can be overridden when creating the layer.
     :param log_shape: log input/output shape of this layer
     :param use_scope: whether to call this layer with an extra first argument as scope
+        if set to False, will try to figure out whether the first argument is a scope name
     """
     def wrapper(func):
@@ -45,8 +46,16 @@ def layer_register(
             assert isinstance(name, six.string_types), name
         else:
             assert not log_shape and not summary_activation
-            inputs = args[0]
-            name = None
+            if isinstance(args[0], six.string_types):
+                name, inputs = args[0], args[1]
+                args = args[1:]  # actual positional args used to call func
+            else:
+                inputs = args[0]
+                name = None
+        if not (isinstance(inputs, (tf.Tensor, tf.Variable)) or
+                (isinstance(inputs, (list, tuple)) and
+                 isinstance(inputs[0], (tf.Tensor, tf.Variable)))):
+            raise ValueError("Invalid inputs to layer: " + str(inputs))
         do_summary = kwargs.pop(
             'summary_activation', summary_activation)
...
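Note: with use_scope=False the registered wrapper now inspects its first
positional argument: a string is treated as an optional name, a tensor
means the layer is called anonymously. A minimal sketch (the Scale layer
is hypothetical, for illustration only; tf.mul matches the TF version
used elsewhere in this commit):

    @layer_register(use_scope=False, log_shape=False)
    def Scale(x, factor, name=None):
        return tf.mul(x, factor, name=name)

    y = Scale(x, 2.0)            # first arg is a tensor: no name
    y = Scale('scale0', x, 2.0)  # first arg is a string: treated as the name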
@@ -47,7 +47,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
         name = 'output'
     return tf.mul(x, 0.5, name=name)

-@layer_register(log_shape=False)
+@layer_register(use_scope=False, log_shape=False)
 def LeakyReLU(x, alpha, name=None):
     """
     Leaky relu as in `Rectifier Nonlinearities Improve Neural Network Acoustic
@@ -64,7 +64,7 @@ def LeakyReLU(x, alpha, name=None):
     #x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
     #return tf.mul(x, 0.5, name=name)

-# TODO wrap it as a layer with use_scope=False?
+@layer_register(log_shape=False, use_scope=False)
 def BNReLU(x, name=None):
     x = BatchNorm('bn', x, use_local_stat=None)
     x = tf.nn.relu(x, name=name)
...
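Note: registering BNReLU resolves the old TODO: it can now be passed
directly as a nonlinearity, e.g. (a usage sketch, argument names as in
the DQN hunk above):

    l = Conv2D('conv0', l, out_channel=64, kernel_shape=3, nl=BNReLU)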
@@ -73,7 +73,8 @@ def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):

 def build_patch_list(patch_list,
                      nr_row=None, nr_col=None, border=None,
                      max_width=1000, max_height=1000,
-                     shuffle=False, bgcolor=255, viz=False, lclick_cb=None):
+                     shuffle=False, bgcolor=255,
+                     viz=False, lclick_cb=None):
     """
     Generate patches.
     :param patch_list: bhw or bhwc
...