Commit 01f54c26 authored by Yuxin Wu

docs update

parent 000b2ea3
@@ -6,21 +6,21 @@ Subpackages
 .. toctree::
-    tensorpack.dataflow.dataset
-    tensorpack.dataflow.imgaug
+    dataflow.dataset
+    dataflow.imgaug
 
-tensorpack.dataflow.dftools module
-----------------------------------
+Module contents
+---------------
 
-.. automodule:: tensorpack.dataflow.dftools
+.. automodule:: tensorpack.dataflow
     :members:
     :undoc-members:
     :show-inheritance:
 
-Module contents
----------------
+tensorpack.dataflow.dftools module
+----------------------------------
 
-.. automodule:: tensorpack.dataflow
+.. automodule:: tensorpack.dataflow.dftools
     :members:
     :undoc-members:
     :show-inheritance:
@@ -5,12 +5,12 @@ API Documentation
   :maxdepth: 1
 
-  tensorpack.models
-  tensorpack.dataflow
-  tensorpack.callbacks
-  tensorpack.train
-  tensorpack.utils
-  tensorpack.tfutils
-  tensorpack.predict
-  tensorpack.RL
+  models
+  dataflow
+  callbacks
+  train
+  predict
+  tfutils
+  utils
+  RL
@@ -5,3 +5,11 @@ tensorpack.train package
     :members:
     :undoc-members:
     :show-inheritance:
+
+tensorpack.train.monitor module
+------------------------------------
+
+.. automodule:: tensorpack.train.monitor
+    :members:
+    :undoc-members:
+    :show-inheritance:
@@ -46,8 +46,6 @@ TrainConfig(
         ProgressBar(),
         # run `tf.summary.merge_all` and save results every epoch
         MergeAllSummaries(),
-        # print all the statistics I've created, and scalar tensors I've summarized
-        StatPrinter(),
     ]
 )
 ```
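The hunk above drops `StatPrinter()` from the tutorial's callback list, while the `monitor.py` changes later in this commit add a `ScalarPrinter` monitor whose docstring is "Print all scalar data in terminal". A minimal, hedged sketch of constructing that monitor, assuming it is importable from `tensorpack.train.monitor` as the Sphinx hunk above indicates; how a trainer is handed the monitor list is not shown in this diff.

```python
# Hedged sketch only: the class comes from the monitor.py changes in this
# commit; wiring it into TrainConfig is not shown in the diff.
from tensorpack.train.monitor import ScalarPrinter

# ScalarPrinter prints scalar data in the terminal, roughly the role the
# removed StatPrinter callback used to play.  TFSummaryWriter / JSONWriter
# also exist, but their __new__ only builds a real writer when
# logger.LOG_DIR is set.
monitors = [ScalarPrinter()]
```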
@@ -6,7 +6,7 @@ from pkgutil import iter_modules
 import os
 import os.path
 
-__all__ = []
+__all__ = ['monitor']
 
 
 def global_import(name):
@@ -19,6 +19,7 @@ def global_import(name):
 _CURR_DIR = os.path.dirname(__file__)
+_SKIP = ['monitor']
 for _, module_name, _ in iter_modules(
         [_CURR_DIR]):
     srcpath = os.path.join(_CURR_DIR, module_name + '.py')
@@ -26,4 +27,5 @@ for _, module_name, _ in iter_modules(
         continue
     if module_name.startswith('_'):
         continue
-    global_import(module_name)
+    if module_name not in _SKIP:
+        global_import(module_name)
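The `__init__.py` hunks above add `monitor` to `__all__` but exclude it from the automatic re-export loop via `_SKIP`. A minimal sketch of what that means for callers, assuming the package layout shown in this diff: the submodule is still importable, but its classes are reached through the submodule rather than `tensorpack.train` itself.

```python
# The submodule is listed in __all__, but its contents are not pulled into
# the tensorpack.train namespace, so classes are imported explicitly.
from tensorpack.train import monitor
from tensorpack.train.monitor import TrainingMonitor, ScalarPrinter
```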
@@ -21,20 +21,35 @@ class TrainingMonitor(object):
     """
     Monitor a training progress, by processing different types of
     summary/statistics from trainer.
+
+    .. document private functions
+    .. automethod:: _setup
     """
     def setup(self, trainer):
         self._trainer = trainer
+        self._setup()
+
+    def _setup(self):
+        """ Override this method to setup the monitor."""
+        pass
 
     def put_summary(self, summary):
+        """
+        Process a tf.Summary.
+        """
         pass
 
-    def put(self, name, val):  # TODO split by types?
+    def put(self, name, val):
+        """
+        Process a key-value pair.
+        """
         pass
 
+    def put_scalar(self, name, val):
+        self.put(name, val)
+
+    # TODO put other types
+
     def flush(self):
         pass
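The base class above gains three hooks: `_setup()` (run after the trainer is attached by `setup()`), `put_scalar()` (which falls back to `put()`), and `flush()`. A hedged sketch of a custom monitor built only on those hooks; the class name and the printing behaviour are illustrative, not part of the commit.

```python
from tensorpack.train.monitor import TrainingMonitor


class PrintEveryScalar(TrainingMonitor):
    """Hypothetical monitor: print each scalar value as it arrives."""

    def _setup(self):
        # setup() has already stored the trainer as self._trainer
        self._seen = 0

    def put_scalar(self, name, val):
        self._seen += 1
        print("{}: {}".format(name, val))

    def flush(self):
        print("{} scalar values seen so far".format(self._seen))
```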
@@ -47,6 +62,9 @@ class NoOpMonitor(TrainingMonitor):
 class Monitors(TrainingMonitor):
+    """
+    Merge monitors together for trainer to use.
+    """
     def __init__(self, monitors):
         # TODO filter by names
         self._scalar_history = ScalarHistory()
@@ -68,9 +86,9 @@ class Monitors(TrainingMonitor):
         for m in self._monitors:
             m.put_summary(summary)
 
-    def _dispatch_put(self, name, val):
+    def _dispatch_put_scalar(self, name, val):
         for m in self._monitors:
-            m.put(name, val)
+            m.put_scalar(name, val)
 
     def put_summary(self, summary):
         if isinstance(summary, six.binary_type):
@@ -86,24 +104,35 @@ class Monitors(TrainingMonitor):
                 suffix = '-summary'  # issue#6150
                 if val.tag.endswith(suffix):
                     val.tag = val.tag[:-len(suffix)]
-                self._dispatch_put(val.tag, val.simple_value)
+                self._dispatch_put_scalar(val.tag, val.simple_value)
 
     def put(self, name, val):
-        val = float(val)  # TODO only support numeric for now
+        val = float(val)  # TODO only support scalar for now
+        self.put_scalar(name, val)
 
-        self._dispatch_put(name, val)
+    def put_scalar(self, name, val):
+        self._dispatch_put_scalar(name, val)
         s = tf.Summary()
         s.value.add(tag=name, simple_value=val)
         self._dispatch_put_summary(s)
 
     def get_latest(self, name):
+        """
+        Get latest scalar value of some data.
+        """
         return self._scalar_history.get_latest(name)
 
     def get_history(self, name):
+        """
+        Get a history of the scalar value of some data.
+        """
         return self._scalar_history.get_history(name)
 
 
 class TFSummaryWriter(TrainingMonitor):
+    """
+    Write summaries to TensorFlow event file.
+    """
     def __new__(cls):
         if logger.LOG_DIR:
             return super(TFSummaryWriter, cls).__new__(cls)
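`Monitors.get_latest` / `get_history` above delegate to the internal `ScalarHistory` monitor defined near the end of this file, which appends every `put_scalar` value to a per-name list. A small hedged sketch of that behaviour in isolation; the body of `get_latest` is not shown in this diff, so the value it returns is an assumption.

```python
from tensorpack.train.monitor import ScalarHistory

hist = ScalarHistory()
hist.setup(None)              # its setup(_) only initializes the defaultdict
hist.put_scalar('loss', 0.7)
hist.put_scalar('loss', 0.5)
hist.get_history('loss')      # -> [0.7, 0.5]
hist.get_latest('loss')       # presumably 0.5, the last value appended
```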
@@ -126,6 +155,9 @@ class TFSummaryWriter(TrainingMonitor):
 class JSONWriter(TrainingMonitor):
+    """
+    Write all scalar data to a json, grouped by their global step.
+    """
     def __new__(cls):
         if logger.LOG_DIR:
             return super(JSONWriter, cls).__new__(cls)
@@ -150,7 +182,7 @@ class JSONWriter(TrainingMonitor):
         self._last_gs = -1
 
-    def put(self, name, val):
+    def put_scalar(self, name, val):
         gs = self._trainer.global_step
         if gs != self._last_gs:
             self._push()
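`JSONWriter.put_scalar` above reads the trainer's global step and calls `_push()` whenever it changes, matching the class docstring ("grouped by their global step"). The exact JSON schema is not visible in this diff; the record below is purely illustrative.

```python
# Illustrative only -- the field names are assumptions, not taken from the diff.
example_records = [
    {"global_step": 100, "loss": 0.52},
    {"global_step": 200, "loss": 0.47},
]
```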
@@ -181,6 +213,9 @@ class JSONWriter(TrainingMonitor):
 # TODO print interval
 class ScalarPrinter(TrainingMonitor):
+    """
+    Print all scalar data in terminal.
+    """
     def __init__(self):
         self._whitelist = None
         self._blacklist = set([])
@@ -188,7 +223,7 @@ class ScalarPrinter(TrainingMonitor):
     def setup(self, _):
         self._dic = {}
 
-    def put(self, name, val):
+    def put_scalar(self, name, val):
         self._dic[name] = float(val)
 
     def _print_stat(self):
@@ -203,10 +238,13 @@ class ScalarPrinter(TrainingMonitor):
 class ScalarHistory(TrainingMonitor):
+    """
+    Only used by monitors internally.
+    """
     def setup(self, _):
         self._dic = defaultdict(list)
 
-    def put(self, name, val):
+    def put_scalar(self, name, val):
         self._dic[name].append(float(val))
 
     def get_latest(self, name):