Commit b1c82a9c authored by Yuxin Wu's avatar Yuxin Wu

fix #1202

parent 2981c5d4
......@@ -41,7 +41,7 @@ It has:
Keras does not respect variable scopes or variable
collections, which contradicts with tensorpack trainers.
Therefore Keras support is __experimental__.
Therefore Keras support is __experimental__ and __unofficial__.
These simple examples can run within tensorpack smoothly, but note that a future
version of Keras or a complicated model may break them (unlikely, though).
These simple examples can run within tensorpack smoothly, but note that a
complicated model or a future version of Keras may break them.
......@@ -12,6 +12,7 @@ from tensorpack.contrib.keras import KerasPhaseCallback
from tensorpack.dataflow import dataset
from tensorpack.utils.argtools import memoized
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.tfutils.tower import get_current_tower_context
KL = keras.layers
......@@ -61,9 +62,14 @@ class Model(ModelDesc):
def build_graph(self, image, label):
image = tf.expand_dims(image, 3) * 2 - 1
ctx = get_current_tower_context()
M = get_keras_model()
logits = M(image)
if ctx.is_main_training_tower:
for op in M.updates:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS)
# build cost function by tensorflow
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment