Commit be39dbdf authored by Yuxin Wu

fix deprecation about dropout; fix Keras compatibility in tf1.13

parent 79148350
@@ -100,7 +100,7 @@ class Model(ModelDesc):
                   .apply(fg)
                   .BatchNorm('bn5').apply(activate)
                   # 5
-                  .tf.nn.dropout(0.5 if is_training else 1.0)
+                  .Dropout(rate=0.5 if is_training else 0.0)
                   .Conv2D('conv6', 512, 5, padding='VALID')
                   .apply(fg).BatchNorm('bn6')
                   .apply(nonlin)
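The change above follows TF 1.13's dropout deprecation: tf.nn.dropout's keep_prob argument gives way to rate, where rate = 1 - keep_prob, so "no dropout" is now spelled rate=0.0 instead of keep_prob=1.0 (the CIFAR example below gets the same treatment). A minimal sketch of the migration in plain TensorFlow, assuming TF >= 1.13 and illustrative tensors:

import tensorflow as tf

x = tf.random_normal([8, 128])
is_training = True

# Deprecated spelling: keep_prob=1.0 meant "keep every unit", i.e. no dropout.
y_old = tf.nn.dropout(x, keep_prob=0.5 if is_training else 1.0)

# TF 1.13 spelling: rate = 1 - keep_prob, so rate=0.0 disables dropout.
y_new = tf.nn.dropout(x, rate=0.5 if is_training else 0.0)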
@@ -33,7 +33,7 @@ class Model(ModelDesc):
     def build_graph(self, image, label):
         is_training = get_current_tower_context().is_training
-        keep_prob = tf.constant(0.5 if is_training else 1.0)
+        drop_rate = tf.constant(0.5 if is_training else 0.0)

         if is_training:
             tf.summary.image("train_image", image, 10)
@@ -56,7 +56,7 @@ class Model(ModelDesc):
                 .Conv2D('conv3.1', filters=128, padding='VALID') \
                 .Conv2D('conv3.2', filters=128, padding='VALID') \
                 .FullyConnected('fc0', 1024 + 512, activation=tf.nn.relu) \
-                .tf.nn.dropout(keep_prob) \
+                .Dropout(rate=drop_rate) \
                 .FullyConnected('fc1', 512, activation=tf.nn.relu) \
                 .FullyConnected('linear', out_dim=self.cifar_classnum)()
@@ -147,8 +147,8 @@ if __name__ == '__main__':
     num_gpu = get_num_gpu()
     if args.fake:
-        df_train = FakeData([[64, 224, 224, 3], [64, 1000]], 5000, random=False, dtype='uint8')
-        df_val = FakeData([[64, 224, 224, 3], [64, 1000]], 5000, random=False)
+        df_train = FakeData([[32, 224, 224, 3], [32, 1000]], 5000, random=False, dtype='uint8')
+        df_val = FakeData([[32, 224, 224, 3], [32, 1000]], 5000, random=False)
     else:
         batch_size = TOTAL_BATCH_SIZE // num_gpu
         assert args.data is not None
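The fake-data batch shape shrinks from 64 to 32 images per element here. For context, a minimal sketch of how such a FakeData dataflow is read, assuming tensorpack's DataFlow protocol (reset_state() before iteration; each datapoint is a list of arrays with the declared shapes):

from tensorpack.dataflow import FakeData

# 5000 datapoints: a [32, 224, 224, 3] uint8 image batch and a [32, 1000] label batch each.
df = FakeData([[32, 224, 224, 3], [32, 1000]], 5000, random=False, dtype='uint8')
df.reset_state()
for images, labels in df:
    print(images.shape, labels.shape)  # (32, 224, 224, 3) (32, 1000)
    break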
@@ -4,7 +4,9 @@
 import tensorflow as tf
 import six
 from tensorflow import keras
+import tensorflow.keras.backend as K
 from tensorflow.python.keras import metrics as metrics_module
+from contextlib import contextmanager

 from ..models.regularize import regularize_cost_from_collection
 from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
@@ -82,7 +84,19 @@ class KerasModelCaller(object):
         if self.cached_model is None:
             assert not reuse
-            model = self.cached_model = self.get_model(*input_tensors)
+
+            # starting from some versions, tf.keras starts to prepend name scope to variable names ..
+            @contextmanager
+            def clear_tower0_name_scope():
+                ns = tf.get_default_graph().get_name_scope()
+                if ns == 'tower0':
+                    with tf.name_scope('/'):
+                        yield
+                else:
+                    yield
+
+            with clear_tower0_name_scope():
+                model = self.cached_model = self.get_model(*input_tensors)
             outputs = model.outputs
         elif reuse:
             # use the cached Keras model to mimic reuse
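The workaround above exists because, as the inline comment notes, some tf.keras versions (1.13 included) prepend the current name scope to variable names; a model first built inside tensorpack's 'tower0' scope would then get tower-prefixed weights that later towers cannot match for reuse. The fix leans on a TF 1.x rule: a name_scope argument ending in '/' is treated as an absolute scope rather than nested, so tf.name_scope('/') drops back to the root. A minimal sketch of that rule, assuming graph mode:

import tensorflow as tf

with tf.name_scope('tower0'):
    a = tf.constant(0., name='a')      # op named 'tower0/a'
    with tf.name_scope('/'):           # '/' resets to the root scope
        b = tf.constant(0., name='b')  # op named 'b', not 'tower0/b'

print(a.name, b.name)  # tower0/a:0 b:0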
@@ -108,7 +122,7 @@ class KerasPhaseCallback(Callback):
     def __init__(self, isTrain):
         assert isinstance(isTrain, bool), isTrain
         self._isTrain = isTrain
-        self._learning_phase = keras.backend.learning_phase()
+        self._learning_phase = K.learning_phase()

     def _setup_graph(self):
         logger.info("Using Keras learning phase {} in the graph!".format(
@@ -200,7 +214,8 @@ def setup_keras_trainer(
         input,
         get_cost,
         lambda: optimizer)
-    if model_caller.cached_model.uses_learning_phase:
+    if len(K.learning_phase().consumers()) > 0:
+        # check if learning_phase is used in this model
         trainer.register_callback(KerasPhaseCallback(True))
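Rather than trusting the model's uses_learning_phase attribute, the trainer now asks the graph directly: Keras' learning phase is an ordinary tensor, and if no op consumes it, no layer depends on it. A minimal sketch of the idea, assuming TF 1.x graph mode, with Dropout standing in for any phase-dependent layer:

import tensorflow as tf
import tensorflow.keras.backend as K

phase = K.learning_phase()
print(phase.consumers())             # []: nothing reads the phase tensor yet

x = tf.keras.Input(shape=(16,))
y = tf.keras.layers.Dropout(0.5)(x)  # Dropout conditions on learning_phase

print(len(phase.consumers()) > 0)    # True: the phase tensor now feeds ops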