Commit 0937a01f authored by Yuxin Wu's avatar Yuxin Wu

small fix on keras examples

parent 0430c07c
......@@ -83,8 +83,7 @@ class Model(ModelDesc):
def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=self.learning_rate, trainable=False)
opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
return optimizer.apply_grad_processors(
opt, [gradproc.GlobalNormClip(10), gradproc.SummaryGradient()])
return optimizer.apply_grad_processors(opt, [gradproc.SummaryGradient()])
@staticmethod
def update_target_param():
......
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu
import numpy as np
import os
import tensorflow as tf
import argparse
......
......@@ -56,6 +56,7 @@ class KerasModelCaller(object):
old_trainable_names = set([x.name for x in tf.trainable_variables()])
trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES])
update_ops_backup = backup_collection([tf.GraphKeys.UPDATE_OPS])
def post_process_model(model):
added_trainable_names = set([x.name for x in tf.trainable_variables()])
......@@ -73,6 +74,11 @@ class KerasModelCaller(object):
logger.warn("Keras created trainable variable '{}' which is actually not trainable. "
"This was automatically corrected by tensorpack.".format(n))
# Keras models might not use this collection at all (in some versions).
restore_collection(update_ops_backup)
for op in model.updates:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, op)
if self.cached_model is None:
assert not reuse
model = self.cached_model = self.get_model(*input_tensors)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment