Commit 97f47539 authored by Yuxin Wu's avatar Yuxin Wu

speed up GAN

parent bcee048d
......@@ -5,6 +5,7 @@
import tensorflow as tf
import numpy as np
import time
from tensorpack import (FeedfreeTrainer, TowerContext,
get_global_step_var, QueueInput)
from tensorpack.tfutils.summary import summary_moving_average, add_moving_summary
......@@ -21,17 +22,19 @@ class GANTrainer(FeedfreeTrainer):
actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs)
self.g_min = self.config.optimizer.minimize(self.model.g_loss,
var_list=self.model.g_vars, name='g_op')
var_list=self.model.g_vars, name='g_op',
gate_gradients=tf.train.Optimizer.GATE_NONE)
self.d_min = self.config.optimizer.minimize(self.model.d_loss,
var_list=self.model.d_vars, name='d_op')
var_list=self.model.d_vars, name='d_op',
gate_gradients=tf.train.Optimizer.GATE_NONE)
self.gs_incr = tf.assign_add(get_global_step_var(), 1, name='global_step_incr')
self.summary_op = summary_moving_average()
self.d_min = tf.group(self.d_min, self.summary_op, self.gs_incr)
#self.train_op = tf.group(self.g_min, self.d_min)
with tf.control_dependencies([self.g_min]):
self.d_min = tf.group(self.d_min, self.summary_op, self.gs_incr)
self.train_op = self.d_min
def run_step(self):
self.sess.run(self.g_min)
self.sess.run(self.d_min)
self.sess.run(self.train_op)
class RandomZData(DataFlow):
def __init__(self, shape):
......
......@@ -28,7 +28,7 @@ To visualize on test set:
"""
SHAPE = 256
BATCH = 4
BATCH = 1
IN_CH = 3
OUT_CH = 3
LAMBDA = 100
......@@ -159,7 +159,7 @@ def get_config():
dataset=dataset,
optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
callbacks=Callbacks([
StatPrinter(), ModelSaver(),
StatPrinter(), PeriodicCallback(ModelSaver(), 3),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
]),
model=Model(),
......
......@@ -50,17 +50,22 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
def run_step(self):
""" Simply run self.train_op"""
self.sess.run(self.train_op)
# debug-benchmark code:
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#run_metadata=run_metadata
#)
#from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
#if not hasattr(self, 'cnt'):
#self.cnt = 0
#else:
#self.cnt += 1
#if self.cnt % 10 == 0:
## debug-benchmark code:
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#run_metadata=run_metadata
#)
#from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
class SimpleFeedfreeTrainer(
MultiPredictorTowerTrainer,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment