Commit d2262d1d authored by ppwwyyxx's avatar ppwwyyxx

fix image.py

parent 977134e1
......@@ -11,11 +11,12 @@ __all__ = ['ImageFromFile']
class ImageFromFile(DataFlow):
""" generate rgb images from files """
def __init__(self, files, channel, resize=None):
def __init__(self, files, channel=3, resize=None):
""" files: list of file path
channel: 1 or 3 channel
resize: a (w, h) tuple. If given, will force a resize
resize: a (h, w) tuple. If given, will force a resize
"""
assert len(self.files)
self.files = files
self.channel = int(channel)
self.resize = resize
......@@ -30,6 +31,6 @@ class ImageFromFile(DataFlow):
if self.channel == 3:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if self.resize is not None:
im = cv2.resize(im, self.resize)
yield (im,)
im = cv2.resize(im, self.resize[::-1])
yield [im]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment