Commit 712fd299 authored by Yuxin Wu's avatar Yuxin Wu

Not all Keras variables are marked trainable (#748)

parent ac72ab73
...@@ -136,7 +136,7 @@ if __name__ == '__main__': ...@@ -136,7 +136,7 @@ if __name__ == '__main__':
parser.add_argument('--data', help='ILSVRC dataset dir') parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true') parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true')
args = parser.parse_args() args = parser.parse_args()
logger.set_logger_dir("train_log/imagenet-resnet-keras") logger.set_logger_dir(os.path.join("train_log", "imagenet-resnet-keras"))
tf.keras.backend.set_image_data_format('channels_first') tf.keras.backend.set_image_data_format('channels_first')
......
...@@ -78,7 +78,7 @@ class KerasModelCaller(object): ...@@ -78,7 +78,7 @@ class KerasModelCaller(object):
for v in M.weights: for v in M.weights:
# In Keras, the collection is not respected and could contain non-trainable vars. # In Keras, the collection is not respected and could contain non-trainable vars.
# We put M.weights into the collection instead. # We put M.weights into the collection instead.
if v.name not in old_trainable_names: if v.name not in old_trainable_names and v.name in added_trainable_names:
tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, v) tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, v)
new_trainable_names = set([x.name for x in tf.trainable_variables()]) new_trainable_names = set([x.name for x in tf.trainable_variables()])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment