Commit 4879a4e4 authored by Yuxin Wu

clean-ups

parent 4f0cb51e
......@@ -8,27 +8,27 @@ import tensorflow as tf
import argparse
from tensorpack import InputDesc, QueueInput, StagingInput, SyncMultiGPUTrainerReplicated
from tensorpack.dataflow import MapData, FakeData, MapDataComponent
from tensorpack.dataflow import FakeData, MapDataComponent
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.contrib.keras import KerasModel
from tensorpack.callbacks import *
from tensorpack.tfutils import get_current_tower_context
from tensorflow import keras
from tensorflow.python.keras.layers import *
from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor, ImageNetModel
TOTAL_BATCH_SIZE = 512
BASE_LR = 0.1 * (TOTAL_BATCH_SIZE // 256)
def bn(x, name, zero_init=False):
    """Apply a fused Keras ``BatchNormalization`` layer to ``x``.

    Args:
        x: input tensor; normalization is applied along axis 1.
        name (str): name assigned to the created layer.
        zero_init (bool): when True, gamma is initialized with 'zeros'
            instead of 'ones'.

    Returns:
        The result of calling the newly created layer on ``x``.
    """
    gamma_init = 'zeros' if zero_init else 'ones'
    layer = BatchNormalization(
        axis=1, name=name, fused=True,
        momentum=0.9, epsilon=1e-5,
        gamma_initializer=gamma_init)
    return layer(x)
def conv(x, filters, kernel, strides=1, name=None):
return Conv2D(filters, kernel, name=name,
strides=strides, use_bias=False, padding='same',
......@@ -36,6 +36,7 @@ def conv(x, filters, kernel, strides=1, name=None):
scale=2.0, mode='fan_out', distribution='normal'),
kernel_regularizer=tf.keras.regularizers.l2(5e-5))(x)
def identity_block(input_tensor, kernel_size, filters, stage, block):
filters1, filters2, filters3 = filters
conv_name_base = 'res' + str(stage) + block + '_branch'
......@@ -56,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
x = Activation('relu')(x)
return x
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
filters1, filters2, filters3 = filters
conv_name_base = 'res' + str(stage) + block + '_branch'
......@@ -72,7 +74,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
x = conv(x, filters3, (1, 1), name=conv_name_base + '2c')
x = bn(x, name=bn_name_base + '2c', zero_init=True)
shortcut = conv(input_tensor,
shortcut = conv(
input_tensor,
filters3, (1, 1), strides=strides,
name=conv_name_base + '1')
shortcut = bn(shortcut, name=bn_name_base + '1')
......@@ -81,6 +84,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
x = Activation('relu')(x)
return x
def resnet50(inputs):
input = tf.layers.Input(tensor=inputs[0])
......@@ -132,7 +136,7 @@ if __name__ == '__main__':
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true')
args = parser.parse_args()
logger.set_logger_dir("train_log/dynamic-bn")
logger.set_logger_dir("train_log/imagenet-resnet-keras")
tf.keras.backend.set_image_data_format('channels_first')
......@@ -154,13 +158,11 @@ if __name__ == '__main__':
df_train = MapDataComponent(df_train, one_hot, 1)
df_val = MapDataComponent(df_val, one_hot, 1)
input = QueueInput(df_train)
input = StagingInput(input)
M = KerasModel(
resnet50,
inputs_desc=[InputDesc(tf.uint8, [None, 224, 224, 3], 'images')],
targets_desc=[InputDesc(tf.float32, [None, 1000], 'labels')],
input=input,
input=df_train,
trainer=SyncMultiGPUTrainerReplicated(nr_gpu))
lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
......@@ -185,7 +187,7 @@ if __name__ == '__main__':
if not args.fake:
callbacks.append(
DataParallelInferenceRunner(
df_val, ScalarStats(['categorical_accuracy']), list(range(nr_gpu))))
df_val, ScalarStats(['categorical_accuracy']), nr_gpu))
M.fit(
steps_per_epoch=100 if args.fake else 1281167 // TOTAL_BATCH_SIZE,
......
......@@ -10,6 +10,7 @@ from tensorflow.python.keras import metrics as metrics_module
from ..models.regularize import regularize_cost_from_collection
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
from ..train.trainers import DistributedTrainerBase
from ..train.interface import apply_default_prefetch
from ..callbacks import (
Callback, InferenceRunnerBase, InferenceRunner, CallbackToHook,
ScalarStats)
......@@ -177,7 +178,7 @@ class KerasModel(object):
get_model ( -> keras.model.Model):
inputs_desc ([InputDesc]):
targets_desc ([InputDesc]):
input (InputSource):
input (InputSource | DataFlow):
trainer (Trainer): the default will check the number of available
GPUs and use them all.
"""
......@@ -194,7 +195,7 @@ class KerasModel(object):
assert isinstance(trainer, Trainer), trainer
assert not isinstance(trainer, DistributedTrainerBase)
self.input = input
self.input = apply_default_prefetch(input, trainer)
self.trainer = trainer
def compile(self, optimizer, loss, metrics=None):
......
......@@ -22,12 +22,16 @@ def apply_default_prefetch(input_source_or_dataflow, trainer):
Args:
input_source_or_dataflow(InputSource | DataFlow):
trainer (Trainer):
Returns:
InputSource
"""
if not isinstance(input_source_or_dataflow, InputSource):
# to mimic same behavior of the old trainer interface
if type(trainer) == SimpleTrainer:
input = FeedInput(input_source_or_dataflow)
else:
logger.info("Automatically applying QueueInput on the DataFlow.")
input = QueueInput(input_source_or_dataflow)
else:
input = input_source_or_dataflow
......@@ -39,6 +43,7 @@ def apply_default_prefetch(input_source_or_dataflow, trainer):
assert tf.test.is_gpu_available()
if not isinstance(input, (StagingInput, DummyConstantInput)):
logger.info("Automatically applying StagingInput on the DataFlow.")
input = StagingInput(input)
return input
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment