Commit 4e472eb5 authored by Yuxin Wu

misc funcs

parent c83f2d9f
@@ -148,22 +148,14 @@ class Model(ModelDesc):
         return self.predict_value.eval(feed_dict={'state:0': [state]})[0]

 def play_one_episode(player, func, verbose=False):
-    while True:
-        s = player.current_state()
-        outputs = func([[s]])
-        action_value = outputs[0][0]
-        act = action_value.argmax()
-        if verbose:
-            print action_value, act
+    def f(s):
+        act = func([[s]])[0][0].argmax()
         if random.random() < 0.01:
             act = random.choice(range(NUM_ACTIONS))
-        reward, isOver = player.action(act)
-        if isOver:
-            sc = player.stats['score'][0]
-            player.reset_stat()
-            return sc
+        if verbose:
+            print(act)
+        return act
+    return player.play_one_episode(f)

 def play_model(model_path):
     player = get_player(0.013)
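Not part of the commit, just a reading aid: after the refactor, `play_one_episode` only needs `func` to behave like the batched predictor used in this example, i.e. `func([[state]])` returns a list whose first element is a batch of Q-value vectors. A throwaway stand-in that follows the same calling convention:

```python
# Sketch only: a random "predictor" mimicking the expected signature.
# NUM_ACTIONS and the player come from the surrounding example; the value 6
# below is an arbitrary placeholder.
import numpy as np

NUM_ACTIONS = 6

def dummy_q_func(batch_of_states):
    # outputs[0][0] must be a length-NUM_ACTIONS vector of action values
    return [np.random.rand(len(batch_of_states[0]), NUM_ACTIONS)]

# score = play_one_episode(player, dummy_q_func, verbose=True)
```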
@@ -8,7 +8,7 @@ import numpy as np
 from collections import deque
 from .envbase import ProxyPlayer
-__all__ = ['HistoryFramePlayer', 'PreventStuckPlayer']
+__all__ = ['HistoryFramePlayer', 'PreventStuckPlayer', 'LimitLengthPlayer']

 class HistoryFramePlayer(ProxyPlayer):
     """ Include history frames in state, or use black images"""
@@ -62,3 +62,19 @@ class PreventStuckPlayer(ProxyPlayer):
         if isOver:
             self.act_que.clear()
         return (r, isOver)
+
+class LimitLengthPlayer(ProxyPlayer):
+    """ Limit the total number of actions in an episode """
+    def __init__(self, player, limit):
+        super(LimitLengthPlayer, self).__init__(player)
+        self.limit = limit
+        self.cnt = 0
+
+    def action(self, act):
+        r, isOver = self.player.action(act)
+        self.cnt += 1
+        if self.cnt == self.limit:
+            isOver = True
+        if isOver:
+            self.cnt = 0    # reset the step counter when an episode ends
+        return (r, isOver)
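A hedged usage sketch, not from this commit: `LimitLengthPlayer` is an ordinary `ProxyPlayer`, so it can wrap any existing environment (including other wrappers such as `HistoryFramePlayer` or `PreventStuckPlayer`); `base_player` below stands in for whatever environment the caller already has.

```python
# Sketch only; base_player is assumed to be any RLEnvironment instance.
player = LimitLengthPlayer(base_player, 30000)   # cap an episode at 30000 actions

isOver = False
while not isOver:
    r, isOver = player.action(0)   # any policy works; isOver is forced once the cap is hit
```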
@@ -39,6 +39,20 @@ class RLEnvironment(object):
         """ reset the statistics counter"""
         self.stats = defaultdict(list)

+    def play_one_episode(self, func, stat='score'):
+        """ Play one episode for evaluation.
+            :param func: called with the current state, returns an action
+            :returns: the `stat` statistic (the score by default) recorded for this episode
+        """
+        while True:
+            s = self.current_state()
+            act = func(s)
+            r, isOver = self.action(act)
+            if isOver:
+                score = self.stats[stat]
+                self.reset_stat()
+                return score
+
 class NaiveRLEnvironment(RLEnvironment):
     """ for testing only"""
     def __init__(self):
@@ -71,3 +85,6 @@ class ProxyPlayer(RLEnvironment):
     def stats(self):
         return self.player.stats
+
+    def play_one_episode(self, func, stat='score'):
+        return self.player.play_one_episode(func, stat)
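For context, a minimal sketch (nothing below ships with the library) of what the new `play_one_episode` contract expects from a concrete environment: `current_state`/`action` drive the loop, and whatever the environment records under `stats['score']` is returned once the episode ends.

```python
# Toy environment for illustration only.
import random

class ToyEnv(RLEnvironment):
    def __init__(self):
        self.reset_stat()        # gives self.stats, a defaultdict(list) (see above)
        self.steps = 0

    def current_state(self):
        return self.steps

    def action(self, act):
        self.steps += 1
        isOver = self.steps >= 5
        if isOver:
            self.stats['score'].append(self.steps)   # normally done by the real env
            self.steps = 0
        return (1.0, isOver)

score = ToyEnv().play_one_episode(lambda state: random.choice([0, 1]))
```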
@@ -140,7 +140,7 @@ class HumanHyperParamSetter(HyperParamSetter):
             return ret
         except:
             logger.warn(
-                "Failed to parse {} in {}".format(
+                "Failed to find {} in {}".format(
                     self.param.readable_name, self.file_name))
             return None
@@ -11,7 +11,8 @@ from ..utils import *
 __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
            'RepeatedData', 'MapDataComponent', 'RandomChooseData',
-           'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent']
+           'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
+           'DataFromQueue']

 class BatchData(ProxyDataFlow):
     def __init__(self, ds, batch_size, remainder=False):
@@ -25,7 +26,11 @@ class BatchData(ProxyDataFlow):
         """
         super(BatchData, self).__init__(ds)
         if not remainder:
-            assert batch_size <= ds.size()
+            try:
+                s = ds.size()
+                assert batch_size <= s
+            except NotImplementedError:
+                pass
         self.batch_size = batch_size
         self.remainder = remainder
@@ -313,6 +318,16 @@ class JoinData(DataFlow):
             for itr in itrs:
                 del itr

+class DataFromQueue(DataFlow):
+    """ Provide data from a queue """
+    def __init__(self, queue):
+        self.queue = queue
+
+    def get_data(self):
+        while True:
+            yield self.queue.get()
+
 def SelectComponent(ds, idxs):
     """
     :param ds: a :mod:`DataFlow` instance
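A hedged sketch of how the two dataflow changes work together (the producer below is illustrative, not part of the commit): `DataFromQueue` has no meaningful `size()`, which is exactly the case the new `try/except NotImplementedError` in `BatchData` now tolerates.

```python
# Sketch only: a separate process feeds datapoints (lists of components) into
# a multiprocessing.Queue, and DataFromQueue exposes them as a DataFlow.
import multiprocessing as mp

def producer(q):
    for i in range(1000):
        q.put([i])

q = mp.Queue(maxsize=128)
p = mp.Process(target=producer, args=(q,))
p.daemon = True
p.start()

df = BatchData(DataFromQueue(q), 32)   # size() is unknown, so the size assert is skipped
for batch in df.get_data():            # an endless flow; consume as many batches as needed
    break
```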
@@ -21,8 +21,6 @@ from .common import *
 try:
     if six.PY2:
         from tornado.concurrent import Future
-        import tornado.options as options
-        options.parse_command_line(['--logging=debug'])
     else:
         from concurrent.futures import Future
 except ImportError:
@@ -146,6 +144,9 @@ class MultiThreadAsyncPredictor(object):
             for id, f in enumerate(
                 trainer.get_predict_funcs(
                     input_names, output_names, nr_thread))]
+        # TODO XXX set logging here to avoid affecting TF logging
+        import tornado.options as options
+        options.parse_command_line(['--logging=debug'])

     def run(self):
         for t in self.threads:
@@ -78,3 +78,8 @@ def print_stat(x):
         Use it like: x = print_stat(x)
     """
     return tf.Print(x, [tf.reduce_mean(x), x], summarize=20)
+
+def rms(x, name=None):
+    if name is None:
+        name = x.op.name + '/rms'
+    return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
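A short usage sketch (assuming the TF 0.x-era API used elsewhere in this file): `rms` reports the root-mean-square of any float tensor, which is convenient for monitoring activations or gradients.

```python
# Sketch only; the placeholder is illustrative.
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 128], name='activation')
x_rms = rms(x)        # scalar tensor named 'activation/rms'
x = print_stat(x)     # the existing helper above, for comparison
```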
@@ -152,9 +152,9 @@ class Trainer(object):
         tf.train.start_queue_runners(
             sess=self.sess, coord=self.coord, daemon=True, start=True)
-        # avoid sigint get handled by other processes
-        start_proc_mask_signal(self.extra_threads_procs)
         with self.sess.as_default():
+            # avoid sigint get handled by other processes
+            start_proc_mask_signal(self.extra_threads_procs)

     def process_grads(self, grads):
         g = []
@@ -20,3 +20,6 @@ class LookUpTable(object):
     def get_idx(self, obj):
         return self.obj2idx[obj]
+
+    def __str__(self):
+        return self.idx2obj.__str__()