Commit 7346f13b authored by Yuxin Wu

misc updates

parent ed7a0793
......@@ -9,7 +9,7 @@ import tensorflow as tf
import imp
from tensorpack.utils import *
from tensorpack.utils import sessinit
from tensorpack.tfutils import sessinit
from tensorpack.dataflow import *
parser = argparse.ArgumentParser()
......@@ -22,7 +22,7 @@ get_config_func = imp.load_source('config_script', args.config).get_config
with tf.Graph().as_default() as G:
config = get_config_func()
config.get_model_func(config.inputs, is_training=False)
config.model.get_cost(config.model.get_input_vars(), is_training=False)
init = sessinit.SaverRestore(args.model)
sess = tf.Session()
init.init(sess)
......
......@@ -55,7 +55,7 @@ class CallbackTimeLogger(object):
msgs.append("{}:{:.3f}sec".format(name, t))
logger.info(
"Callbacks took {:.3f} sec in total. {}".format(
self.tot, ' '.join(msgs)))
self.tot, '; '.join(msgs)))
class TestCallbackContext(object):
"""
......
......@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from six.moves import zip, map
......
......@@ -23,7 +23,8 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
if max_count is None:
max_count = sys.maxint
for i, dp in enumerate(ds.get_data()):
print i
if i % 100 == 0:
print(i)
if i > max_count:
return
img = dp[index]
......
......@@ -19,10 +19,14 @@ class ModelDesc(object):
def get_input_vars(self):
"""
Create and return raw input vars in the graph.
Create or return (if already created) input TF vars in the graph.
:returns: the list of raw input vars in the graph
"""
try:
return self.reuse_input_vars()
except KeyError:
pass
input_vars = self._get_input_vars()
ret = []
for v in input_vars:
......@@ -37,7 +41,7 @@ class ModelDesc(object):
@abstractmethod
def _get_input_vars(self):
pass
""":returns: a list of InputVar """
def get_cost(self, input_vars, is_training):
"""
......@@ -59,3 +63,4 @@ class ModelDesc(object):
def get_gradient_processor(self):
""" Return a list of GradientProcessor. They will be executed in order"""
return [CheckGradient()]#, SummaryGradient()]
......@@ -75,8 +75,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
:returns: NHWC tensor
"""
shape = shape2d(shape)
input_shape = x.get_shape().as_list()
assert len(input_shape) == 4
input_shape = tf.shape(x)
if unpool_mat is None:
mat = np.zeros(shape, dtype='float32')
mat[0][0] = 1
......@@ -90,13 +89,11 @@ def FixedUnPooling(x, shape, unpool_mat=None):
fx = tf.expand_dims(fx, -1) # (bchw)x1
mat = tf.expand_dims(flatten(unpool_mat), 0) #1x(shxsw)
prod = tf.matmul(fx, mat) #(bchw) x(shxsw)
prod = tf.reshape(prod, [-1, input_shape[3],
input_shape[1], input_shape[2],
shape[0], shape[1]])
prod = tf.reshape(prod, tf.pack(
[-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]]))
prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1])
prod = tf.reshape(prod, [-1, input_shape[1] * shape[0],
input_shape[2] * shape[1],
input_shape[3]])
prod = tf.reshape(prod, tf.pack(
[-1, input_shape[1] * shape[0], input_shape[2] * shape[1], input_shape[3]]))
return prod
@layer_register()
......
......@@ -8,6 +8,7 @@ import os
def _global_import(name):
p = __import__(name, globals(), None, level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
if name in ['common', 'argscope']:
del globals()[name]
for k in lst:
globals()[k] = p.__dict__[k]
......
......@@ -39,8 +39,10 @@ def batch_flatten(x):
"""
Flatten the tensor except the first dimension.
"""
total_dim = np.prod(x.get_shape()[1:].as_list())
return tf.reshape(x, [-1, total_dim])
shape = x.get_shape().as_list()[1:]
if None not in shape:
return tf.reshape(x, [-1, np.prod(shape)])
return tf.reshape(x, tf.pack([tf.shape(x)[0], -1]))
def logSoftmax(x):
"""
......
......@@ -128,7 +128,14 @@ class Trainer(object):
sess=self.sess, coord=self.coord, daemon=True, start=True)
def process_grads(self, grads):
g = []
for grad, var in grads:
if grad is None:
logger.warn("No Gradient w.r.t {}".format(var.op.name))
else:
g.append((grad, var))
procs = self.config.model.get_gradient_processor()
for proc in procs:
grads = proc.process(grads)
return grads
g = proc.process(g)
return g
......@@ -30,7 +30,7 @@ class SimpleTrainer(Trainer):
input_vars = model.get_input_vars()
self.input_vars = input_vars
cost_var = model.get_cost(input_vars, is_training=True)
avg_maintain_op = summary_moving_average(cost_var)
avg_maintain_op = summary_moving_average()
grads = self.config.optimizer.compute_gradients(cost_var)
grads = self.process_grads(grads)
......@@ -66,13 +66,14 @@ class EnqueueThread(threading.Thread):
self.daemon = True
def run(self):
with self.sess.as_default():
try:
while True:
for dp in self.dataflow.get_data():
if self.coord.should_stop():
return
feed = dict(zip(self.input_vars, dp))
self.op.run(feed_dict=feed, session=self.sess)
self.op.run(feed_dict=feed)
except tf.errors.CancelledError as e:
pass
except Exception:
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: atari.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ale_python_interface import ALEInterface
import numpy as np
import os
import cv2
from .utils import get_rng
__all__ = ['AtariDriver']
class AtariDriver(object):
    """
    Thin wrapper around an ALEInterface instance for stepping an Atari game.
    """

    def __init__(self, rom_file, frame_skip=1, viz=False):
        """
        :param rom_file: path to the rom
        :param frame_skip: value passed to the emulator's "frame_skip" setting
        :param viz: if True, show every grabbed frame in an OpenCV window
        """
        self.ale = ALEInterface()
        self.rng = get_rng(self)

        # Emulator settings are applied before the ROM is loaded.
        seed = self.rng.randint(99999)
        self.ale.setInt("random_seed", seed)
        self.ale.setInt("frame_skip", frame_skip)
        self.ale.loadROM(rom_file)

        screen_dims = self.ale.getScreenDims()
        self.width, self.height = screen_dims
        self.actions = self.ale.getMinimalActionSet()

        self.viz = viz
        # The ROM's base name doubles as the title of the display window.
        self.romname = os.path.basename(rom_file)
        if self.viz:
            cv2.startWindowThread()
            cv2.namedWindow(self.romname)

        self._reset()
        # Seed the "previous frame" consumed by grab_image().
        self.last_image = self._grab_raw_image()

    def _grab_raw_image(self):
        """
        :returns: the current screen as a (height, width, 3) uint8 array
        """
        frame_shape = (self.height, self.width, 3)
        # The emulator fills a flat buffer in place; it is reshaped afterwards.
        buf = np.zeros(self.height * self.width * 3, dtype=np.uint8)
        self.ale.getScreenRGB(buf)
        return buf.reshape(frame_shape)

    def grab_image(self):
        """
        :returns: a gray-scale image: the element-wise maximum of the current
            raw frame and the previously grabbed one, reduced to channel 0
            of its YUV conversion
        """
        current = self._grab_raw_image()
        merged = np.maximum(current, self.last_image)
        self.last_image = current
        if self.viz:
            cv2.imshow(self.romname, merged)
        # NOTE(review): the buffer is filled by getScreenRGB but converted
        # with a BGR code -- confirm the channel order the emulator writes.
        return cv2.cvtColor(merged, cv2.COLOR_BGR2YUV)[:, :, 0]

    def get_num_actions(self):
        """
        :returns: the number of legal actions
        """
        return len(self.actions)

    def _reset(self):
        """ Restart the game in the emulator. """
        self.ale.reset_game()

    def next(self, act):
        """
        :param act: an index into the list of legal actions
        :returns: (next_image, reward, isOver)
        """
        reward = self.ale.act(self.actions[act])
        image = self.grab_image()
        finished = self.ale.game_over()
        if finished:
            # Restart right away so the following call begins a new game.
            self._reset()
        return (image, reward, finished)
if __name__ == '__main__':
    # Manual demo: play random actions forever, showing the game window and
    # printing the reward and game-over flag after every step.
    import time

    driver = AtariDriver('breakout.bin', viz=True)
    num_actions = driver.get_num_actions()
    rng = get_rng(num_actions)
    while True:
        chosen = rng.choice(range(num_actions))
        frame, reward, over = driver.next(chosen)
        time.sleep(0.1)
        print(reward, over)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment