Commit 4e472eb5 authored by Yuxin Wu

misc funcs

parent c83f2d9f
...@@ -148,22 +148,14 @@ class Model(ModelDesc): ...@@ -148,22 +148,14 @@ class Model(ModelDesc):
return self.predict_value.eval(feed_dict={'state:0': [state]})[0] return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
def play_one_episode(player, func, verbose=False): def play_one_episode(player, func, verbose=False):
while True: def f(s):
s = player.current_state() act = func([[s]])[0][0].argmax()
outputs = func([[s]])
action_value = outputs[0][0]
act = action_value.argmax()
if verbose:
print action_value, act
if random.random() < 0.01: if random.random() < 0.01:
act = random.choice(range(NUM_ACTIONS)) act = random.choice(range(NUM_ACTIONS))
if verbose: if verbose:
print(act) print(act)
reward, isOver = player.action(act) return act
if isOver: return player.play_one_episode(f)
sc = player.stats['score'][0]
player.reset_stat()
return sc
def play_model(model_path): def play_model(model_path):
player = get_player(0.013) player = get_player(0.013)
......
...@@ -8,7 +8,7 @@ import numpy as np ...@@ -8,7 +8,7 @@ import numpy as np
from collections import deque from collections import deque
from .envbase import ProxyPlayer from .envbase import ProxyPlayer
__all__ = ['HistoryFramePlayer', 'PreventStuckPlayer'] __all__ = ['HistoryFramePlayer', 'PreventStuckPlayer', 'LimitLengthPlayer']
class HistoryFramePlayer(ProxyPlayer): class HistoryFramePlayer(ProxyPlayer):
""" Include history frames in state, or use black images""" """ Include history frames in state, or use black images"""
...@@ -62,3 +62,19 @@ class PreventStuckPlayer(ProxyPlayer): ...@@ -62,3 +62,19 @@ class PreventStuckPlayer(ProxyPlayer):
if isOver: if isOver:
self.act_que.clear() self.act_que.clear()
return (r, isOver) return (r, isOver)
class LimitLengthPlayer(ProxyPlayer):
    """ Limit the total number of actions in an episode.

    Wraps another player and reports the episode as over once ``limit``
    actions have been taken since the previous episode end, even if the
    wrapped player itself has not signalled the end.
    """

    def __init__(self, player, limit):
        """
        :param player: the player to wrap
        :param limit: number of actions after which the episode is
            reported as over
        """
        super(LimitLengthPlayer, self).__init__(player)
        self.limit = limit
        # number of actions taken in the current episode
        self.cnt = 0

    def action(self, act):
        """ Forward ``act`` to the wrapped player and enforce the limit.

        :param act: the action to take
        :returns: a ``(reward, isOver)`` tuple; ``isOver`` is forced to
            True when the action count reaches ``self.limit``
        """
        r, isOver = self.player.action(act)
        self.cnt += 1
        if self.cnt == self.limit:
            isOver = True
        if isOver:
            # Must be an assignment: the former `self.cnt == 0` only
            # compared and discarded the result, so the counter kept
            # growing and the limit never triggered on later episodes.
            self.cnt = 0
        return (r, isOver)
...@@ -39,6 +39,20 @@ class RLEnvironment(object): ...@@ -39,6 +39,20 @@ class RLEnvironment(object):
""" reset the statistics counter""" """ reset the statistics counter"""
self.stats = defaultdict(list) self.stats = defaultdict(list)
def play_one_episode(self, func, stat='score'):
    """ Play a single episode, for evaluation.

    :param func: callable taking the current state and returning an action
    :param stat: key into ``self.stats`` selecting what to report
    :returns: ``self.stats[stat]`` as it stood when the episode ended;
        the statistics are reset before returning
    """
    episode_over = False
    while not episode_over:
        chosen = func(self.current_state())
        # the per-step reward is not needed here, only the end-of-episode flag
        _, episode_over = self.action(chosen)
    result = self.stats[stat]
    self.reset_stat()
    return result
class NaiveRLEnvironment(RLEnvironment): class NaiveRLEnvironment(RLEnvironment):
""" for testing only""" """ for testing only"""
def __init__(self): def __init__(self):
...@@ -71,3 +85,6 @@ class ProxyPlayer(RLEnvironment): ...@@ -71,3 +85,6 @@ class ProxyPlayer(RLEnvironment):
def stats(self): def stats(self):
return self.player.stats return self.player.stats
def play_one_episode(self, func, stat='score'):
    """ Play one episode through this proxy.

    :param func: callable taking the current state and returning an action
    :param stat: key into the statistics selecting what to report
    :returns: whatever the inherited ``play_one_episode`` returns
    """
    # The previous body called
    # `self.player.play_one_episode(self, func, stat)`, handing `self` as an
    # extra positional argument to an already-bound method. That is one
    # argument more than the (func, stat) signature accepts, so the call
    # could never succeed.
    # Running the inherited loop on this object drives the episode through
    # this proxy's own current_state()/action().
    # NOTE(review): if plain delegation to the wrapped player was intended,
    # use `self.player.play_one_episode(func, stat)` instead -- confirm.
    return super(ProxyPlayer, self).play_one_episode(func, stat)
...@@ -140,7 +140,7 @@ class HumanHyperParamSetter(HyperParamSetter): ...@@ -140,7 +140,7 @@ class HumanHyperParamSetter(HyperParamSetter):
return ret return ret
except: except:
logger.warn( logger.warn(
"Failed to parse {} in {}".format( "Failed to find {} in {}".format(
self.param.readable_name, self.file_name)) self.param.readable_name, self.file_name))
return None return None
......
...@@ -11,7 +11,8 @@ from ..utils import * ...@@ -11,7 +11,8 @@ from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'RepeatedData', 'MapDataComponent', 'RandomChooseData', 'RepeatedData', 'MapDataComponent', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent'] 'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'DataFromQueue']
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
...@@ -25,7 +26,11 @@ class BatchData(ProxyDataFlow): ...@@ -25,7 +26,11 @@ class BatchData(ProxyDataFlow):
""" """
super(BatchData, self).__init__(ds) super(BatchData, self).__init__(ds)
if not remainder: if not remainder:
assert batch_size <= ds.size() try:
s = ds.size()
assert batch_size <= ds.size()
except NotImplementedError:
pass
self.batch_size = batch_size self.batch_size = batch_size
self.remainder = remainder self.remainder = remainder
...@@ -313,6 +318,16 @@ class JoinData(DataFlow): ...@@ -313,6 +318,16 @@ class JoinData(DataFlow):
for itr in itrs: for itr in itrs:
del itr del itr
class DataFromQueue(DataFlow):
    """ Provide data from a queue.

    Every datapoint is obtained by calling ``get()`` on the given queue.
    """

    def __init__(self, queue):
        """
        :param queue: an object with a ``get()`` method
        """
        self.queue = queue

    def get_data(self):
        """ Endlessly yield whatever ``self.queue.get()`` returns. """
        while True:
            datapoint = self.queue.get()
            yield datapoint
def SelectComponent(ds, idxs): def SelectComponent(ds, idxs):
""" """
:param ds: a :mod:`DataFlow` instance :param ds: a :mod:`DataFlow` instance
......
...@@ -21,8 +21,6 @@ from .common import * ...@@ -21,8 +21,6 @@ from .common import *
try: try:
if six.PY2: if six.PY2:
from tornado.concurrent import Future from tornado.concurrent import Future
import tornado.options as options
options.parse_command_line(['--logging=debug'])
else: else:
from concurrent.futures import Future from concurrent.futures import Future
except ImportError: except ImportError:
...@@ -146,6 +144,9 @@ class MultiThreadAsyncPredictor(object): ...@@ -146,6 +144,9 @@ class MultiThreadAsyncPredictor(object):
for id, f in enumerate( for id, f in enumerate(
trainer.get_predict_funcs( trainer.get_predict_funcs(
input_names, output_names, nr_thread))] input_names, output_names, nr_thread))]
# TODO XXX set logging here to avoid affecting TF logging
import tornado.options as options
options.parse_command_line(['--logging=debug'])
def run(self): def run(self):
for t in self.threads: for t in self.threads:
......
...@@ -78,3 +78,8 @@ def print_stat(x): ...@@ -78,3 +78,8 @@ def print_stat(x):
Use it like: x = print_stat(x) Use it like: x = print_stat(x)
""" """
return tf.Print(x, [tf.reduce_mean(x), x], summarize=20) return tf.Print(x, [tf.reduce_mean(x), x], summarize=20)
def rms(x, name=None):
    """ Build the root-mean-square of a tensor.

    :param x: the input tensor
    :param name: name for the resulting op; when None, the name of
        ``x``'s op with ``'/rms'`` appended is used
    :returns: ``sqrt(mean(square(x)))``
    """
    op_name = name if name is not None else x.op.name + '/rms'
    mean_square = tf.reduce_mean(tf.square(x))
    return tf.sqrt(mean_square, name=op_name)
...@@ -152,9 +152,9 @@ class Trainer(object): ...@@ -152,9 +152,9 @@ class Trainer(object):
tf.train.start_queue_runners( tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True) sess=self.sess, coord=self.coord, daemon=True, start=True)
# avoid sigint get handled by other processes with self.sess.as_default():
start_proc_mask_signal(self.extra_threads_procs) # avoid sigint get handled by other processes
start_proc_mask_signal(self.extra_threads_procs)
def process_grads(self, grads): def process_grads(self, grads):
g = [] g = []
......
...@@ -20,3 +20,6 @@ class LookUpTable(object): ...@@ -20,3 +20,6 @@ class LookUpTable(object):
def get_idx(self, obj): def get_idx(self, obj):
return self.obj2idx[obj] return self.obj2idx[obj]
def __str__(self):
return self.idx2obj.__str__()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment