Commit a979a3f3 authored by Yuxin Wu's avatar Yuxin Wu

fix image summary

parent 8abdaf77
......@@ -75,7 +75,7 @@ class Model(ModelDesc):
W_init=tf.truncated_normal_initializer(stddev=0.02)):
with tf.variable_scope('gen'):
image_gen = self.generator(z)
tf.summary.image('gen', image_gen, max_images=30)
tf.summary.image('gen', image_gen, max_outputs=30)
with tf.variable_scope('discrim'):
vecpos = self.discriminator(image_pos)
with tf.variable_scope('discrim', reuse=True):
......@@ -105,12 +105,11 @@ def get_config():
optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
callbacks=Callbacks([
StatPrinter(), ModelSaver(),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
]),
session_config=get_default_sess_config(0.5),
model=Model(),
step_per_epoch=300,
max_epoch=300,
max_epoch=200,
)
def sample(model_path):
......
......@@ -117,7 +117,7 @@ class Model(ModelDesc):
fake_output = tf.image.grayscale_to_rgb(fake_output)
viz = (tf.concat(2, [input, output, fake_output]) + 1.0) * 128.0
viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
tf.image_summary('gen', viz, max_images=max(30, BATCH))
tf.image_summary('gen', viz, max_outputs=max(30, BATCH))
all_vars = tf.trainable_variables()
self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
......
......@@ -66,7 +66,7 @@ class Model(ModelDesc):
W_init=tf.truncated_normal_initializer(stddev=0.02)):
with tf.variable_scope('gen'):
image_gen = self.generator(z)
tf.summary.image('gen', image_gen, max_images=30)
tf.summary.image('gen', image_gen, max_outputs=30)
with tf.variable_scope('discrim'):
vecpos, _ = self.discriminator(image_pos)
with tf.variable_scope('discrim', reuse=True):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment