Commit 9fac1a6c authored by Yuxin Wu's avatar Yuxin Wu

use yield from in dataflow; update logger name in dataflow.

parent 0641618d
......@@ -60,7 +60,7 @@ If this command failed, tell us your version of Python/TF/tensorpack.
Note that:
+ You can install Tensorpack master by `pip install -U git+https://github.com/ppwwyyxx/tensorpack.git`
+ You can install Tensorpack master by `pip install -U git+https://github.com/tensorpack/tensorpack.git`
and see if your issue is already solved.
+ If you're not using tensorpack under a normal command line shell (e.g.,
using an IDE or jupyter notebook), please retry under a normal command line shell.
......
......@@ -8,6 +8,7 @@ so you don't need to look at here very often.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
+ 2019/11/10. Drop Python 2 support.
+ [2019/03/20](https://github.com/tensorpack/tensorpack/commit/b8a50d72a7c655b6dc6facb17efd74069ba7f86c).
The concept of `InputDesc` was replaced by its equivalent in TF:
`tf.TensorSpec`. This may be a breaking change if you have customized
......
......@@ -2,7 +2,6 @@
Tensorpack is a neural network training interface based on TensorFlow.
[![Build Status](https://travis-ci.org/tensorpack/tensorpack.svg?branch=master)](https://travis-ci.org/tensorpack/tensorpack)
[![ReadTheDoc](https://readthedocs.org/projects/tensorpack/badge/?version=latest)](http://tensorpack.readthedocs.io)
[![Gitter chat](https://img.shields.io/badge/chat-on%20gitter-46bc99.svg)](https://gitter.im/tensorpack/users)
[![model-zoo](https://img.shields.io/badge/model-zoo-brightgreen.svg)](http://models.tensorpack.com)
......
......@@ -48,8 +48,7 @@ class TestDataSpeed(ProxyDataFlow):
def __iter__(self):
""" Will run testing at the beginning, then produce data normally. """
self.start()
for dp in self.ds:
yield dp
yield from self.ds
def start(self):
"""
......@@ -387,12 +386,10 @@ class RepeatedData(ProxyDataFlow):
def __iter__(self):
if self.num == -1:
while True:
for dp in self.ds:
yield dp
yield from self.ds
else:
for _ in range(self.num):
for dp in self.ds:
yield dp
yield from self.ds
class RepeatedDataPoint(ProxyDataFlow):
......@@ -519,8 +516,7 @@ class ConcatData(DataFlow):
def __iter__(self):
for d in self.df_lists:
for dp in d.__iter__():
yield dp
yield from d
class JoinData(DataFlow):
......@@ -702,8 +698,7 @@ class CacheData(ProxyDataFlow):
if len(self.buffer):
if self.shuffle:
self.rng.shuffle(self.buffer)
for dp in self.buffer:
yield dp
yield from self.buffer
else:
for dp in self.ds:
yield dp
......
......@@ -48,8 +48,7 @@ class _ExceptionWrapper:
def _repeat_iter(get_itr):
while True:
for x in get_itr():
yield x
yield from get_itr()
def _bind_guard(sock, name):
......
......@@ -85,11 +85,9 @@ class _ParallelMapData(ProxyDataFlow):
def __iter__(self):
if self._strict:
for dp in self.get_data_strict():
yield dp
yield from self.get_data_strict()
else:
for dp in self.get_data_non_strict():
yield dp
yield from self.get_data_non_strict()
class MultiThreadMapData(_ParallelMapData):
......@@ -205,8 +203,7 @@ class MultiThreadMapData(_ParallelMapData):
def __iter__(self):
with self._guard:
for dp in super(MultiThreadMapData, self).__iter__():
yield dp
yield from super(MultiThreadMapData, self).__iter__()
def __del__(self):
if self._evt is not None:
......@@ -320,8 +317,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
def __iter__(self):
with self._guard, _zmq_catch_error(type(self).__name__):
for dp in super(MultiProcessMapDataZMQ, self).__iter__():
yield dp
yield from super(MultiProcessMapDataZMQ, self).__iter__()
class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
......
......@@ -85,8 +85,7 @@ class DataFromList(RNGDataFlow):
def __iter__(self):
if not self.shuffle:
for k in self.lst:
yield k
yield from self.lst
else:
idxs = np.arange(len(self.lst))
self.rng.shuffle(idxs)
......@@ -110,9 +109,7 @@ class DataFromGenerator(DataFlow):
self._gen = gen
def __iter__(self):
# yield from
for dp in self._gen():
yield dp
yield from self._gen()
class DataFromIterable(DataFlow):
......@@ -129,5 +126,4 @@ class DataFromIterable(DataFlow):
return self._len
def __iter__(self):
for dp in self._itr:
yield dp
yield from self._itr
......@@ -46,7 +46,9 @@ class _MyFormatter(logging.Formatter):
def _getlogger():
logger = logging.getLogger('tensorpack')
# this file is synced to "dataflow" package as well
package_name = "dataflow" if __name__.startswith("dataflow") else "tensorpack"
logger = logging.getLogger(package_name)
logger.propagate = False
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment