Commit e741d7b4 authored by Yuxin Wu

update WGAN to include two possible ways of clipping

parent 321440af
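Background: WGAN enforces an approximate Lipschitz constraint on its critic by clipping every discriminator weight into [-0.01, 0.01] after each update (Arjovsky et al., 2017). The diff below implements this once as a commented-out optimizer wrapper and once as a per-step ClipCallback; a standalone sketch of the clip op follows the diff.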
@@ -36,15 +36,34 @@ class Model(DCGAN.Model):
     def _get_optimizer(self):
         lr = symbolic_functions.get_scalar_var('learning_rate', 1e-4, summary=True)
         opt = tf.train.RMSPropOptimizer(lr)
-        # add clipping to D optimizer
-        def clip(p):
-            n = p.op.name
+        return opt
+
+        # An alternative way to implement the clipping:
+        """
+        def clip(v):
+            n = v.op.name
             if not n.startswith('discrim/'):
                 return None
             logger.info("Clip {}".format(n))
-            return tf.clip_by_value(p, -0.01, 0.01)
+            return tf.clip_by_value(v, -0.01, 0.01)
         return optimizer.VariableAssignmentOptimizer(opt, clip)
+        """
+
+
+class ClipCallback(Callback):
+    def _setup_graph(self):
+        vars = tf.trainable_variables()
+        ops = []
+        for v in vars:
+            n = v.op.name
+            if not n.startswith('discrim/'):
+                continue
+            logger.info("Clip {}".format(n))
+            ops.append(tf.assign(v, tf.clip_by_value(v, -0.01, 0.01)))
+        self._op = tf.group(*ops, name='clip')
+
+    def _trigger_step(self):
+        self._op.run()
+
+
 if __name__ == '__main__':
@@ -58,7 +77,7 @@ if __name__ == '__main__':
     config = TrainConfig(
         model=Model(),
         dataflow=DCGAN.get_data(args.data),
-        callbacks=[ModelSaver()],
+        callbacks=[ModelSaver(), ClipCallback()],
         steps_per_epoch=500,
         max_epoch=200,
         session_init=SaverRestore(args.load) if args.load else None
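For reference, here is a minimal standalone sketch (not part of the commit) of the per-step clipping that ClipCallback performs, written against the plain TF1 graph API. The variable w is an illustrative assumption; the 'discrim/' prefix and the [-0.01, 0.01] bounds come from the diff above.

import tensorflow as tf

# Toy discriminator variable; any trainable variable under the
# 'discrim/' scope qualifies for clipping.
with tf.variable_scope('discrim'):
    w = tf.get_variable('w', shape=[3],
                        initializer=tf.random_normal_initializer(stddev=1.0))

# One grouped op that clips every discriminator variable in place,
# mirroring ClipCallback._setup_graph().
ops = [tf.assign(v, tf.clip_by_value(v, -0.01, 0.01))
       for v in tf.trainable_variables()
       if v.op.name.startswith('discrim/')]
clip_op = tf.group(*ops, name='clip')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(clip_op)   # ClipCallback._trigger_step() runs this after every step
    print(sess.run(w))  # every entry now lies in [-0.01, 0.01]

The commented-out alternative reaches the same state through tensorpack's VariableAssignmentOptimizer, which applies the clip as part of the optimizer's own update; the callback version keeps the optimizer untouched and simply runs the grouped assign once per global step.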