Commit 0d20cb3d authored by Yuxin Wu

new inference framework based on trainer.get_predict_func

parent 8644248a
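
The commit message is terse, so spelled out: instead of every callback building its own feed_dict against the raw input placeholders, a callback now asks its trainer for a predict function via get_predict_func(input_names, output_names), and the trainer decides how those names map to tensors (plain placeholders for SimpleTrainer, dequeued tensors and 'tower0/'-prefixed outputs for QueueInputTrainer). A minimal usage sketch, where 'input' and 'prob' are hypothetical tensor names not taken from this commit:

# Sketch only: 'input' and 'prob' are hypothetical names; real names come from the model.
pred = trainer.get_predict_func(['input'], ['prob'])  # names in, callable out
outputs = pred([batch_images])                        # one array per input name
prob = outputs[0]                                     # one fetched value per output name
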
@@ -8,14 +8,9 @@ import argparse
import numpy as np
import os
from tensorpack.train import TrainConfig, QueueInputTrainer
from tensorpack.models import *
from tensorpack.callbacks import *
from tensorpack.utils import *
from tensorpack.tfutils import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
"""
A small convnet model for cifar 10 or cifar100 dataset.
@@ -65,7 +60,7 @@ class Model(ModelDesc):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ClassificationError to use at test time
wrong = prediction_incorrect(logits, label)
wrong = symbf.prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
@@ -161,4 +156,5 @@ if __name__ == '__main__':
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
QueueInputTrainer(config).train()
#QueueInputTrainer(config).train()
SimpleTrainer(config).train()
@@ -9,9 +9,6 @@ import os, sys
import argparse
from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.callbacks import *
"""
MNIST ConvNet example.
@@ -117,6 +114,5 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
#QueueInputTrainer(config).train()
SimpleInputTrainer(config).train()
QueueInputTrainer(config).train()
@@ -63,7 +63,7 @@ class InferenceRunner(Callback):
"""
A callback that runs different kinds of inferencer.
"""
type = TestCallbackType()
#type = TestCallbackType()
def __init__(self, ds, vcs):
"""
@@ -82,12 +82,15 @@ class InferenceRunner(Callback):
def _before_train(self):
self.input_vars = self.trainer.model.reuse_input_vars()
self._find_output_tensors()
input_names = [x.name for x in self.input_vars]
self.pred_func = self.trainer.get_predict_func(
input_names, self.output_tensors)
for v in self.vcs:
v.trainer = self.trainer
def _find_output_tensors(self):
self.output_tensors = []
self.vc_to_vars = []
self.output_tensors = [] # list of names
self.vc_to_vars = [] # list of lists of (var_name, output_idx) pairs
for vc in self.vcs:
vc_vars = vc._get_output_tensors()
def find_oid(var):
@@ -99,12 +102,6 @@ class InferenceRunner(Callback):
vc_vars = [(var, find_oid(var)) for var in vc_vars]
self.vc_to_vars.append(vc_vars)
# convert name to tensors
def get_tensor(name):
_, varname = get_op_var_name(name)
return self.graph.get_tensor_by_name(varname)
self.output_tensors = list(map(get_tensor, self.output_tensors))
def _trigger_epoch(self):
for vc in self.vcs:
vc.before_inference()
@@ -112,8 +109,9 @@ class InferenceRunner(Callback):
sess = tf.get_default_session()
with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data():
feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
outputs = sess.run(self.output_tensors, feed_dict=feed)
#feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
#outputs = sess.run(self.output_tensors, feed_dict=feed)
outputs = self.pred_func(dp)
for vc, varsmap in zip(self.vcs, self.vc_to_vars):
vc_output = [outputs[k[1]] for k in varsmap]
vc.datapoint(dp, vc_output)
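
For clarity, the bookkeeping built in _find_output_tensors above can be sketched on its own: every inferencer's requested tensor names are folded into one deduplicated fetch list, and each inferencer keeps the indices of that list it cares about. A standalone sketch (not the commit's code) of that mapping:

def build_output_mapping(requested_per_inferencer):
    output_names = []     # deduplicated names, fetched once per datapoint by pred_func
    per_inferencer = []   # for each inferencer: [(name, index into output_names)]
    for names in requested_per_inferencer:
        idxs = []
        for n in names:
            if n not in output_names:
                output_names.append(n)
            idxs.append((n, output_names.index(n)))
        per_inferencer.append(idxs)
    return output_names, per_inferencer

For example, build_output_mapping([['prob', 'wrong'], ['wrong']]) returns (['prob', 'wrong'], [[('prob', 0), ('wrong', 1)], [('wrong', 1)]]), so 'wrong' is fetched once and shared by both inferencers.
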
@@ -9,7 +9,9 @@ import tensorflow as tf
__all__ = ['get_default_sess_config',
'get_global_step',
'get_global_step_var',
'get_op_var_name']
'get_op_var_name',
'get_vars_by_names'
]
def get_default_sess_config(mem_fraction=0.9):
"""
@@ -53,3 +55,14 @@ def get_op_var_name(name):
return name[:-2], name
else:
return name, name + ':0'
def get_vars_by_names(names):
"""
Get a list of variables in the default graph by a list of names
"""
ret = []
G = tf.get_default_graph()
for n in names:
opn, varn = get_op_var_name(n)
ret.append(G.get_tensor_by_name(varn))
return ret
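
A quick usage sketch of the helper above; the names are hypothetical, and both op names and tensor names are accepted because get_op_var_name appends ':0' when it is missing:

prob, wrong = get_vars_by_names(['prob:0', 'wrong'])  # graph tensors, in input order
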
@@ -50,6 +50,11 @@ class Trainer(object):
""" run an iteration"""
pass
@abstractmethod
def get_predict_func(self, input_names, output_names):
""" return a predict function"""
pass
def trigger_epoch(self):
self._trigger_epoch()
self.config.callbacks.trigger_epoch()
@@ -53,25 +53,16 @@ class SimpleTrainer(Trainer):
self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names):
input_vars = []
for n in input_names:
opn, varn = get_op_var_name(n)
v = tf.get_default_graph().get_tensor_by_name(varn)
input_vars = get_vars_by_names(input_names)
for v in input_vars:
assert v in self.input_vars
input_vars.append(v)
output_vars = []
for n in output_names:
opn, varn = get_op_var_name(n)
v = tf.get_default_graph().get_tensor_by_name(varn)
output_vars.append(v)
output_vars = get_vars_by_names(output_names)
def func(inputs):
assert len(inputs) == len(input_vars)
feed = dict(zip(input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed)
return func
class EnqueueThread(threading.Thread):
def __init__(self, trainer, queue, enqueue_op, raw_input_var):
super(EnqueueThread, self).__init__()
@@ -126,6 +117,7 @@ class QueueInputTrainer(Trainer):
self.async = async
if self.async:
assert self.config.nr_tower > 1
self._dequed_inputs = []
@staticmethod
def _average_grads(tower_grads):
@@ -148,6 +140,7 @@ class QueueInputTrainer(Trainer):
assert len(ret) == len(self.input_vars)
for qv, v in zip(ret, self.input_vars):
qv.set_shape(v.get_shape())
self._dequed_inputs.append(ret)
return ret
def _single_tower_grad(self):
@@ -248,6 +241,27 @@ class QueueInputTrainer(Trainer):
summary_str = self.summary_op.eval()
self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names):
raw_input_vars = get_vars_by_names(input_names)
input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]
if self.config.nr_tower == 1:
dequed = self._dequed_inputs[0]
input_vars = [dequed[k] for k in input_var_idxs]
output_vars = get_vars_by_names(output_names)
else:
# TODO naive impl: use the first tower only
dequed = self._dequed_inputs[0]
input_vars = [dequed[k] for k in input_var_idxs]
output_names = ['tower0/' + n for n in output_names]
output_vars = get_vars_by_names(output_names)
def func(inputs):
assert len(inputs) == len(input_vars)
feed = dict(zip(input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed)
return func
def start_train(config):
tr = QueueInputTrainer(config)
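
Putting the pieces together: the InferenceRunner callback from the first file obtains its pred_func from whichever trainer is configured, so the same callback works with SimpleTrainer and QueueInputTrainer. A hedged wiring sketch; dataset_test, the Callbacks container contents, and the ClassificationError() arguments are assumptions for illustration, and only InferenceRunner(ds, vcs) plus the trainer call are taken from this commit:

# Sketch only: dataset_test and ClassificationError() defaults are assumed here.
config.callbacks = Callbacks([
    InferenceRunner(dataset_test, [ClassificationError()]),
])
QueueInputTrainer(config).train()  # the callback will call trainer.get_predict_func()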