Commit ce2cc714 authored by Yuxin Wu's avatar Yuxin Wu

refactor inferencer and improve hed

parent 6ef876c9
......@@ -55,8 +55,6 @@ class Model(ModelDesc):
def _build_graph(self, input_vars, is_training):
image, edgemap = input_vars
# TODO fix this
edgemap = tf.identity(edgemap, name='edgemap-tmp')
image = image - tf.constant([104, 116, 122], dtype='float32')
def branch(name, l, up):
......@@ -119,8 +117,8 @@ def get_data(name):
class CropMultiple16(imgaug.ImageAugmentor):
def _get_augment_params(self, img):
newh = img.shape[0] / 16 * 16
neww = img.shape[1] / 16 * 16
newh = img.shape[0] // 16 * 16
neww = img.shape[1] // 16 * 16
assert newh > 0 and neww > 0
diffh = img.shape[0] - newh
h0 = 0 if diffh == 0 else self.rng.randint(diffh)
......@@ -141,10 +139,12 @@ def get_data(name):
imgaug.Flip(horiz=True),
imgaug.Flip(vert=True),
]
ds = AugmentImageComponents(ds, shape_aug, (0, 1))
else:
# the original image shape (320x480) in bsds is already a multiple of 16
pass
# the original image shape (321x481) in BSDS is not a multiple of 16
IMAGE_SHAPE = (320, 480)
shape_aug = [imgaug.RandomCrop(IMAGE_SHAPE)]
ds = AugmentImageComponents(ds, shape_aug, (0, 1))
def f(m):
m[m>=0.49] = 1
m[m<0.49] = 0
......@@ -163,10 +163,12 @@ def get_data(name):
#ds = PrefetchDataZMQ(ds, 3)
return ds
def view_data(ds):
def view_data():
ds = get_data('train')
ds.reset_state()
for ims, edgemaps in ds.get_data():
for im, edgemap in zip(ims, edgemaps):
assert im.shape[0] % 16 == 0 and im.shape[1] % 16 == 0, im.shape
cv2.imshow("im", im / 255.0)
cv2.waitKey(1000)
cv2.imshow("edge", edgemap)
......@@ -191,7 +193,7 @@ def get_config():
HumanHyperParamSetter('learning_rate'),
InferenceRunner(dataset_val,
BinaryClassificationStats('prediction',
'edgemap-tmp'))
'edgemap'))
]),
model=Model(),
step_per_epoch=step_per_epoch,
......@@ -207,7 +209,7 @@ def run(model_path, image_path):
predict_func = get_predict_func(pred_config)
im = cv2.imread(image_path)
assert im is not None
im = cv2.resize(im, (im.shape[0] / 16 * 16, im.shape[1] / 16 * 16))
im = cv2.resize(im, (im.shape[0] // 16 * 16, im.shape[1] // 16 * 16))
outputs = predict_func([[im.astype('float32')]])
for k in range(6):
pred = outputs[k][0]
......@@ -223,8 +225,7 @@ if __name__ == '__main__':
args = parser.parse_args()
if args.view:
ds = get_data('train')
view_data(ds)
view_data()
elif args.run:
run(args.load, args.run)
else:
......
......@@ -6,6 +6,7 @@ import tensorflow as tf
import numpy as np
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import six
from six.moves import zip, map
......@@ -66,10 +67,14 @@ class InferenceRunner(Callback):
A callback that runs different kinds of inferencer.
"""
def __init__(self, ds, infs):
IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, ds, infs, input_tensors=None):
"""
:param ds: inference dataset. a `DataFlow` instance.
:param infs: a list of `Inferencer` instance.
:param input_tensor_names: list of tensors to feed the dataflow to.
default to all the input placeholders.
"""
assert isinstance(ds, DataFlow), type(ds)
self.ds = ds
......@@ -79,27 +84,36 @@ class InferenceRunner(Callback):
self.infs = infs
for v in self.infs:
assert isinstance(v, Inferencer), str(v)
self.input_tensors = input_tensors
def _setup_graph(self):
self.input_vars = self.trainer.model.reuse_input_vars()
self._find_output_tensors()
input_names = [x.name for x in self.input_vars]
self._find_input_tensors() # these are all tensor names
self._find_output_tensors() # may be either tensor name or op name
self.pred_func = self.trainer.get_predict_func(
input_names, self.output_tensors)
self.input_tensors, self.output_tensors)
def _find_input_tensors(self):
if self.input_tensors is None:
input_vars = self.trainer.model.reuse_input_vars()
self.input_tensors = [x.name for x in input_vars]
def _find_output_tensors(self):
self.output_tensors = [] # list of names
self.inf_to_tensors = [] # list of list of (var_name: output_idx)
for inf in self.infs:
inf_tensors = inf.get_output_tensors()
def find_oid(t):
if t in self.output_tensors:
return self.output_tensors.index(t)
else:
self.output_tensors.append(t)
return len(self.output_tensors) - 1
inf_tensors = [(t, find_oid(t)) for t in inf_tensors]
self.inf_to_tensors.append(inf_tensors)
IOTensor = InferenceRunner.IOTensor
self.output_tensors = []
def find_oid(t):
tensorname = get_op_tensor_name(t)[1]
if tensorname in self.input_tensors:
# this inferencer needs the input dp
return IOTensor(self.input_tensors.index(tensorname), False)
if t in self.output_tensors:
return IOTensor(self.output_tensors.index(t), True)
else:
self.output_tensors.append(t)
return IOTensor(len(self.output_tensors) - 1, True)
self.inf_to_tensors = [
[find_oid(t) for t in inf.get_output_tensors()]
for inf in self.infs]
# list of list of (var_name: IOTensor)
def _trigger_epoch(self):
for inf in self.infs:
......@@ -109,11 +123,10 @@ class InferenceRunner(Callback):
self.ds.reset_state()
with tqdm(total=self.ds.size(), **get_tqdm_kwargs()) as pbar:
for dp in self.ds.get_data():
#feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
#outputs = sess.run(self.output_tensors, feed_dict=feed)
outputs = self.pred_func(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [outputs[k[1]] for k in tensormap]
inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap]
inf.datapoint(dp, inf_output)
pbar.update()
......@@ -166,7 +179,7 @@ class ScalarStats(Inferencer):
class ClassificationError(Inferencer):
"""
Compute classification error from a `wrong` variable
Compute classification error in batch mode, from a `wrong` variable
The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch.
You can use `tf.nn.in_top_k` to record top-k error as well.
......
......@@ -53,8 +53,8 @@ class RotationAndCropValid(ImageAugmentor):
neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
neww = min(neww, ret.shape[1])
newh = min(newh, ret.shape[0])
newx = center[0] - neww * 0.5
newy = center[1] - newh * 0.5
newx = int(center[0] - neww * 0.5)
newy = int(center[1] - newh * 0.5)
#print(ret.shape, deg, newx, newy, neww, newh)
return ret[newy:newy+newh,newx:newx+neww]
......@@ -81,4 +81,4 @@ class RotationAndCropValid(ImageAugmentor):
cos_2a = cos_a*cos_a - sin_a*sin_a
wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a
return wr,hr
return int(wr), int(hr)
......@@ -33,6 +33,7 @@ class PredictConfig(object):
Predict specific output might not require all input variables.
:param return_input: whether to return (input, output) pair or just output. default to False.
"""
# TODO use the name "tensor" instead of "variable"
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
# XXX does it work? start with minimal memory, but allow growth.
......
......@@ -13,7 +13,9 @@ __all__ = ['get_default_sess_config',
'get_global_step',
'get_global_step_var',
'get_op_var_name',
'get_op_tensor_name',
'get_vars_by_names',
'get_tensors_by_names',
'backup_collection',
'restore_collection',
'clear_collection',
......@@ -53,21 +55,23 @@ def get_global_step():
tf.get_default_session(),
get_global_step_var())
def get_op_var_name(name):
def get_op_tensor_name(name):
"""
Variable name is assumed to be ``op_name + ':0'``
Tensor name is assumed to be ``op_name + ':0'``
:param name: an op or a variable name
:returns: (op_name, variable_name)
:param name: an op or a tensor name
:returns: (op_name, tensor_name)
"""
if name.endswith(':0'):
return name[:-2], name
else:
return name, name + ':0'
def get_vars_by_names(names):
get_op_var_name = get_op_tensor_name
def get_tensors_by_names(names):
"""
Get a list of variables in the default graph by a list of names
Get a list of tensors in the default graph by a list of names
"""
ret = []
G = tf.get_default_graph()
......@@ -76,6 +80,8 @@ def get_vars_by_names(names):
ret.append(G.get_tensor_by_name(varn))
return ret
get_vars_by_names = get_tensors_by_names
def backup_collection(keys):
ret = {}
for k in keys:
......
......@@ -68,7 +68,7 @@ def print_stat(x, message=None):
"""
if message is None:
message = x.op.name
return tf.Print(x, [tf.reduce_mean(x), x], summarize=20,
return tf.Print(x, [tf.shape(x), tf.reduce_mean(x), x], summarize=20,
message=message, name='print_' + x.op.name)
def rms(x, name=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment