Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b9a15df7
Commit
b9a15df7
authored
Aug 09, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use QueueInput in DataParallelInferenceRunner, correctness verified.
parent
5c241e09
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
38 additions
and
29 deletions
+38
-29
docs/conf.py
docs/conf.py
+1
-0
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+0
-1
examples/ResNet/imagenet_resnet_utils.py
examples/ResNet/imagenet_resnet_utils.py
+1
-0
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+13
-21
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+1
-1
tensorpack/dataflow/dataset/visualqa.py
tensorpack/dataflow/dataset/visualqa.py
+2
-0
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+6
-1
tensorpack/graph_builder/input_source.py
tensorpack/graph_builder/input_source.py
+6
-2
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+2
-1
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+5
-1
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+1
-1
No files found.
docs/conf.py
View file @
b9a15df7
...
@@ -364,6 +364,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
...
@@ -364,6 +364,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'get_predictors'
,
'get_predictors'
,
'vs_name_for_predictor'
,
'vs_name_for_predictor'
,
'dump_chkpt_vars'
,
'dump_chkpt_vars'
,
'VisualQA'
,
'ParamRestore'
]:
'ParamRestore'
]:
return
True
return
True
if
name
in
[
'get_data'
,
'size'
,
'reset_state'
]:
if
name
in
[
'get_data'
,
'size'
,
'reset_state'
]:
...
...
examples/ResNet/imagenet-resnet.py
View file @
b9a15df7
...
@@ -74,7 +74,6 @@ class Model(ModelDesc):
...
@@ -74,7 +74,6 @@ class Model(ModelDesc):
def
get_data
(
name
):
def
get_data
(
name
):
isTrain
=
name
==
'train'
isTrain
=
name
==
'train'
augmentors
=
fbresnet_augmentor
(
isTrain
)
augmentors
=
fbresnet_augmentor
(
isTrain
)
augmentors
.
append
(
imgaug
.
ToUint8
())
datadir
=
args
.
data
datadir
=
args
.
data
return
get_imagenet_dataflow
(
return
get_imagenet_dataflow
(
datadir
,
name
,
BATCH_SIZE
,
augmentors
,
dir_structure
=
'original'
)
datadir
,
name
,
BATCH_SIZE
,
augmentors
,
dir_structure
=
'original'
)
...
...
examples/ResNet/imagenet_resnet_utils.py
View file @
b9a15df7
...
@@ -87,6 +87,7 @@ def get_imagenet_dataflow(
...
@@ -87,6 +87,7 @@ def get_imagenet_dataflow(
http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
"""
"""
assert
name
in
[
'train'
,
'val'
,
'test'
]
assert
name
in
[
'train'
,
'val'
,
'test'
]
assert
datadir
is
not
None
isTrain
=
name
==
'train'
isTrain
=
name
==
'train'
cpu
=
min
(
30
,
multiprocessing
.
cpu_count
())
cpu
=
min
(
30
,
multiprocessing
.
cpu_count
())
if
isTrain
:
if
isTrain
:
...
...
tensorpack/callbacks/inference_runner.py
View file @
b9a15df7
...
@@ -20,12 +20,11 @@ from ..dataflow.base import DataFlow
...
@@ -20,12 +20,11 @@ from ..dataflow.base import DataFlow
from
..graph_builder.input_source_base
import
InputSource
from
..graph_builder.input_source_base
import
InputSource
from
..graph_builder.input_source
import
(
from
..graph_builder.input_source
import
(
FeedInput
,
DataParallelFeed
Input
)
FeedInput
,
Queue
Input
)
from
.base
import
Callback
from
.base
import
Callback
from
.group
import
Callbacks
from
.group
import
Callbacks
from
.inference
import
Inferencer
from
.inference
import
Inferencer
from
.hooks
import
CallbackToHook
__all__
=
[
'InferenceRunner'
,
'FeedfreeInferenceRunner'
,
__all__
=
[
'InferenceRunner'
,
'FeedfreeInferenceRunner'
,
'DataParallelInferenceRunner'
]
'DataParallelInferenceRunner'
]
...
@@ -151,7 +150,7 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -151,7 +150,7 @@ class InferenceRunner(InferenceRunnerBase):
return
InferencerToHook
(
inf
,
fetches
)
return
InferencerToHook
(
inf
,
fetches
)
@
deprecated
(
"Just use InferenceRunner since it now accepts TensorInput!"
)
@
deprecated
(
"Just use InferenceRunner since it now accepts TensorInput!"
,
"2017-11-11"
)
def
FeedfreeInferenceRunner
(
*
args
,
**
kwargs
):
def
FeedfreeInferenceRunner
(
*
args
,
**
kwargs
):
return
InferenceRunner
(
*
args
,
**
kwargs
)
return
InferenceRunner
(
*
args
,
**
kwargs
)
...
@@ -170,9 +169,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -170,9 +169,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
"""
"""
self
.
_tower_names
=
[
'InferenceTower{}'
.
format
(
k
)
for
k
in
range
(
len
(
gpus
))]
self
.
_tower_names
=
[
'InferenceTower{}'
.
format
(
k
)
for
k
in
range
(
len
(
gpus
))]
if
isinstance
(
input
,
DataFlow
):
if
isinstance
(
input
,
DataFlow
):
input
=
DataParallelFeedInput
(
input
,
self
.
_tower_names
)
input
=
QueueInput
(
input
)
assert
isinstance
(
input
,
DataParallelFeedInput
),
input
super
(
DataParallelInferenceRunner
,
self
)
.
__init__
(
input
,
infs
)
super
(
DataParallelInferenceRunner
,
self
)
.
__init__
(
input
,
infs
)
self
.
_gpus
=
gpus
self
.
_gpus
=
gpus
...
@@ -187,13 +184,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -187,13 +184,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self
.
trainer
.
predictor_factory
.
build
(
self
.
trainer
.
predictor_factory
.
build
(
tower_name
,
device
,
self
.
_input_source
))
tower_name
,
device
,
self
.
_input_source
))
# setup
feeds
and hooks
# setup
callbacks
and hooks
self
.
_
hooks_parallel
=
[
self
.
_build_hook_parallel
(
inf
)
for
inf
in
self
.
infs
]
self
.
_
input_callbacks
=
Callbacks
(
cbs
)
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks_parallel
.
extend
([
CallbackToHook
(
cb
)
for
cb
in
cbs
])
self
.
_hooks_parallel
=
[
self
.
_build_hook_parallel
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks_parallel
.
extend
(
self
.
_input_callbacks
.
get_hooks
())
for
inf
in
self
.
infs
:
for
inf
in
self
.
infs
:
inf
.
setup_graph
(
self
.
trainer
)
inf
.
setup_graph
(
self
.
trainer
)
self
.
_input_callbacks
.
setup_graph
(
self
.
trainer
)
class
InferencerToHookDataParallel
(
InferencerToHook
):
class
InferencerToHookDataParallel
(
InferencerToHook
):
def
__init__
(
self
,
inf
,
fetches
,
size
):
def
__init__
(
self
,
inf
,
fetches
,
size
):
...
@@ -223,7 +222,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -223,7 +222,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
return
InferencerToHook
(
inf
,
fetches
)
return
InferencerToHook
(
inf
,
fetches
)
def
_before_train
(
self
):
def
_before_train
(
self
):
s
elf
.
_hooked_sess
=
HookedSession
(
self
.
trainer
.
sess
,
self
.
_hooks
)
s
uper
(
DataParallelInferenceRunner
,
self
)
.
_before_train
(
)
self
.
_parallel_hooked_sess
=
HookedSession
(
self
.
trainer
.
sess
,
self
.
_hooks_parallel
)
self
.
_parallel_hooked_sess
=
HookedSession
(
self
.
trainer
.
sess
,
self
.
_hooks_parallel
)
def
_trigger
(
self
):
def
_trigger
(
self
):
...
@@ -239,16 +238,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -239,16 +238,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
pbar
.
update
(
nr_tower
)
pbar
.
update
(
nr_tower
)
total
-=
nr_tower
total
-=
nr_tower
# take care of the rest
# take care of the rest
try
:
while
total
>
0
:
while
total
>
0
:
self
.
_hooked_sess
.
run
(
fetches
=
[])
# TODO XXX doesn't support remap
pbar
.
update
(
1
)
feed
=
self
.
_input_source
.
next_feed
(
cnt
=
1
)
total
-=
1
self
.
_hooked_sess
.
run
(
fetches
=
[],
feed_dict
=
feed
)
pbar
.
update
(
1
)
total
-=
1
except
AttributeError
:
logger
.
error
(
"[DataParallelInferenceRunner] doesn't support InputSource wrappers very well!"
)
logger
.
error
(
"[DataParallelInferenceRunner] Skipping the rest of the datapoints ..."
)
for
inf
in
self
.
infs
:
for
inf
in
self
.
infs
:
inf
.
trigger_epoch
()
inf
.
trigger_epoch
()
tensorpack/callbacks/param.py
View file @
b9a15df7
...
@@ -185,7 +185,7 @@ class HumanHyperParamSetter(HyperParamSetter):
...
@@ -185,7 +185,7 @@ class HumanHyperParamSetter(HyperParamSetter):
"""
"""
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
param
)
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
param
)
self
.
file_name
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
file_name
)
self
.
file_name
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
file_name
)
logger
.
info
(
"Use {} to
control hyperparam {}
."
.
format
(
logger
.
info
(
"Use {} to
set hyperparam: '{}'
."
.
format
(
self
.
file_name
,
self
.
param
.
readable_name
))
self
.
file_name
,
self
.
param
.
readable_name
))
def
_get_value_to_set
(
self
):
def
_get_value_to_set
(
self
):
...
...
tensorpack/dataflow/dataset/visualqa.py
View file @
b9a15df7
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
from
..base
import
DataFlow
from
..base
import
DataFlow
from
...utils.timer
import
timed_operation
from
...utils.timer
import
timed_operation
from
...utils
import
logger
from
six.moves
import
zip
,
map
from
six.moves
import
zip
,
map
from
collections
import
Counter
from
collections
import
Counter
import
json
import
json
...
@@ -26,6 +27,7 @@ class VisualQA(DataFlow):
...
@@ -26,6 +27,7 @@ class VisualQA(DataFlow):
"""
"""
def
__init__
(
self
,
question_file
,
annotation_file
):
def
__init__
(
self
,
question_file
,
annotation_file
):
logger
.
warn
(
"dataset.VisualQA is deprecated!"
)
with
timed_operation
(
'Reading VQA JSON file'
):
with
timed_operation
(
'Reading VQA JSON file'
):
qobj
,
aobj
=
list
(
map
(
read_json
,
[
question_file
,
annotation_file
]))
qobj
,
aobj
=
list
(
map
(
read_json
,
[
question_file
,
annotation_file
]))
self
.
task_type
=
qobj
[
'task_type'
]
self
.
task_type
=
qobj
[
'task_type'
]
...
...
tensorpack/dataflow/prefetch.py
View file @
b9a15df7
...
@@ -275,7 +275,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -275,7 +275,7 @@ class ThreadedMapData(ProxyDataFlow):
dp
=
self
.
queue_get_stoppable
(
self
.
inq
)
dp
=
self
.
queue_get_stoppable
(
self
.
inq
)
dp
=
self
.
func
(
dp
)
dp
=
self
.
func
(
dp
)
if
dp
is
not
None
:
if
dp
is
not
None
:
self
.
queue_put_stoppable
(
self
.
outq
,
dp
)
self
.
outq
.
put
(
dp
)
else
:
else
:
assert
not
self
.
_strict
,
\
assert
not
self
.
_strict
,
\
"[ThreadedMapData] Map function cannot return None when strict mode is used."
"[ThreadedMapData] Map function cannot return None when strict mode is used."
...
@@ -345,3 +345,8 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -345,3 +345,8 @@ class ThreadedMapData(ProxyDataFlow):
for
_
in
range
(
self
.
buffer_size
):
for
_
in
range
(
self
.
buffer_size
):
self
.
_in_queue
.
put
(
next
(
self
.
_iter
))
self
.
_in_queue
.
put
(
next
(
self
.
_iter
))
yield
self
.
_out_queue
.
get
()
yield
self
.
_out_queue
.
get
()
def
__del__
(
self
):
for
p
in
self
.
_threads
:
p
.
stop
()
p
.
join
()
tensorpack/graph_builder/input_source.py
View file @
b9a15df7
...
@@ -192,13 +192,13 @@ class EnqueueThread(ShareSessionThread):
...
@@ -192,13 +192,13 @@ class EnqueueThread(ShareSessionThread):
except
(
tf
.
errors
.
CancelledError
,
tf
.
errors
.
OutOfRangeError
,
DataFlowTerminated
):
except
(
tf
.
errors
.
CancelledError
,
tf
.
errors
.
OutOfRangeError
,
DataFlowTerminated
):
pass
pass
except
Exception
:
except
Exception
:
logger
.
exception
(
"Exception in
EnqueueThread:"
)
logger
.
exception
(
"Exception in
{}:"
.
format
(
self
.
name
)
)
finally
:
finally
:
try
:
try
:
self
.
close_op
.
run
()
self
.
close_op
.
run
()
except
Exception
:
except
Exception
:
pass
pass
logger
.
info
(
"
EnqueueThread Exited."
)
logger
.
info
(
"
{} Exited."
.
format
(
self
.
name
)
)
class
QueueInput
(
FeedfreeInput
):
class
QueueInput
(
FeedfreeInput
):
...
@@ -234,6 +234,10 @@ class QueueInput(FeedfreeInput):
...
@@ -234,6 +234,10 @@ class QueueInput(FeedfreeInput):
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
_input_placehdrs
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
_input_placehdrs
)
def
_create_ema_callback
(
self
):
def
_create_ema_callback
(
self
):
"""
Create a hook-only callback which maintain EMA of the queue size.
Also tf.summary.scalar the EMA.
"""
with
self
.
cached_name_scope
():
with
self
.
cached_name_scope
():
# in TF there is no API to get queue capacity, so we can only summary the size
# in TF there is no API to get queue capacity, so we can only summary the size
size
=
tf
.
cast
(
self
.
queue
.
size
(),
tf
.
float32
,
name
=
'queue_size'
)
size
=
tf
.
cast
(
self
.
queue
.
size
(),
tf
.
float32
,
name
=
'queue_size'
)
...
...
tensorpack/graph_builder/predictor_factory.py
View file @
b9a15df7
...
@@ -60,7 +60,8 @@ class PredictorFactory(object):
...
@@ -60,7 +60,8 @@ class PredictorFactory(object):
input (InputSource): must be setup already. If None, will use InputDesc from the model.
input (InputSource): must be setup already. If None, will use InputDesc from the model.
"""
"""
logger
.
info
(
"Building predictor tower '{}' on device {} ..."
.
format
(
tower_name
,
device
))
logger
.
info
(
"Building predictor tower '{}' on device {} ..."
.
format
(
tower_name
,
device
))
assert
tower_name
not
in
self
.
_names_built
assert
tower_name
not
in
self
.
_names_built
,
\
"Prediction tower with name '{}' already exists!"
.
format
(
tower_name
)
with
tf
.
device
(
device
),
\
with
tf
.
device
(
device
),
\
TowerContext
(
tower_name
,
is_training
=
False
),
\
TowerContext
(
tower_name
,
is_training
=
False
),
\
...
...
tensorpack/tfutils/sessinit.py
View file @
b9a15df7
...
@@ -104,6 +104,9 @@ class SaverRestore(SessionInit):
...
@@ -104,6 +104,9 @@ class SaverRestore(SessionInit):
prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
ignore (list[str]): list of tensor names that should be ignored during loading, e.g. learning-rate
ignore (list[str]): list of tensor names that should be ignored during loading, e.g. learning-rate
"""
"""
if
model_path
.
endswith
(
'.npy'
)
or
model_path
.
endswith
(
'.npz'
):
logger
.
warn
(
"SaverRestore expect a TF checkpoint, but got a model path '{}'."
.
format
(
model_path
)
+
" To load from a dict, use 'DictRestore'."
)
model_path
=
get_checkpoint_path
(
model_path
)
model_path
=
get_checkpoint_path
(
model_path
)
self
.
path
=
model_path
self
.
path
=
model_path
self
.
prefix
=
prefix
self
.
prefix
=
prefix
...
@@ -192,6 +195,7 @@ class DictRestore(SessionInit):
...
@@ -192,6 +195,7 @@ class DictRestore(SessionInit):
Args:
Args:
param_dict (dict): a dict of {name: value}
param_dict (dict): a dict of {name: value}
"""
"""
assert
isinstance
(
param_dict
,
dict
),
type
(
param_dict
)
# use varname (with :0) for consistency
# use varname (with :0) for consistency
self
.
prms
=
{
get_op_tensor_name
(
n
)[
1
]:
v
for
n
,
v
in
six
.
iteritems
(
param_dict
)}
self
.
prms
=
{
get_op_tensor_name
(
n
)[
1
]:
v
for
n
,
v
in
six
.
iteritems
(
param_dict
)}
...
@@ -220,7 +224,7 @@ class DictRestore(SessionInit):
...
@@ -220,7 +224,7 @@ class DictRestore(SessionInit):
upd
.
update
({
name
:
value
for
name
,
value
in
six
.
iteritems
(
self
.
prms
)
if
name
in
intersect
})
upd
.
update
({
name
:
value
for
name
,
value
in
six
.
iteritems
(
self
.
prms
)
if
name
in
intersect
})
@
deprecated
(
"Use `DictRestore` instead!"
,
"2017-0
6
-01"
)
@
deprecated
(
"Use `DictRestore` instead!"
,
"2017-0
9
-01"
)
def
ParamRestore
(
d
):
def
ParamRestore
(
d
):
return
DictRestore
(
d
)
return
DictRestore
(
d
)
...
...
tensorpack/train/multigpu.py
View file @
b9a15df7
...
@@ -62,8 +62,8 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
...
@@ -62,8 +62,8 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
Returns:
Returns:
List of outputs of ``func``, evaluated on each tower.
List of outputs of ``func``, evaluated on each tower.
"""
"""
logger
.
info
(
"Training a model of {} tower"
.
format
(
len
(
towers
)))
if
len
(
towers
)
>
1
:
if
len
(
towers
)
>
1
:
logger
.
info
(
"Training a model of {} towers"
.
format
(
len
(
towers
)))
_check_tf_version
()
_check_tf_version
()
ret
=
[]
ret
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment