Commit d7a85f44 authored by Yuxin Wu

misc update on framework

parent 11c46a71
......@@ -60,3 +60,5 @@ docs/_build/
# PyBuilder
target/
*.dat
*.bin
......@@ -16,7 +16,7 @@ from collections import deque
from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.concurrency import ensure_proc_terminate
from tensorpack.utils.concurrency import ensure_proc_terminate, subproc_call
from tensorpack.utils.stat import *
from tensorpack.predict import PredictConfig, get_predict_func, ParallelPredictWorker
from tensorpack.tfutils import symbolic_functions as symbf
......@@ -33,7 +33,6 @@ for atari games
BATCH_SIZE = 32
IMAGE_SIZE = 84
NUM_ACTIONS = None
FRAME_HISTORY = 4
ACTION_REPEAT = 4
HEIGHT_RANGE = (36, 204) # for breakout
......@@ -49,6 +48,15 @@ INIT_MEMORY_SIZE = 50000
STEP_PER_EPOCH = 10000
EVAL_EPISODE = 100
NUM_ACTIONS = None
ROM_FILE = None
def get_player(viz=False):
pl = AtariPlayer(ROM_FILE, viz=viz, height_range=HEIGHT_RANGE, frame_skip=ACTION_REPEAT)
global NUM_ACTIONS
NUM_ACTIONS = pl.get_num_actions()
return pl
class Model(ModelDesc):
def _get_input_vars(self):
assert NUM_ACTIONS is not None
......@@ -56,8 +64,7 @@ class Model(ModelDesc):
InputVar(tf.int32, (None,), 'action'),
InputVar(tf.float32, (None,), 'reward'),
InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE, FRAME_HISTORY), 'next_state'),
InputVar(tf.bool, (None,), 'isOver')
]
InputVar(tf.bool, (None,), 'isOver') ]
def _get_DQN_prediction(self, image, is_training):
""" image: [0,255]"""
......@@ -89,7 +96,7 @@ class Model(ModelDesc):
with tf.variable_scope('target'):
targetQ_predict_value = tf.stop_gradient(
self._get_DQN_prediction(next_state, False)) # NxA
target = reward + (1 - tf.cast(isOver, tf.int32)) *
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * \
GAMMA * tf.reduce_max(targetQ_predict_value, 1) # Nx1
sqrcost = tf.square(target - pred_action_value)
......@@ -108,7 +115,7 @@ class Model(ModelDesc):
new_name = target_name.replace('target/', '')
logger.info("{} <- {}".format(target_name, new_name))
ops.append(v.assign(tf.get_default_graph().get_tensor_by_name(new_name + ':0')))
return tf.group(*ops)
return tf.group(*ops, name='update_target_network')
def get_gradient_processor(self):
return [MapGradient(lambda grad: \
......@@ -120,28 +127,11 @@ def current_predictor(state):
pred = pred_var.eval(feed_dict={'state:0': [state]})
return pred[0]
class TargetNetworkUpdator(Callback):
def __init__(self, M):
self.M = M
def _setup_graph(self):
self.update_op = self.M.update_target_param()
def _update(self):
logger.info("Delayed Predictor updating...")
self.update_op.run()
def _before_train(self):
self._update()
def _trigger_epoch(self):
self._update()
def play_one_episode(player, func, verbose=False):
tot_reward = 0
que = deque(maxlen=30)
while True:
s = player.current_state() # XXX
s = player.current_state()
outputs = func([[s]])
action_value = outputs[0][0]
act = action_value.argmax()
......@@ -160,16 +150,10 @@ def play_one_episode(player, func, verbose=False):
if isOver:
return tot_reward
def play_model(model_path, romfile):
player = HistoryFramePlayer(AtariPlayer(
romfile, viz=0.01, height_range=HEIGHT_RANGE,
frame_skip=ACTION_REPEAT), FRAME_HISTORY)
global NUM_ACTIONS
NUM_ACTIONS = player.player.get_num_actions()
M = Model()
def play_model(model_path):
player = HistoryFramePlayer(get_player(0.01), FRAME_HISTORY)
cfg = PredictConfig(
model=M,
model=Model(),
input_data_mapping=[0],
session_init=SaverRestore(model_path),
output_var_names=['fct/output:0'])
......@@ -178,7 +162,7 @@ def play_model(model_path, romfile):
score = play_one_episode(player, predfunc)
print("Total:", score)
def eval_model_multiprocess(model_path, romfile):
def eval_model_multiprocess(model_path):
M = Model()
cfg = PredictConfig(
model=M,
......@@ -192,11 +176,7 @@ def eval_model_multiprocess(model_path, romfile):
self.outq = outqueue
def run(self):
player = HistoryFramePlayer(AtariPlayer(
romfile, viz=0, height_range=HEIGHT_RANGE,
frame_skip=ACTION_REPEAT), FRAME_HISTORY)
global NUM_ACTIONS
NUM_ACTIONS = player.player.get_num_actions()
player = HistoryFramePlayer(get_player(), FRAME_HISTORY)
self._init_runtime()
while True:
score = play_one_episode(player, self.func)
......@@ -216,32 +196,32 @@ def eval_model_multiprocess(model_path, romfile):
r = q.get()
stat.feed(r)
finally:
for p in procs:
p.terminate()
p.join()
if stat.count() > 0:
logger.info("Average Score: {}; Max Score: {}".format(
stat.average, stat.max))
return (stat.average, stat.max)
else:
return (0, 0)
def get_config(romfile):
logger.info("Average Score: {}; Max Score: {}".format(
stat.average, stat.max))
class Evaluator(Callback):
def _trigger_epoch(self):
logger.info("Evaluating...")
output = subproc_call(
"CUDA_VISIBLE_DEVICES= {} --task eval --rom {} --load {}".format(
sys.argv[0], romfile, os.path.join(logger.LOG_DIR, 'checkpoint')),
timeout=10*60)
if output:
last = output.strip().split('\n')[-1]
last = last[last.find(']')+1:]
mean, maximum = re.findall('[0-9\.\-]+', last)[-2:]
self.trainer.write_scalar_summary('mean_score', mean)
self.trainer.write_scalar_summary('max_score', maximum)
def get_config():
basename = os.path.basename(__file__)
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
M = Model()
player = AtariPlayer(
romfile, height_range=HEIGHT_RANGE,
frame_skip=ACTION_REPEAT)
global NUM_ACTIONS
NUM_ACTIONS = player.get_num_actions()
M = Model()
dataset_train = ExpReplay(
predictor=current_predictor,
player=player,
player=get_player(),
num_actions=NUM_ACTIONS,
memory_size=MEMORY_SIZE,
batch_size=BATCH_SIZE,
......@@ -255,18 +235,6 @@ def get_config(romfile):
lr = tf.Variable(0.00025, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
class Evaluator(Callback):
def _trigger_epoch(self):
logger.info("Evaluating...")
output = subprocess.check_output(
"""CUDA_VISIBLE_DEVICES= {} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
sys.argv[0], romfile, os.path.join(logger.LOG_DIR, 'checkpoint')), shell=True)
output = output.strip()
output = output[output.find(']')+1:]
mean, maximum = re.findall('[0-9\.\-]+', output)[-2:]
self.trainer.write_scalar_summary('mean_score', mean)
self.trainer.write_scalar_summary('max_score', maximum)
return TrainConfig(
dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
......@@ -274,15 +242,13 @@ def get_config(romfile):
StatPrinter(),
ModelSaver(),
HumanHyperParamSetter('learning_rate', 'hyper.txt'),
HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
TargetNetworkUpdator(M),
HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
RunOp(lambda: M.update_target_param()),
dataset_train,
PeriodicCallback(Evaluator(), 2),
]),
session_config=get_default_sess_config(0.5),
model=M,
step_per_epoch=STEP_PER_EPOCH,
max_epoch=10000,
)
if __name__ == '__main__':
......@@ -300,15 +266,18 @@ if __name__ == '__main__':
if args.task != 'train':
assert args.load is not None
global ROM_FILE
ROM_FILE = args.rom
if args.task == 'play':
play_model(args.load, args.rom)
play_model(args.load)
sys.exit()
if args.task == 'eval':
eval_model_multiprocess(args.load, args.rom)
eval_model_multiprocess(args.load)
sys.exit()
with tf.Graph().as_default():
config = get_config(args.rom)
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train()
......
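Note: this commit replaces the TargetNetworkUpdator callback with a generic RunOp wired to M.update_target_param(). A minimal sketch of the copy pattern that op implements, assuming the TF 0.x API used throughout this codebase (illustrative, not part of the diff; the startswith filter is an assumption about the elided loop):

import tensorflow as tf

def build_update_target_op():
    # For every variable under the 'target/' scope, assign it the value of
    # the same-named tensor in the online network, then group the assigns
    # into a single named op, as in update_target_param() above.
    ops = []
    graph = tf.get_default_graph()
    for v in tf.all_variables():            # TF 0.x API, as used in this repo
        target_name = v.op.name
        if target_name.startswith('target/'):
            online_name = target_name.replace('target/', '')
            ops.append(v.assign(graph.get_tensor_by_name(online_name + ':0')))
    return tf.group(*ops, name='update_target_network')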
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: graph.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
""" Graph related callbacks"""
from .base import Callback
from ..utils import logger
__all__ = ['RunOp']
class RunOp(Callback):
""" Run an op periodically"""
def __init__(self, setup_func, run_before=True, run_epoch=True):
"""
:param setup_func: a function that returns the op in the graph
:param run_before: run the op before training
:param run_epoch: run the op on every epoch trigger
"""
self.setup_func = setup_func
self.run_before = run_before
self.run_epoch = run_epoch
def _setup_graph(self):
self._op = self.setup_func()
#self._op_name = self._op.name
def _before_train(self):
if self.run_before:
self._op.run()
def _trigger_epoch(self):
if self.run_epoch:
self._op.run()
#def _log(self):
#logger.info("Running op {} ...".format(self._op_name))
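A minimal usage sketch for RunOp, matching how the DQN example above wires it in (assumes a ModelDesc M whose update_target_param() returns an op):

# run the op once before training and again at every epoch end (the defaults)
callbacks = [
    RunOp(lambda: M.update_target_param(), run_before=True, run_epoch=True),
]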
......@@ -4,15 +4,74 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from abc import abstractmethod, ABCMeta
from abc import abstractmethod, ABCMeta, abstractproperty
import operator
import six
from .base import Callback
from ..utils import logger
from ..tfutils import get_op_var_name
__all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter']
'ScheduledHyperParamSetter',
'HyperParam', 'GraphVarParam', 'ObjAttrParam']
class HyperParam(object):
""" Base class for a hyper param"""
__metaclass__ = ABCMeta
def setup_graph(self):
""" setup the graph in `setup_graph` callback stage, if necessary"""
pass
@abstractmethod
def set_value(self, v):
""" define how the value of the param will be set"""
pass
@abstractproperty
def readable_name(self):
pass
class GraphVarParam(HyperParam):
""" a variable in the graph"""
def __init__(self, name, shape=[]):
self.name = name
self.shape = shape
self._readable_name, self.var_name = get_op_var_name(name)
def setup_graph(self):
all_vars = tf.all_variables()
for v in all_vars:
if v.name == self.var_name:
self.var = v
break
else:
raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
self.val_holder = tf.placeholder(tf.float32, shape=self.shape,
name=self._readable_name + '_feed')
self.assign_op = self.var.assign(self.val_holder)
def set_value(self, v):
self.assign_op.eval(feed_dict={self.val_holder:v})
@property
def readable_name(self):
return self._readable_name
class ObjAttrParam(HyperParam):
""" an attribute of an object"""
def __init__(self, obj, attrname):
self.obj = obj
self.attrname = attrname
def set_value(self, v):
setattr(self.obj, self.attrname, v)
@property
def readable_name(self):
return self.attrname
class HyperParamSetter(Callback):
"""
......@@ -20,51 +79,33 @@ class HyperParamSetter(Callback):
"""
__metaclass__ = ABCMeta
TF_VAR = 0
OBJ_ATTR = 1
def __init__(self, param, shape=[]):
def __init__(self, param):
"""
:param param: either a name of the variable in the graph, or a (object, attribute) tuple
:param shape: shape of the param
:param param: a `HyperParam` instance, or a string (assumed to be a scalar `GraphVarParam`)
"""
if isinstance(param, tuple):
self.param_type = HyperParamSetter.OBJ_ATTR
self.obj_attr = param
self.readable_name = param[1]
else:
self.param_type = HyperParamSetter.TF_VAR
self.readable_name, self.var_name = get_op_var_name(param)
self.shape = shape
# if a string, assumed to be a scalar graph variable
if isinstance(param, six.string_types):
param = GraphVarParam(param)
assert isinstance(param, HyperParam), type(param)
self.param = param
self.last_value = None
def _setup_graph(self):
if self.param_type == HyperParamSetter.TF_VAR:
all_vars = tf.all_variables()
for v in all_vars:
if v.name == self.var_name:
self.var = v
break
else:
raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
self.val_holder = tf.placeholder(tf.float32, shape=self.shape,
name=self.readable_name + '_feed')
self.assign_op = self.var.assign(self.val_holder)
def get_current_value(self):
self.param.setup_graph()
def get_value_to_set(self):
"""
:returns: the value to assign to the variable now.
"""
ret = self._get_current_value()
ret = self._get_value_to_set()
if ret is not None and ret != self.last_value:
logger.info("{} at epoch {} will change to {}".format(
self.readable_name, self.epoch_num + 1, ret))
self.param.readable_name, self.epoch_num + 1, ret))
self.last_value = ret
return ret
@abstractmethod
def _get_current_value(self):
def _get_value_to_set(self):
pass
def _trigger_epoch(self):
......@@ -74,12 +115,9 @@ class HyperParamSetter(Callback):
self._set_param()
def _set_param(self):
v = self.get_current_value()
v = self.get_value_to_set()
if v is not None:
if self.param_type == HyperParamSetter.TF_VAR:
self.assign_op.eval(feed_dict={self.val_holder:v})
else:
setattr(self.obj_attr[0], self.obj_attr[1], v)
self.param.set_value(v)
class HumanHyperParamSetter(HyperParamSetter):
"""
......@@ -92,18 +130,18 @@ class HumanHyperParamSetter(HyperParamSetter):
self.file_name = file_name
super(HumanHyperParamSetter, self).__init__(param)
def _get_current_value(self):
def _get_value_to_set(self):
try:
with open(self.file_name) as f:
lines = f.readlines()
lines = [s.strip().split(':') for s in lines]
dic = {str(k):float(v) for k, v in lines}
ret = dic[self.readable_name]
ret = dic[self.param.readable_name]
return ret
except:
logger.warn(
"Failed to parse {} in {}".format(
self.readable_name, self.file_name))
self.param.readable_name, self.file_name))
return None
class ScheduledHyperParamSetter(HyperParamSetter):
......@@ -118,11 +156,9 @@ class ScheduledHyperParamSetter(HyperParamSetter):
self.schedule = sorted(schedule, key=operator.itemgetter(0))
super(ScheduledHyperParamSetter, self).__init__(param)
def _get_current_value(self):
def _get_value_to_set(self):
for e, v in self.schedule:
if e == self.epoch_num:
return v
return None
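Usage sketches for the refactored HyperParam API (schedule values are illustrative, not from the diff):

# a bare string is wrapped into a scalar GraphVarParam automatically
HumanHyperParamSetter('learning_rate', 'hyper.txt')

# an attribute of a Python object, set via ObjAttrParam
HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt')

# (epoch, value) pairs; a value is applied when the epoch number matches
ScheduledHyperParamSetter('learning_rate',
                          [(0, 2.5e-4), (100, 1e-4), (200, 5e-5)])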
File mode changed from 100755 to 100644
mnist_data
cifar10_data
cifar100_data
svhn_data
ilsvrc_metadata
bsds500_data
......@@ -30,7 +30,7 @@ class TrainConfig(object):
:param model: a `ModelDesc` instance.
:param starting_epoch: int. default to be 1.
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
:param max_epoch: maximum number of epochs to run training. default to 100
:param max_epoch: maximum number of epochs to run training. default to inf
:param nr_tower: int. number of towers. default to 1.
:param extra_threads_procs: list of `Startable` threads or processes
"""
......@@ -51,7 +51,7 @@ class TrainConfig(object):
assert_type(self.session_init, SessionInit)
self.step_per_epoch = int(kwargs.pop('step_per_epoch'))
self.starting_epoch = int(kwargs.pop('starting_epoch', 1))
self.max_epoch = int(kwargs.pop('max_epoch', 100))
self.max_epoch = int(kwargs.pop('max_epoch', 99999))
assert self.step_per_epoch > 0 and self.max_epoch > 0
self.nr_tower = int(kwargs.pop('nr_tower', 1))
self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
......
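For reference, a minimal TrainConfig sketch under the new default (only fields grounded in the docstring and the DQN example above; anything beyond them is an assumption):

config = TrainConfig(
    dataset=dataset_train,                   # the input DataFlow
    optimizer=tf.train.AdamOptimizer(1e-3),
    model=Model(),                           # a ModelDesc instance
    step_per_epoch=STEP_PER_EPOCH,
    # max_epoch omitted: now defaults to 99999 (effectively inf) instead of 100
)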
......@@ -14,6 +14,8 @@ if six.PY2:
else:
import subprocess
from . import logger
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
......