Commit 3398df09 authored by Yuxin Wu's avatar Yuxin Wu

update docs. check bounds (fix #652)

parent 9227aa8e
......@@ -18,7 +18,7 @@ from ..utils.utils import get_tqdm_kwargs
from ..dataflow.base import DataFlow
from ..input_source import (
InputSource, FeedInput, QueueInput)
InputSource, FeedInput, QueueInput, StagingInput)
from ..graph_builder.predict import SimplePredictBuilder
from .base import Callback
......@@ -118,6 +118,7 @@ class InferenceRunner(InferenceRunnerBase):
if isinstance(input, DataFlow):
input = FeedInput(input, infinite=True) # TODO a better way to handle inference size
assert isinstance(input, InputSource), input
assert not isinstance(input, StagingInput), input
self._tower_name = tower_name
self._device = device
super(InferenceRunner, self).__init__(input, infs)
......
......@@ -45,6 +45,8 @@ class CenterCrop(TransformAugmentorBase):
def _get_augment_params(self, img):
orig_shape = img.shape
assert orig_shape[0] >= self.crop_shape[0] \
and orig_shape[1] >= self.crop_shape[1], orig_shape
h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1])
......
......@@ -123,7 +123,8 @@ class MultiThreadMapData(_ParallelMapData):
if self.stopped():
return
# cannot ignore None here. will lead to unsynced send/recv
self.outq.put(self.func(dp))
obj = self.func(dp)
self.queue_put_stoppable(self.outq, obj)
except Exception:
if self.stopped():
pass # skip duplicated error messages
......@@ -190,7 +191,10 @@ class MultiThreadMapData(_ParallelMapData):
if self._evt is not None:
self._evt.set()
for p in self._threads:
p.join()
p.stop()
p.join(timeout=5.0)
# if p.is_alive():
# logger.warn("Cannot join thread {}.".format(p.name))
# TODO deprecated
......
......@@ -70,12 +70,12 @@ class DataFromQueue(DataFlow):
class DataFromList(RNGDataFlow):
""" Wrap a list of datapoitns to a DataFlow"""
""" Wrap a list of datapoints to a DataFlow"""
def __init__(self, lst, shuffle=True):
"""
Args:
lst (list): input list.
lst (list): input list. Each element is a datapoint.
shuffle (bool): shuffle data.
"""
super(DataFromList, self).__init__()
......
......@@ -339,6 +339,8 @@ class HorovodTrainer(SingleCostTrainer):
# NOTE It will fail if GPU was already detected before initializing the session
# https://github.com/tensorflow/tensorflow/issues/8136
session_creator.config.gpu_options.visible_device_list = str(self._local_rank)
# TODO split #CPUs
# session_creator.config.inter_op_parallelism_threads =
super(HorovodTrainer, self).initialize(
session_creator, session_init)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment