Commit 4879a4e4 authored by Yuxin Wu

clean-ups

parent 4f0cb51e
...@@ -8,27 +8,27 @@ import tensorflow as tf ...@@ -8,27 +8,27 @@ import tensorflow as tf
import argparse import argparse
from tensorpack import InputDesc, QueueInput, StagingInput, SyncMultiGPUTrainerReplicated from tensorpack import InputDesc, QueueInput, StagingInput, SyncMultiGPUTrainerReplicated
from tensorpack.dataflow import MapData, FakeData, MapDataComponent from tensorpack.dataflow import FakeData, MapDataComponent
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.gpu import get_nr_gpu from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.contrib.keras import KerasModel from tensorpack.contrib.keras import KerasModel
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.tfutils import get_current_tower_context
from tensorflow import keras
from tensorflow.python.keras.layers import * from tensorflow.python.keras.layers import *
from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor, ImageNetModel from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor, ImageNetModel
TOTAL_BATCH_SIZE = 512 TOTAL_BATCH_SIZE = 512
BASE_LR = 0.1 * (TOTAL_BATCH_SIZE // 256) BASE_LR = 0.1 * (TOTAL_BATCH_SIZE // 256)
def bn(x, name, zero_init=False): def bn(x, name, zero_init=False):
return BatchNormalization( return BatchNormalization(
axis=1, name=name, fused=True, axis=1, name=name, fused=True,
momentum=0.9, epsilon=1e-5, momentum=0.9, epsilon=1e-5,
gamma_initializer='zeros' if zero_init else 'ones')(x) gamma_initializer='zeros' if zero_init else 'ones')(x)
def conv(x, filters, kernel, strides=1, name=None): def conv(x, filters, kernel, strides=1, name=None):
return Conv2D(filters, kernel, name=name, return Conv2D(filters, kernel, name=name,
strides=strides, use_bias=False, padding='same', strides=strides, use_bias=False, padding='same',
...@@ -36,6 +36,7 @@ def conv(x, filters, kernel, strides=1, name=None): ...@@ -36,6 +36,7 @@ def conv(x, filters, kernel, strides=1, name=None):
scale=2.0, mode='fan_out', distribution='normal'), scale=2.0, mode='fan_out', distribution='normal'),
kernel_regularizer=tf.keras.regularizers.l2(5e-5))(x) kernel_regularizer=tf.keras.regularizers.l2(5e-5))(x)
def identity_block(input_tensor, kernel_size, filters, stage, block): def identity_block(input_tensor, kernel_size, filters, stage, block):
filters1, filters2, filters3 = filters filters1, filters2, filters3 = filters
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
...@@ -56,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): ...@@ -56,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
x = Activation('relu')(x) x = Activation('relu')(x)
return x return x
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)): def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
filters1, filters2, filters3 = filters filters1, filters2, filters3 = filters
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
...@@ -72,7 +74,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) ...@@ -72,7 +74,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
x = conv(x, filters3, (1, 1), name=conv_name_base + '2c') x = conv(x, filters3, (1, 1), name=conv_name_base + '2c')
x = bn(x, name=bn_name_base + '2c', zero_init=True) x = bn(x, name=bn_name_base + '2c', zero_init=True)
shortcut = conv(input_tensor, shortcut = conv(
input_tensor,
filters3, (1, 1), strides=strides, filters3, (1, 1), strides=strides,
name=conv_name_base + '1') name=conv_name_base + '1')
shortcut = bn(shortcut, name=bn_name_base + '1') shortcut = bn(shortcut, name=bn_name_base + '1')
...@@ -81,6 +84,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) ...@@ -81,6 +84,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
x = Activation('relu')(x) x = Activation('relu')(x)
return x return x
def resnet50(inputs): def resnet50(inputs):
input = tf.layers.Input(tensor=inputs[0]) input = tf.layers.Input(tensor=inputs[0])
...@@ -132,7 +136,7 @@ if __name__ == '__main__': ...@@ -132,7 +136,7 @@ if __name__ == '__main__':
parser.add_argument('--data', help='ILSVRC dataset dir') parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true') parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true')
args = parser.parse_args() args = parser.parse_args()
logger.set_logger_dir("train_log/dynamic-bn") logger.set_logger_dir("train_log/imagenet-resnet-keras")
tf.keras.backend.set_image_data_format('channels_first') tf.keras.backend.set_image_data_format('channels_first')
...@@ -154,13 +158,11 @@ if __name__ == '__main__': ...@@ -154,13 +158,11 @@ if __name__ == '__main__':
df_train = MapDataComponent(df_train, one_hot, 1) df_train = MapDataComponent(df_train, one_hot, 1)
df_val = MapDataComponent(df_val, one_hot, 1) df_val = MapDataComponent(df_val, one_hot, 1)
input = QueueInput(df_train)
input = StagingInput(input)
M = KerasModel( M = KerasModel(
resnet50, resnet50,
inputs_desc=[InputDesc(tf.uint8, [None, 224, 224, 3], 'images')], inputs_desc=[InputDesc(tf.uint8, [None, 224, 224, 3], 'images')],
targets_desc=[InputDesc(tf.float32, [None, 1000], 'labels')], targets_desc=[InputDesc(tf.float32, [None, 1000], 'labels')],
input=input, input=df_train,
trainer=SyncMultiGPUTrainerReplicated(nr_gpu)) trainer=SyncMultiGPUTrainerReplicated(nr_gpu))
lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False) lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
...@@ -185,7 +187,7 @@ if __name__ == '__main__': ...@@ -185,7 +187,7 @@ if __name__ == '__main__':
if not args.fake: if not args.fake:
callbacks.append( callbacks.append(
DataParallelInferenceRunner( DataParallelInferenceRunner(
df_val, ScalarStats(['categorical_accuracy']), list(range(nr_gpu)))) df_val, ScalarStats(['categorical_accuracy']), nr_gpu))
M.fit( M.fit(
steps_per_epoch=100 if args.fake else 1281167 // TOTAL_BATCH_SIZE, steps_per_epoch=100 if args.fake else 1281167 // TOTAL_BATCH_SIZE,
......
...@@ -10,6 +10,7 @@ from tensorflow.python.keras import metrics as metrics_module ...@@ -10,6 +10,7 @@ from tensorflow.python.keras import metrics as metrics_module
from ..models.regularize import regularize_cost_from_collection from ..models.regularize import regularize_cost_from_collection
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
from ..train.trainers import DistributedTrainerBase from ..train.trainers import DistributedTrainerBase
from ..train.interface import apply_default_prefetch
from ..callbacks import ( from ..callbacks import (
Callback, InferenceRunnerBase, InferenceRunner, CallbackToHook, Callback, InferenceRunnerBase, InferenceRunner, CallbackToHook,
ScalarStats) ScalarStats)
...@@ -177,7 +178,7 @@ class KerasModel(object): ...@@ -177,7 +178,7 @@ class KerasModel(object):
get_model ( -> keras.model.Model): get_model ( -> keras.model.Model):
inputs_desc ([InputDesc]): inputs_desc ([InputDesc]):
targets_desc ([InputDesc]): targets_desc ([InputDesc]):
input (InputSource): input (InputSource | DataFlow):
trainer (Trainer): the default will check the number of available trainer (Trainer): the default will check the number of available
GPUs and use them all. GPUs and use them all.
""" """
...@@ -194,7 +195,7 @@ class KerasModel(object): ...@@ -194,7 +195,7 @@ class KerasModel(object):
assert isinstance(trainer, Trainer), trainer assert isinstance(trainer, Trainer), trainer
assert not isinstance(trainer, DistributedTrainerBase) assert not isinstance(trainer, DistributedTrainerBase)
self.input = input self.input = apply_default_prefetch(input, trainer)
self.trainer = trainer self.trainer = trainer
def compile(self, optimizer, loss, metrics=None): def compile(self, optimizer, loss, metrics=None):
......
...@@ -22,12 +22,16 @@ def apply_default_prefetch(input_source_or_dataflow, trainer): ...@@ -22,12 +22,16 @@ def apply_default_prefetch(input_source_or_dataflow, trainer):
Args: Args:
input_source_or_dataflow(InputSource | DataFlow): input_source_or_dataflow(InputSource | DataFlow):
trainer (Trainer): trainer (Trainer):
Returns:
InputSource
""" """
if not isinstance(input_source_or_dataflow, InputSource): if not isinstance(input_source_or_dataflow, InputSource):
# to mimic same behavior of the old trainer interface # to mimic same behavior of the old trainer interface
if type(trainer) == SimpleTrainer: if type(trainer) == SimpleTrainer:
input = FeedInput(input_source_or_dataflow) input = FeedInput(input_source_or_dataflow)
else: else:
logger.info("Automatically applying QueueInput on the DataFlow.")
input = QueueInput(input_source_or_dataflow) input = QueueInput(input_source_or_dataflow)
else: else:
input = input_source_or_dataflow input = input_source_or_dataflow
...@@ -39,6 +43,7 @@ def apply_default_prefetch(input_source_or_dataflow, trainer): ...@@ -39,6 +43,7 @@ def apply_default_prefetch(input_source_or_dataflow, trainer):
assert tf.test.is_gpu_available() assert tf.test.is_gpu_available()
if not isinstance(input, (StagingInput, DummyConstantInput)): if not isinstance(input, (StagingInput, DummyConstantInput)):
logger.info("Automatically applying StagingInput on the DataFlow.")
input = StagingInput(input) input = StagingInput(input)
return input return input
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment