Commit 2d99afea authored by Yuxin Wu

[WIP] Switch GANs to use Trainerv2

parent 17a73a4c
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import os import os
import argparse import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_nr_gpu from tensorpack.utils.gpu import get_nr_gpu
......
...@@ -10,6 +10,7 @@ import sys ...@@ -10,6 +10,7 @@ import sys
import cv2 import cv2
import argparse import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
......
...@@ -9,6 +9,7 @@ import glob ...@@ -9,6 +9,7 @@ import glob
from six.moves import map, zip, range from six.moves import map, zip, range
import numpy as np import numpy as np
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
......
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import os, sys import os, sys
import argparse import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
...@@ -156,11 +157,11 @@ if __name__ == '__main__': ...@@ -156,11 +157,11 @@ if __name__ == '__main__':
assert args.data assert args.data
logger.auto_set_dir() logger.auto_set_dir()
config = TrainConfig( config = TrainConfig(
model=Model(),
dataflow=get_data(args.data),
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None session_init=SaverRestore(args.load) if args.load else None
) )
GANTrainer(config).train() GANTrainer(
input=QueueInput(get_data(args.data)),
model=Model()).train_with_config(config)
...@@ -8,6 +8,7 @@ import argparse ...@@ -8,6 +8,7 @@ import argparse
from six.moves import map, zip from six.moves import map, zip
import numpy as np import numpy as np
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import time import time
from tensorpack import (Trainer, QueueInput, from tensorpack import (TowerTrainer, QueueInput,
ModelDescBase, DataFlow, StagingInput, ModelDescBase, DataFlow, StagingInput,
TowerContext) TowerContext, TowerFuncWrapper)
from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized
...@@ -64,20 +64,15 @@ class GANModelDesc(ModelDescBase): ...@@ -64,20 +64,15 @@ class GANModelDesc(ModelDescBase):
return self._get_optimizer() return self._get_optimizer()
class GANTrainer(Trainer): class GANTrainer(TowerTrainer):
def __init__(self, config): def __init__(self, input, model):
""" super(GANTrainer, self).__init__()
GANTrainer expects a ModelDesc in config which sets the following attribute assert isinstance(model, GANModelDesc), model
after :meth:`_build_graph`: g_loss, d_loss, g_vars, d_vars.
"""
input = QueueInput(config.dataflow)
model = config.model
cbs = input.setup(model.get_inputs_desc()) cbs = input.setup(model.get_inputs_desc())
config.callbacks.extend(cbs)
tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
model.build_graph(input) tower_func(input)
opt = model.get_optimizer() opt = model.get_optimizer()
# by default, run one d_min after one g_min # by default, run one d_min after one g_min
...@@ -86,29 +81,29 @@ class GANTrainer(Trainer): ...@@ -86,29 +81,29 @@ class GANTrainer(Trainer):
with tf.control_dependencies([g_min]): with tf.control_dependencies([g_min]):
d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op') d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
self.train_op = d_min self.train_op = d_min
self.set_tower_func(tower_func)
super(GANTrainer, self).__init__(config) for cb in cbs:
self._register_callback(cb)
class SeparateGANTrainer(Trainer): class SeparateGANTrainer(TowerTrainer):
""" A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """ """ A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """
def __init__(self, config, d_period=1, g_period=1): def __init__(self, input, model, d_period=1, g_period=1):
""" """
Args: Args:
d_period(int): period of each d_opt run d_period(int): period of each d_opt run
g_period(int): period of each g_opt run g_period(int): period of each g_opt run
""" """
super(SeparateGANTrainer, self).__init__()
self._d_period = int(d_period) self._d_period = int(d_period)
self._g_period = int(g_period) self._g_period = int(g_period)
assert min(d_period, g_period) == 1 assert min(d_period, g_period) == 1
input = QueueInput(config.dataflow)
model = config.model
cbs = input.setup(model.get_inputs_desc()) cbs = input.setup(model.get_inputs_desc())
config.callbacks.extend(cbs) tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
model.build_graph(input) tower_func(input)
opt = model.get_optimizer() opt = model.get_optimizer()
with tf.name_scope('optimize'): with tf.name_scope('optimize'):
...@@ -117,7 +112,9 @@ class SeparateGANTrainer(Trainer): ...@@ -117,7 +112,9 @@ class SeparateGANTrainer(Trainer):
self.g_min = opt.minimize( self.g_min = opt.minimize(
model.g_loss, var_list=model.g_vars, name='g_min') model.g_loss, var_list=model.g_vars, name='g_min')
super(SeparateGANTrainer, self).__init__(config) self.set_tower_func(tower_func)
for cb in cbs:
self._register_callback(cb)
def run_step(self): def run_step(self):
if self.global_step % (self._d_period) == 0: if self.global_step % (self._d_period) == 0:
...@@ -126,26 +123,25 @@ class SeparateGANTrainer(Trainer): ...@@ -126,26 +123,25 @@ class SeparateGANTrainer(Trainer):
self.hooked_sess.run(self.g_min) self.hooked_sess.run(self.g_min)
class MultiGPUGANTrainer(Trainer): class MultiGPUGANTrainer(TowerTrainer):
""" """
A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support. A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
""" """
def __init__(self, config): def __init__(self, nr_gpu, input, model):
nr_gpu = config.nr_tower super(MultiGPUGANTrainer, self).__init__()
assert nr_gpu > 1 assert nr_gpu > 1
raw_devices = ['/gpu:{}'.format(k) for k in config.tower] raw_devices = ['/gpu:{}'.format(k) for k in range(nr_gpu)]
# setup input # setup input
input = StagingInput(QueueInput(config.dataflow), config.tower) input = StagingInput(input, list(range(nr_gpu)))
model = config.model
cbs = input.setup(model.get_inputs_desc()) cbs = input.setup(model.get_inputs_desc())
config.callbacks.extend(cbs)
def get_cost(): def get_cost():
model.build_graph(input) model.build_graph(input.get_input_tensors())
return [model.d_loss, model.g_loss] return [model.d_loss, model.g_loss]
tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices] devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = DataParallelBuilder.build_on_towers(config.tower, get_cost, devices) cost_list = DataParallelBuilder.build_on_towers(list(range(nr_gpu)), tower_func, devices)
# simply average the cost. It might get faster to average the gradients # simply average the cost. It might get faster to average the gradients
with tf.name_scope('optimize'): with tf.name_scope('optimize'):
d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu) d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
...@@ -159,7 +155,9 @@ class MultiGPUGANTrainer(Trainer): ...@@ -159,7 +155,9 @@ class MultiGPUGANTrainer(Trainer):
d_min = opt.minimize(d_loss, var_list=model.d_vars, d_min = opt.minimize(d_loss, var_list=model.d_vars,
colocate_gradients_with_ops=True, name='d_op') colocate_gradients_with_ops=True, name='d_op')
self.train_op = d_min self.train_op = d_min
super(MultiGPUGANTrainer, self).__init__(config) self.set_tower_func(tower_func)
for cb in cbs:
self._register_callback(cb)
class RandomZData(DataFlow): class RandomZData(DataFlow):
......
...@@ -12,6 +12,7 @@ import os ...@@ -12,6 +12,7 @@ import os
import sys import sys
import argparse import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import os import os
import argparse import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.globvars import globalns as G from tensorpack.utils.globvars import globalns as G
......
...@@ -10,6 +10,7 @@ import os ...@@ -10,6 +10,7 @@ import os
import sys import sys
import argparse import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.utils import viz from tensorpack.utils import viz
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import os import os
import argparse import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils import optimizer from tensorpack.tfutils import optimizer
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
...@@ -76,8 +77,6 @@ if __name__ == '__main__': ...@@ -76,8 +77,6 @@ if __name__ == '__main__':
assert args.data assert args.data
logger.auto_set_dir() logger.auto_set_dir()
config = TrainConfig( config = TrainConfig(
model=Model(),
dataflow=DCGAN.get_data(args.data),
callbacks=[ModelSaver(), ClipCallback()], callbacks=[ModelSaver(), ClipCallback()],
steps_per_epoch=500, steps_per_epoch=500,
max_epoch=200, max_epoch=200,
...@@ -85,4 +84,6 @@ if __name__ == '__main__': ...@@ -85,4 +84,6 @@ if __name__ == '__main__':
) )
# The original code uses a different schedule, but this seems to work well. # The original code uses a different schedule, but this seems to work well.
# Train 1 D after 2 G # Train 1 D after 2 G
SeparateGANTrainer(config, d_period=3).train() SeparateGANTrainer(
input=QueueInput(DCGAN.get_data(args.data)),
model=Model(), d_period=3).train_with_config(config)
...@@ -218,6 +218,28 @@ class Trainer(object): ...@@ -218,6 +218,28 @@ class Trainer(object):
self.initialize(session_creator, session_init) self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch) self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
def train_with_config(self, config):
    """
    An alias to simplify the use of `TrainConfig`.

    Only the training-loop options of ``config`` are used. Its ``data``,
    ``dataflow`` and ``model`` fields are ignored, and a warning is logged
    when any of them is set.

    It is equivalent to the following:

    .. code-block:: python

        self.train(
            config.callbacks, config.monitors,
            config.session_creator, config.session_init,
            config.steps_per_epoch, config.starting_epoch, config.max_epoch)

    Args:
        config: the `TrainConfig` holding the options forwarded to
            :meth:`train`.
    """
    # Test for presence, not truthiness: an object that was supplied but
    # evaluates falsy (or whose __bool__/__len__ raises) must still be
    # reported as ignored instead of being missed or crashing here.
    # NOTE(review): assumes TrainConfig leaves unset fields as None -- confirm.
    graph_fields_given = (
        config.data is not None
        or config.dataflow is not None
        or config.model is not None
    )
    if graph_fields_given:
        logger.warn(
            "data/dataflow/model in TrainConfig will not be used "
            "in `Trainer.train_with_config`")
        logger.warn("To build the graph from config, use `launch_train_with_config`!")

    # Arguments are passed positionally, in the order `train` declares them.
    self.train(
        config.callbacks, config.monitors,
        config.session_creator, config.session_init,
        config.steps_per_epoch, config.starting_epoch, config.max_epoch)
# create the old trainer when called with TrainConfig # create the old trainer when called with TrainConfig
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if (len(args) > 0 and isinstance(args[0], TrainConfig)) \ if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
...@@ -337,20 +359,6 @@ class SingleCostTrainer(TowerTrainer): ...@@ -337,20 +359,6 @@ class SingleCostTrainer(TowerTrainer):
To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`. To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`.
""" """
def train(self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Same as :meth:`Trainer.train()`, except that the callbacks this
trainer needs are automatically added.
"""
callbacks = callbacks + self._internal_callbacks
super(SingleCostTrainer, self).train(
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch)
@call_only_once @call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn): def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
""" """
...@@ -375,8 +383,10 @@ class SingleCostTrainer(TowerTrainer): ...@@ -375,8 +383,10 @@ class SingleCostTrainer(TowerTrainer):
input_callbacks = self._setup_input(inputs_desc, input) input_callbacks = self._setup_input(inputs_desc, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn) train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
self._internal_callbacks = input_callbacks + train_callbacks internal_callbacks = input_callbacks + train_callbacks
return self._internal_callbacks for cb in internal_callbacks:
self._register_callback(cb)
return internal_callbacks
@abstractmethod @abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn): def _setup_graph(self, input, get_cost_fn, get_opt_fn):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment