Commit 030a1d31 authored by Yuxin Wu's avatar Yuxin Wu

param dumper

parent 6a2425d0
......@@ -115,9 +115,7 @@ class Model(ModelDesc):
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
sqrcost = tf.square(target - pred_action_value)
abscost = tf.abs(target - pred_action_value) # robust error func
cost = tf.select(abscost < 1, sqrcost, abscost)
cost = symbf.clipped_l2_loss(target - pred_action_value)
summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
('fc.*/W', ['histogram', 'rms']) ]) # monitor all W
self.cost = tf.reduce_mean(cost, name='cost')
......
......@@ -39,12 +39,21 @@ def eval_with_funcs(predict_funcs, nr_eval):
class Worker(StoppableThread):
    """Thread that repeatedly plays evaluation episodes and pushes scores to a queue.

    The predict function handed to ``play_one_episode`` is wrapped so that it
    raises ``RuntimeError`` once the thread has been stopped; this lets a
    blocking episode unwind promptly instead of running to completion.
    """

    def __init__(self, func, queue):
        super(Worker, self).__init__()
        self._func = func  # the underlying predict function
        self.q = queue     # result queue: one score per finished episode

    def func(self, *args, **kwargs):
        # Guarded proxy for self._func: refuse to run after stop() was called.
        if self.stopped():
            raise RuntimeError("stopped!")
        return self._func(*args, **kwargs)

    def run(self):
        player = get_player()
        while not self.stopped():
            try:
                score = play_one_episode(player, self.func)
            except RuntimeError:
                # raised by self.func when the thread was stopped mid-episode
                return
            self.queue_put_stoppable(self.q, score)
q = queue.Queue(maxsize=2)
......
......@@ -7,23 +7,39 @@ import argparse
# Dump the parameters of a trained model, either to a .npy file or to a
# TensorFlow checkpoint.  The graph is rebuilt from a config script (--config)
# or recovered from a saved metagraph (--meta).
import argparse
import imp

import tensorflow as tf

from tensorpack import *
from tensorpack.tfutils import sessinit, varmanip
from tensorpack.dataflow import *

parser = argparse.ArgumentParser()
parser.add_argument('--config', help='config file')
parser.add_argument('--meta', help='metagraph file')
parser.add_argument(dest='model')   # checkpoint holding the trained weights
parser.add_argument(dest='output')  # output path (.npy or checkpoint prefix)
args = parser.parse_args()

# One of the two ways of rebuilding the graph must be given.
assert args.config or args.meta, "Either config or metagraph must be present!"

with tf.Graph().as_default() as G:
    if args.config:
        # Rebuild the graph from the Model class defined in the config script.
        MODEL = imp.load_source('config_script', args.config).Model
        M = MODEL()
        M.build_graph(M.get_input_vars(), is_training=False)
    else:
        # Recover the graph structure from a saved metagraph file.
        M = ModelFromMetaGraph(args.meta)

    # loading...
    init = sessinit.SaverRestore(args.model)
    sess = tf.Session()
    init.init(sess)

    # dump ...
    with sess.as_default():
        if args.output.endswith('npy'):
            varmanip.dump_session_params(args.output)
        else:
            # Save trainable variables plus the extra-save collection as a
            # TF checkpoint (weights only, no metagraph).
            var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
            logger.info("Variables to dump:")
            logger.info(", ".join([v.name for v in var]))
            saver = tf.train.Saver(var_list=var)
            saver.save(sess, args.output, write_meta_graph=False)
......@@ -24,9 +24,9 @@ class GymEnv(RLEnvironment):
"""
def __init__(self, name, dumpdir=None, viz=False):
    """Create a gym-backed RL environment.

    :param name: gym environment id, passed to ``gym.make``.
    :param dumpdir: if given, the directory is created and gym's monitor
        starts recording episodes into it.
    :param viz: visualization flag; not used in this snippet — presumably
        consumed elsewhere in the class (verify against the full source).
    """
    self.gymenv = gym.make(name)
    if dumpdir:
        mkdir_p(dumpdir)
        self.gymenv.monitor.start(dumpdir)
    self.reset_stat()
    self.rwd_counter = StatCounter()
......
......@@ -78,6 +78,13 @@ def rms(x, name=None):
return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
def clipped_l2_loss(x, name=None):
    """Elementwise robust loss: ``x^2`` where ``|x| < 1``, ``|x|`` elsewhere.

    The two branches agree at ``|x| == 1`` (both equal 1), so the loss is
    continuous; large errors contribute linearly rather than quadratically.
    """
    if name is None:
        name = 'clipped_l2_loss'
    quadratic = tf.square(x)
    linear = tf.abs(x)
    return tf.select(linear < 1, quadratic, linear, name=name)
def get_scalar_var(name, init_value):
return tf.get_variable(name, shape=[],
initializer=tf.constant_initializer(init_value),
......
......@@ -9,6 +9,7 @@ from collections import defaultdict
import re
import numpy as np
from ..utils import logger
from ..utils.naming import *
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'get_savename_from_varname']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment