Commit c95292dc authored by Yuxin Wu

Several bug fixes and improvements

parent 744defbe
...@@ -189,8 +189,7 @@ class Model(ModelDesc): ...@@ -189,8 +189,7 @@ class Model(ModelDesc):
wd_w = tf.train.exponential_decay(0.00004, get_global_step_var(), wd_w = tf.train.exponential_decay(0.00004, get_global_step_var(),
80000, 0.7, True) 80000, 0.7, True)
wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='l2_regularize_loss') wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='l2_regularize_loss')
for k in [loss1, loss2, wd_cost]: add_moving_summary(loss1, loss2, wd_cost)
add_moving_summary(k)
self.cost = tf.add_n([0.4 * loss1, loss2, wd_cost], name='cost') self.cost = tf.add_n([0.4 * loss1, loss2, wd_cost], name='cost')
......
...@@ -59,7 +59,7 @@ class BSDS500(RNGDataFlow): ...@@ -59,7 +59,7 @@ class BSDS500(RNGDataFlow):
image_files = glob.glob(image_glob) image_files = glob.glob(image_glob)
gt_dir = os.path.join(self.data_root, 'groundTruth', name) gt_dir = os.path.join(self.data_root, 'groundTruth', name)
self.data = np.zeros((len(image_files), IMG_H, IMG_W, 3), dtype='uint8') self.data = np.zeros((len(image_files), IMG_H, IMG_W, 3), dtype='uint8')
self.label = np.zeros((len(image_files), IMG_H, IMG_W), dtype='bool') self.label = np.zeros((len(image_files), IMG_H, IMG_W), dtype='float32')
for idx, f in enumerate(image_files): for idx, f in enumerate(image_files):
im = cv2.imread(f, cv2.IMREAD_COLOR) im = cv2.imread(f, cv2.IMREAD_COLOR)
...@@ -73,14 +73,15 @@ class BSDS500(RNGDataFlow): ...@@ -73,14 +73,15 @@ class BSDS500(RNGDataFlow):
gt = loadmat(gt_file)['groundTruth'][0] gt = loadmat(gt_file)['groundTruth'][0]
n_annot = gt.shape[0] n_annot = gt.shape[0]
gt = sum(gt[k]['Boundaries'][0][0] for k in range(n_annot)) gt = sum(gt[k]['Boundaries'][0][0] for k in range(n_annot))
gt[gt < 3] = 0 gt[gt > 3] = 3
gt[gt != 0] = 1 gt = gt / 3.0
if gt.shape[0] > gt.shape[1]: if gt.shape[0] > gt.shape[1]:
gt = gt.transpose() gt = gt.transpose()
assert gt.shape == (IMG_H, IMG_W) assert gt.shape == (IMG_H, IMG_W)
self.data[idx] = im self.data[idx] = im
self.label[idx] = gt self.label[idx] = gt
#self.label[self.label<0.9] = 0
def size(self): def size(self):
return self.data.shape[0] return self.data.shape[0]
......
...@@ -57,6 +57,7 @@ class AugmentImageComponent(MapDataComponent): ...@@ -57,6 +57,7 @@ class AugmentImageComponent(MapDataComponent):
class AugmentImagesTogether(MapData): class AugmentImagesTogether(MapData):
""" Augment a list of images of the same shape, with the same parameters"""
def __init__(self, ds, augmentors, index=(0,1)): def __init__(self, ds, augmentors, index=(0,1)):
""" """
:param ds: a `DataFlow` instance. :param ds: a `DataFlow` instance.
......
...@@ -11,7 +11,9 @@ __all__ = ['Rotation'] ...@@ -11,7 +11,9 @@ __all__ = ['Rotation']
class Rotation(ImageAugmentor): class Rotation(ImageAugmentor):
""" Random rotate the image w.r.t a random center""" """ Random rotate the image w.r.t a random center"""
def __init__(self, max_deg, center_range=(0,1)): def __init__(self, max_deg, center_range=(0,1),
interp=cv2.INTER_CUBIC,
border=cv2.BORDER_REPLICATE):
""" """
:param max_deg: max abs value of the rotation degree :param max_deg: max abs value of the rotation degree
:param center_range: the location of the rotation center :param center_range: the location of the rotation center
...@@ -25,6 +27,7 @@ class Rotation(ImageAugmentor): ...@@ -25,6 +27,7 @@ class Rotation(ImageAugmentor):
return cv2.getRotationMatrix2D(tuple(center), deg, 1) return cv2.getRotationMatrix2D(tuple(center), deg, 1)
def _augment(self, img, rot_m): def _augment(self, img, rot_m):
return cv2.warpAffine(img, rot_m, img.shape[1::-1], ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE) flags=self.interp, borderMode=self.border)
return ret
...@@ -26,7 +26,7 @@ class GaussianNoise(ImageAugmentor): ...@@ -26,7 +26,7 @@ class GaussianNoise(ImageAugmentor):
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
return self.rng.randn(img.shape) return self.rng.randn(*img.shape)
def _augment(self, img, noise): def _augment(self, img, noise):
ret = img + noise ret = img + noise
......
...@@ -92,16 +92,16 @@ def FixedUnPooling(x, shape, unpool_mat=None): ...@@ -92,16 +92,16 @@ def FixedUnPooling(x, shape, unpool_mat=None):
shape = shape2d(shape) shape = shape2d(shape)
# a faster implementation for this special case # a faster implementation for this special case
if shape[0] == 2 and shape[1] == 2 and unpool_mat is None: if shape[0] == 2 and shape[1] == 2 and unpool_mat is None:
return UnPooling2x2ZeroFilled(x) return UnPooling2x2ZeroFilled(x)
input_shape = tf.shape(x) input_shape = tf.shape(x)
if unpool_mat is None: if unpool_mat is None:
mat = np.zeros(shape, dtype='float32') mat = np.zeros(shape, dtype='float32')
mat[0][0] = 1 mat[0][0] = 1
unpool_mat = tf.Variable(mat, trainable=False, name='unpool_mat') unpool_mat = tf.constant(mat, name='unpool_mat')
elif isinstance(unpool_mat, np.ndarray): elif isinstance(unpool_mat, np.ndarray):
unpool_mat = tf.Variable(unpool_mat, trainable=False, name='unpool_mat') unpool_mat = tf.constant(unpool_mat, name='unpool_mat')
assert unpool_mat.get_shape().as_list() == list(shape) assert unpool_mat.get_shape().as_list() == list(shape)
# perform a tensor-matrix kronecker product # perform a tensor-matrix kronecker product
......
...@@ -150,7 +150,7 @@ class ParamRestore(SessionInit): ...@@ -150,7 +150,7 @@ class ParamRestore(SessionInit):
variable_names = set([k.name for k in variables]) variable_names = set([k.name for k in variables])
param_names = set(six.iterkeys(self.prms)) param_names = set(six.iterkeys(self.prms))
intersect = variable_names and param_names intersect = variable_names & param_names
logger.info("Params to restore: {}".format( logger.info("Params to restore: {}".format(
', '.join(map(str, intersect)))) ', '.join(map(str, intersect))))
...@@ -159,12 +159,13 @@ class ParamRestore(SessionInit): ...@@ -159,12 +159,13 @@ class ParamRestore(SessionInit):
for k in param_names - variable_names: for k in param_names - variable_names:
logger.warn("Variable {} in the dict not found in this graph!".format(k)) logger.warn("Variable {} in the dict not found in this graph!".format(k))
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect]) upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
logger.info("Restoring from dict ...") logger.info("Restoring from dict ...")
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect}) upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
def ChainInit(SessionInit): class ChainInit(SessionInit):
""" Init a session by a list of SessionInit instance.""" """ Init a session by a list of SessionInit instance."""
def __init__(self, sess_inits, new_session=True): def __init__(self, sess_inits, new_session=True):
""" """
......
...@@ -8,6 +8,7 @@ import re ...@@ -8,6 +8,7 @@ import re
from ..utils import * from ..utils import *
from . import get_global_step_var from . import get_global_step_var
from .symbolic_functions import rms
__all__ = ['create_summary', 'add_param_summary', 'add_activation_summary', __all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
'add_moving_summary', 'summary_moving_average'] 'add_moving_summary', 'summary_moving_average']
...@@ -36,8 +37,7 @@ def add_activation_summary(x, name=None): ...@@ -36,8 +37,7 @@ def add_activation_summary(x, name=None):
tf.histogram_summary(name + '/activation', x) tf.histogram_summary(name + '/activation', x)
tf.scalar_summary(name + '/activation_sparsity', tf.nn.zero_fraction(x)) tf.scalar_summary(name + '/activation_sparsity', tf.nn.zero_fraction(x))
tf.scalar_summary( tf.scalar_summary(
name + '/activation_rms', name + '/activation_rms', rms(x))
tf.sqrt(tf.reduce_mean(tf.square(x))))
def add_param_summary(summary_lists): def add_param_summary(summary_lists):
""" """
...@@ -64,12 +64,10 @@ def add_param_summary(summary_lists): ...@@ -64,12 +64,10 @@ def add_param_summary(summary_lists):
tf.scalar_summary(name + '/mean', tf.reduce_mean(var)) tf.scalar_summary(name + '/mean', tf.reduce_mean(var))
return return
if action == 'rms': if action == 'rms':
tf.scalar_summary(name + '/rms', tf.scalar_summary(name + '/rms', rms(var))
tf.sqrt(tf.reduce_mean(tf.square(var))))
return return
raise RuntimeError("Unknown summary type: {}".format(action)) raise RuntimeError("Unknown summary type: {}".format(action))
import re
params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
with tf.name_scope('param_summary'): with tf.name_scope('param_summary'):
for p in params: for p in params:
...@@ -84,6 +82,7 @@ def add_param_summary(summary_lists): ...@@ -84,6 +82,7 @@ def add_param_summary(summary_lists):
def add_moving_summary(v, *args): def add_moving_summary(v, *args):
""" """
:param v: tensor or list of tensor to summary :param v: tensor or list of tensor to summary
:param args: tensors to summary
""" """
if not isinstance(v, list): if not isinstance(v, list):
v = [v] v = [v]
......
...@@ -64,11 +64,14 @@ def load_caffe(model_desc, model_file): ...@@ -64,11 +64,14 @@ def load_caffe(model_desc, model_file):
prev_data_shape = net.blobs[prev_blob_name].data.shape[1:] prev_data_shape = net.blobs[prev_blob_name].data.shape[1:]
except ValueError: except ValueError:
prev_data_shape = None prev_data_shape = None
logger.info("Processing layer {} of type {}".format(
layername, layer.type))
if layer.type in param_processors: if layer.type in param_processors:
param_dict.update(param_processors[layer.type]( param_dict.update(param_processors[layer.type](
layername, layer.blobs, prev_data_shape)) layername, layer.blobs, prev_data_shape))
else: else:
assert len(layer.blobs) == 0, len(layer.blobs) if len(layer.blobs) != 0:
logger.warn("Layer type {} not supported!".format(layer.type))
logger.info("Model loaded from caffe. Params: " + \ logger.info("Model loaded from caffe. Params: " + \
" ".join(sorted(param_dict.keys()))) " ".join(sorted(param_dict.keys())))
return param_dict return param_dict
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment