Commit 8156810d authored by Yuxin Wu

misc small fixes

parent 65449110
...@@ -17,7 +17,7 @@ $(function (){ ...@@ -17,7 +17,7 @@ $(function (){
if (fullname.startsWith('tensorpack.')) if (fullname.startsWith('tensorpack.'))
fullname = fullname.substr(11); fullname = fullname.substr(11);
if (fullname == "tensorpack.dataflow.MultiProcessMapData") { if (fullname == "dataflow.MultiProcessMapData") {
groupName = "parallel_map"; groupName = "parallel_map";
} }
......
...@@ -11,7 +11,7 @@ from tensorpack.tfutils.summary import add_moving_summary ...@@ -11,7 +11,7 @@ from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
import DCGAN import DCGAN
from GAN import GANModelDesc, GANTrainer, MultiGPUGANTrainer from GAN import GANModelDesc, GANTrainer
""" """
Boundary Equilibrium GAN. Boundary Equilibrium GAN.
...@@ -139,11 +139,7 @@ if __name__ == '__main__': ...@@ -139,11 +139,7 @@ if __name__ == '__main__':
input = QueueInput(DCGAN.get_data()) input = QueueInput(DCGAN.get_data())
model = Model() model = Model()
nr_tower = max(get_num_gpu(), 1) nr_tower = max(get_num_gpu(), 1)
if nr_tower == 1: trainer = GANTrainer(input, model, num_gpu=nr_tower)
trainer = GANTrainer(input, model)
else:
trainer = MultiGPUGANTrainer(nr_tower, input, model)
trainer.train_with_defaults( trainer.train_with_defaults(
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
......
...@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer): ...@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer):
not needed. Just calling model.build_graph directly is OK. not needed. Just calling model.build_graph directly is OK.
""" """
# Build the graph # Build the graph
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature()) self.tower_func = TowerFuncWrapper(model.build_graph, model.inputs())
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
self.tower_func(*input.get_input_tensors()) self.tower_func(*input.get_input_tensors())
opt = model.get_optimizer() opt = model.get_optimizer()
...@@ -167,7 +167,7 @@ class SeparateGANTrainer(TowerTrainer): ...@@ -167,7 +167,7 @@ class SeparateGANTrainer(TowerTrainer):
self.register_callback(cbs) self.register_callback(cbs)
# Build the graph # Build the graph
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature()) self.tower_func = TowerFuncWrapper(model.build_graph, model.inputs())
with TowerContext('', is_training=True), \ with TowerContext('', is_training=True), \
argscope(BatchNorm, ema_update='internal'): argscope(BatchNorm, ema_update='internal'):
# should not hook the EMA updates to both train_op, it will hurt training speed. # should not hook the EMA updates to both train_op, it will hurt training speed.
......
...@@ -93,7 +93,7 @@ if __name__ == '__main__': ...@@ -93,7 +93,7 @@ if __name__ == '__main__':
logger.auto_set_dir() logger.auto_set_dir()
SeparateGANTrainer( SeparateGANTrainer(
QueueInput(DCGAN.get_data()), QueueInput(DCGAN.get_data()),
M, g_period=6).train_with_defaults( M, g_period=5).train_with_defaults(
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=200,
......
...@@ -63,6 +63,7 @@ Train a simple GAN on mnist, conditioned on the class labels. ...@@ -63,6 +63,7 @@ Train a simple GAN on mnist, conditioned on the class labels.
## [WGAN.py](WGAN.py), [Improved-WGAN.py](Improved-WGAN.py), [BEGAN.py](BEGAN.py) ## [WGAN.py](WGAN.py), [Improved-WGAN.py](Improved-WGAN.py), [BEGAN.py](BEGAN.py)
These variants are implemented by some small modifications on top of DCGAN.py. These variants are implemented by some small modifications on top of DCGAN.py.
BEGAN has the best visual quality among them.
Some BEGAN samples: Some BEGAN samples:
![began-sample](demo/BEGAN-CelebA-samples.jpg) ![began-sample](demo/BEGAN-CelebA-samples.jpg)
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# File: dump-model-params.py # File: dump-model-params.py
import argparse import argparse
import sys
import numpy as np import numpy as np
import os import os
import six import six
...@@ -21,7 +20,7 @@ def _import_external_ops(message): ...@@ -21,7 +20,7 @@ def _import_external_ops(message):
logger.info("Importing horovod ...") logger.info("Importing horovod ...")
import horovod.tensorflow # noqa import horovod.tensorflow # noqa
return return
if "MaxBytesInUse" in message: if "MaxBytesInUse" in message:
logger.info("Importing memory_stats ...") logger.info("Importing memory_stats ...")
from tensorflow.contrib.memory_stats import MaxBytesInUse # noqa from tensorflow.contrib.memory_stats import MaxBytesInUse # noqa
return return
...@@ -34,9 +33,9 @@ def _import_external_ops(message): ...@@ -34,9 +33,9 @@ def _import_external_ops(message):
pass pass
else: else:
_validate_and_load_nccl_so() _validate_and_load_nccl_so()
from tensorflow.contrib.nccl.ops import gen_nccl_ops from tensorflow.contrib.nccl.ops import gen_nccl_ops # noqa
else: else:
from tensorflow.python.ops import gen_nccl_ops from tensorflow.python.ops import gen_nccl_ops # noqa
return return
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment