Commit 3db6ccac authored by Yuxin Wu's avatar Yuxin Wu

RunOp accepts either Op or lambda

parent 8e2428d9
...@@ -20,12 +20,13 @@ __all__ = ['RunOp', 'RunUpdateOps', 'ProcessTensors', 'DumpTensors', 'DumpTensor ...@@ -20,12 +20,13 @@ __all__ = ['RunOp', 'RunUpdateOps', 'ProcessTensors', 'DumpTensors', 'DumpTensor
class RunOp(Callback): class RunOp(Callback):
""" Run an Op. """ """ Run an Op. """
def __init__(self, setup_func, def __init__(self, op,
run_before=True, run_as_trigger=True, run_before=True, run_as_trigger=True,
run_step=False, verbose=False): run_step=False, verbose=False):
""" """
Args: Args:
setup_func: a function that returns the Op in the graph op (tf.Operation or function): an Op, or a function that returns the Op in the graph.
The function will be called later (in the `setup_graph` callback).
run_before (bool): run the Op before training run_before (bool): run the Op before training
run_as_trigger (bool): run the Op on every trigger run_as_trigger (bool): run the Op on every trigger
run_step (bool): run the Op every step (along with training) run_step (bool): run the Op every step (along with training)
...@@ -36,7 +37,9 @@ class RunOp(Callback): ...@@ -36,7 +37,9 @@ class RunOp(Callback):
<https://github.com/ppwwyyxx/tensorpack/blob/master/examples/DeepQNetwork/>`_ <https://github.com/ppwwyyxx/tensorpack/blob/master/examples/DeepQNetwork/>`_
uses this callback to update target network. uses this callback to update target network.
""" """
self.setup_func = setup_func if not callable(op):
op = lambda: op # noqa
self.setup_func = op
self.run_before = run_before self.run_before = run_before
self.run_as_trigger = run_as_trigger self.run_as_trigger = run_as_trigger
self.run_step = run_step self.run_step = run_step
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment