Commit 63004976 authored by Yuxin Wu's avatar Yuxin Wu

simplify WGAN. fix bug in viz.

parent 13c96b94
...@@ -39,7 +39,16 @@ class Model(DCGAN.Model): ...@@ -39,7 +39,16 @@ class Model(DCGAN.Model):
def _get_optimizer(self): def _get_optimizer(self):
lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True) lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True)
return tf.train.RMSPropOptimizer(lr) opt = tf.train.RMSPropOptimizer(lr)
# add clipping to D optimizer
def clip(p):
n = p.op.name
if not n.startswith('discrim/'):
return None
logger.info("Clip {}".format(n))
return tf.clip_by_value(p, -0.01, 0.01)
return optimizer.VariableAssignmentOptimizer(opt, clip)
DCGAN.BATCH = 64 DCGAN.BATCH = 64
...@@ -67,17 +76,10 @@ class WGANTrainer(FeedfreeTrainerBase): ...@@ -67,17 +76,10 @@ class WGANTrainer(FeedfreeTrainerBase):
super(WGANTrainer, self)._setup() super(WGANTrainer, self)._setup()
self.build_train_tower() self.build_train_tower()
# add clipping to D optimizer opt = self.model.get_optimizer()
def clip(p): self.d_min = opt.minimize(
n = p.op.name
logger.info("Clip {}".format(n))
return tf.clip_by_value(p, -0.01, 0.01)
opt_G = self.model.get_optimizer()
opt_D = optimizer.VariableAssignmentOptimizer(opt_G, clip)
self.d_min = opt_D.minimize(
self.model.d_loss, var_list=self.model.d_vars, name='d_min') self.model.d_loss, var_list=self.model.d_vars, name='d_min')
self.g_min = opt_G.minimize( self.g_min = opt.minimize(
self.model.g_loss, var_list=self.model.g_vars, name='g_op') self.model.g_loss, var_list=self.model.g_vars, name='g_op')
def run_step(self): def run_step(self):
......
...@@ -83,8 +83,8 @@ class PostProcessOptimizer(ProxyOptimizer): ...@@ -83,8 +83,8 @@ class PostProcessOptimizer(ProxyOptimizer):
for _, var in grads_and_vars: for _, var in grads_and_vars:
with self._maybe_colocate(var): with self._maybe_colocate(var):
op = self._func(var) op = self._func(var)
assert isinstance(op, tf.Operation), op
if op is not None: if op is not None:
assert isinstance(op, tf.Operation), op
ops.append(op) ops.append(op)
update_op = tf.group(update_op, *ops, name=name) update_op = tf.group(update_op, *ops, name=name)
return update_op return update_op
......
...@@ -202,7 +202,7 @@ def stack_patches( ...@@ -202,7 +202,7 @@ def stack_patches(
canvas.draw_patches(patch_list) canvas.draw_patches(patch_list)
if viz: if viz:
interactive_imshow(canvas.canvas, lclick_cb=lclick_callback) interactive_imshow(canvas.canvas, lclick_cb=lclick_callback)
return canvas return canvas.canvas
def gen_stack_patches(patch_list, def gen_stack_patches(patch_list,
...@@ -260,7 +260,7 @@ def gen_stack_patches(patch_list, ...@@ -260,7 +260,7 @@ def gen_stack_patches(patch_list,
canvas.draw_patches(cur_list) canvas.draw_patches(cur_list)
if viz: if viz:
interactive_imshow(canvas.canvas, lclick_cb=lclick_callback) interactive_imshow(canvas.canvas, lclick_cb=lclick_callback)
yield canvas yield canvas.canvas
start = end start = end
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment