Commit 0937a01f authored by Yuxin Wu

small fix on keras examples

parent 0430c07c
...@@ -83,8 +83,7 @@ class Model(ModelDesc): ...@@ -83,8 +83,7 @@ class Model(ModelDesc):
def optimizer(self): def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=self.learning_rate, trainable=False) lr = tf.get_variable('learning_rate', initializer=self.learning_rate, trainable=False)
opt = tf.train.AdamOptimizer(lr, epsilon=1e-3) opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
return optimizer.apply_grad_processors( return optimizer.apply_grad_processors(opt, [gradproc.SummaryGradient()])
opt, [gradproc.GlobalNormClip(10), gradproc.SummaryGradient()])
@staticmethod @staticmethod
def update_target_param(): def update_target_param():
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu # Author: Yuxin Wu
import numpy as np import numpy as np
import os
import tensorflow as tf import tensorflow as tf
import argparse import argparse
......
...@@ -56,6 +56,7 @@ class KerasModelCaller(object): ...@@ -56,6 +56,7 @@ class KerasModelCaller(object):
old_trainable_names = set([x.name for x in tf.trainable_variables()]) old_trainable_names = set([x.name for x in tf.trainable_variables()])
trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES]) trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES])
update_ops_backup = backup_collection([tf.GraphKeys.UPDATE_OPS])
def post_process_model(model): def post_process_model(model):
added_trainable_names = set([x.name for x in tf.trainable_variables()]) added_trainable_names = set([x.name for x in tf.trainable_variables()])
...@@ -73,6 +74,11 @@ class KerasModelCaller(object): ...@@ -73,6 +74,11 @@ class KerasModelCaller(object):
logger.warn("Keras created trainable variable '{}' which is actually not trainable. " logger.warn("Keras created trainable variable '{}' which is actually not trainable. "
"This was automatically corrected by tensorpack.".format(n)) "This was automatically corrected by tensorpack.".format(n))
# Keras models might not use this collection at all (in some versions).
restore_collection(update_ops_backup)
for op in model.updates:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, op)
if self.cached_model is None: if self.cached_model is None:
assert not reuse assert not reuse
model = self.cached_model = self.get_model(*input_tensors) model = self.cached_model = self.get_model(*input_tensors)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment