Commit 0831fe9d authored by Yuxin Wu's avatar Yuxin Wu

Rename eval_on_ILSVRC12->eval_classification; fix #1194

parent 1e9342a5
......@@ -18,7 +18,7 @@ from tensorpack.tfutils.varreplace import remap_variables
from tensorpack.utils.gpu import get_num_gpu
from dorefa import get_dorefa, ternarize
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, fbresnet_augmentor, get_imagenet_dataflow
from imagenet_utils import ImageNetModel, eval_classification, fbresnet_augmentor, get_imagenet_dataflow
"""
This is a tensorpack script for the ImageNet results in paper:
......@@ -219,7 +219,7 @@ if __name__ == '__main__':
if args.eval:
BATCH_SIZE = 128
ds = get_data('val')
eval_on_ILSVRC12(Model(), get_model_loader(args.load), ds)
eval_classification(Model(), get_model_loader(args.load), ds)
sys.exit()
nr_tower = max(get_num_gpu(), 1)
......
......@@ -13,7 +13,7 @@ from tensorpack.dataflow import dataset
from tensorpack.tfutils.varreplace import remap_variables
from dorefa import get_dorefa
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, fbresnet_augmentor
from imagenet_utils import ImageNetModel, eval_classification, fbresnet_augmentor
"""
This script loads the pre-trained ResNet-18 model with (W,A,G) = (1,4,32)
......@@ -163,7 +163,7 @@ if __name__ == '__main__':
ds = dataset.ILSVRC12(args.data, 'val', shuffle=False)
ds = AugmentImageComponent(ds, get_inference_augmentor())
ds = BatchData(ds, 192, remainder=True)
eval_on_ILSVRC12(Model(), get_model_loader(args.load), ds)
eval_classification(Model(), get_model_loader(args.load), ds)
elif args.run:
assert args.load.endswith('.npz')
run_image(Model(), DictRestore(dict(np.load(args.load))), args.run)
......@@ -264,7 +264,11 @@ def fbresnet_mapper(isTrain):
"""
def eval_on_ILSVRC12(model, sessinit, dataflow):
def eval_classification(model, sessinit, dataflow):
"""
Eval a classification model on the dataset. It assumes the model inputs are
named "input" and "label", and contains "wrong-top1" and "wrong-top5" in the graph.
"""
pred_config = PredictConfig(
model=model,
session_init=sessinit,
......
......@@ -16,7 +16,7 @@ from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow
from imagenet_utils import ImageNetModel, eval_classification, get_imagenet_dataflow
@layer_register(log_shape=True)
......@@ -251,7 +251,7 @@ if __name__ == '__main__':
if args.eval:
batch = 128 # something that can run on one gpu
ds = get_data('val', batch)
eval_on_ILSVRC12(model, get_model_loader(args.load), ds)
eval_classification(model, get_model_loader(args.load), ds)
elif args.flops:
# manually build the graph with batch=1
with TowerContext('', is_training=False):
......
......@@ -13,7 +13,7 @@ from tensorpack.tfutils import argscope, get_model_loader
from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow, get_imagenet_tfdata
from imagenet_utils import ImageNetModel, eval_classification, get_imagenet_dataflow, get_imagenet_tfdata
from resnet_model import (
preresnet_basicblock, preresnet_bottleneck, preresnet_group,
resnet_backbone, resnet_group,
......@@ -143,7 +143,7 @@ if __name__ == '__main__':
if args.eval:
batch = 128 # something that can run on one gpu
ds = get_imagenet_dataflow(args.data, 'val', batch)
eval_on_ILSVRC12(model, get_model_loader(args.load), ds)
eval_classification(model, get_model_loader(args.load), ds)
else:
if args.fake:
logger.set_logger_dir(os.path.join('train_log', 'tmp'), 'd')
......
......@@ -16,7 +16,7 @@ from tensorpack import *
from tensorpack.dataflow.dataset import ILSVRCMeta
from tensorpack.utils import logger
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow
from imagenet_utils import ImageNetModel, eval_classification, get_imagenet_dataflow
from resnet_model import resnet_bottleneck, resnet_group
DEPTH = None
......@@ -172,6 +172,6 @@ if __name__ == '__main__':
if args.eval:
ds = get_imagenet_dataflow(args.eval, 'val', 128, get_inference_augmentor())
eval_on_ILSVRC12(Model(), DictRestore(param), ds)
eval_classification(Model(), DictRestore(param), ds)
elif args.input:
run_test(param, args.input)
......@@ -98,7 +98,7 @@ class KerasModelCaller(object):
# NOTE: ctx.is_training won't be useful inside model,
# because inference will always use the cached Keras model
model = self.cached_model
outputs = model.call(input_tensors)
outputs = model.call(*input_tensors)
else:
# create new Keras model if not reuse
model = self.get_model(*input_tensors)
......
......@@ -45,7 +45,7 @@ class ExceptionHandler:
class ImageFromFile(RNGDataFlow):
""" Produce images read from a list of files. """
""" Produce images read from a list of files as (h, w, c) arrays. """
def __init__(self, files, channel=3, resize=None, shuffle=False):
"""
Args:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment