Commit c346e924 authored by Yuxin Wu

fix #786

parent 56a77747
@@ -80,8 +80,8 @@ You can overwrite any of the following methods to define a new callback:
   The training loops would become `sess.run([training_op, my_op])`.
   This is different from `sess.run(training_op); sess.run(my_op);`,
-  which is what you would get if you run `my_op` in `_trigger_step`.
-  Sometimes the difference matters, please choose carefully.
+  which is what you would get if you write `self.trainer.sess.run(my_op)` in `_trigger_step`.
+  Usually the difference matters, please choose carefully.
 * `_trigger_step(self)`
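
To make the distinction in this hunk concrete: with tensorpack's `Callback` interface, returning fetches from `_before_run` folds an op into the same `sess.run` call as the training op, while `self.trainer.sess.run(...)` inside `_trigger_step` issues a separate call. A minimal sketch (the callback and its counter op are hypothetical, not part of the commit):

```python
import tensorflow as tf
from tensorpack.callbacks import Callback

class RunMyOp(Callback):
    """Hypothetical callback contrasting the two ways to run an extra op."""

    def _setup_graph(self):
        # Extra ops must be created here; this counter is made up for illustration.
        counter = tf.get_variable('my_counter', shape=[], dtype=tf.int64,
                                  initializer=tf.zeros_initializer(), trainable=False)
        self._my_op = tf.assign_add(counter, 1)

    def _before_run(self, ctx):
        # Fetched together with the training op:
        # one sess.run([training_op, my_op]) call per step.
        return tf.train.SessionRunArgs(fetches=self._my_op)

    # The alternative discussed above -- a separate call after each step,
    # i.e. sess.run(training_op); sess.run(my_op);
    # def _trigger_step(self):
    #     self.trainer.sess.run(self._my_op)
```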
@@ -105,6 +105,7 @@ You can overwrite any of the following methods to define a new callback:
 * Access tensors / ops (details mentioned above):
   * For existing tensors/ops created in the tower, access them through [self.trainer.towers](../../modules/train.html#tensorpack.train.TowerTrainer.towers).
   * Extra tensors/ops have to be created in `_setup_graph` callback method.
+* Access the current graph and session by `self.trainer.graph` and `self.trainer.sess`.
 * Write stuff to the monitor backend, by `self.trainer.monitors.put_xxx`.
   The monitors might direct your events to TensorFlow events file, JSON file, stdout, etc.
   You can access history monitor data as well. See the docs for [Monitors](../../modules/callbacks.html#tensorpack.callbacks.Monitors)
...
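
The line this hunk adds, about `self.trainer.graph` and `self.trainer.sess`, pairs naturally with the monitor backend. A sketch of a callback that evaluates an existing tensor and logs it (the tensor name `'tower0/total_cost:0'` and the callback itself are assumptions for illustration):

```python
from tensorpack.callbacks import Callback

class CostLogger(Callback):
    """Hypothetical callback: read a tensor via the trainer's session, log via monitors."""

    def _setup_graph(self):
        # Look up an existing tensor through the current graph;
        # the name used here is an assumed example, not a fixed tensorpack name.
        self._cost = self.trainer.graph.get_tensor_by_name('tower0/total_cost:0')

    def _trigger_step(self):
        val = self.trainer.sess.run(self._cost)
        # The monitor backend may forward this to TF event files, JSON, stdout, etc.
        self.trainer.monitors.put_scalar('callback/cost', val)
```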
@@ -57,7 +57,7 @@ class Model(ModelDesc):
         def get_basic_cell():
             cell = rnn.BasicLSTMCell(num_units=HIDDEN_SIZE, forget_bias=0.0, reuse=tf.get_variable_scope().reuse)
             if is_training:
-                cell = rnn.DropoutWrapper(cell, output_keep_prob=DROPOUT)
+                cell = rnn.DropoutWrapper(cell, output_keep_prob=1 - DROPOUT)
             return cell
         cell = rnn.MultiRNNCell([get_basic_cell() for _ in range(NUM_LAYER)])
@@ -73,7 +73,7 @@ class Model(ModelDesc):
         embeddingW = tf.get_variable('embedding', [VOCAB_SIZE, HIDDEN_SIZE], initializer=initializer)
         input_feature = tf.nn.embedding_lookup(embeddingW, input)  # B x seqlen x hiddensize
-        input_feature = Dropout(input_feature, rate=DROPOUT)
+        input_feature = Dropout(input_feature, keep_prob=1 - DROPOUT)
         with tf.variable_scope('LSTM', initializer=initializer):
             input_list = tf.unstack(input_feature, num=SEQ_LEN, axis=1)  # seqlen x (Bxhidden)
...
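
The substantive fix in both hunks is the direction of the probability: `DROPOUT` in this example is a drop probability, while `rnn.DropoutWrapper`'s `output_keep_prob` (and tensorpack's `keep_prob` argument) expect a keep probability. A short sketch of the arithmetic, with an illustrative value (0.35 is assumed here, not taken from the example's config):

```python
from tensorflow.contrib import rnn  # TF1-era API, as used in the example

DROPOUT = 0.35           # drop probability (illustrative value)
keep_prob = 1 - DROPOUT  # keep probability = 0.65

cell = rnn.BasicLSTMCell(num_units=650)  # hidden size is illustrative
# Passing DROPOUT directly as output_keep_prob would keep only 35% of the
# outputs; 1 - DROPOUT keeps 65%, which is the intended behaviour.
cell = rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)
```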