Commit 712ea325 authored by Yuxin Wu's avatar Yuxin Wu

Use getter and setter for trainer.tower_func, instead of `set_tower_func`.

parent a988fc18
...@@ -372,6 +372,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options): ...@@ -372,6 +372,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'DumpTensor', 'DumpTensor',
'StagingInputWrapper', 'StagingInputWrapper',
'StepTensorPrinter', 'StepTensorPrinter',
'set_tower_func',
'guided_relu', 'saliency_map', 'get_scalar_var', 'guided_relu', 'saliency_map', 'get_scalar_var',
'prediction_incorrect', 'huber_loss', 'prediction_incorrect', 'huber_loss',
......
...@@ -72,9 +72,9 @@ class GANTrainer(TowerTrainer): ...@@ -72,9 +72,9 @@ class GANTrainer(TowerTrainer):
# we need to set towerfunc because it's a TowerTrainer, # we need to set towerfunc because it's a TowerTrainer,
# and only TowerTrainer supports automatic graph creation for inference during training. # and only TowerTrainer supports automatic graph creation for inference during training.
tower_func = TowerFuncWrapper(model.build_graph, inputs_desc) self.tower_func = TowerFuncWrapper(model.build_graph, inputs_desc)
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
tower_func(*input.get_input_tensors()) self.tower_func(*input.get_input_tensors())
opt = model.get_optimizer() opt = model.get_optimizer()
# by default, run one d_min after one g_min # by default, run one d_min after one g_min
...@@ -83,7 +83,6 @@ class GANTrainer(TowerTrainer): ...@@ -83,7 +83,6 @@ class GANTrainer(TowerTrainer):
with tf.control_dependencies([g_min]): with tf.control_dependencies([g_min]):
d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op') d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
self.train_op = d_min self.train_op = d_min
self.set_tower_func(tower_func)
for cb in cbs: for cb in cbs:
self.register_callback(cb) self.register_callback(cb)
...@@ -103,9 +102,9 @@ class SeparateGANTrainer(TowerTrainer): ...@@ -103,9 +102,9 @@ class SeparateGANTrainer(TowerTrainer):
assert min(d_period, g_period) == 1 assert min(d_period, g_period) == 1
cbs = input.setup(model.get_inputs_desc()) cbs = input.setup(model.get_inputs_desc())
tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc()) self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
tower_func(*input.get_input_tensors()) self.tower_func(*input.get_input_tensors())
opt = model.get_optimizer() opt = model.get_optimizer()
with tf.name_scope('optimize'): with tf.name_scope('optimize'):
...@@ -114,7 +113,6 @@ class SeparateGANTrainer(TowerTrainer): ...@@ -114,7 +113,6 @@ class SeparateGANTrainer(TowerTrainer):
self.g_min = opt.minimize( self.g_min = opt.minimize(
model.g_loss, var_list=model.g_vars, name='g_min') model.g_loss, var_list=model.g_vars, name='g_min')
self.set_tower_func(tower_func)
for cb in cbs: for cb in cbs:
self.register_callback(cb) self.register_callback(cb)
...@@ -142,11 +140,11 @@ class MultiGPUGANTrainer(TowerTrainer): ...@@ -142,11 +140,11 @@ class MultiGPUGANTrainer(TowerTrainer):
def get_cost(*inputs): def get_cost(*inputs):
model.build_graph(*inputs) model.build_graph(*inputs)
return [model.d_loss, model.g_loss] return [model.d_loss, model.g_loss]
tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc()) self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices] devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = DataParallelBuilder.build_on_towers( cost_list = DataParallelBuilder.build_on_towers(
list(range(nr_gpu)), list(range(nr_gpu)),
lambda: tower_func(*input.get_input_tensors()), lambda: self.tower_func(*input.get_input_tensors()),
devices) devices)
# simply average the cost. It might get faster to average the gradients # simply average the cost. It might get faster to average the gradients
with tf.name_scope('optimize'): with tf.name_scope('optimize'):
...@@ -161,7 +159,6 @@ class MultiGPUGANTrainer(TowerTrainer): ...@@ -161,7 +159,6 @@ class MultiGPUGANTrainer(TowerTrainer):
d_min = opt.minimize(d_loss, var_list=model.d_vars, d_min = opt.minimize(d_loss, var_list=model.d_vars,
colocate_gradients_with_ops=True, name='d_op') colocate_gradients_with_ops=True, name='d_op')
self.train_op = d_min self.train_op = d_min
self.set_tower_func(tower_func)
for cb in cbs: for cb in cbs:
self.register_callback(cb) self.register_callback(cb)
......
...@@ -7,6 +7,7 @@ import six ...@@ -7,6 +7,7 @@ import six
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from ..utils.argtools import call_only_once, memoized from ..utils.argtools import call_only_once, memoized
from ..utils.develop import deprecated
from ..graph_builder.predict import SimplePredictBuilder from ..graph_builder.predict import SimplePredictBuilder
from ..input_source import PlaceholderInput from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor from ..predict.base import OnlinePredictor
...@@ -25,22 +26,33 @@ class TowerTrainer(Trainer): ...@@ -25,22 +26,33 @@ class TowerTrainer(Trainer):
This is required by some features that replicates the model This is required by some features that replicates the model
automatically, e.g. creating a predictor. automatically, e.g. creating a predictor.
"""
tower_func = None To use features of :class:`TowerTrainer`, set `tower_func` and use it to build the graph.
""" Note that `tower_func` can only be set once per instance.
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
""" """
_tower_func = None
@call_only_once @call_only_once
def _set_tower_func(self, tower_func):
assert isinstance(tower_func, TowerFuncWrapper), tower_func
self._tower_func = tower_func
@deprecated("Just use tower_func = xxx instead!")
def set_tower_func(self, tower_func): def set_tower_func(self, tower_func):
self._set_tower_func(tower_func)
@property
def tower_func(self):
""" """
Args: A :class:`TowerFuncWrapper` instance.
tower_func (TowerFuncWrapper) A callable which takes some input tensors and builds one replicate of the model.
""" """
assert isinstance(tower_func, TowerFuncWrapper), tower_func return self._tower_func
self.tower_func = tower_func
@tower_func.setter
def tower_func(self, val):
self._set_tower_func(val)
@property @property
def inputs_desc(self): def inputs_desc(self):
...@@ -128,7 +140,7 @@ class SingleCostTrainer(TowerTrainer): ...@@ -128,7 +140,7 @@ class SingleCostTrainer(TowerTrainer):
""" """
get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc) get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
get_opt_fn = memoized(get_opt_fn) get_opt_fn = memoized(get_opt_fn)
self.set_tower_func(get_cost_fn) self.tower_func = get_cost_fn
# TODO setup may want to register monitor as well?? # TODO setup may want to register monitor as well??
input_callbacks = self._setup_input(inputs_desc, input) input_callbacks = self._setup_input(inputs_desc, input)
......
...@@ -147,19 +147,25 @@ _FUNC_CALLED = set() ...@@ -147,19 +147,25 @@ _FUNC_CALLED = set()
def call_only_once(func): def call_only_once(func):
""" """
Decorate a method of a class, so that this method can only Decorate a method or property of a class, so that this method can only
be called once for every instance. be called once for every instance.
Calling it more than once will result in exception. Calling it more than once will result in exception.
""" """
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
self = args[0] self = args[0]
assert hasattr(self, func.__name__), "call_only_once can only be used on method!" # cannot use hasattr here, because hasattr tries to getattr, which
# fails if func is a property
assert func.__name__ in dir(self), "call_only_once can only be used on method or property!"
cls = type(self)
# cannot use ismethod(), because decorated method becomes a function
is_method = inspect.isfunction(getattr(cls, func.__name__))
key = (self, func) key = (self, func)
assert key not in _FUNC_CALLED, \ assert key not in _FUNC_CALLED, \
"Method {}.{} can only be called once per object!".format( "{} {}.{} can only be called once per object!".format(
type(self).__name__, func.__name__) 'Method' if is_method else 'Property',
cls.__name__, func.__name__)
_FUNC_CALLED.add(key) _FUNC_CALLED.add(key)
return func(*args, **kwargs) return func(*args, **kwargs)
...@@ -169,13 +175,32 @@ def call_only_once(func): ...@@ -169,13 +175,32 @@ def call_only_once(func):
if __name__ == '__main__': if __name__ == '__main__':
class A(): class A():
def __init__(self):
self._p = 0
@call_only_once @call_only_once
def f(self, x): def f(self, x):
print(x) print(x)
@property
def p(self):
return self._p
@p.setter
@call_only_once
def p(self, val):
self._p = val
a = A() a = A()
a.f(1) a.f(1)
b = A() b = A()
b.f(2) b.f(2)
b.f(1) b.f(1)
print(b.p)
print(b.p)
b.p = 2
print(b.p)
b.p = 3
print(b.p)
Markdown is supported
0% or drag and drop
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment