Commit 175bc41a authored by Yuxin Wu's avatar Yuxin Wu

some updates to trainer

parent de6d5502
...@@ -22,7 +22,7 @@ get_config_func = imp.load_source('config_script', args.config).get_config ...@@ -22,7 +22,7 @@ get_config_func = imp.load_source('config_script', args.config).get_config
with tf.Graph().as_default() as G: with tf.Graph().as_default() as G:
config = get_config_func() config = get_config_func()
config.model.get_cost(config.model.get_input_vars(), is_training=False) config.model.build_graph(config.model.get_input_vars(), is_training=False)
init = sessinit.SaverRestore(args.model) init = sessinit.SaverRestore(args.model)
sess = tf.Session() sess = tf.Session()
init.init(sess) init.init(sess)
......
...@@ -9,7 +9,7 @@ import os ...@@ -9,7 +9,7 @@ import os
import cv2 import cv2
from collections import deque from collections import deque
from ...utils import get_rng from ...utils import get_rng
from . import RLEnvironment from .rlenv import RLEnvironment
__all__ = ['AtariDriver', 'AtariPlayer'] __all__ = ['AtariDriver', 'AtariPlayer']
...@@ -32,6 +32,8 @@ class AtariDriver(object): ...@@ -32,6 +32,8 @@ class AtariDriver(object):
self.width, self.height = self.ale.getScreenDims() self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet() self.actions = self.ale.getMinimalActionSet()
if isinstance(viz, int):
viz = float(viz)
self.viz = viz self.viz = viz
self.romname = os.path.basename(rom_file) self.romname = os.path.basename(rom_file)
if self.viz and isinstance(self.viz, float): if self.viz and isinstance(self.viz, float):
...@@ -64,6 +66,7 @@ class AtariDriver(object): ...@@ -64,6 +66,7 @@ class AtariDriver(object):
cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret) cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret)
self.framenum += 1 self.framenum += 1
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0] ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
ret = ret[36:204,:] # several online repos all use this
return ret return ret
def get_num_actions(self): def get_num_actions(self):
...@@ -109,6 +112,7 @@ class AtariPlayer(RLEnvironment): ...@@ -109,6 +112,7 @@ class AtariPlayer(RLEnvironment):
""" """
self.frames.clear() self.frames.clear()
s = self.driver.grab_image() s = self.driver.grab_image()
s = cv2.resize(s, self.image_shape) s = cv2.resize(s, self.image_shape)
for _ in range(self.hist_len): for _ in range(self.hist_len):
self.frames.append(s) self.frames.append(s)
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
__all__ = ['RLEnvironment'] __all__ = ['RLEnvironment', 'NaiveRLEnvironment']
class RLEnvironment(object): class RLEnvironment(object):
__meta__ = ABCMeta __meta__ = ABCMeta
...@@ -23,3 +23,15 @@ class RLEnvironment(object): ...@@ -23,3 +23,15 @@ class RLEnvironment(object):
:params act: the action :params act: the action
:returns: (reward, isOver) :returns: (reward, isOver)
""" """
class NaiveRLEnvironment(RLEnvironment):
def __init__(self):
self.k = 0
def current_state(self):
self.k += 1
return self.k
def action(self, act):
self.k = act
return (self.k, self.k > 10)
...@@ -36,6 +36,7 @@ class Trainer(object): ...@@ -36,6 +36,7 @@ class Trainer(object):
assert isinstance(config, TrainConfig), type(config) assert isinstance(config, TrainConfig), type(config)
self.config = config self.config = config
self.model = config.model self.model = config.model
self.extra_threads_procs = config.extra_threads_procs
@abstractmethod @abstractmethod
def train(self): def train(self):
...@@ -84,7 +85,7 @@ class Trainer(object): ...@@ -84,7 +85,7 @@ class Trainer(object):
callbacks.setup_graph(self) callbacks.setup_graph(self)
self.config.session_init.init(self.sess) self.config.session_init.init(self.sess)
tf.get_default_graph().finalize() tf.get_default_graph().finalize()
self._start_all_threads() self._start_concurrency()
with self.sess.as_default(): with self.sess.as_default():
try: try:
...@@ -121,12 +122,15 @@ class Trainer(object): ...@@ -121,12 +122,15 @@ class Trainer(object):
self.sess = tf.Session(config=self.config.session_config) self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator() self.coord = tf.train.Coordinator()
def _start_all_threads(self): def _start_concurrency(self):
""" """
Run all threads before starting training Run all threads before starting training
""" """
tf.train.start_queue_runners( tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True) sess=self.sess, coord=self.coord, daemon=True, start=True)
for k in self.extra_threads_procs:
k.start()
def process_grads(self, grads): def process_grads(self, grads):
g = [] g = []
......
...@@ -32,6 +32,7 @@ class TrainConfig(object): ...@@ -32,6 +32,7 @@ class TrainConfig(object):
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch. :param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
:param max_epoch: maximum number of epoch to run training. default to 100 :param max_epoch: maximum number of epoch to run training. default to 100
:param nr_tower: int. number of towers. default to 1. :param nr_tower: int. number of towers. default to 1.
:param extra_threads_procs: list of `Startable` threads or processes
""" """
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
...@@ -53,5 +54,6 @@ class TrainConfig(object): ...@@ -53,5 +54,6 @@ class TrainConfig(object):
self.max_epoch = int(kwargs.pop('max_epoch', 100)) self.max_epoch = int(kwargs.pop('max_epoch', 100))
assert self.step_per_epoch > 0 and self.max_epoch > 0 assert self.step_per_epoch > 0 and self.max_epoch > 0
self.nr_tower = int(kwargs.pop('nr_tower', 1)) self.nr_tower = int(kwargs.pop('nr_tower', 1))
self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
...@@ -209,12 +209,9 @@ class QueueInputTrainer(Trainer): ...@@ -209,12 +209,9 @@ class QueueInputTrainer(Trainer):
self.init_session_and_coord() self.init_session_and_coord()
# create a thread that keeps filling the queue # create a thread that keeps filling the queue
self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars) self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
self.extra_threads_procs.append(self.input_th)
self.main_loop() self.main_loop()
def _start_all_threads(self):
super(QueueInputTrainer, self)._start_all_threads()
self.input_th.start()
def run_step(self): def run_step(self):
if self.async: if self.async:
if not self.async_running: if not self.async_running:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment