Commit f16aef9d authored by Yuxin Wu's avatar Yuxin Wu

fix dependencies

parent 97f47539
...@@ -22,16 +22,13 @@ class GANTrainer(FeedfreeTrainer): ...@@ -22,16 +22,13 @@ class GANTrainer(FeedfreeTrainer):
actual_inputs = self._get_input_tensors() actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs) self.model.build_graph(actual_inputs)
self.g_min = self.config.optimizer.minimize(self.model.g_loss, self.g_min = self.config.optimizer.minimize(self.model.g_loss,
var_list=self.model.g_vars, name='g_op', var_list=self.model.g_vars, name='g_op')
gate_gradients=tf.train.Optimizer.GATE_NONE) with tf.control_dependencies([self.g_min]):
self.d_min = self.config.optimizer.minimize(self.model.d_loss, self.d_min = self.config.optimizer.minimize(self.model.d_loss,
var_list=self.model.d_vars, name='d_op', var_list=self.model.d_vars, name='d_op')
gate_gradients=tf.train.Optimizer.GATE_NONE)
self.gs_incr = tf.assign_add(get_global_step_var(), 1, name='global_step_incr') self.gs_incr = tf.assign_add(get_global_step_var(), 1, name='global_step_incr')
self.summary_op = summary_moving_average() self.summary_op = summary_moving_average()
with tf.control_dependencies([self.g_min]): self.train_op = tf.group(self.d_min, self.summary_op, self.gs_incr)
self.d_min = tf.group(self.d_min, self.summary_op, self.gs_incr)
self.train_op = self.d_min
def run_step(self): def run_step(self):
self.sess.run(self.train_op) self.sess.run(self.train_op)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment