Commit ff0a4528 authored by Yuxin Wu

separate predict related code

parent d7a85f44
@@ -266,7 +266,6 @@ if __name__ == '__main__':
     if args.task != 'train':
         assert args.load is not None
-    global ROM_FILE
     ROM_FILE = args.rom
     if args.task == 'play':
...
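The deleted `global ROM_FILE` was dead weight: the assignment already happens at module scope (the body of `if __name__ == '__main__':` is still module scope), where `global` has no effect. The keyword only matters when rebinding a module-level name from inside a function. A small sketch for contrast:

```python
ROM_FILE = None  # module-level name

def set_rom(path):
    # Inside a function, rebinding the module-level name
    # requires an explicit `global` declaration.
    global ROM_FILE
    ROM_FILE = path

# At module scope, plain assignment rebinds the name directly;
# a `global` statement here would be redundant.
ROM_FILE = 'breakout.bin'  # illustrative value
```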
@@ -36,7 +36,7 @@ class AtariPlayer(RLEnvironment):
         self.ale = ALEInterface()
         self.rng = get_rng(self)
-        self.ale.setInt("random_seed", self.rng.randint(self.rng.randint(0, 1000)))
+        self.ale.setInt("random_seed", self.rng.randint(0, 1000))
         self.ale.setInt("frame_skip", frame_skip)
         self.ale.setBool('color_averaging', True)
         self.ale.loadROM(rom_file)
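The fixed line repairs a real bug, not just style: the old nested call first draws k from [0, 1000) and then draws the seed from [0, k), which skews seeds toward small values and raises a `ValueError` whenever the inner draw returns 0, since numpy's single-argument `randint(k)` samples from the empty range [0, 0). A quick demonstration (assuming `get_rng` returns a `numpy.random.RandomState`, as its usage here suggests):

```python
import numpy as np

rng = np.random.RandomState()
rng.randint(0, 1000)   # new code: a uniform seed in [0, 1000)
rng.randint(0)         # what the old nested call hits when the inner
                       # randint returns 0: ValueError, low >= high
```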
@@ -125,7 +125,7 @@ if __name__ == '__main__':
         #im = a.grab_image()
         #cv2.imshow(a.romname, im)
         act = rng.choice(range(num))
-        print act
+        print(act)
         r, o = a.action(act)
         a.current_state()
         #time.sleep(0.1)
...
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from pkgutil import walk_packages
import os
import os.path

def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        globals()[k] = p.__dict__[k]
    del globals()[name]

for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        global_import(module_name)
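This new `__init__.py` walks every non-private module in the package directory, copies each module's public names (its `__all__`, or failing that `dir(p)`) into the package namespace, and drops the submodule name itself. The net effect is that callers can import the split-out symbols from the package directly. For example, assuming the package path is `tensorpack.predict` (not shown on this page, but consistent with the `..tfutils` relative imports below):

```python
# Both forms resolve to the same function; the auto-import
# in __init__.py is what makes the first one work.
from tensorpack.predict import get_predict_func
from tensorpack.predict.common import get_predict_func
```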
 # -*- coding: UTF-8 -*-
-# File: predict.py
+# File: common.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-import numpy as np
 from collections import namedtuple
-from tqdm import tqdm
-from six.moves import zip, range
-import multiprocessing
-from .utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
-from .tfutils import *
-from .utils import logger
-from .tfutils.modelutils import describe_model
-from .dataflow import DataFlow, BatchData
-from .dataflow.dftools import dataflow_to_process_queue
+from six.moves import zip
+from ..tfutils import *
+import multiprocessing
-__all__ = ['PredictConfig', 'DatasetPredictor', 'get_predict_func',
-           'ParallelPredictWorker']
+__all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']
 PredictResult = namedtuple('PredictResult', ['input', 'output'])
@@ -27,7 +19,6 @@ class PredictConfig(object):
     """
     The config used by `get_predict_func`.
-    :param session_config: a `tf.ConfigProto` instance to instantiate the session.
     :param session_init: a `utils.sessinit.SessionInit` instance to
         initialize variables of a session.
     :param input_data_mapping: Decide the mapping from each component in data
@@ -68,6 +59,7 @@ class PredictConfig(object):
 def get_predict_func(config):
     """
+    Produce a simple predictor function in a newly-created session, without any parallelism.
     :param config: a `PredictConfig` instance.
     :returns: A prediction function that takes a list of input values, and returns
         a list of output values defined in ``config.output_var_names``.
@@ -86,10 +78,7 @@ def get_predict_func(config):
     output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
                    for n in output_var_names]
-    if config.session_config:
-        sess = tf.Session(config=config.session_config)
-    else:
-        sess = tf.Session()
+    sess = tf.Session()
     config.session_init.init(sess)
     def run_input(dp):
@@ -99,119 +88,3 @@ def get_predict_func(config):
         feed = dict(zip(input_map, dp))
         return sess.run(output_vars, feed_dict=feed)
     return run_input
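For reference, a minimal sketch of how the simplified function is meant to be used; the checkpoint path, output tensor name, and the `PredictConfig` fields elided by the collapsed hunks above are assumptions, not part of this commit:

```python
# Hedged sketch: build a predictor from a saved model and run one datapoint.
config = PredictConfig(
    session_init=SaverRestore('/path/to/checkpoint'),  # hypothetical checkpoint
    output_var_names=['prob'],                         # hypothetical tensor name
    # ... plus the model / input_data_mapping fields collapsed above
)
predict = get_predict_func(config)   # creates its own tf.Session now
outputs = predict([some_datapoint])  # list of inputs -> list of output values
```

Everything parallel or dataset-level that used to live below this point in predict.py now moves to the new concurrency.py and dataset.py files.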
class ParallelPredictWorker(multiprocessing.Process):
def __init__(self, idx, gpuid, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param config: a `PredictConfig`
"""
super(ParallelPredictWorker, self).__init__()
self.idx = idx
self.gpuid = gpuid
self.config = config
def _init_runtime(self):
if self.gpuid >= 0:
logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid))
os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
else:
logger.info("Worker {} uses CPU".format(self.idx))
os.environ['CUDA_VISIBLE_DEVICES'] = ''
G = tf.Graph() # build a graph for each process, because they don't need to share anything
with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'):
if self.idx != 0:
from tensorpack.models._common import disable_layer_logging
disable_layer_logging()
self.func = get_predict_func(self.config)
if self.idx == 0:
describe_model()
class QueuePredictWorker(ParallelPredictWorker):
""" A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point
:param outqueue: output queue put result
:param config: a `PredictConfig`
"""
super(QueuePredictWorker, self).__init__(idx, gpuid, config)
self.inqueue = inqueue
self.outqueue = outqueue
def run(self):
self._init_runtime()
while True:
tid, dp = self.inqueue.get()
if tid == DIE:
self.outqueue.put((DIE, None))
return
else:
res = PredictResult(dp, self.func(dp))
self.outqueue.put((tid, res))
class DatasetPredictor(object):
"""
Run the predict_config on a given `DataFlow`.
"""
def __init__(self, config, dataset):
"""
:param config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance.
"""
assert isinstance(dataset, DataFlow)
self.ds = dataset
self.nr_gpu = config.nr_gpu
if self.nr_gpu > 1:
self.inqueue, self.inqueue_proc = dataflow_to_process_queue(self.ds, 10, self.nr_gpu)
self.outqueue = multiprocessing.Queue()
try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
except KeyError:
gpus = list(range(self.nr_gpu))
self.workers = [QueuePredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
for i in range(self.nr_gpu)]
self.result_queue = OrderedResultGatherProc(self.outqueue)
# setup all the procs
self.inqueue_proc.start()
for p in self.workers: p.start()
self.result_queue.start()
ensure_proc_terminate(self.workers)
ensure_proc_terminate([self.result_queue, self.inqueue_proc])
else:
self.func = get_predict_func(config)
def get_result(self):
""" A generator to produce prediction for each data"""
with tqdm(total=self.ds.size()) as pbar:
if self.nr_gpu == 1:
for dp in self.ds.get_data():
yield PredictResult(dp, self.func(dp))
pbar.update()
else:
die_cnt = 0
while True:
res = self.result_queue.get()
pbar.update()
if res[0] != DIE:
yield res[1]
else:
die_cnt += 1
if die_cnt == self.nr_gpu:
break
self.inqueue_proc.join()
self.inqueue_proc.terminate()
for p in self.workers:
p.join(); p.terminate()
def get_all_result(self):
"""
Run over the dataset and return a list of all predictions.
"""
return list(self.get_result())
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing
import os
import tensorflow as tf
from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model
from ..utils import logger
from ..tfutils import *
from .common import *
__all__ = ['ParallelPredictWorker', 'QueuePredictWorker']
class ParallelPredictWorker(multiprocessing.Process):
def __init__(self, idx, gpuid, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
:param config: a `PredictConfig`
"""
super(ParallelPredictWorker, self).__init__()
self.idx = idx
self.gpuid = gpuid
self.config = config
def _init_runtime(self):
if self.gpuid >= 0:
logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid))
os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpuid)
else:
logger.info("Worker {} uses CPU".format(self.idx))
os.environ['CUDA_VISIBLE_DEVICES'] = ''
G = tf.Graph() # build a graph for each process, because they don't need to share anything
with G.as_default():
if self.idx != 0:
from tensorpack.models._common import disable_layer_logging
disable_layer_logging()
self.func = get_predict_func(self.config)
if self.idx == 0:
describe_model()
class QueuePredictWorker(ParallelPredictWorker):
""" A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point
:param outqueue: output queue put result
:param config: a `PredictConfig`
"""
super(QueuePredictWorker, self).__init__(idx, gpuid, config)
self.inqueue = inqueue
self.outqueue = outqueue
def run(self):
self._init_runtime()
while True:
tid, dp = self.inqueue.get()
if tid == DIE:
self.outqueue.put((DIE, None))
return
else:
res = PredictResult(dp, self.func(dp))
self.outqueue.put((tid, res))
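The queue protocol the worker expects is visible in `run()`: tasks arrive as `(tid, datapoint)` tuples, and a `tid` of `DIE` shuts the worker down, echoed back so the consumer can count exits. A hand-wired sketch (normally `DatasetPredictor` below does this wiring; `config` and `dp` are assumed to already exist):

```python
inqueue, outqueue = multiprocessing.Queue(), multiprocessing.Queue()
worker = QueuePredictWorker(0, 0, inqueue, outqueue, config)  # worker 0 on GPU 0
worker.start()

inqueue.put((0, dp))         # task id 0 with one datapoint
tid, res = outqueue.get()    # res is a PredictResult(input, output)
inqueue.put((DIE, None))     # ask the worker to exit
worker.join()
```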
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dataset.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing
import os
from six.moves import range
from tqdm import tqdm
from ..dataflow import DataFlow, BatchData
from ..dataflow.dftools import dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from .common import PredictResult, get_predict_func
from .concurrency import *
__all__ = ['DatasetPredictor']
class DatasetPredictor(object):
"""
Run the predict_config on a given `DataFlow`.
"""
def __init__(self, config, dataset):
"""
:param config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance.
"""
assert isinstance(dataset, DataFlow)
self.ds = dataset
self.nr_gpu = config.nr_gpu
if self.nr_gpu > 1:
self.inqueue, self.inqueue_proc = dataflow_to_process_queue(self.ds, 10, self.nr_gpu)
self.outqueue = multiprocessing.Queue()
try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
except KeyError:
gpus = list(range(self.nr_gpu))
self.workers = [QueuePredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
for i in range(self.nr_gpu)]
self.result_queue = OrderedResultGatherProc(self.outqueue)
# setup all the procs
self.inqueue_proc.start()
for p in self.workers: p.start()
self.result_queue.start()
ensure_proc_terminate(self.workers)
ensure_proc_terminate([self.result_queue, self.inqueue_proc])
else:
self.func = get_predict_func(config)
def get_result(self):
""" A generator to produce prediction for each data"""
with tqdm(total=self.ds.size()) as pbar:
if self.nr_gpu == 1:
for dp in self.ds.get_data():
yield PredictResult(dp, self.func(dp))
pbar.update()
else:
die_cnt = 0
while True:
res = self.result_queue.get()
pbar.update()
if res[0] != DIE:
yield res[1]
else:
die_cnt += 1
if die_cnt == self.nr_gpu:
break
self.inqueue_proc.join()
self.inqueue_proc.terminate()
for p in self.workers:
p.join(); p.terminate()
def get_all_result(self):
"""
Run over the dataset and return a list of all predictions.
"""
return list(self.get_result())
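Usage stays the same after the move out of predict.py; a short sketch (the `config` and dataflow construction are assumed, not shown in this commit):

```python
pred = DatasetPredictor(config, dataflow)  # config: PredictConfig, dataflow: DataFlow
for res in pred.get_result():              # streams results with a tqdm progress bar
    inputs, outputs = res.input, res.output
results = pred.get_all_result()            # or materialize everything at once
```

With `config.nr_gpu > 1` the same loop transparently consumes the ordered multi-process queue; with one GPU it simply runs the predict function in-process.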
@@ -12,9 +12,13 @@ import six
 from ..utils import logger
-__all__ = ['SessionInit', 'NewSession', 'SaverRestore', 'ParamRestore',
-           'dump_session_params']
+__all__ = ['SessionInit', 'NewSession', 'SaverRestore',
+           'ParamRestore',
+           'JustCurrentSession',
+           'dump_session_params']
+# TODO they initialize_all at the beginning by default.
 class SessionInit(object):
     """ Base class for utilities to initialize a session"""
     __metaclass__ = ABCMeta
@@ -30,6 +34,11 @@ class SessionInit(object):
     def _init(self, sess):
         pass
+class JustCurrentSession(SessionInit):
+    """ Just use the current default session. This is a no-op placeholder."""
+    def _init(self, sess):
+        logger.info("Using the current running session ..")
 class NewSession(SessionInit):
     """
     Create a new session. All variables will be initialized by their
@@ -139,7 +148,7 @@ class ParamRestore(SessionInit):
 def dump_session_params(path):
     """ Dump value of all trainable variables to a dict and save to `path` as
-    npy format.
+    npy format, loadable by ParamRestore.
     """
     var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
     result = {}
...
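A sketch of how the initializers round-trip with `dump_session_params`; the file name is illustrative, and the `.item()` unpacking assumes the dict was stored with `np.save` as a zero-dimensional object array:

```python
import numpy as np

# Inside the training session (it must be the default session):
dump_session_params('params.npy')      # illustrative path

# Later, for inference:
params = np.load('params.npy').item()  # back to a {name: value} dict
init = ParamRestore(params)            # feeds the values into a new session,
                                       # or use JustCurrentSession() when the
                                       # default session is already set up
# pass `init` as the session_init of a PredictConfig
```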