Commit 97f47539 authored by Yuxin Wu

speed up GAN

parent bcee048d
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import time
from tensorpack import (FeedfreeTrainer, TowerContext, from tensorpack import (FeedfreeTrainer, TowerContext,
get_global_step_var, QueueInput) get_global_step_var, QueueInput)
from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary
...@@ -21,17 +22,19 @@ class GANTrainer(FeedfreeTrainer): ...@@ -21,17 +22,19 @@ class GANTrainer(FeedfreeTrainer):
actual_inputs = self._get_input_tensors() actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs) self.model.build_graph(actual_inputs)
self.g_min = self.config.optimizer.minimize(self.model.g_loss, self.g_min = self.config.optimizer.minimize(self.model.g_loss,
var_list=self.model.g_vars, name='g_op') var_list=self.model.g_vars, name='g_op',
gate_gradients=tf.train.Optimizer.GATE_NONE)
self.d_min = self.config.optimizer.minimize(self.model.d_loss, self.d_min = self.config.optimizer.minimize(self.model.d_loss,
var_list=self.model.d_vars, name='d_op') var_list=self.model.d_vars, name='d_op',
gate_gradients=tf.train.Optimizer.GATE_NONE)
self.gs_incr = tf.assign_add(get_global_step_var(), 1, name='global_step_incr') self.gs_incr = tf.assign_add(get_global_step_var(), 1, name='global_step_incr')
self.summary_op = summary_moving_average() self.summary_op = summary_moving_average()
self.d_min = tf.group(self.d_min, self.summary_op, self.gs_incr) with tf.control_dependencies([self.g_min]):
#self.train_op = tf.group(self.g_min, self.d_min) self.d_min = tf.group(self.d_min, self.summary_op, self.gs_incr)
self.train_op = self.d_min
def run_step(self): def run_step(self):
self.sess.run(self.g_min) self.sess.run(self.train_op)
self.sess.run(self.d_min)
class RandomZData(DataFlow): class RandomZData(DataFlow):
def __init__(self, shape): def __init__(self, shape):
......
...@@ -28,7 +28,7 @@ To visualize on test set: ...@@ -28,7 +28,7 @@ To visualize on test set:
""" """
SHAPE = 256 SHAPE = 256
BATCH = 4 BATCH = 1
IN_CH = 3 IN_CH = 3
OUT_CH = 3 OUT_CH = 3
LAMBDA = 100 LAMBDA = 100
...@@ -159,7 +159,7 @@ def get_config(): ...@@ -159,7 +159,7 @@ def get_config():
dataset=dataset, dataset=dataset,
optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3), optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), ModelSaver(), StatPrinter(), PeriodicCallback(ModelSaver(), 3),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)]) ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
]), ]),
model=Model(), model=Model(),
......
...@@ -50,17 +50,22 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer): ...@@ -50,17 +50,22 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
def run_step(self): def run_step(self):
""" Simply run self.train_op""" """ Simply run self.train_op"""
self.sess.run(self.train_op) self.sess.run(self.train_op)
# debug-benchmark code: #if not hasattr(self, 'cnt'):
#run_metadata = tf.RunMetadata() #self.cnt = 0
#self.sess.run([self.train_op], #else:
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE), #self.cnt += 1
#run_metadata=run_metadata #if self.cnt % 10 == 0:
#) ## debug-benchmark code:
#from tensorflow.python.client import timeline #run_metadata = tf.RunMetadata()
#trace = timeline.Timeline(step_stats=run_metadata.step_stats) #self.sess.run([self.train_op],
#trace_file = open('timeline.ctf.json', 'w') #options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#trace_file.write(trace.generate_chrome_trace_format()) #run_metadata=run_metadata
#import sys; sys.exit() #)
#from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
class SimpleFeedfreeTrainer( class SimpleFeedfreeTrainer(
MultiPredictorTowerTrainer, MultiPredictorTowerTrainer,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment