Commit 822997c7 authored by Yuxin Wu's avatar Yuxin Wu

update docs and new model

parent 65bbd28a
@@ -2,3 +2,6 @@ An issue has to be one of the following:
- Unexpected Problems / Potential Bugs
- Feature Requests
- Questions on Using/Understanding Tensorpack
+To post an issue, please click "New Issue", choose your category, and read
+the instructions there.
@@ -13,6 +13,13 @@ A datapoint is a **list** of Python objects which are called the `components` of
that yields datapoints (lists) of two components:
a numpy array of shape (64, 28, 28), and an array of shape (64,).
+As you saw,
+DataFlow is __independent__ of TensorFlow since it produces arbitrary Python objects
+(usually numpy arrays).
+To `import tensorpack.dataflow`, you don't even have to install TensorFlow.
+You can simply use DataFlow as a data processing pipeline and plug it into any other framework.
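As a quick illustration of this independence (a sketch for this doc, not part of the commit; `FakeMNIST` is a made-up name, while `DataFlow`, `get_data`, `size`, and `reset_state` are the interface described above), a DataFlow can be written and consumed with nothing but numpy:

```
import numpy as np
from tensorpack.dataflow import DataFlow   # works without TensorFlow installed

class FakeMNIST(DataFlow):
    """Yield [image, label] datapoints made of plain Python/numpy objects."""
    def size(self):
        return 100

    def get_data(self):
        for _ in range(self.size()):
            yield [np.random.rand(28, 28), np.random.randint(10)]

df = FakeMNIST()
df.reset_state()            # part of the DataFlow contract (a no-op here)
for dp in df.get_data():
    image, label = dp       # dp is just a list: [(28, 28) array, int]
```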
### Composition of DataFlow
One good thing about having a standard interface is being able to provide
the greatest code reusability.
@@ -65,8 +72,3 @@ generator = df.get_data()
for dp in generator:
    # dp is now a list. do whatever
```
-DataFlow is __independent__ of both tensorpack and TensorFlow.
-To `import tensorpack.dataflow`, you don't even have to install TensorFlow.
-You can simply use it as a data processing pipeline and plug it into any other frameworks.
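To make the composition point above concrete, here is a sketch (not part of the commit) that chains stock wrappers; it assumes `DataFromList`, `MapData`, and `BatchData` from `tensorpack.dataflow`, each of which is itself a DataFlow exposing the same `get_data()` interface:

```
import numpy as np
from tensorpack.dataflow import DataFromList, MapData, BatchData

data = [[np.random.rand(28, 28), i % 10] for i in range(256)]
df = DataFromList(data, shuffle=False)            # source DataFlow
df = MapData(df, lambda dp: [dp[0] * 2, dp[1]])   # wrapper: transform datapoints
df = BatchData(df, 64)                            # wrapper: batch each component

df.reset_state()
for dp in df.get_data():
    print(dp[0].shape, dp[1].shape)               # (64, 28, 28) (64,)
```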
@@ -441,7 +441,7 @@ class EvalCallback(Callback):
        logger.info("[EvalCallback] Will evaluate every {} epochs".format(interval))

    def _eval(self):
-        if cfg.TRAINER == 'replicated' or cfg.TRAIN.NUM_GPUS == 1:
+        if cfg.TRAINER == 'replicated':
            with ThreadPoolExecutor(max_workers=self.num_predictor, thread_name_prefix='EvalWorker') as executor, \
                    tqdm.tqdm(total=sum([df.size() for df in self.dataflows])) as pbar:
                futures = []
...
@@ -86,11 +86,13 @@ class PeriodicRunHooks(ProxyCallback):
class EnableCallbackIf(ProxyCallback):
    """
-    Enable the ``{before,after}_epoch``, ``{before,after}_run``,
+    Disable the ``{before,after}_epoch``, ``{before,after}_run``,
    ``trigger_{epoch,step}``
-    methods of a callback, only when some condition satisfies.
+    methods of a callback, unless some condition is satisfied.
    The other methods are unaffected.
+
+    A more accurate name for this callback would be "DisableCallbackUnless", but that's too ugly.

    Note:
        If you use ``{before,after}_run``,
        ``pred`` will be evaluated only in ``before_run``.
@@ -101,6 +103,7 @@ class EnableCallbackIf(ProxyCallback):
        Args:
            callback (Callback):
            pred (self -> bool): a callable predicate. Has to be a pure function.
+                The callback is disabled unless this predicate returns True.
        """
        self._pred = pred
        super(EnableCallbackIf, self).__init__(callback)
...
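For reference, a usage sketch matching the semantics documented above (not part of this commit; `ModelSaver` and `trainer.epoch_num` are standard tensorpack names, used here as assumptions):

```
from tensorpack.callbacks import EnableCallbackIf, ModelSaver

# ModelSaver's trigger/run/epoch methods fire only from epoch 10 on;
# its other methods (e.g. graph setup) stay enabled.
saver = EnableCallbackIf(
    ModelSaver(),
    lambda self: self.trainer.epoch_num >= 10)
```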
-#!/usr/bin/env python
import os
from .serialize import loads_msgpack, loads_pyarrow, dumps_msgpack, dumps_pyarrow
...
@@ -2,8 +2,6 @@
# File: serialize.py

import os
-import pyarrow as pa
-
from .develop import create_dummy_func

__all__ = ['loads', 'dumps']
@@ -46,6 +44,16 @@ def loads_pyarrow(buf):
    return pa.deserialize(buf)

+try:
+    # "import pyarrow" has a lot of side effects: https://github.com/apache/arrow/pull/2329
+    # So we need an option to disable it.
+    if os.environ.get('TENSORPACK_SERIALIZE', 'pyarrow') == 'pyarrow':
+        import pyarrow as pa
+    else:
+        raise ImportError   # an explicit opt-out behaves like a missing pyarrow
+except ImportError:
+    pa = None
+    dumps_pyarrow = create_dummy_func('dumps_pyarrow', ['pyarrow'])  # noqa
+    loads_pyarrow = create_dummy_func('loads_pyarrow', ['pyarrow'])  # noqa
+
try:
    import msgpack
    import msgpack_numpy
...
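A sketch of the opt-out this hunk introduces (the `TENSORPACK_SERIALIZE` variable comes from the diff above; the exact failure mode of the dummy functions is an assumption based on `create_dummy_func`):

```
# Must be set before tensorpack is first imported.
import os
os.environ['TENSORPACK_SERIALIZE'] = 'msgpack'   # any value other than 'pyarrow'

import tensorpack.dataflow   # pyarrow and its import side effects are skipped;
# dumps_pyarrow/loads_pyarrow become dummies that raise if actually called.
```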