Commit 8644248a authored by Yuxin Wu's avatar Yuxin Wu

get_predict_func in simpletrainer. before trying a different inference framework

parent 65fb37b9
...@@ -117,5 +117,6 @@ if __name__ == '__main__': ...@@ -117,5 +117,6 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
QueueInputTrainer(config).train() #QueueInputTrainer(config).train()
SimpleInputTrainer(config).train()
...@@ -20,7 +20,7 @@ class ModelDesc(object): ...@@ -20,7 +20,7 @@ class ModelDesc(object):
def get_input_vars(self): def get_input_vars(self):
""" """
Create or return (if already created) input TF vars in the graph. Create or return (if already created) raw input TF placeholder vars in the graph.
:returns: the list of raw input vars in the graph :returns: the list of raw input vars in the graph
""" """
......
...@@ -46,6 +46,7 @@ class PredictConfig(object): ...@@ -46,6 +46,7 @@ class PredictConfig(object):
variables can be any computable tensor in the graph. variables can be any computable tensor in the graph.
Predict specific output might not require all input variables. Predict specific output might not require all input variables.
:param return_input: whether to produce (input, output) pair or just output. default to False. :param return_input: whether to produce (input, output) pair or just output. default to False.
It's only effective for `DatasetPredictorBase`.
""" """
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
......
...@@ -47,11 +47,8 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): ...@@ -47,11 +47,8 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" A worker process to run predictor on one GPU """ """ A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config): def __init__(self, idx, gpuid, inqueue, outqueue, config):
""" """
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point. elements are (task_id, dp) :param inqueue: input queue to get data point. elements are (task_id, dp)
:param outqueue: output queue put result. elements are (task_id, output) :param outqueue: output queue put result. elements are (task_id, output)
:param config: a `PredictConfig`
""" """
super(MultiProcessQueuePredictWorker, self).__init__(idx, gpuid, config) super(MultiProcessQueuePredictWorker, self).__init__(idx, gpuid, config)
self.inqueue = inqueue self.inqueue = inqueue
...@@ -67,17 +64,17 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): ...@@ -67,17 +64,17 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
else: else:
self.outqueue.put((tid, self.func(dp))) self.outqueue.put((tid, self.func(dp)))
class MultiThreadPredictWorker(threading.Thread): #class CurrentSessionPredictor():
def __init__(self, idx, gpuid, config): #def __init__(self, idx, gpuid, config):
""" #"""
:param idx: index of the worker. the 0th worker will print log. #:param idx: index of the worker. the 0th worker will print log.
:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU. #:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
:param config: a `PredictConfig` #:param config: a `PredictConfig`
""" #"""
super(MultiProcessPredictWorker, self).__init__() #super(MultiProcessPredictWorker, self).__init__()
self.idx = idx #self.idx = idx
self.gpuid = gpuid #self.gpuid = gpuid
self.config = config #self.config = config
def run(self): #def run(self):
pass #pass
...@@ -37,7 +37,7 @@ class SessionInit(object): ...@@ -37,7 +37,7 @@ class SessionInit(object):
class JustCurrentSession(SessionInit): class JustCurrentSession(SessionInit):
""" Just use the current default session. This is a no-op placeholder""" """ Just use the current default session. This is a no-op placeholder"""
def _init(self, sess): def _init(self, sess):
logger.info("Using the current running session ..") pass
class NewSession(SessionInit): class NewSession(SessionInit):
""" """
......
...@@ -27,9 +27,8 @@ class SimpleTrainer(Trainer): ...@@ -27,9 +27,8 @@ class SimpleTrainer(Trainer):
def train(self): def train(self):
model = self.model model = self.model
input_vars = model.get_input_vars() self.input_vars = model.get_input_vars()
self.input_vars = input_vars model.build_graph(self.input_vars, True)
model.build_graph(input_vars, True)
cost_var = model.get_cost() cost_var = model.get_cost()
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
...@@ -53,6 +52,26 @@ class SimpleTrainer(Trainer): ...@@ -53,6 +52,26 @@ class SimpleTrainer(Trainer):
summary_str = self.summary_op.eval(feed_dict=feed) summary_str = self.summary_op.eval(feed_dict=feed)
self._process_summary(summary_str) self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names):
input_vars = []
for n in input_names:
opn, varn = get_op_var_name(n)
v = tf.get_default_graph().get_tensor_by_name(varn)
assert v in self.input_vars
input_vars.append(v)
output_vars = []
for n in output_names:
opn, varn = get_op_var_name(n)
v = tf.get_default_graph().get_tensor_by_name(varn)
output_vars.append(v)
def func(inputs):
assert len(inputs) == len(input_vars)
feed = dict(zip(input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed)
return func
class EnqueueThread(threading.Thread): class EnqueueThread(threading.Thread):
def __init__(self, trainer, queue, enqueue_op, raw_input_var): def __init__(self, trainer, queue, enqueue_op, raw_input_var):
super(EnqueueThread, self).__init__() super(EnqueueThread, self).__init__()
...@@ -85,7 +104,6 @@ class EnqueueThread(threading.Thread): ...@@ -85,7 +104,6 @@ class EnqueueThread(threading.Thread):
finally: finally:
logger.info("Enqueue Thread Exited.") logger.info("Enqueue Thread Exited.")
class QueueInputTrainer(Trainer): class QueueInputTrainer(Trainer):
""" """
Trainer which builds a FIFO queue for input. Trainer which builds a FIFO queue for input.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment