Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3398df09
Commit
3398df09
authored
Feb 13, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update docs. check bounds (fix #652)
parent
9227aa8e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
14 additions
and
5 deletions
+14
-5
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-1
tensorpack/dataflow/imgaug/crop.py
tensorpack/dataflow/imgaug/crop.py
+2
-0
tensorpack/dataflow/parallel_map.py
tensorpack/dataflow/parallel_map.py
+6
-2
tensorpack/dataflow/raw.py
tensorpack/dataflow/raw.py
+2
-2
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+2
-0
No files found.
tensorpack/callbacks/inference_runner.py
View file @
3398df09
...
@@ -18,7 +18,7 @@ from ..utils.utils import get_tqdm_kwargs
...
@@ -18,7 +18,7 @@ from ..utils.utils import get_tqdm_kwargs
from
..dataflow.base
import
DataFlow
from
..dataflow.base
import
DataFlow
from
..input_source
import
(
from
..input_source
import
(
InputSource
,
FeedInput
,
QueueInput
)
InputSource
,
FeedInput
,
QueueInput
,
StagingInput
)
from
..graph_builder.predict
import
SimplePredictBuilder
from
..graph_builder.predict
import
SimplePredictBuilder
from
.base
import
Callback
from
.base
import
Callback
...
@@ -118,6 +118,7 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -118,6 +118,7 @@ class InferenceRunner(InferenceRunnerBase):
if
isinstance
(
input
,
DataFlow
):
if
isinstance
(
input
,
DataFlow
):
input
=
FeedInput
(
input
,
infinite
=
True
)
# TODO a better way to handle inference size
input
=
FeedInput
(
input
,
infinite
=
True
)
# TODO a better way to handle inference size
assert
isinstance
(
input
,
InputSource
),
input
assert
isinstance
(
input
,
InputSource
),
input
assert
not
isinstance
(
input
,
StagingInput
),
input
self
.
_tower_name
=
tower_name
self
.
_tower_name
=
tower_name
self
.
_device
=
device
self
.
_device
=
device
super
(
InferenceRunner
,
self
)
.
__init__
(
input
,
infs
)
super
(
InferenceRunner
,
self
)
.
__init__
(
input
,
infs
)
...
...
tensorpack/dataflow/imgaug/crop.py
View file @
3398df09
...
@@ -45,6 +45,8 @@ class CenterCrop(TransformAugmentorBase):
...
@@ -45,6 +45,8 @@ class CenterCrop(TransformAugmentorBase):
def
_get_augment_params
(
self
,
img
):
def
_get_augment_params
(
self
,
img
):
orig_shape
=
img
.
shape
orig_shape
=
img
.
shape
assert
orig_shape
[
0
]
>=
self
.
crop_shape
[
0
]
\
and
orig_shape
[
1
]
>=
self
.
crop_shape
[
1
],
orig_shape
h0
=
int
((
orig_shape
[
0
]
-
self
.
crop_shape
[
0
])
*
0.5
)
h0
=
int
((
orig_shape
[
0
]
-
self
.
crop_shape
[
0
])
*
0.5
)
w0
=
int
((
orig_shape
[
1
]
-
self
.
crop_shape
[
1
])
*
0.5
)
w0
=
int
((
orig_shape
[
1
]
-
self
.
crop_shape
[
1
])
*
0.5
)
return
CropTransform
(
h0
,
w0
,
self
.
crop_shape
[
0
],
self
.
crop_shape
[
1
])
return
CropTransform
(
h0
,
w0
,
self
.
crop_shape
[
0
],
self
.
crop_shape
[
1
])
...
...
tensorpack/dataflow/parallel_map.py
View file @
3398df09
...
@@ -123,7 +123,8 @@ class MultiThreadMapData(_ParallelMapData):
...
@@ -123,7 +123,8 @@ class MultiThreadMapData(_ParallelMapData):
if
self
.
stopped
():
if
self
.
stopped
():
return
return
# cannot ignore None here. will lead to unsynced send/recv
# cannot ignore None here. will lead to unsynced send/recv
self
.
outq
.
put
(
self
.
func
(
dp
))
obj
=
self
.
func
(
dp
)
self
.
queue_put_stoppable
(
self
.
outq
,
obj
)
except
Exception
:
except
Exception
:
if
self
.
stopped
():
if
self
.
stopped
():
pass
# skip duplicated error messages
pass
# skip duplicated error messages
...
@@ -190,7 +191,10 @@ class MultiThreadMapData(_ParallelMapData):
...
@@ -190,7 +191,10 @@ class MultiThreadMapData(_ParallelMapData):
if
self
.
_evt
is
not
None
:
if
self
.
_evt
is
not
None
:
self
.
_evt
.
set
()
self
.
_evt
.
set
()
for
p
in
self
.
_threads
:
for
p
in
self
.
_threads
:
p
.
join
()
p
.
stop
()
p
.
join
(
timeout
=
5.0
)
# if p.is_alive():
# logger.warn("Cannot join thread {}.".format(p.name))
# TODO deprecated
# TODO deprecated
...
...
tensorpack/dataflow/raw.py
View file @
3398df09
...
@@ -70,12 +70,12 @@ class DataFromQueue(DataFlow):
...
@@ -70,12 +70,12 @@ class DataFromQueue(DataFlow):
class
DataFromList
(
RNGDataFlow
):
class
DataFromList
(
RNGDataFlow
):
""" Wrap a list of datapoi
tn
s to a DataFlow"""
""" Wrap a list of datapoi
nt
s to a DataFlow"""
def
__init__
(
self
,
lst
,
shuffle
=
True
):
def
__init__
(
self
,
lst
,
shuffle
=
True
):
"""
"""
Args:
Args:
lst (list): input list.
lst (list): input list.
Each element is a datapoint.
shuffle (bool): shuffle data.
shuffle (bool): shuffle data.
"""
"""
super
(
DataFromList
,
self
)
.
__init__
()
super
(
DataFromList
,
self
)
.
__init__
()
...
...
tensorpack/train/trainers.py
View file @
3398df09
...
@@ -339,6 +339,8 @@ class HorovodTrainer(SingleCostTrainer):
...
@@ -339,6 +339,8 @@ class HorovodTrainer(SingleCostTrainer):
# NOTE It will fail if GPU was already detected before initializing the session
# NOTE It will fail if GPU was already detected before initializing the session
# https://github.com/tensorflow/tensorflow/issues/8136
# https://github.com/tensorflow/tensorflow/issues/8136
session_creator
.
config
.
gpu_options
.
visible_device_list
=
str
(
self
.
_local_rank
)
session_creator
.
config
.
gpu_options
.
visible_device_list
=
str
(
self
.
_local_rank
)
# TODO split #CPUs
# session_creator.config.inter_op_parallelism_threads =
super
(
HorovodTrainer
,
self
)
.
initialize
(
super
(
HorovodTrainer
,
self
)
.
initialize
(
session_creator
,
session_init
)
session_creator
,
session_init
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment