Commit 9fac1a6c authored by Yuxin Wu's avatar Yuxin Wu

use yield from in dataflow; update logger name in dataflow.

parent 0641618d
...@@ -60,7 +60,7 @@ If this command failed, tell us your version of Python/TF/tensorpack. ...@@ -60,7 +60,7 @@ If this command failed, tell us your version of Python/TF/tensorpack.
Note that: Note that:
+ You can install Tensorpack master by `pip install -U git+https://github.com/ppwwyyxx/tensorpack.git` + You can install Tensorpack master by `pip install -U git+https://github.com/tensorpack/tensorpack.git`
and see if your issue is already solved. and see if your issue is already solved.
+ If you're not using tensorpack under a normal command line shell (e.g., + If you're not using tensorpack under a normal command line shell (e.g.,
using an IDE or jupyter notebook), please retry under a normal command line shell. using an IDE or jupyter notebook), please retry under a normal command line shell.
......
...@@ -8,6 +8,7 @@ so you don't need to look at here very often. ...@@ -8,6 +8,7 @@ so you don't need to look at here very often.
Here are a list of things that were changed, starting from an early version. Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here. TensorFlow itself also changes API and those are not listed here.
+ 2019/11/10. Drop Python 2 support.
+ [2019/03/20](https://github.com/tensorpack/tensorpack/commit/b8a50d72a7c655b6dc6facb17efd74069ba7f86c). + [2019/03/20](https://github.com/tensorpack/tensorpack/commit/b8a50d72a7c655b6dc6facb17efd74069ba7f86c).
The concept of `InputDesc` was replaced by its equivalent in TF: The concept of `InputDesc` was replaced by its equivalent in TF:
`tf.TensorSpec`. This may be a breaking change if you have customized `tf.TensorSpec`. This may be a breaking change if you have customized
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
Tensorpack is a neural network training interface based on TensorFlow. Tensorpack is a neural network training interface based on TensorFlow.
[![Build Status](https://travis-ci.org/tensorpack/tensorpack.svg?branch=master)](https://travis-ci.org/tensorpack/tensorpack)
[![ReadTheDoc](https://readthedocs.org/projects/tensorpack/badge/?version=latest)](http://tensorpack.readthedocs.io) [![ReadTheDoc](https://readthedocs.org/projects/tensorpack/badge/?version=latest)](http://tensorpack.readthedocs.io)
[![Gitter chat](https://img.shields.io/badge/chat-on%20gitter-46bc99.svg)](https://gitter.im/tensorpack/users) [![Gitter chat](https://img.shields.io/badge/chat-on%20gitter-46bc99.svg)](https://gitter.im/tensorpack/users)
[![model-zoo](https://img.shields.io/badge/model-zoo-brightgreen.svg)](http://models.tensorpack.com) [![model-zoo](https://img.shields.io/badge/model-zoo-brightgreen.svg)](http://models.tensorpack.com)
......
...@@ -48,8 +48,7 @@ class TestDataSpeed(ProxyDataFlow): ...@@ -48,8 +48,7 @@ class TestDataSpeed(ProxyDataFlow):
def __iter__(self): def __iter__(self):
""" Will run testing at the beginning, then produce data normally. """ """ Will run testing at the beginning, then produce data normally. """
self.start() self.start()
for dp in self.ds: yield from self.ds
yield dp
def start(self): def start(self):
""" """
...@@ -387,12 +386,10 @@ class RepeatedData(ProxyDataFlow): ...@@ -387,12 +386,10 @@ class RepeatedData(ProxyDataFlow):
def __iter__(self): def __iter__(self):
if self.num == -1: if self.num == -1:
while True: while True:
for dp in self.ds: yield from self.ds
yield dp
else: else:
for _ in range(self.num): for _ in range(self.num):
for dp in self.ds: yield from self.ds
yield dp
class RepeatedDataPoint(ProxyDataFlow): class RepeatedDataPoint(ProxyDataFlow):
...@@ -519,8 +516,7 @@ class ConcatData(DataFlow): ...@@ -519,8 +516,7 @@ class ConcatData(DataFlow):
def __iter__(self): def __iter__(self):
for d in self.df_lists: for d in self.df_lists:
for dp in d.__iter__(): yield from d
yield dp
class JoinData(DataFlow): class JoinData(DataFlow):
...@@ -702,8 +698,7 @@ class CacheData(ProxyDataFlow): ...@@ -702,8 +698,7 @@ class CacheData(ProxyDataFlow):
if len(self.buffer): if len(self.buffer):
if self.shuffle: if self.shuffle:
self.rng.shuffle(self.buffer) self.rng.shuffle(self.buffer)
for dp in self.buffer: yield from self.buffer
yield dp
else: else:
for dp in self.ds: for dp in self.ds:
yield dp yield dp
......
...@@ -48,8 +48,7 @@ class _ExceptionWrapper: ...@@ -48,8 +48,7 @@ class _ExceptionWrapper:
def _repeat_iter(get_itr): def _repeat_iter(get_itr):
while True: while True:
for x in get_itr(): yield from get_itr()
yield x
def _bind_guard(sock, name): def _bind_guard(sock, name):
......
...@@ -85,11 +85,9 @@ class _ParallelMapData(ProxyDataFlow): ...@@ -85,11 +85,9 @@ class _ParallelMapData(ProxyDataFlow):
def __iter__(self): def __iter__(self):
if self._strict: if self._strict:
for dp in self.get_data_strict(): yield from self.get_data_strict()
yield dp
else: else:
for dp in self.get_data_non_strict(): yield from self.get_data_non_strict()
yield dp
class MultiThreadMapData(_ParallelMapData): class MultiThreadMapData(_ParallelMapData):
...@@ -205,8 +203,7 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -205,8 +203,7 @@ class MultiThreadMapData(_ParallelMapData):
def __iter__(self): def __iter__(self):
with self._guard: with self._guard:
for dp in super(MultiThreadMapData, self).__iter__(): yield from super(MultiThreadMapData, self).__iter__()
yield dp
def __del__(self): def __del__(self):
if self._evt is not None: if self._evt is not None:
...@@ -320,8 +317,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -320,8 +317,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
def __iter__(self): def __iter__(self):
with self._guard, _zmq_catch_error(type(self).__name__): with self._guard, _zmq_catch_error(type(self).__name__):
for dp in super(MultiProcessMapDataZMQ, self).__iter__(): yield from super(MultiProcessMapDataZMQ, self).__iter__()
yield dp
class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow): class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
......
...@@ -85,8 +85,7 @@ class DataFromList(RNGDataFlow): ...@@ -85,8 +85,7 @@ class DataFromList(RNGDataFlow):
def __iter__(self): def __iter__(self):
if not self.shuffle: if not self.shuffle:
for k in self.lst: yield from self.lst
yield k
else: else:
idxs = np.arange(len(self.lst)) idxs = np.arange(len(self.lst))
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
...@@ -110,9 +109,7 @@ class DataFromGenerator(DataFlow): ...@@ -110,9 +109,7 @@ class DataFromGenerator(DataFlow):
self._gen = gen self._gen = gen
def __iter__(self): def __iter__(self):
# yield from yield from self._gen()
for dp in self._gen():
yield dp
class DataFromIterable(DataFlow): class DataFromIterable(DataFlow):
...@@ -129,5 +126,4 @@ class DataFromIterable(DataFlow): ...@@ -129,5 +126,4 @@ class DataFromIterable(DataFlow):
return self._len return self._len
def __iter__(self): def __iter__(self):
for dp in self._itr: yield from self._itr
yield dp
...@@ -46,7 +46,9 @@ class _MyFormatter(logging.Formatter): ...@@ -46,7 +46,9 @@ class _MyFormatter(logging.Formatter):
def _getlogger(): def _getlogger():
logger = logging.getLogger('tensorpack') # this file is synced to "dataflow" package as well
package_name = "dataflow" if __name__.startswith("dataflow") else "tensorpack"
logger = logging.getLogger(package_name)
logger.propagate = False logger.propagate = False
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout) handler = logging.StreamHandler(sys.stdout)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment