Commit b1c82a9c authored by Yuxin Wu's avatar Yuxin Wu

fix #1202

parent 2981c5d4
...@@ -41,7 +41,7 @@ It has: ...@@ -41,7 +41,7 @@ It has:
Keras does not respect variable scopes or variable Keras does not respect variable scopes or variable
collections, which contradicts with tensorpack trainers. collections, which contradicts with tensorpack trainers.
Therefore Keras support is __experimental__. Therefore Keras support is __experimental__ and __unofficial__.
These simple examples can run within tensorpack smoothly, but note that a future These simple examples can run within tensorpack smoothly, but note that a
version of Keras or a complicated model may break them (unlikely, though). complicated model or a future version of Keras may break them.
...@@ -12,6 +12,7 @@ from tensorpack.contrib.keras import KerasPhaseCallback ...@@ -12,6 +12,7 @@ from tensorpack.contrib.keras import KerasPhaseCallback
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from tensorpack.tfutils.tower import get_current_tower_context
KL = keras.layers KL = keras.layers
...@@ -61,9 +62,14 @@ class Model(ModelDesc): ...@@ -61,9 +62,14 @@ class Model(ModelDesc):
def build_graph(self, image, label): def build_graph(self, image, label):
image = tf.expand_dims(image, 3) * 2 - 1 image = tf.expand_dims(image, 3) * 2 - 1
ctx = get_current_tower_context()
M = get_keras_model() M = get_keras_model()
logits = M(image) logits = M(image)
if ctx.is_main_training_tower:
for op in M.updates:
                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, op)
# build cost function by tensorflow # build cost function by tensorflow
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment