Commit 1b06a41a authored by Yuxin Wu's avatar Yuxin Wu

Inceptionv3, compute batch size from --gpu option instead of hard-coded. (#246)

parent ece733a9
...@@ -192,4 +192,5 @@ if __name__ == '__main__': ...@@ -192,4 +192,5 @@ if __name__ == '__main__':
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: if args.gpu:
config.nr_tower = len(args.gpu.split(',')) config.nr_tower = len(args.gpu.split(','))
assert config.nr_tower == NR_GPU
SyncMultiGPUTrainer(config).train() SyncMultiGPUTrainer(config).train()
...@@ -22,13 +22,11 @@ This config follows the official inceptionv3 setup ...@@ -22,13 +22,11 @@ This config follows the official inceptionv3 setup
(https://github.com/tensorflow/models/tree/master/inception/inception) (https://github.com/tensorflow/models/tree/master/inception/inception)
with much much fewer lines of code. with much much fewer lines of code.
It reaches 74% single-crop validation accuracy, similar to the official code. It reaches 74% single-crop validation accuracy, similar to the official code.
The hyperparameters here are for 8 GPUs, so the effective batch size is 8*64 = 512.
""" """
TOTAL_BATCH_SIZE = 512 TOTAL_BATCH_SIZE = 512
NR_GPU = 8 NR_GPU = None
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU BATCH_SIZE = None
INPUT_SHAPE = 299 INPUT_SHAPE = 299
...@@ -285,19 +283,19 @@ def get_config(): ...@@ -285,19 +283,19 @@ def get_config():
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', required=True)
parser.add_argument('--data', help='ILSVRC dataset dir') parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
args = parser.parse_args() args = parser.parse_args()
logger.auto_set_dir() logger.auto_set_dir()
if args.gpu: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu NR_GPU = len(args.gpu.split(','))
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: config.nr_tower = NR_GPU
config.nr_tower = len(args.gpu.split(','))
SyncMultiGPUTrainer(config).train() SyncMultiGPUTrainer(config).train()
...@@ -235,7 +235,7 @@ def eval_on_ILSVRC12(model_file, data_dir): ...@@ -235,7 +235,7 @@ def eval_on_ILSVRC12(model_file, data_dir):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', required=True)
parser.add_argument('--data', help='ILSVRC dataset dir') parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true') parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true')
...@@ -247,14 +247,13 @@ if __name__ == '__main__': ...@@ -247,14 +247,13 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
DEPTH = args.depth DEPTH = args.depth
if args.gpu: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.eval: if args.eval:
BATCH_SIZE = 128 # something that can run on one gpu BATCH_SIZE = 128 # something that can run on one gpu
eval_on_ILSVRC12(args.load, args.data) eval_on_ILSVRC12(args.load, args.data)
sys.exit() sys.exit()
assert args.gpu is not None, "Need to specify a list of gpu for training!"
NR_GPU = len(args.gpu.split(',')) NR_GPU = len(args.gpu.split(','))
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
......
...@@ -9,6 +9,7 @@ from .base import RNGDataFlow ...@@ -9,6 +9,7 @@ from .base import RNGDataFlow
from .common import MapDataComponent, MapData from .common import MapDataComponent, MapData
from .imgaug import AugmentorList from .imgaug import AugmentorList
from ..utils import logger from ..utils import logger
from ..utils.argtools import shape2d
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents'] __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
...@@ -20,13 +21,13 @@ class ImageFromFile(RNGDataFlow): ...@@ -20,13 +21,13 @@ class ImageFromFile(RNGDataFlow):
Args: Args:
files (list): list of file paths. files (list): list of file paths.
channel (int): 1 or 3. Will convert grayscale to RGB images if channel==3. channel (int): 1 or 3. Will convert grayscale to RGB images if channel==3.
resize (tuple): (h, w). If given, resize the image. resize (tuple): int or (h, w) tuple. If given, resize the image.
""" """
assert len(files), "No image files given to ImageFromFile!" assert len(files), "No image files given to ImageFromFile!"
self.files = files self.files = files
self.channel = int(channel) self.channel = int(channel)
self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR
self.resize = resize self.resize = shape2d(resize)
self.shuffle = shuffle self.shuffle = shuffle
def size(self): def size(self):
...@@ -40,7 +41,7 @@ class ImageFromFile(RNGDataFlow): ...@@ -40,7 +41,7 @@ class ImageFromFile(RNGDataFlow):
if self.channel == 3: if self.channel == 3:
im = im[:, :, ::-1] im = im[:, :, ::-1]
if self.resize is not None: if self.resize is not None:
im = cv2.resize(im, self.resize[::-1]) im = cv2.resize(im, tuple(self.resize[::-1]))
if self.channel == 1: if self.channel == 1:
im = im[:, :, np.newaxis] im = im[:, :, np.newaxis]
yield [im] yield [im]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment