Commit ce2cc714 authored by Yuxin Wu's avatar Yuxin Wu

refactor inferencer and improve hed

parent 6ef876c9
...@@ -55,8 +55,6 @@ class Model(ModelDesc): ...@@ -55,8 +55,6 @@ class Model(ModelDesc):
def _build_graph(self, input_vars, is_training): def _build_graph(self, input_vars, is_training):
image, edgemap = input_vars image, edgemap = input_vars
# TODO fix this
edgemap = tf.identity(edgemap, name='edgemap-tmp')
image = image - tf.constant([104, 116, 122], dtype='float32') image = image - tf.constant([104, 116, 122], dtype='float32')
def branch(name, l, up): def branch(name, l, up):
...@@ -119,8 +117,8 @@ def get_data(name): ...@@ -119,8 +117,8 @@ def get_data(name):
class CropMultiple16(imgaug.ImageAugmentor): class CropMultiple16(imgaug.ImageAugmentor):
def _get_augment_params(self, img): def _get_augment_params(self, img):
newh = img.shape[0] / 16 * 16 newh = img.shape[0] // 16 * 16
neww = img.shape[1] / 16 * 16 neww = img.shape[1] // 16 * 16
assert newh > 0 and neww > 0 assert newh > 0 and neww > 0
diffh = img.shape[0] - newh diffh = img.shape[0] - newh
h0 = 0 if diffh == 0 else self.rng.randint(diffh) h0 = 0 if diffh == 0 else self.rng.randint(diffh)
...@@ -141,10 +139,12 @@ def get_data(name): ...@@ -141,10 +139,12 @@ def get_data(name):
imgaug.Flip(horiz=True), imgaug.Flip(horiz=True),
imgaug.Flip(vert=True), imgaug.Flip(vert=True),
] ]
ds = AugmentImageComponents(ds, shape_aug, (0, 1))
else: else:
# the original image shape (320x480) in bsds is already a multiple of 16 # the original image shape (321x481) in BSDS is not a multiple of 16
pass IMAGE_SHAPE = (320, 480)
shape_aug = [imgaug.RandomCrop(IMAGE_SHAPE)]
ds = AugmentImageComponents(ds, shape_aug, (0, 1))
def f(m): def f(m):
m[m>=0.49] = 1 m[m>=0.49] = 1
m[m<0.49] = 0 m[m<0.49] = 0
...@@ -163,10 +163,12 @@ def get_data(name): ...@@ -163,10 +163,12 @@ def get_data(name):
#ds = PrefetchDataZMQ(ds, 3) #ds = PrefetchDataZMQ(ds, 3)
return ds return ds
def view_data(ds): def view_data():
ds = get_data('train')
ds.reset_state() ds.reset_state()
for ims, edgemaps in ds.get_data(): for ims, edgemaps in ds.get_data():
for im, edgemap in zip(ims, edgemaps): for im, edgemap in zip(ims, edgemaps):
assert im.shape[0] % 16 == 0 and im.shape[1] % 16 == 0, im.shape
cv2.imshow("im", im / 255.0) cv2.imshow("im", im / 255.0)
cv2.waitKey(1000) cv2.waitKey(1000)
cv2.imshow("edge", edgemap) cv2.imshow("edge", edgemap)
...@@ -191,7 +193,7 @@ def get_config(): ...@@ -191,7 +193,7 @@ def get_config():
HumanHyperParamSetter('learning_rate'), HumanHyperParamSetter('learning_rate'),
InferenceRunner(dataset_val, InferenceRunner(dataset_val,
BinaryClassificationStats('prediction', BinaryClassificationStats('prediction',
'edgemap-tmp')) 'edgemap'))
]), ]),
model=Model(), model=Model(),
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
...@@ -207,7 +209,7 @@ def run(model_path, image_path): ...@@ -207,7 +209,7 @@ def run(model_path, image_path):
predict_func = get_predict_func(pred_config) predict_func = get_predict_func(pred_config)
im = cv2.imread(image_path) im = cv2.imread(image_path)
assert im is not None assert im is not None
im = cv2.resize(im, (im.shape[0] / 16 * 16, im.shape[1] / 16 * 16)) im = cv2.resize(im, (im.shape[0] // 16 * 16, im.shape[1] // 16 * 16))
outputs = predict_func([[im.astype('float32')]]) outputs = predict_func([[im.astype('float32')]])
for k in range(6): for k in range(6):
pred = outputs[k][0] pred = outputs[k][0]
...@@ -223,8 +225,7 @@ if __name__ == '__main__': ...@@ -223,8 +225,7 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
if args.view: if args.view:
ds = get_data('train') view_data()
view_data(ds)
elif args.run: elif args.run:
run(args.load, args.run) run(args.load, args.run)
else: else:
......
...@@ -6,6 +6,7 @@ import tensorflow as tf ...@@ -6,6 +6,7 @@ import tensorflow as tf
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import namedtuple
import six import six
from six.moves import zip, map from six.moves import zip, map
...@@ -66,10 +67,14 @@ class InferenceRunner(Callback): ...@@ -66,10 +67,14 @@ class InferenceRunner(Callback):
A callback that runs different kinds of inferencer. A callback that runs different kinds of inferencer.
""" """
def __init__(self, ds, infs): IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, ds, infs, input_tensors=None):
""" """
:param ds: inference dataset. a `DataFlow` instance. :param ds: inference dataset. a `DataFlow` instance.
:param infs: a list of `Inferencer` instance. :param infs: a list of `Inferencer` instance.
:param input_tensor_names: list of tensors to feed the dataflow to.
default to all the input placeholders.
""" """
assert isinstance(ds, DataFlow), type(ds) assert isinstance(ds, DataFlow), type(ds)
self.ds = ds self.ds = ds
...@@ -79,27 +84,36 @@ class InferenceRunner(Callback): ...@@ -79,27 +84,36 @@ class InferenceRunner(Callback):
self.infs = infs self.infs = infs
for v in self.infs: for v in self.infs:
assert isinstance(v, Inferencer), str(v) assert isinstance(v, Inferencer), str(v)
self.input_tensors = input_tensors
def _setup_graph(self): def _setup_graph(self):
self.input_vars = self.trainer.model.reuse_input_vars() self._find_input_tensors() # these are all tensor names
self._find_output_tensors() self._find_output_tensors() # may be either tensor name or op name
input_names = [x.name for x in self.input_vars]
self.pred_func = self.trainer.get_predict_func( self.pred_func = self.trainer.get_predict_func(
input_names, self.output_tensors) self.input_tensors, self.output_tensors)
def _find_input_tensors(self):
if self.input_tensors is None:
input_vars = self.trainer.model.reuse_input_vars()
self.input_tensors = [x.name for x in input_vars]
def _find_output_tensors(self): def _find_output_tensors(self):
self.output_tensors = [] # list of names IOTensor = InferenceRunner.IOTensor
self.inf_to_tensors = [] # list of list of (var_name: output_idx) self.output_tensors = []
for inf in self.infs: def find_oid(t):
inf_tensors = inf.get_output_tensors() tensorname = get_op_tensor_name(t)[1]
def find_oid(t): if tensorname in self.input_tensors:
if t in self.output_tensors: # this inferencer needs the input dp
return self.output_tensors.index(t) return IOTensor(self.input_tensors.index(tensorname), False)
else: if t in self.output_tensors:
self.output_tensors.append(t) return IOTensor(self.output_tensors.index(t), True)
return len(self.output_tensors) - 1 else:
inf_tensors = [(t, find_oid(t)) for t in inf_tensors] self.output_tensors.append(t)
self.inf_to_tensors.append(inf_tensors) return IOTensor(len(self.output_tensors) - 1, True)
self.inf_to_tensors = [
[find_oid(t) for t in inf.get_output_tensors()]
for inf in self.infs]
# list of list of (var_name: IOTensor)
def _trigger_epoch(self): def _trigger_epoch(self):
for inf in self.infs: for inf in self.infs:
...@@ -109,11 +123,10 @@ class InferenceRunner(Callback): ...@@ -109,11 +123,10 @@ class InferenceRunner(Callback):
self.ds.reset_state() self.ds.reset_state()
with tqdm(total=self.ds.size(), **get_tqdm_kwargs()) as pbar: with tqdm(total=self.ds.size(), **get_tqdm_kwargs()) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
#feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
#outputs = sess.run(self.output_tensors, feed_dict=feed)
outputs = self.pred_func(dp) outputs = self.pred_func(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors): for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [outputs[k[1]] for k in tensormap] inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap]
inf.datapoint(dp, inf_output) inf.datapoint(dp, inf_output)
pbar.update() pbar.update()
...@@ -166,7 +179,7 @@ class ScalarStats(Inferencer): ...@@ -166,7 +179,7 @@ class ScalarStats(Inferencer):
class ClassificationError(Inferencer): class ClassificationError(Inferencer):
""" """
Compute classification error from a `wrong` variable Compute classification error in batch mode, from a `wrong` variable
The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch. The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch.
You can use `tf.nn.in_top_k` to record top-k error as well. You can use `tf.nn.in_top_k` to record top-k error as well.
......
...@@ -53,8 +53,8 @@ class RotationAndCropValid(ImageAugmentor): ...@@ -53,8 +53,8 @@ class RotationAndCropValid(ImageAugmentor):
neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg) neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
neww = min(neww, ret.shape[1]) neww = min(neww, ret.shape[1])
newh = min(newh, ret.shape[0]) newh = min(newh, ret.shape[0])
newx = center[0] - neww * 0.5 newx = int(center[0] - neww * 0.5)
newy = center[1] - newh * 0.5 newy = int(center[1] - newh * 0.5)
#print(ret.shape, deg, newx, newy, neww, newh) #print(ret.shape, deg, newx, newy, neww, newh)
return ret[newy:newy+newh,newx:newx+neww] return ret[newy:newy+newh,newx:newx+neww]
...@@ -81,4 +81,4 @@ class RotationAndCropValid(ImageAugmentor): ...@@ -81,4 +81,4 @@ class RotationAndCropValid(ImageAugmentor):
cos_2a = cos_a*cos_a - sin_a*sin_a cos_2a = cos_a*cos_a - sin_a*sin_a
wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a
return wr,hr return int(wr), int(hr)
...@@ -33,6 +33,7 @@ class PredictConfig(object): ...@@ -33,6 +33,7 @@ class PredictConfig(object):
Predict specific output might not require all input variables. Predict specific output might not require all input variables.
:param return_input: whether to return (input, output) pair or just output. default to False. :param return_input: whether to return (input, output) pair or just output. default to False.
""" """
# TODO use the name "tensor" instead of "variable"
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
# XXX does it work? start with minimal memory, but allow growth. # XXX does it work? start with minimal memory, but allow growth.
......
...@@ -13,7 +13,9 @@ __all__ = ['get_default_sess_config', ...@@ -13,7 +13,9 @@ __all__ = ['get_default_sess_config',
'get_global_step', 'get_global_step',
'get_global_step_var', 'get_global_step_var',
'get_op_var_name', 'get_op_var_name',
'get_op_tensor_name',
'get_vars_by_names', 'get_vars_by_names',
'get_tensors_by_names',
'backup_collection', 'backup_collection',
'restore_collection', 'restore_collection',
'clear_collection', 'clear_collection',
...@@ -53,21 +55,23 @@ def get_global_step(): ...@@ -53,21 +55,23 @@ def get_global_step():
tf.get_default_session(), tf.get_default_session(),
get_global_step_var()) get_global_step_var())
def get_op_var_name(name): def get_op_tensor_name(name):
""" """
Variable name is assumed to be ``op_name + ':0'`` Tensor name is assumed to be ``op_name + ':0'``
:param name: an op or a variable name :param name: an op or a tensor name
:returns: (op_name, variable_name) :returns: (op_name, tensor_name)
""" """
if name.endswith(':0'): if name.endswith(':0'):
return name[:-2], name return name[:-2], name
else: else:
return name, name + ':0' return name, name + ':0'
def get_vars_by_names(names): get_op_var_name = get_op_tensor_name
def get_tensors_by_names(names):
""" """
Get a list of variables in the default graph by a list of names Get a list of tensors in the default graph by a list of names
""" """
ret = [] ret = []
G = tf.get_default_graph() G = tf.get_default_graph()
...@@ -76,6 +80,8 @@ def get_vars_by_names(names): ...@@ -76,6 +80,8 @@ def get_vars_by_names(names):
ret.append(G.get_tensor_by_name(varn)) ret.append(G.get_tensor_by_name(varn))
return ret return ret
get_vars_by_names = get_tensors_by_names
def backup_collection(keys): def backup_collection(keys):
ret = {} ret = {}
for k in keys: for k in keys:
......
...@@ -68,7 +68,7 @@ def print_stat(x, message=None): ...@@ -68,7 +68,7 @@ def print_stat(x, message=None):
""" """
if message is None: if message is None:
message = x.op.name message = x.op.name
return tf.Print(x, [tf.reduce_mean(x), x], summarize=20, return tf.Print(x, [tf.shape(x), tf.reduce_mean(x), x], summarize=20,
message=message, name='print_' + x.op.name) message=message, name='print_' + x.op.name)
def rms(x, name=None): def rms(x, name=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment