Commit 01f54c26 authored by Yuxin Wu's avatar Yuxin Wu

docs update

parent 000b2ea3
...@@ -6,21 +6,21 @@ Subpackages ...@@ -6,21 +6,21 @@ Subpackages
.. toctree:: .. toctree::
tensorpack.dataflow.dataset dataflow.dataset
tensorpack.dataflow.imgaug dataflow.imgaug
tensorpack.dataflow.dftools module Module contents
---------------------------------- ---------------
.. automodule:: tensorpack.dataflow.dftools .. automodule:: tensorpack.dataflow
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
Module contents tensorpack.dataflow.dftools module
--------------- ----------------------------------
.. automodule:: tensorpack.dataflow .. automodule:: tensorpack.dataflow.dftools
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
...@@ -5,12 +5,12 @@ API Documentation ...@@ -5,12 +5,12 @@ API Documentation
:maxdepth: 1 :maxdepth: 1
tensorpack.models models
tensorpack.dataflow dataflow
tensorpack.callbacks callbacks
tensorpack.train train
tensorpack.utils predict
tensorpack.tfutils tfutils
tensorpack.predict utils
tensorpack.RL RL
...@@ -5,3 +5,11 @@ tensorpack.train package ...@@ -5,3 +5,11 @@ tensorpack.train package
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
tensorpack.train.monitor module
------------------------------------
.. automodule:: tensorpack.train.monitor
:members:
:undoc-members:
:show-inheritance:
...@@ -46,8 +46,6 @@ TrainConfig( ...@@ -46,8 +46,6 @@ TrainConfig(
ProgressBar(), ProgressBar(),
# run `tf.summary.merge_all` and save results every epoch # run `tf.summary.merge_all` and save results every epoch
MergeAllSummaries(), MergeAllSummaries(),
# print all the statistics I've created, and scalar tensors I've summarized
StatPrinter(),
] ]
) )
``` ```
......
...@@ -6,7 +6,7 @@ from pkgutil import iter_modules ...@@ -6,7 +6,7 @@ from pkgutil import iter_modules
import os import os
import os.path import os.path
__all__ = [] __all__ = ['monitor']
def global_import(name): def global_import(name):
...@@ -19,6 +19,7 @@ def global_import(name): ...@@ -19,6 +19,7 @@ def global_import(name):
_CURR_DIR = os.path.dirname(__file__) _CURR_DIR = os.path.dirname(__file__)
_SKIP = ['monitor']
for _, module_name, _ in iter_modules( for _, module_name, _ in iter_modules(
[_CURR_DIR]): [_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py') srcpath = os.path.join(_CURR_DIR, module_name + '.py')
...@@ -26,4 +27,5 @@ for _, module_name, _ in iter_modules( ...@@ -26,4 +27,5 @@ for _, module_name, _ in iter_modules(
continue continue
if module_name.startswith('_'): if module_name.startswith('_'):
continue continue
global_import(module_name) if module_name not in _SKIP:
global_import(module_name)
...@@ -21,20 +21,35 @@ class TrainingMonitor(object): ...@@ -21,20 +21,35 @@ class TrainingMonitor(object):
""" """
Monitor a training progress, by processing different types of Monitor a training progress, by processing different types of
summary/statistics from trainer. summary/statistics from trainer.
.. document private functions
.. automethod:: _setup
""" """
def setup(self, trainer): def setup(self, trainer):
self._trainer = trainer self._trainer = trainer
self._setup() self._setup()
def _setup(self): def _setup(self):
""" Override this method to setup the monitor."""
pass pass
def put_summary(self, summary): def put_summary(self, summary):
"""
Process a tf.Summary.
"""
pass pass
def put(self, name, val): # TODO split by types? def put(self, name, val):
"""
Process a key-value pair.
"""
pass pass
def put_scalar(self, name, val):
self.put(name, val)
# TODO put other types
def flush(self): def flush(self):
pass pass
...@@ -47,6 +62,9 @@ class NoOpMonitor(TrainingMonitor): ...@@ -47,6 +62,9 @@ class NoOpMonitor(TrainingMonitor):
class Monitors(TrainingMonitor): class Monitors(TrainingMonitor):
"""
Merge monitors together for trainer to use.
"""
def __init__(self, monitors): def __init__(self, monitors):
# TODO filter by names # TODO filter by names
self._scalar_history = ScalarHistory() self._scalar_history = ScalarHistory()
...@@ -68,9 +86,9 @@ class Monitors(TrainingMonitor): ...@@ -68,9 +86,9 @@ class Monitors(TrainingMonitor):
for m in self._monitors: for m in self._monitors:
m.put_summary(summary) m.put_summary(summary)
def _dispatch_put(self, name, val): def _dispatch_put_scalar(self, name, val):
for m in self._monitors: for m in self._monitors:
m.put(name, val) m.put_scalar(name, val)
def put_summary(self, summary): def put_summary(self, summary):
if isinstance(summary, six.binary_type): if isinstance(summary, six.binary_type):
...@@ -86,24 +104,35 @@ class Monitors(TrainingMonitor): ...@@ -86,24 +104,35 @@ class Monitors(TrainingMonitor):
suffix = '-summary' # issue#6150 suffix = '-summary' # issue#6150
if val.tag.endswith(suffix): if val.tag.endswith(suffix):
val.tag = val.tag[:-len(suffix)] val.tag = val.tag[:-len(suffix)]
self._dispatch_put(val.tag, val.simple_value) self._dispatch_put_scalar(val.tag, val.simple_value)
def put(self, name, val): def put(self, name, val):
val = float(val) # TODO only support numeric for now val = float(val) # TODO only support scalar for now
self.put_scalar(name, val)
self._dispatch_put(name, val) def put_scalar(self, name, val):
self._dispatch_put_scalar(name, val)
s = tf.Summary() s = tf.Summary()
s.value.add(tag=name, simple_value=val) s.value.add(tag=name, simple_value=val)
self._dispatch_put_summary(s) self._dispatch_put_summary(s)
def get_latest(self, name): def get_latest(self, name):
"""
Get latest scalar value of some data.
"""
return self._scalar_history.get_latest(name) return self._scalar_history.get_latest(name)
def get_history(self, name): def get_history(self, name):
"""
Get a history of the scalar value of some data.
"""
return self._scalar_history.get_history(name) return self._scalar_history.get_history(name)
class TFSummaryWriter(TrainingMonitor): class TFSummaryWriter(TrainingMonitor):
"""
Write summaries to TensorFlow event file.
"""
def __new__(cls): def __new__(cls):
if logger.LOG_DIR: if logger.LOG_DIR:
return super(TFSummaryWriter, cls).__new__(cls) return super(TFSummaryWriter, cls).__new__(cls)
...@@ -126,6 +155,9 @@ class TFSummaryWriter(TrainingMonitor): ...@@ -126,6 +155,9 @@ class TFSummaryWriter(TrainingMonitor):
class JSONWriter(TrainingMonitor): class JSONWriter(TrainingMonitor):
"""
Write all scalar data to a json, grouped by their global step.
"""
def __new__(cls): def __new__(cls):
if logger.LOG_DIR: if logger.LOG_DIR:
return super(JSONWriter, cls).__new__(cls) return super(JSONWriter, cls).__new__(cls)
...@@ -150,7 +182,7 @@ class JSONWriter(TrainingMonitor): ...@@ -150,7 +182,7 @@ class JSONWriter(TrainingMonitor):
self._last_gs = -1 self._last_gs = -1
def put(self, name, val): def put_scalar(self, name, val):
gs = self._trainer.global_step gs = self._trainer.global_step
if gs != self._last_gs: if gs != self._last_gs:
self._push() self._push()
...@@ -181,6 +213,9 @@ class JSONWriter(TrainingMonitor): ...@@ -181,6 +213,9 @@ class JSONWriter(TrainingMonitor):
# TODO print interval # TODO print interval
class ScalarPrinter(TrainingMonitor): class ScalarPrinter(TrainingMonitor):
"""
Print all scalar data in terminal.
"""
def __init__(self): def __init__(self):
self._whitelist = None self._whitelist = None
self._blacklist = set([]) self._blacklist = set([])
...@@ -188,7 +223,7 @@ class ScalarPrinter(TrainingMonitor): ...@@ -188,7 +223,7 @@ class ScalarPrinter(TrainingMonitor):
def setup(self, _): def setup(self, _):
self._dic = {} self._dic = {}
def put(self, name, val): def put_scalar(self, name, val):
self._dic[name] = float(val) self._dic[name] = float(val)
def _print_stat(self): def _print_stat(self):
...@@ -203,10 +238,13 @@ class ScalarPrinter(TrainingMonitor): ...@@ -203,10 +238,13 @@ class ScalarPrinter(TrainingMonitor):
class ScalarHistory(TrainingMonitor): class ScalarHistory(TrainingMonitor):
"""
Only used by monitors internally.
"""
def setup(self, _): def setup(self, _):
self._dic = defaultdict(list) self._dic = defaultdict(list)
def put(self, name, val): def put_scalar(self, name, val):
self._dic[name].append(float(val)) self._dic[name].append(float(val))
def get_latest(self, name): def get_latest(self, name):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment