Commit c088e2a6 authored by Yuxin Wu's avatar Yuxin Wu

fix doc build. make MergeAllSummaries a Triggerable

parent c15c8639
...@@ -22,12 +22,7 @@ sys.path.insert(0, os.path.abspath('../')) ...@@ -22,12 +22,7 @@ sys.path.insert(0, os.path.abspath('../'))
os.environ['TENSORPACK_DOC_BUILDING'] = '1' os.environ['TENSORPACK_DOC_BUILDING'] = '1'
MOCK_MODULES = ['scipy', MOCK_MODULES = ['scipy', 'tabulate',
#'tensorflow', 'tensorflow.contrib',
#'tensorflow.python.ops',
#'tensorflow.contrib.framework',
#'tensorflow.python',
#'tensorflow.python.training',
'sklearn.datasets', 'sklearn', 'sklearn.datasets', 'sklearn',
'scipy.misc', 'h5py', 'nltk', 'scipy.misc', 'h5py', 'nltk',
'cv2', 'scipy.io', 'dill', 'zmq', 'subprocess32', 'lmdb', 'cv2', 'scipy.io', 'dill', 'zmq', 'subprocess32', 'lmdb',
......
...@@ -20,7 +20,6 @@ except ImportError: ...@@ -20,7 +20,6 @@ except ImportError:
# configure requirements # configure requirements
reqfile = os.path.join(CURRENT_DIR, 'requirements.txt') reqfile = os.path.join(CURRENT_DIR, 'requirements.txt')
req = [x.strip() for x in open(reqfile).readlines()] req = [x.strip() for x in open(reqfile).readlines()]
reqfile = os.path.join(CURRENT_DIR, 'opt-requirements.txt') reqfile = os.path.join(CURRENT_DIR, 'opt-requirements.txt')
extra_req = [x.strip() for x in open(reqfile).readlines()] extra_req = [x.strip() for x in open(reqfile).readlines()]
...@@ -43,7 +42,6 @@ setup( ...@@ -43,7 +42,6 @@ setup(
version=__version__, version=__version__,
description='Neural Network Toolbox on TensorFlow', description='Neural Network Toolbox on TensorFlow',
long_description=long_description, long_description=long_description,
install_requires=req, install_requires=req,
tests_require=['flake8'], tests_require=['flake8'],
extras_require={ extras_require={
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import tensorflow as tf import tensorflow as tf
from ..utils.naming import MOVING_SUMMARY_OPS_KEY from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .base import Callback from .base import Callback, Triggerable
__all__ = ['MovingAverageSummary', 'MergeAllSummaries'] __all__ = ['MovingAverageSummary', 'MergeAllSummaries']
...@@ -32,7 +32,7 @@ class MovingAverageSummary(Callback): ...@@ -32,7 +32,7 @@ class MovingAverageSummary(Callback):
return [self.ema_op] return [self.ema_op]
class MergeAllSummaries(Callback): class MergeAllSummaries(Triggerable):
""" """
Evaluate all summaries by `tf.summary.merge_all`, and write to logs. Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
""" """
...@@ -70,7 +70,15 @@ class MergeAllSummaries(Callback): ...@@ -70,7 +70,15 @@ class MergeAllSummaries(Callback):
return return
self.trainer.monitors.put_summary(summary) self.trainer.monitors.put_summary(summary)
def _summary_run_alone(self):
summary = self.summary_op.eval()
self.trainer.monitors.put_summary(summary)
def _trigger_epoch(self): def _trigger_epoch(self):
if self._run_alone: if self._run_alone:
summary = self.summary_op.eval() self._summary_run_alone()
self.trainer.monitors.put_summary(summary)
def _trigger(self):
assert self._run_alone, \
"MergeAllSummaries can be used as a trigger only if run_alone=True."
self._summary_run_alone()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment