Commit 5cdf1d33 authored by Yuxin Wu's avatar Yuxin Wu

Better picklability for "spawn" mp context

parent fa0f57da
...@@ -295,16 +295,17 @@ class MapDataComponent(MapData): ...@@ -295,16 +295,17 @@ class MapDataComponent(MapData):
return None to discard this datapoint. return None to discard this datapoint.
index (int): index of the component. index (int): index of the component.
""" """
index = int(index) self._index = int(index)
self._func = func
def f(dp): super(MapDataComponent, self).__init__(ds, self._mapper)
r = func(dp[index])
if r is None: def _mapper(self, dp):
return None r = self._func(dp[self._index])
dp = list(dp) # shallow copy to avoid modifying the list if r is None:
dp[index] = r return None
return dp dp = list(dp) # shallow copy to avoid modifying the list
super(MapDataComponent, self).__init__(ds, f) dp[self._index] = r
return dp
class RepeatedData(ProxyDataFlow): class RepeatedData(ProxyDataFlow):
......
...@@ -103,23 +103,22 @@ class AugmentImageComponent(MapDataComponent): ...@@ -103,23 +103,22 @@ class AugmentImageComponent(MapDataComponent):
self.augs = augmentors self.augs = augmentors
else: else:
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
self._copy = copy
exception_handler = ExceptionHandler(catch_exceptions) self._exception_handler = ExceptionHandler(catch_exceptions)
super(AugmentImageComponent, self).__init__(ds, self._aug_mapper, index)
def func(x):
check_dtype(x)
with exception_handler.catch():
if copy:
x = copy_mod.deepcopy(x)
return self.augs.augment(x)
super(AugmentImageComponent, self).__init__(
ds, func, index)
def reset_state(self): def reset_state(self):
self.ds.reset_state() self.ds.reset_state()
self.augs.reset_state() self.augs.reset_state()
def _aug_mapper(self, x):
check_dtype(x)
with self._exception_handler.catch():
if self._copy:
x = copy_mod.deepcopy(x)
return self.augs.augment(x)
class AugmentImageCoordinates(MapData): class AugmentImageCoordinates(MapData):
""" """
...@@ -142,27 +141,30 @@ class AugmentImageCoordinates(MapData): ...@@ -142,27 +141,30 @@ class AugmentImageCoordinates(MapData):
else: else:
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
exception_handler = ExceptionHandler(catch_exceptions) self._img_index = img_index
self._coords_index = coords_index
self._copy = copy
self._exception_handler = ExceptionHandler(catch_exceptions)
def func(dp): super(AugmentImageCoordinates, self).__init__(ds, self._aug_mapper)
with exception_handler.catch():
img, coords = dp[img_index], dp[coords_index]
check_dtype(img)
validate_coords(coords)
if copy:
img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs._augment_return_params(img)
dp[img_index] = img
coords = self.augs._augment_coords(coords, prms)
dp[coords_index] = coords
return dp
super(AugmentImageCoordinates, self).__init__(ds, func)
def reset_state(self): def reset_state(self):
self.ds.reset_state() self.ds.reset_state()
self.augs.reset_state() self.augs.reset_state()
def _aug_mapper(self, dp):
with self._exception_handler.catch():
img, coords = dp[self._img_index], dp[self._coords_index]
check_dtype(img)
validate_coords(coords)
if self._copy:
img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs._augment_return_params(img)
dp[self._img_index] = img
coords = self.augs._augment_coords(coords, prms)
dp[self._coords_index] = coords
return dp
class AugmentImageComponents(MapData): class AugmentImageComponents(MapData):
""" """
......
...@@ -170,7 +170,9 @@ class MultiProcessPrefetchData(ProxyDataFlow): ...@@ -170,7 +170,9 @@ class MultiProcessPrefetchData(ProxyDataFlow):
nr_proc (int): number of processes to use. nr_proc (int): number of processes to use.
""" """
if os.name == 'nt': if os.name == 'nt':
logger.warn("MultiProcessPrefetchData may not support windows!") logger.warn("MultiProcessPrefetchData does support windows. \
However, windows requires more strict picklability on processes, which may \
lead of failure on some of the code.")
super(MultiProcessPrefetchData, self).__init__(ds) super(MultiProcessPrefetchData, self).__init__(ds)
try: try:
self._size = ds.size() self._size = ds.size()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment