Commit 8250786f authored by yogurfrul's avatar yogurfrul Committed by Yuxin Wu

make the pix2pix support multi gpu trainning (#884)

* make the pix2pix support multi gpu trainning

* Update Image2Image.py

* Update Image2Image.py

* Update Image2Image.py
parent b8c7b6f4
......@@ -12,10 +12,11 @@ import argparse
from tensorpack import *
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils.viz import stack_patches
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from GAN import GANTrainer, GANModelDesc
from GAN import GANTrainer, MultiGPUGANTrainer, GANModelDesc
"""
To train Image-to-Image translation model with image pairs:
......@@ -218,7 +219,13 @@ if __name__ == '__main__':
data = QueueInput(get_data())
GANTrainer(data, Model()).train_with_defaults(
nr_tower = max(get_num_gpu(), 1)
if nr_tower == 1:
trainer = GANTrainer(data, Model())
else:
trainer = MultiGPUGANTrainer(nr_tower, data, Model())
trainer.train_with_defaults(
callbacks=[
PeriodicTrigger(ModelSaver(), every_k_epochs=3),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment