Commit 260936ff authored by Yuxin Wu's avatar Yuxin Wu

Fix imports; Add call_only_once decorator;

parent 9fff46d5
......@@ -9,7 +9,7 @@ from six.moves import zip
from contextlib import contextmanager
import tensorflow as tf
from ..utils.argtools import memoized
from ..utils.argtools import memoized, call_only_once
from ..callbacks.base import CallbackFactory
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
......@@ -85,6 +85,7 @@ class InputSource(object):
def _get_input_tensors(self):
pass
@call_only_once
def setup(self, inputs_desc):
"""
Args:
......
......@@ -10,6 +10,7 @@ import six
from abc import abstractmethod, ABCMeta
from ..utils import logger
from ..utils.argtools import call_only_once
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils.model_utils import describe_trainable_vars
......@@ -73,6 +74,7 @@ class Trainer(object):
"of Trainer.run_step()!")
self.hooked_sess.run(self.train_op)
@call_only_once
def setup_callbacks(self, callbacks, monitors):
"""
Setup callbacks and monitors. Must be called after the main graph is built.
......@@ -92,6 +94,7 @@ class Trainer(object):
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
@call_only_once
def initialize(self, session_creator, session_init):
"""
Initialize self.sess and self.hooked_sess.
......@@ -120,6 +123,7 @@ class Trainer(object):
and self.hooked_sess (the session with hooks and coordinator)
"""
@call_only_once
def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
"""
Run the main training loop.
......@@ -213,6 +217,10 @@ class SingleCostTrainer(Trainer):
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Same as :meth:`Trainer.train()`, except that the callbacks this
trainer needs are automatically added.
"""
callbacks = callbacks + self._internal_callbacks
Trainer.train(
self,
......@@ -220,6 +228,7 @@ class SingleCostTrainer(Trainer):
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch)
@call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
"""
Build the main training graph. Defaults to do nothing.
......
......@@ -6,17 +6,19 @@ import os
from ..callbacks.graph import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..utils import logger
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..input_source import QueueInput
from ..graph_builder.training import (
SimpleBuilder,
SyncMultiGPUParameterServerBuilder,
SyncMultiGPUReplicatedBuilder,
AsyncMultiGPUBuilder,
DistributedReplicatedBuilder)
AsyncMultiGPUBuilder)
from ..graph_builder.distributed import DistributedReplicatedBuilder
from ..graph_builder.utils import override_to_local_variable
from ..utils import logger
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..input_source import QueueInput
from .base import SingleCostTrainer
......
......@@ -12,7 +12,7 @@ else:
import functools
__all__ = ['map_arg', 'memoized', 'graph_memoized', 'shape2d', 'shape4d',
'memoized_ignoreargs', 'log_once']
'memoized_ignoreargs', 'log_once', 'call_only_once']
def map_arg(**maps):
......@@ -140,3 +140,42 @@ def log_once(message, func):
func(str): the name of the logger method. e.g. "info", "warn", "error".
"""
getattr(logger, func)(message)
_FUNC_CALLED = set()
def call_only_once(func):
"""
Decorate a method of a class, so that this method can only
be called once for every instance.
Calling it more than once will result in exception.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
self = args[0]
assert hasattr(self, func.__name__), "call_only_once can only be used on method!"
key = (self, func)
assert key not in _FUNC_CALLED, \
"Method {}.{} can only be called once per object!".format(
type(self).__name__, func.__name__)
_FUNC_CALLED.add(key)
func(*args, **kwargs)
return wrapper
if __name__ == '__main__':
class A():
@call_only_once
def f(self, x):
print(x)
a = A()
a.f(1)
b = A()
b.f(2)
b.f(1)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment