Commit 2c5b1bec authored by Yuxin Wu's avatar Yuxin Wu

Allow double-wrapping TowerFuncWrapper

parent d9817a56
......@@ -161,10 +161,19 @@ class TowerFuncWrapper(object):
It takes several input tensors and could return anything.
inputs_desc ([InputDesc]): use this to figure out the right name for the input tensors.
"""
self._tower_fn = tower_fn
self._inputs_desc = inputs_desc
assert callable(tower_fn), tower_fn
if not isinstance(tower_fn, TowerFuncWrapper):
self._tower_fn = tower_fn
self._inputs_desc = inputs_desc
self._towers = []
self._towers = []
def __new__(cls, tower_fn, inputs_desc=None):
# to avoid double-wrapping a function
if isinstance(tower_fn, TowerFuncWrapper):
return tower_fn
else:
return super(TowerFuncWrapper, cls).__new__(cls)
def __call__(self, *args):
ctx = get_current_tower_context()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment