Commit 5828a161 authored by Yuxin Wu

call model without scope & fix atari.py

parent e457e2db
@@ -72,7 +72,8 @@ class Model(ModelDesc):
     def _get_DQN_prediction(self, image):
         """ image: [0,255]"""
         image = image / 255.0
-        with argscope(Conv2D, nl=PReLU.f, use_bias=True):
+        with argscope(Conv2D, nl=PReLU.f, use_bias=True), \
+                argscope(LeakyReLU, alpha=0.01):
             l = (LinearWrap(image)
                  .Conv2D('conv0', out_channel=32, kernel_shape=5)
                  .MaxPooling('pool0', 2)
@@ -87,7 +88,7 @@ class Model(ModelDesc):
                  #.Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2)
                  #.Conv2D('conv2', out_channel=64, kernel_shape=3)
-                 .FullyConnected('fc0', 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))())
+                 .FullyConnected('fc0', 512, nl=LeakyReLU)())
         if METHOD != 'Dueling':
             Q = FullyConnected('fct', l, NUM_ACTIONS, nl=tf.identity)
         else:
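With LeakyReLU registered as a scope-free layer (see the @layer_register(use_scope=False, ...) hunk below), the slope can be supplied once through argscope and the registered layer passed directly as nl, replacing the earlier lambda. A minimal sketch of the two spellings, using only symbols that appear in this diff:

    # before: the slope had to be baked into a lambda
    l = FullyConnected('fc0', l, 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
    # after: argscope supplies alpha=0.01 when FullyConnected invokes nl
    with argscope(LeakyReLU, alpha=0.01):
        l = FullyConnected('fc0', l, 512, nl=LeakyReLU)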
@@ -10,8 +10,7 @@ from collections import deque
 import threading
 import six
 from six.moves import range
-from tensorpack.utils import (get_rng, logger, memoized,
-        get_dataset_path, execute_only_once)
+from tensorpack.utils import (get_rng, logger, get_dataset_path, execute_only_once)
 from tensorpack.utils.stat import StatCounter
 from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
@@ -50,17 +50,13 @@ class Model(ModelDesc):
         with argscope(Conv2D, nl=tf.identity, kernel_shape=5, stride=2), \
                 argscope(LeakyReLU, alpha=0.2):
             l = (LinearWrap(imgs)
-                 .Conv2D('conv0', 64)
-                 .LeakyReLU('lr0')
+                 .Conv2D('conv0', 64, nl=LeakyReLU)
                  .Conv2D('conv1', 64*2)
-                 .BatchNorm('bn1')
-                 .LeakyReLU('lr1')
+                 .BatchNorm('bn1').LeakyReLU()
                  .Conv2D('conv2', 64*4)
-                 .BatchNorm('bn2')
-                 .LeakyReLU('lr2')
+                 .BatchNorm('bn2').LeakyReLU()
                  .Conv2D('conv3', 64*8)
-                 .BatchNorm('bn3')
-                 .LeakyReLU('lr3')
+                 .BatchNorm('bn3').LeakyReLU()
                  .FullyConnected('fct', 1, nl=tf.identity)())
         return l
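The chained .BatchNorm('bn1').LeakyReLU() form works because only the layers that own variables keep a scope name; the activation is now anonymous. A sketch of the same change outside LinearWrap, with illustrative variable names:

    with argscope(LeakyReLU, alpha=0.2):
        # old style: the activation needed a throwaway scope name
        l = LeakyReLU('lr1', BatchNorm('bn1', l))
        # new style: no scope argument for the activation
        l = LeakyReLU(BatchNorm('bn1', l))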
@@ -56,7 +56,7 @@ class ExpReplay(DataFlow, Callback):
             setattr(self, k, v)
         self.num_actions = player.get_action_space().num_actions()
         logger.info("Number of Legal actions: {}".format(self.num_actions))
-        self.mem = deque(maxlen=memory_size)
+        self.mem = deque(maxlen=int(memory_size))
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
         self._predictor_io_names = predictor_io_names
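The int() cast guards against memory_size arriving as a float (it is plausibly written as a literal like 1e6 in the caller; an assumption here), since deque refuses a non-integer maxlen:

    from collections import deque
    deque(maxlen=1e6)       # TypeError: an integer is required
    deque(maxlen=int(1e6))  # ok: bounded buffer of one million transitions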
@@ -5,6 +5,7 @@
 from pkgutil import walk_packages
 from types import ModuleType
 import tensorflow as tf
+import six
 import os
 import os.path
 from ..utils import logger
@@ -49,13 +50,18 @@ class LinearWrap(object):
         layer = eval(layer_name)
         if hasattr(layer, 'f'):
             # this is a registered tensorpack layer
+            # parse arguments by tensorpack model convention
             if layer.use_scope:
                 def f(name, *args, **kwargs):
                     ret = layer(name, self._t, *args, **kwargs)
                     return LinearWrap(ret)
             else:
                 def f(*args, **kwargs):
-                    ret = layer(self._t, *args, **kwargs)
+                    if len(args) and isinstance(args[0], six.string_types):
+                        name, args = args[0], args[1:]
+                        ret = layer(name, self._t, *args, **kwargs)
+                    else:
+                        ret = layer(self._t, *args, **kwargs)
                     return LinearWrap(ret)
             return f
         else:
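After this change a scope-free layer reached through LinearWrap accepts an optional leading name and otherwise falls through to the anonymous call. A sketch of both shapes, assuming LeakyReLU is registered with use_scope=False as in this commit:

    l = LinearWrap(image)
    l = l.LeakyReLU(0.01)         # anonymous: becomes layer(self._t, 0.01)
    l = l.LeakyReLU('lr0', 0.01)  # named: becomes layer('lr0', self._t, 0.01)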
@@ -34,6 +34,7 @@ def layer_register(
         Can be overridden when creating the layer.
     :param log_shape: log input/output shape of this layer
     :param use_scope: whether to call this layer with an extra first argument as scope
+        If set to False, will try to figure out whether the first argument is a scope name.
     """
     def wrapper(func):
@@ -45,8 +46,16 @@ def layer_register(
                 assert isinstance(name, six.string_types), name
             else:
                 assert not log_shape and not summary_activation
-                inputs = args[0]
-                name = None
+                if isinstance(args[0], six.string_types):
+                    name, inputs = args[0], args[1]
+                    args = args[1:]  # actual positional args used to call func
+                else:
+                    inputs = args[0]
+                    name = None
+                if not (isinstance(inputs, (tf.Tensor, tf.Variable)) or
+                        (isinstance(inputs, (list, tuple)) and
+                         isinstance(inputs[0], (tf.Tensor, tf.Variable)))):
+                    raise ValueError("Invalid inputs to layer: " + str(inputs))
             do_summary = kwargs.pop(
                 'summary_activation', summary_activation)
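Direct calls follow the same convention: a layer registered with use_scope=False may be called with or without a leading scope name, and a non-tensor in the input slot now fails fast. A hedged sketch:

    y = LeakyReLU(x, alpha=0.01)           # anonymous call, no variable scope
    y = LeakyReLU('lrelu', x, alpha=0.01)  # optional explicit name still accepted
    LeakyReLU(0.01)                        # ValueError: Invalid inputs to layer: 0.01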
@@ -47,7 +47,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
         name = 'output'
     return tf.mul(x, 0.5, name=name)
 
-@layer_register(log_shape=False)
+@layer_register(use_scope=False, log_shape=False)
 def LeakyReLU(x, alpha, name=None):
     """
     Leaky relu as in `Rectifier Nonlinearities Improve Neural Network Acoustic
@@ -64,7 +64,7 @@ def LeakyReLU(x, alpha, name=None):
     #x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
     #return tf.mul(x, 0.5, name=name)
 
-# TODO wrap it as a layer with use_scope=False?
+@layer_register(log_shape=False, use_scope=False)
 def BNReLU(x, name=None):
     x = BatchNorm('bn', x, use_local_stat=None)
     x = tf.nn.relu(x, name=name)
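With BNReLU scope-free as well, it can be passed as a nonlinearity without inventing a scope name; the inner BatchNorm keeps its fixed 'bn' sub-name. A sketch of typical use (the Conv2D call shape follows the other hunks in this diff):

    l = Conv2D('conv0', l, 32, kernel_shape=3, nl=BNReLU)  # BN + ReLU on conv0's output
    l = BNReLU(l)                                          # or standalone, no scope argument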
@@ -73,7 +73,8 @@ def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
 def build_patch_list(patch_list,
                      nr_row=None, nr_col=None, border=None,
                      max_width=1000, max_height=1000,
-                     shuffle=False, bgcolor=255, viz=False, lclick_cb=None):
+                     shuffle=False, bgcolor=255,
+                     viz=False, lclick_cb=None):
     """
     Generate patches.
     :param patch_list: bhw or bhwc
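For orientation, a hypothetical call to build_patch_list matching the signature above (the shapes and the viz behavior are assumptions, not taken from this diff):

    import numpy as np
    patches = np.random.rand(16, 32, 32, 3)  # bhwc: 16 patches of 32x32 RGB
    build_patch_list(patches, nr_row=4, nr_col=4, border=2, viz=True)  # viz=True presumably displays the grid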