Commit 8644248a authored by Yuxin Wu

get_predict_func in simpletrainer. before trying a different inference framework

parent 65fb37b9
@@ -117,5 +117,6 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
QueueInputTrainer(config).train()
#QueueInputTrainer(config).train()
SimpleInputTrainer(config).train()
@@ -20,7 +20,7 @@ class ModelDesc(object):
def get_input_vars(self):
"""
Create or return (if already created) input TF vars in the graph.
Create or return (if already created) raw input TF placeholder vars in the graph.
:returns: the list of raw input vars in the graph
"""
@@ -46,6 +46,7 @@ class PredictConfig(object):
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param return_input: whether to produce (input, output) pair or just output. default to False.
It's only effective for `DatasetPredictorBase`.
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
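For orientation, a hedged sketch of building a PredictConfig with return_input; the keyword names (model, session_init, output_var_names, return_input) and the checkpoint path are assumptions for illustration and may differ in this revision:
# Illustration only -- keyword names and the checkpoint path are assumptions:
pred_config = PredictConfig(
    model=MyModel(),                               # hypothetical ModelDesc subclass
    session_init=SaverRestore('train_log/model'),  # hypothetical checkpoint path
    output_var_names=['prob:0'],                   # any computable tensor in the graph
    return_input=True)                             # only effective for DatasetPredictorBase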
@@ -47,11 +47,8 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point. elements are (task_id, dp)
:param outqueue: output queue put result. elements are (task_id, output)
:param config: a `PredictConfig`
"""
super(MultiProcessQueuePredictWorker, self).__init__(idx, gpuid, config)
self.inqueue = inqueue
@@ -67,17 +64,17 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
else:
self.outqueue.put((tid, self.func(dp)))
class MultiThreadPredictWorker(threading.Thread):
def __init__(self, idx, gpuid, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
:param config: a `PredictConfig`
"""
super(MultiProcessPredictWorker, self).__init__()
self.idx = idx
self.gpuid = gpuid
self.config = config
#class CurrentSessionPredictor():
#def __init__(self, idx, gpuid, config):
#"""
#:param idx: index of the worker. the 0th worker will print log.
#:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
#:param config: a `PredictConfig`
#"""
#super(MultiProcessPredictWorker, self).__init__()
#self.idx = idx
#self.gpuid = gpuid
#self.config = config
def run(self):
pass
#def run(self):
#pass
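Given the constructor shown above and the (task_id, dp) / (task_id, output) queue conventions in the docstring, a driving sketch might look like this (the multiprocessing queues, pred_config, and dp are assumed, not code from this commit):
from multiprocessing import Queue

inq, outq = Queue(), Queue()
# idx 0 is the worker that prints the log; gpuid -1 would run on CPU
worker = MultiProcessQueuePredictWorker(0, 0, inq, outq, pred_config)
worker.start()
inq.put((42, dp))               # inqueue elements are (task_id, dp)
task_id, output = outq.get()    # outqueue elements are (task_id, output)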
@@ -37,7 +37,7 @@ class SessionInit(object):
class JustCurrentSession(SessionInit):
""" Just use the current default session. This is a no-op placeholder"""
def _init(self, sess):
logger.info("Using the current running session ..")
pass
class NewSession(SessionInit):
"""
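JustCurrentSession plugs into the same session_init slot that the first hunk fills with SaverRestore; a minimal sketch (the surrounding config object is assumed):
# Reuse whatever is already initialized in the current default session,
# instead of restoring from a checkpoint:
config.session_init = JustCurrentSession()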
@@ -27,9 +27,8 @@ class SimpleTrainer(Trainer):
def train(self):
model = self.model
input_vars = model.get_input_vars()
self.input_vars = input_vars
model.build_graph(input_vars, True)
self.input_vars = model.get_input_vars()
model.build_graph(self.input_vars, True)
cost_var = model.get_cost()
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
@@ -53,6 +52,26 @@ class SimpleTrainer(Trainer):
summary_str = self.summary_op.eval(feed_dict=feed)
self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names):
input_vars = []
for n in input_names:
opn, varn = get_op_var_name(n)
v = tf.get_default_graph().get_tensor_by_name(varn)
assert v in self.input_vars
input_vars.append(v)
output_vars = []
for n in output_names:
opn, varn = get_op_var_name(n)
v = tf.get_default_graph().get_tensor_by_name(varn)
output_vars.append(v)
def func(inputs):
assert len(inputs) == len(input_vars)
feed = dict(zip(input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed)
return func
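A usage sketch for the new get_predict_func: the tensor names 'input' and 'prob', the trained trainer instance, and image_batch are assumptions for illustration. The returned func takes one value per name in input_names and returns the fetched tensors in the order of output_names:
# Sketch: build a predictor from a trainer whose graph has been built.
predict = trainer.get_predict_func(['input'], ['prob'])
outputs = predict([image_batch])   # one value per name in input_names
prob = outputs[0]                  # results follow the order of output_names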
class EnqueueThread(threading.Thread):
def __init__(self, trainer, queue, enqueue_op, raw_input_var):
super(EnqueueThread, self).__init__()
@@ -85,7 +104,6 @@ class EnqueueThread(threading.Thread):
finally:
logger.info("Enqueue Thread Exited.")
class QueueInputTrainer(Trainer):
"""
Trainer which builds a FIFO queue for input.
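For context, the FIFO-queue input pattern that this trainer builds looks roughly like the following TensorFlow sketch; it is illustrative only, not the trainer's actual code:
# Illustration of the queue-input pattern (not code from this commit):
queue = tf.FIFOQueue(50, [v.dtype for v in input_vars])
enqueue_op = queue.enqueue(input_vars)   # EnqueueThread feeds and runs this op
model_inputs = queue.dequeue()           # the model trains on dequeued tensors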