Commit 712ea325 authored by Yuxin Wu's avatar Yuxin Wu

Use getter and setter for trainer.tower_func, instead of `set_tower_func`.

parent a988fc18
......@@ -372,6 +372,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'DumpTensor',
'StagingInputWrapper',
'StepTensorPrinter',
'set_tower_func',
'guided_relu', 'saliency_map', 'get_scalar_var',
'prediction_incorrect', 'huber_loss',
......
......@@ -72,9 +72,9 @@ class GANTrainer(TowerTrainer):
# we need to set towerfunc because it's a TowerTrainer,
# and only TowerTrainer supports automatic graph creation for inference during training.
tower_func = TowerFuncWrapper(model.build_graph, inputs_desc)
self.tower_func = TowerFuncWrapper(model.build_graph, inputs_desc)
with TowerContext('', is_training=True):
tower_func(*input.get_input_tensors())
self.tower_func(*input.get_input_tensors())
opt = model.get_optimizer()
# by default, run one d_min after one g_min
......@@ -83,7 +83,6 @@ class GANTrainer(TowerTrainer):
with tf.control_dependencies([g_min]):
d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
self.train_op = d_min
self.set_tower_func(tower_func)
for cb in cbs:
self.register_callback(cb)
......@@ -103,9 +102,9 @@ class SeparateGANTrainer(TowerTrainer):
assert min(d_period, g_period) == 1
cbs = input.setup(model.get_inputs_desc())
tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
with TowerContext('', is_training=True):
tower_func(*input.get_input_tensors())
self.tower_func(*input.get_input_tensors())
opt = model.get_optimizer()
with tf.name_scope('optimize'):
......@@ -114,7 +113,6 @@ class SeparateGANTrainer(TowerTrainer):
self.g_min = opt.minimize(
model.g_loss, var_list=model.g_vars, name='g_min')
self.set_tower_func(tower_func)
for cb in cbs:
self.register_callback(cb)
......@@ -142,11 +140,11 @@ class MultiGPUGANTrainer(TowerTrainer):
def get_cost(*inputs):
model.build_graph(*inputs)
return [model.d_loss, model.g_loss]
tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = DataParallelBuilder.build_on_towers(
list(range(nr_gpu)),
lambda: tower_func(*input.get_input_tensors()),
lambda: self.tower_func(*input.get_input_tensors()),
devices)
# simply average the cost. It might get faster to average the gradients
with tf.name_scope('optimize'):
......@@ -161,7 +159,6 @@ class MultiGPUGANTrainer(TowerTrainer):
d_min = opt.minimize(d_loss, var_list=model.d_vars,
colocate_gradients_with_ops=True, name='d_op')
self.train_op = d_min
self.set_tower_func(tower_func)
for cb in cbs:
self.register_callback(cb)
......
......@@ -7,6 +7,7 @@ import six
from abc import abstractmethod, ABCMeta
from ..utils.argtools import call_only_once, memoized
from ..utils.develop import deprecated
from ..graph_builder.predict import SimplePredictBuilder
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor
......@@ -25,22 +26,33 @@ class TowerTrainer(Trainer):
This is required by some features that replicates the model
automatically, e.g. creating a predictor.
"""
tower_func = None
"""
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
To use features of :class:`TowerTrainer`, set `tower_func` and use it to build the graph.
Note that `tower_func` can only be set once per instance.
"""
_tower_func = None
@call_only_once
def _set_tower_func(self, tower_func):
assert isinstance(tower_func, TowerFuncWrapper), tower_func
self._tower_func = tower_func
@deprecated("Just use tower_func = xxx instead!")
def set_tower_func(self, tower_func):
self._set_tower_func(tower_func)
@property
def tower_func(self):
"""
Args:
tower_func (TowerFuncWrapper)
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
"""
assert isinstance(tower_func, TowerFuncWrapper), tower_func
self.tower_func = tower_func
return self._tower_func
@tower_func.setter
def tower_func(self, val):
self._set_tower_func(val)
@property
def inputs_desc(self):
......@@ -128,7 +140,7 @@ class SingleCostTrainer(TowerTrainer):
"""
get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
get_opt_fn = memoized(get_opt_fn)
self.set_tower_func(get_cost_fn)
self.tower_func = get_cost_fn
# TODO setup may want to register monitor as well??
input_callbacks = self._setup_input(inputs_desc, input)
......
......@@ -147,19 +147,25 @@ _FUNC_CALLED = set()
def call_only_once(func):
"""
Decorate a method of a class, so that this method can only
Decorate a method or property of a class, so that this method can only
be called once for every instance.
Calling it more than once will result in exception.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
self = args[0]
assert hasattr(self, func.__name__), "call_only_once can only be used on method!"
# cannot use hasattr here, because hasattr tries to getattr, which
# fails if func is a property
assert func.__name__ in dir(self), "call_only_once can only be used on method or property!"
cls = type(self)
# cannot use ismethod(), because decorated method becomes a function
is_method = inspect.isfunction(getattr(cls, func.__name__))
key = (self, func)
assert key not in _FUNC_CALLED, \
"Method {}.{} can only be called once per object!".format(
type(self).__name__, func.__name__)
"{} {}.{} can only be called once per object!".format(
'Method' if is_method else 'Property',
cls.__name__, func.__name__)
_FUNC_CALLED.add(key)
return func(*args, **kwargs)
......@@ -169,13 +175,32 @@ def call_only_once(func):
if __name__ == '__main__':
class A():
def __init__(self):
self._p = 0
@call_only_once
def f(self, x):
print(x)
@property
def p(self):
return self._p
@p.setter
@call_only_once
def p(self, val):
self._p = val
a = A()
a.f(1)
b = A()
b.f(2)
b.f(1)
print(b.p)
print(b.p)
b.p = 2
print(b.p)
b.p = 3
print(b.p)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment