Commit 8250786f authored by yogurfrul's avatar yogurfrul Committed by Yuxin Wu

make the pix2pix support multi gpu trainning (#884)

* make the pix2pix support multi gpu trainning

* Update Image2Image.py

* Update Image2Image.py

* Update Image2Image.py
parent b8c7b6f4
...@@ -12,10 +12,11 @@ import argparse ...@@ -12,10 +12,11 @@ import argparse
from tensorpack import * from tensorpack import *
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils.viz import stack_patches from tensorpack.utils.viz import stack_patches
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from GAN import GANTrainer, GANModelDesc from GAN import GANTrainer, MultiGPUGANTrainer, GANModelDesc
""" """
To train Image-to-Image translation model with image pairs: To train Image-to-Image translation model with image pairs:
...@@ -218,7 +219,13 @@ if __name__ == '__main__': ...@@ -218,7 +219,13 @@ if __name__ == '__main__':
data = QueueInput(get_data()) data = QueueInput(get_data())
GANTrainer(data, Model()).train_with_defaults( nr_tower = max(get_num_gpu(), 1)
if nr_tower == 1:
trainer = GANTrainer(data, Model())
else:
trainer = MultiGPUGANTrainer(nr_tower, data, Model())
trainer.train_with_defaults(
callbacks=[ callbacks=[
PeriodicTrigger(ModelSaver(), every_k_epochs=3), PeriodicTrigger(ModelSaver(), every_k_epochs=3),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)]) ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment