Commit 5d529d03 authored by Yuxin Wu's avatar Yuxin Wu

fix bug in ProxyCallback when before_run is used

parent 087e66db
...@@ -79,9 +79,8 @@ class Model(ModelDesc): ...@@ -79,9 +79,8 @@ class Model(ModelDesc):
def get_data(name, batch): def get_data(name, batch):
isTrain = name == 'train' isTrain = name == 'train'
augmentors = fbresnet_augmentor(isTrain) augmentors = fbresnet_augmentor(isTrain)
datadir = args.data
return get_imagenet_dataflow( return get_imagenet_dataflow(
datadir, name, batch, augmentors, dir_structure='original') args.data, name, batch, augmentors, dir_structure='original')
def get_config(model, fake=False): def get_config(model, fake=False):
...@@ -106,8 +105,10 @@ def get_config(model, fake=False): ...@@ -106,8 +105,10 @@ def get_config(model, fake=False):
infs = [ClassificationError('wrong-top1', 'val-error-top1'), infs = [ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')] ClassificationError('wrong-top5', 'val-error-top5')]
if nr_tower == 1: if nr_tower == 1:
# single-GPU inference with queue prefetch
callbacks.append(InferenceRunner(QueueInput(dataset_val), infs)) callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
else: else:
# multi-GPU inference (with mandatory queue prefetch)
callbacks.append(DataParallelInferenceRunner( callbacks.append(DataParallelInferenceRunner(
dataset_val, infs, list(range(nr_tower)))) dataset_val, infs, list(range(nr_tower))))
......
...@@ -250,7 +250,7 @@ class ProxyCallback(Callback): ...@@ -250,7 +250,7 @@ class ProxyCallback(Callback):
self.cb.after_epoch() self.cb.after_epoch()
def _before_run(self, ctx): def _before_run(self, ctx):
self.cb._before_run(ctx) return self.cb._before_run(ctx)
def _after_run(self, ctx, run_values): def _after_run(self, ctx, run_values):
self.cb._after_run(ctx, run_values) self.cb._after_run(ctx, run_values)
......
...@@ -146,8 +146,7 @@ class ILSVRC12Files(RNGDataFlow): ...@@ -146,8 +146,7 @@ class ILSVRC12Files(RNGDataFlow):
class ILSVRC12(ILSVRC12Files): class ILSVRC12(ILSVRC12Files):
""" """
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999], Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999].
and optionally a bounding box of [xmin, ymin, xmax, ymax].
""" """
def __init__(self, dir, name, meta_dir=None, def __init__(self, dir, name, meta_dir=None,
shuffle=None, dir_structure='original'): shuffle=None, dir_structure='original'):
......
...@@ -41,7 +41,7 @@ class SVHNDigit(RNGDataFlow): ...@@ -41,7 +41,7 @@ class SVHNDigit(RNGDataFlow):
if not os.path.isfile(filename): if not os.path.isfile(filename):
url = SVHN_URL + os.path.basename(filename) url = SVHN_URL + os.path.basename(filename)
logger.info("File {} not found!".format(filename)) logger.info("File {} not found!".format(filename))
logger.info("Downloading from {}.".format(url)) logger.info("Downloading from {} ...".format(url))
download(url, os.path.dirname(filename)) download(url, os.path.dirname(filename))
logger.info("Loading {} ...".format(filename)) logger.info("Loading {} ...".format(filename))
data = scipy.io.loadmat(filename) data = scipy.io.loadmat(filename)
......
...@@ -24,8 +24,9 @@ l1_regularizer = tf.contrib.layers.l1_regularizer ...@@ -24,8 +24,9 @@ l1_regularizer = tf.contrib.layers.l1_regularizer
def regularize_cost(regex, func, name='regularize_cost'): def regularize_cost(regex, func, name='regularize_cost'):
""" """
Apply a regularizer on trainable variables matching the regex. Apply a regularizer on trainable variables matching the regex, and print
In replicated mode, will only regularize variables within the current tower. the matched variables (only print once in multi-tower training).
In replicated mode, it will only regularize variables within the current tower.
Args: Args:
regex (str): a regex to match variable names, e.g. "conv.*/W" regex (str): a regex to match variable names, e.g. "conv.*/W"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment