Commit 47992266 authored by Yuxin Wu

check gpu availability to avoid confusion

parent 9588d6ca
@@ -54,7 +54,7 @@ class AtariPlayer(RLEnvironment):
             "rom {} not found. Please download at {}".format(rom_file, ROM_URL)
         try:
-            ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
+            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
         except AttributeError:
             if execute_only_once():
                 logger.warn("You're not using latest ALE")
@@ -211,6 +211,7 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()
+    assert tf.test.is_gpu_available()
     logger.auto_set_dir()
     data = get_celebA_data(args.data, args.style_A, args.style_B)
@@ -47,6 +47,7 @@ class Model(ModelDesc):
     def _build_graph(self, inputs):
         image, label = inputs
         image = image / 128.0
+        assert tf.test.is_gpu_available()
         image = tf.transpose(image, [0, 3, 1, 2])
         def residual(name, l, increase_dim=False, first=False):
@@ -23,6 +23,8 @@ DEPTH = None
 class Model(ModelDesc):
     def __init__(self, data_format='NCHW'):
+        if data_format == 'NCHW':
+            assert tf.test.is_gpu_available()
         self.data_format = data_format
     def _get_inputs(self):
@@ -40,11 +40,15 @@ class Model(ModelDesc):
         if is_training:
             tf.summary.image("train_image", image, 10)
-        image = tf.transpose(image, [0, 3, 1, 2])
+        if tf.test.is_gpu_available():
+            image = tf.transpose(image, [0, 3, 1, 2])
+            data_format = 'NCHW'
+        else:
+            data_format = 'NHWC'
         image = image / 4.0    # just to make range smaller
         with argscope(Conv2D, nl=BNReLU, use_bias=False, kernel_shape=3), \
-                argscope([Conv2D, MaxPooling, BatchNorm], data_format='NCHW'):
+                argscope([Conv2D, MaxPooling, BatchNorm], data_format=data_format):
             logits = LinearWrap(image) \
                 .Conv2D('conv1.1', out_channel=64) \
                 .Conv2D('conv1.2', out_channel=64) \
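The recurring idea in this commit is the pattern in the hunk above: tf.test.is_gpu_available() decides whether the graph can be built in the GPU-friendly NCHW layout or must fall back to NHWC, because TensorFlow's NCHW convolution kernels are essentially GPU-only and asserting early avoids a confusing failure later. A minimal standalone sketch of that pattern (not code from the repository; the helper name choose_layout is made up for illustration):

import tensorflow as tf

def choose_layout(image_nhwc):
    # image_nhwc: a 4-D tensor in [batch, height, width, channels] order.
    if tf.test.is_gpu_available():
        # A GPU is visible: transpose to NCHW, which cuDNN kernels prefer.
        image = tf.transpose(image_nhwc, [0, 3, 1, 2])
        data_format = 'NCHW'
    else:
        # CPU-only session: NCHW convolutions are not implemented, keep NHWC.
        image = image_nhwc
        data_format = 'NHWC'
    return image, data_format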
@@ -5,11 +5,38 @@
 import numpy as np
 from collections import deque
 from six.moves import range
 from .envbase import ProxyPlayer
 __all__ = ['HistoryFramePlayer']
+class HistoryBuffer(object):
+    def __init__(self, hist_len, concat_axis=2):
+        self.buf = deque(maxlen=hist_len)
+        self.concat_axis = concat_axis
+    def push(self, s):
+        self.buf.append(s)
+    def clear(self):
+        self.buf.clear()
+    def get(self):
+        difflen = self.buf.maxlen - len(self.buf)
+        if difflen == 0:
+            ret = self.buf
+        else:
+            zeros = [np.zeros_like(self.buf[0]) for k in range(difflen)]
+            for k in self.buf:
+                zeros.append(k)
+            ret = zeros
+        return np.concatenate(ret, axis=self.concat_axis)
+    def __len__(self):
+        return len(self.buf)
 class HistoryFramePlayer(ProxyPlayer):
     """ Include history frames in state, or use black images.
         It assumes the underlying player will do auto-restart.
@@ -22,30 +49,23 @@ class HistoryFramePlayer(ProxyPlayer):
             and `hist_len-1` history.
         """
         super(HistoryFramePlayer, self).__init__(player)
-        self.history = deque(maxlen=hist_len)
+        self.history = HistoryBuffer(hist_len, 2)
         s = self.player.current_state()
-        self.history.append(s)
+        self.history.push(s)
     def current_state(self):
         assert len(self.history) != 0
-        diff_len = self.history.maxlen - len(self.history)
-        if diff_len == 0:
-            return np.concatenate(self.history, axis=2)
-        zeros = [np.zeros_like(self.history[0]) for k in range(diff_len)]
-        for k in self.history:
-            zeros.append(k)
-        assert len(zeros) == self.history.maxlen
-        return np.concatenate(zeros, axis=2)
+        return self.history.get()
     def action(self, act):
         r, isOver = self.player.action(act)
         s = self.player.current_state()
-        self.history.append(s)
+        self.history.push(s)
         if isOver:  # s would be a new episode
             self.history.clear()
-            self.history.append(s)
+            self.history.push(s)
         return (r, isOver)
     def restart_episode(self):
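To make the padding behavior of the new HistoryBuffer concrete, here is a short usage sketch (not part of the commit); the 84x84 single-channel frame shape is only an example, chosen to match the axis-2 concatenation that HistoryFramePlayer uses:

import numpy as np

buf = HistoryBuffer(hist_len=4, concat_axis=2)
buf.push(np.ones((84, 84, 1), dtype='float32'))

# Only one frame has been pushed, so get() prepends three zero ("black")
# frames of the same shape before concatenating along axis 2.
state = buf.get()
assert state.shape == (84, 84, 4)
assert len(buf) == 1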