Commit 7346f13b authored by Yuxin Wu

misc updates

parent ed7a0793
...@@ -9,7 +9,7 @@ import tensorflow as tf ...@@ -9,7 +9,7 @@ import tensorflow as tf
import imp import imp
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.utils import sessinit from tensorpack.tfutils import sessinit
from tensorpack.dataflow import * from tensorpack.dataflow import *
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -22,7 +22,7 @@ get_config_func = imp.load_source('config_script', args.config).get_config ...@@ -22,7 +22,7 @@ get_config_func = imp.load_source('config_script', args.config).get_config
with tf.Graph().as_default() as G: with tf.Graph().as_default() as G:
config = get_config_func() config = get_config_func()
config.get_model_func(config.inputs, is_training=False) config.model.get_cost(config.model.get_input_vars(), is_training=False)
init = sessinit.SaverRestore(args.model) init = sessinit.SaverRestore(args.model)
sess = tf.Session() sess = tf.Session()
init.init(sess) init.init(sess)
......
...@@ -55,7 +55,7 @@ class CallbackTimeLogger(object): ...@@ -55,7 +55,7 @@ class CallbackTimeLogger(object):
msgs.append("{}:{:.3f}sec".format(name, t)) msgs.append("{}:{:.3f}sec".format(name, t))
logger.info( logger.info(
"Callbacks took {:.3f} sec in total. {}".format( "Callbacks took {:.3f} sec in total. {}".format(
self.tot, ' '.join(msgs))) self.tot, '; '.join(msgs)))
class TestCallbackContext(object): class TestCallbackContext(object):
""" """
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
import numpy as np
from tqdm import tqdm from tqdm import tqdm
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from six.moves import zip, map from six.moves import zip, map
......
...@@ -23,7 +23,8 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0): ...@@ -23,7 +23,8 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
if max_count is None: if max_count is None:
max_count = sys.maxint max_count = sys.maxint
for i, dp in enumerate(ds.get_data()): for i, dp in enumerate(ds.get_data()):
print i if i % 100 == 0:
print(i)
if i > max_count: if i > max_count:
return return
img = dp[index] img = dp[index]
......
...@@ -19,10 +19,14 @@ class ModelDesc(object): ...@@ -19,10 +19,14 @@ class ModelDesc(object):
def get_input_vars(self): def get_input_vars(self):
""" """
Create and return raw input vars in the graph. Create or return (if already created) input TF vars in the graph.
:returns: the list of raw input vars in the graph :returns: the list of raw input vars in the graph
""" """
try:
return self.reuse_input_vars()
except KeyError:
pass
input_vars = self._get_input_vars() input_vars = self._get_input_vars()
ret = [] ret = []
for v in input_vars: for v in input_vars:
...@@ -37,7 +41,7 @@ class ModelDesc(object): ...@@ -37,7 +41,7 @@ class ModelDesc(object):
@abstractmethod @abstractmethod
def _get_input_vars(self): def _get_input_vars(self):
pass """:returns: a list of InputVar """
def get_cost(self, input_vars, is_training): def get_cost(self, input_vars, is_training):
""" """
...@@ -59,3 +63,4 @@ class ModelDesc(object): ...@@ -59,3 +63,4 @@ class ModelDesc(object):
def get_gradient_processor(self): def get_gradient_processor(self):
""" Return a list of GradientProcessor. They will be executed in order""" """ Return a list of GradientProcessor. They will be executed in order"""
return [CheckGradient()]#, SummaryGradient()] return [CheckGradient()]#, SummaryGradient()]
...@@ -75,8 +75,7 @@ def FixedUnPooling(x, shape, unpool_mat=None): ...@@ -75,8 +75,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
:returns: NHWC tensor :returns: NHWC tensor
""" """
shape = shape2d(shape) shape = shape2d(shape)
input_shape = x.get_shape().as_list() input_shape = tf.shape(x)
assert len(input_shape) == 4
if unpool_mat is None: if unpool_mat is None:
mat = np.zeros(shape, dtype='float32') mat = np.zeros(shape, dtype='float32')
mat[0][0] = 1 mat[0][0] = 1
...@@ -90,13 +89,11 @@ def FixedUnPooling(x, shape, unpool_mat=None): ...@@ -90,13 +89,11 @@ def FixedUnPooling(x, shape, unpool_mat=None):
fx = tf.expand_dims(fx, -1) # (bchw)x1 fx = tf.expand_dims(fx, -1) # (bchw)x1
mat = tf.expand_dims(flatten(unpool_mat), 0) #1x(shxsw) mat = tf.expand_dims(flatten(unpool_mat), 0) #1x(shxsw)
prod = tf.matmul(fx, mat) #(bchw) x(shxsw) prod = tf.matmul(fx, mat) #(bchw) x(shxsw)
prod = tf.reshape(prod, [-1, input_shape[3], prod = tf.reshape(prod, tf.pack(
input_shape[1], input_shape[2], [-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]]))
shape[0], shape[1]])
prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1]) prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1])
prod = tf.reshape(prod, [-1, input_shape[1] * shape[0], prod = tf.reshape(prod, tf.pack(
input_shape[2] * shape[1], [-1, input_shape[1] * shape[0], input_shape[2] * shape[1], input_shape[3]]))
input_shape[3]])
return prod return prod
@layer_register() @layer_register()
......
...@@ -8,6 +8,7 @@ import os ...@@ -8,6 +8,7 @@ import os
def _global_import(name): def _global_import(name):
p = __import__(name, globals(), None, level=1) p = __import__(name, globals(), None, level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
if name in ['common', 'argscope']:
del globals()[name] del globals()[name]
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
......
...@@ -39,8 +39,10 @@ def batch_flatten(x): ...@@ -39,8 +39,10 @@ def batch_flatten(x):
""" """
Flatten the tensor except the first dimension. Flatten the tensor except the first dimension.
""" """
total_dim = np.prod(x.get_shape()[1:].as_list()) shape = x.get_shape().as_list()[1:]
return tf.reshape(x, [-1, total_dim]) if None not in shape:
return tf.reshape(x, [-1, np.prod(shape)])
return tf.reshape(x, tf.pack([tf.shape(x)[0], -1]))
def logSoftmax(x): def logSoftmax(x):
""" """
......
...@@ -128,7 +128,14 @@ class Trainer(object): ...@@ -128,7 +128,14 @@ class Trainer(object):
sess=self.sess, coord=self.coord, daemon=True, start=True) sess=self.sess, coord=self.coord, daemon=True, start=True)
def process_grads(self, grads): def process_grads(self, grads):
g = []
for grad, var in grads:
if grad is None:
logger.warn("No Gradient w.r.t {}".format(var.op.name))
else:
g.append((grad, var))
procs = self.config.model.get_gradient_processor() procs = self.config.model.get_gradient_processor()
for proc in procs: for proc in procs:
grads = proc.process(grads) g = proc.process(g)
return grads return g
...@@ -30,7 +30,7 @@ class SimpleTrainer(Trainer): ...@@ -30,7 +30,7 @@ class SimpleTrainer(Trainer):
input_vars = model.get_input_vars() input_vars = model.get_input_vars()
self.input_vars = input_vars self.input_vars = input_vars
cost_var = model.get_cost(input_vars, is_training=True) cost_var = model.get_cost(input_vars, is_training=True)
avg_maintain_op = summary_moving_average(cost_var) avg_maintain_op = summary_moving_average()
grads = self.config.optimizer.compute_gradients(cost_var) grads = self.config.optimizer.compute_gradients(cost_var)
grads = self.process_grads(grads) grads = self.process_grads(grads)
...@@ -66,13 +66,14 @@ class EnqueueThread(threading.Thread): ...@@ -66,13 +66,14 @@ class EnqueueThread(threading.Thread):
self.daemon = True self.daemon = True
def run(self): def run(self):
with self.sess.as_default():
try: try:
while True: while True:
for dp in self.dataflow.get_data(): for dp in self.dataflow.get_data():
if self.coord.should_stop(): if self.coord.should_stop():
return return
feed = dict(zip(self.input_vars, dp)) feed = dict(zip(self.input_vars, dp))
self.op.run(feed_dict=feed, session=self.sess) self.op.run(feed_dict=feed)
except tf.errors.CancelledError as e: except tf.errors.CancelledError as e:
pass pass
except Exception: except Exception:
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: atari.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ale_python_interface import ALEInterface
import numpy as np
import os
import cv2
from .utils import get_rng
__all__ = ['AtariDriver']
class AtariDriver(object):
    """
    Wrapper around ``ALEInterface`` that loads one Atari ROM, steps it with
    actions from the minimal action set and returns single-channel frames.
    """

    def __init__(self, rom_file, frame_skip=1, viz=False):
        """
        :param rom_file: path to the rom
        :param frame_skip: value written to ALE's "frame_skip" setting
        :param viz: when true, display every grabbed frame in an OpenCV
            window named after the rom file
        """
        emulator = ALEInterface()
        self.ale = emulator
        self.rng = get_rng(self)

        # Settings are written before the ROM is loaded.
        seed = self.rng.randint(99999)
        emulator.setInt("random_seed", seed)
        emulator.setInt("frame_skip", frame_skip)
        emulator.loadROM(rom_file)

        self.width, self.height = emulator.getScreenDims()
        self.actions = emulator.getMinimalActionSet()

        self.viz = viz
        self.romname = os.path.basename(rom_file)
        if viz:
            cv2.startWindowThread()
            cv2.namedWindow(self.romname)

        self._reset()
        # Seed the "previous frame" used by grab_image() with the first screen.
        self.last_image = self._grab_raw_image()

    def _grab_raw_image(self):
        """
        :returns: the current screen as a (height, width, 3) uint8 array
        """
        num_values = self.height * self.width * 3
        screen = np.zeros(num_values, dtype=np.uint8)
        # ALE fills the flat buffer in place.
        self.ale.getScreenRGB(screen)
        return screen.reshape((self.height, self.width, 3))

    def grab_image(self):
        """
        Grab the current screen and merge it with the previously grabbed one.

        :returns: a single-channel image: channel 0 of the YUV conversion of
            the element-wise maximum over the current and the previous raw
            frame
        """
        current = self._grab_raw_image()
        merged = np.maximum(current, self.last_image)
        self.last_image = current
        if self.viz:
            cv2.imshow(self.romname, merged)
        # NOTE(review): the buffer comes from getScreenRGB but is converted
        # with COLOR_BGR2YUV -- confirm the channel order ALE actually emits.
        return cv2.cvtColor(merged, cv2.COLOR_BGR2YUV)[:, :, 0]

    def get_num_actions(self):
        """
        :returns: the number of entries in the minimal action set
        """
        return len(self.actions)

    def _reset(self):
        """Restart the game inside the emulator."""
        self.ale.reset_game()

    def next(self, act):
        """
        Apply one action and observe the result.

        :param act: an index into the minimal action set
        :returns: (next_image, reward, isOver)
        """
        reward = self.ale.act(self.actions[act])
        frame = self.grab_image()
        finished = self.ale.game_over()
        if finished:
            # NOTE(review): last_image is not refreshed here, so the first
            # frame grabbed after a reset is max-merged with the final frame
            # of the finished game -- confirm this is intended.
            self._reset()
        return (frame, reward, finished)
if __name__ == '__main__':
    # Manual smoke test: play 'breakout.bin' with random actions forever,
    # showing the frames and printing the reward and game-over flag each step.
    import time

    driver = AtariDriver('breakout.bin', viz=True)
    num_actions = driver.get_num_actions()
    rng = get_rng(num_actions)
    while True:
        action = rng.choice(range(num_actions))
        _, reward, over = driver.next(action)
        # Slow the loop down so the visualisation is watchable.
        time.sleep(0.1)
        print(reward, over)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment