Commit 5d529d03 authored by Yuxin Wu's avatar Yuxin Wu

fix bug in ProxyCallback when before_run is used

parent 087e66db
......@@ -79,9 +79,8 @@ class Model(ModelDesc):
def get_data(name, batch):
isTrain = name == 'train'
augmentors = fbresnet_augmentor(isTrain)
datadir = args.data
return get_imagenet_dataflow(
datadir, name, batch, augmentors, dir_structure='original')
args.data, name, batch, augmentors, dir_structure='original')
def get_config(model, fake=False):
......@@ -106,8 +105,10 @@ def get_config(model, fake=False):
infs = [ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]
if nr_tower == 1:
# single-GPU inference with queue prefetch
callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
else:
# multi-GPU inference (with mandatory queue prefetch)
callbacks.append(DataParallelInferenceRunner(
dataset_val, infs, list(range(nr_tower))))
......
......@@ -250,7 +250,7 @@ class ProxyCallback(Callback):
self.cb.after_epoch()
def _before_run(self, ctx):
self.cb._before_run(ctx)
return self.cb._before_run(ctx)
def _after_run(self, ctx, run_values):
self.cb._after_run(ctx, run_values)
......
......@@ -146,8 +146,7 @@ class ILSVRC12Files(RNGDataFlow):
class ILSVRC12(ILSVRC12Files):
"""
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999],
and optionally a bounding box of [xmin, ymin, xmax, ymax].
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999].
"""
def __init__(self, dir, name, meta_dir=None,
shuffle=None, dir_structure='original'):
......
......@@ -41,7 +41,7 @@ class SVHNDigit(RNGDataFlow):
if not os.path.isfile(filename):
url = SVHN_URL + os.path.basename(filename)
logger.info("File {} not found!".format(filename))
logger.info("Downloading from {}.".format(url))
logger.info("Downloading from {} ...".format(url))
download(url, os.path.dirname(filename))
logger.info("Loading {} ...".format(filename))
data = scipy.io.loadmat(filename)
......
......@@ -24,8 +24,9 @@ l1_regularizer = tf.contrib.layers.l1_regularizer
def regularize_cost(regex, func, name='regularize_cost'):
"""
Apply a regularizer on trainable variables matching the regex.
In replicated mode, will only regularize variables within the current tower.
Apply a regularizer on trainable variables matching the regex, and print
the matched variables (only print once in multi-tower training).
In replicated mode, it will only regularize variables within the current tower.
Args:
regex (str): a regex to match variable names, e.g. "conv.*/W"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment