Commit 8156810d authored by Yuxin Wu's avatar Yuxin Wu

misc small fixes

parent 65449110
......@@ -17,7 +17,7 @@ $(function (){
if (fullname.startsWith('tensorpack.'))
fullname = fullname.substr(11);
if (fullname == "tensorpack.dataflow.MultiProcessMapData") {
if (fullname == "dataflow.MultiProcessMapData") {
groupName = "parallel_map";
}
......
......@@ -11,7 +11,7 @@ from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_num_gpu
import DCGAN
from GAN import GANModelDesc, GANTrainer, MultiGPUGANTrainer
from GAN import GANModelDesc, GANTrainer
"""
Boundary Equilibrium GAN.
......@@ -139,11 +139,7 @@ if __name__ == '__main__':
input = QueueInput(DCGAN.get_data())
model = Model()
nr_tower = max(get_num_gpu(), 1)
if nr_tower == 1:
trainer = GANTrainer(input, model)
else:
trainer = MultiGPUGANTrainer(nr_tower, input, model)
trainer = GANTrainer(input, model, num_gpu=nr_tower)
trainer.train_with_defaults(
callbacks=[
ModelSaver(),
......
......@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer):
not needed. Just calling model.build_graph directly is OK.
"""
# Build the graph
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
self.tower_func = TowerFuncWrapper(model.build_graph, model.inputs())
with TowerContext('', is_training=True):
self.tower_func(*input.get_input_tensors())
opt = model.get_optimizer()
......@@ -167,7 +167,7 @@ class SeparateGANTrainer(TowerTrainer):
self.register_callback(cbs)
# Build the graph
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
self.tower_func = TowerFuncWrapper(model.build_graph, model.inputs())
with TowerContext('', is_training=True), \
argscope(BatchNorm, ema_update='internal'):
# should not hook the EMA updates to both train_op, it will hurt training speed.
......
......@@ -93,7 +93,7 @@ if __name__ == '__main__':
logger.auto_set_dir()
SeparateGANTrainer(
QueueInput(DCGAN.get_data()),
M, g_period=6).train_with_defaults(
M, g_period=5).train_with_defaults(
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=200,
......
......@@ -63,6 +63,7 @@ Train a simple GAN on mnist, conditioned on the class labels.
## [WGAN.py](WGAN.py), [Improved-WGAN.py](Improved-WGAN.py), [BEGAN.py](BEGAN.py)
These variants are implemented by some small modifications on top of DCGAN.py.
BEGAN has the best visual quality among them.
Some BEGAN samples:
![began-sample](demo/BEGAN-CelebA-samples.jpg)
......
......@@ -3,7 +3,6 @@
# File: dump-model-params.py
import argparse
import sys
import numpy as np
import os
import six
......@@ -34,9 +33,9 @@ def _import_external_ops(message):
pass
else:
_validate_and_load_nccl_so()
from tensorflow.contrib.nccl.ops import gen_nccl_ops
from tensorflow.contrib.nccl.ops import gen_nccl_ops # noqa
else:
from tensorflow.python.ops import gen_nccl_ops
from tensorflow.python.ops import gen_nccl_ops # noqa
return
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment