Commit 712fd299 authored by Yuxin Wu's avatar Yuxin Wu

Not all Keras variables are marked trainable (#748)

parent ac72ab73
......@@ -136,7 +136,7 @@ if __name__ == '__main__':
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true')
args = parser.parse_args()
logger.set_logger_dir("train_log/imagenet-resnet-keras")
logger.set_logger_dir(os.path.join("train_log", "imagenet-resnet-keras"))
tf.keras.backend.set_image_data_format('channels_first')
......
......@@ -78,7 +78,7 @@ class KerasModelCaller(object):
for v in M.weights:
# In Keras, the collection is not respected and could contain non-trainable vars.
# We put M.weights into the collection instead.
if v.name not in old_trainable_names:
if v.name not in old_trainable_names and v.name in added_trainable_names:
tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, v)
new_trainable_names = set([x.name for x in tf.trainable_variables()])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment