Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3bc0bed2
Commit
3bc0bed2
authored
Feb 23, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
some small change in input_data
parent
c5de2ef9
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
67 additions
and
51 deletions
+67
-51
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+15
-13
tensorpack/predict/base.py
tensorpack/predict/base.py
+14
-1
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+3
-16
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+1
-1
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+30
-17
tensorpack/train/predict.py
tensorpack/train/predict.py
+2
-2
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+2
-1
No files found.
tensorpack/callbacks/inference_runner.py
View file @
3bc0bed2
...
@@ -5,14 +5,16 @@
...
@@ -5,14 +5,16 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
collections
import
namedtuple
from
collections
import
namedtuple
import
tqdm
import
six
import
six
import
copy
from
six.moves
import
zip
,
range
from
six.moves
import
zip
,
range
from
..utils
import
logger
,
get_tqdm
from
..utils
import
logger
,
get_tqdm
_kwargs
,
get_tqdm
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils
import
TowerContext
from
..tfutils
import
TowerContext
from
..train.input_data
import
TensorInput
from
..train.input_data
import
TensorInput
,
FeedInput
from
..predict
import
PredictorTowerBuilder
from
..predict
import
PredictorTowerBuilder
from
.base
import
Triggerable
from
.base
import
Triggerable
...
@@ -78,8 +80,9 @@ class InferenceRunner(Triggerable):
...
@@ -78,8 +80,9 @@ class InferenceRunner(Triggerable):
input_tensor_names(list): list of tensors to feed the dataflow to.
input_tensor_names(list): list of tensors to feed the dataflow to.
Defaults to all the input placeholders.
Defaults to all the input placeholders.
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
if
isinstance
(
ds
,
DataFlow
):
self
.
ds
=
ds
self
.
ds
=
FeedInput
(
ds
)
assert
isinstance
(
self
.
ds
,
FeedInput
),
self
.
ds
if
not
isinstance
(
infs
,
list
):
if
not
isinstance
(
infs
,
list
):
self
.
infs
=
[
infs
]
self
.
infs
=
[
infs
]
else
:
else
:
...
@@ -132,14 +135,13 @@ class InferenceRunner(Triggerable):
...
@@ -132,14 +135,13 @@ class InferenceRunner(Triggerable):
inf
.
before_inference
()
inf
.
before_inference
()
self
.
ds
.
reset_state
()
self
.
ds
.
reset_state
()
with
get_tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
for
_
in
tqdm
.
trange
(
self
.
ds
.
size
(),
**
get_tqdm_kwargs
()):
for
dp
in
self
.
ds
.
get_data
():
dp
=
self
.
ds
.
next_feed
()
outputs
=
self
.
predictor
(
dp
)
outputs
=
self
.
predictor
(
dp
)
for
inf
,
tensormap
in
zip
(
self
.
infs
,
self
.
inf_to_tensors
):
for
inf
,
tensormap
in
zip
(
self
.
infs
,
self
.
inf_to_tensors
):
inf_output
=
[(
outputs
if
k
.
isOutput
else
dp
)[
k
.
index
]
inf_output
=
[(
outputs
if
k
.
isOutput
else
dp
)[
k
.
index
]
for
k
in
tensormap
]
for
k
in
tensormap
]
inf
.
datapoint
(
inf_output
)
inf
.
datapoint
(
inf_output
)
pbar
.
update
()
self
.
_write_summary_after_inference
()
self
.
_write_summary_after_inference
()
def
_write_summary_after_inference
(
self
):
def
_write_summary_after_inference
(
self
):
...
@@ -195,7 +197,7 @@ class FeedfreeInferenceRunner(Triggerable):
...
@@ -195,7 +197,7 @@ class FeedfreeInferenceRunner(Triggerable):
self
.
_input_data
.
setup
(
self
.
trainer
.
model
)
self
.
_input_data
.
setup
(
self
.
trainer
.
model
)
# only 1 prediction tower will be used for inference
# only 1 prediction tower will be used for inference
self
.
_input_tensors
=
self
.
_input_data
.
get_input_tensors
()
self
.
_input_tensors
=
self
.
_input_data
.
get_input_tensors
()
model_placehdrs
=
self
.
trainer
.
model
.
get_reused_placehdrs
(
)
model_placehdrs
=
copy
.
copy
(
self
.
trainer
.
model
.
get_reused_placehdrs
()
)
if
self
.
_input_names
is
not
None
:
if
self
.
_input_names
is
not
None
:
raise
NotImplementedError
(
"Random code. Not tested."
)
raise
NotImplementedError
(
"Random code. Not tested."
)
assert
len
(
self
.
_input_names
)
==
len
(
self
.
_input_tensors
),
\
assert
len
(
self
.
_input_names
)
==
len
(
self
.
_input_tensors
),
\
...
...
tensorpack/predict/base.py
View file @
3bc0bed2
...
@@ -11,7 +11,7 @@ from ..utils import logger
...
@@ -11,7 +11,7 @@ from ..utils import logger
from
..utils.develop
import
deprecated
from
..utils.develop
import
deprecated
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
from
..utils.naming
import
SUMMARY_BACKUP_KEYS
from
..utils.naming
import
SUMMARY_BACKUP_KEYS
from
..tfutils
import
get_tensors_by_names
,
TowerContext
from
..tfutils
import
get_tensors_by_names
,
TowerContext
,
get_op_tensor_name
from
..tfutils.collection
import
freeze_collection
from
..tfutils.collection
import
freeze_collection
__all__
=
[
'PredictorBase'
,
'AsyncPredictorBase'
,
__all__
=
[
'PredictorBase'
,
'AsyncPredictorBase'
,
...
@@ -192,6 +192,19 @@ class PredictorTowerBuilder(object):
...
@@ -192,6 +192,19 @@ class PredictorTowerBuilder(object):
TowerContext
(
towername
,
is_training
=
False
):
TowerContext
(
towername
,
is_training
=
False
):
self
.
_fn
(
tower
)
self
.
_fn
(
tower
)
@
staticmethod
def
get_tensors_maybe_in_tower
(
placeholder_names
,
names
,
k
,
prefix
=
''
):
def
maybe_inside_tower
(
name
):
name
=
get_op_tensor_name
(
name
)[
0
]
if
name
in
placeholder_names
:
return
name
else
:
# if the name is not a placeholder, use it's name in each tower
return
TowerContext
.
get_predict_tower_name
(
k
,
prefix
)
+
'/'
+
name
names
=
list
(
map
(
maybe_inside_tower
,
names
))
tensors
=
get_tensors_by_names
(
names
)
return
tensors
def
build_prediction_graph
(
build_tower_fn
,
towers
=
[
0
],
prefix
=
''
):
def
build_prediction_graph
(
build_tower_fn
,
towers
=
[
0
],
prefix
=
''
):
"""
"""
...
...
tensorpack/predict/multigpu.py
View file @
3bc0bed2
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
get_tensors_by_names
,
TowerContext
,
get_op_tensor_name
from
..tfutils
import
get_tensors_by_names
,
TowerContext
from
.base
import
OnlinePredictor
,
build_prediction_graph
from
.base
import
OnlinePredictor
,
build_prediction_graph
,
PredictorTowerBuilder
__all__
=
[
'MultiTowerOfflinePredictor'
,
__all__
=
[
'MultiTowerOfflinePredictor'
,
'DataParallelOfflinePredictor'
]
'DataParallelOfflinePredictor'
]
...
@@ -33,26 +33,13 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
...
@@ -33,26 +33,13 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self
.
sess
=
config
.
session_creator
.
create_session
()
self
.
sess
=
config
.
session_creator
.
create_session
()
config
.
session_init
.
init
(
self
.
sess
)
config
.
session_init
.
init
(
self
.
sess
)
get_tensor_fn
=
MultiTowerOfflinePredicto
r
.
get_tensors_maybe_in_tower
get_tensor_fn
=
PredictorTowerBuilde
r
.
get_tensors_maybe_in_tower
for
k
in
towers
:
for
k
in
towers
:
input_tensors
=
get_tensor_fn
(
placeholder_names
,
config
.
input_names
,
k
)
input_tensors
=
get_tensor_fn
(
placeholder_names
,
config
.
input_names
,
k
)
output_tensors
=
get_tensor_fn
(
placeholder_names
,
config
.
output_names
,
k
)
output_tensors
=
get_tensor_fn
(
placeholder_names
,
config
.
output_names
,
k
)
self
.
predictors
.
append
(
OnlinePredictor
(
self
.
predictors
.
append
(
OnlinePredictor
(
input_tensors
,
output_tensors
,
config
.
return_input
,
self
.
sess
))
input_tensors
,
output_tensors
,
config
.
return_input
,
self
.
sess
))
@
staticmethod
def
get_tensors_maybe_in_tower
(
placeholder_names
,
names
,
k
):
def
maybe_inside_tower
(
name
):
name
=
get_op_tensor_name
(
name
)[
0
]
if
name
in
placeholder_names
:
return
name
else
:
# if the name is not a placeholder, use it's name in each tower
return
TowerContext
.
get_predict_tower_name
(
k
)
+
'/'
+
name
names
=
map
(
maybe_inside_tower
,
names
)
tensors
=
get_tensors_by_names
(
names
)
return
tensors
def
_do_call
(
self
,
dp
):
def
_do_call
(
self
,
dp
):
# use the first tower for compatible PredictorBase interface
# use the first tower for compatible PredictorBase interface
return
self
.
predictors
[
0
]
.
_do_call
(
dp
)
return
self
.
predictors
[
0
]
.
_do_call
(
dp
)
...
...
tensorpack/tfutils/tower.py
View file @
3bc0bed2
...
@@ -76,9 +76,9 @@ class TowerContext(object):
...
@@ -76,9 +76,9 @@ class TowerContext(object):
def
get_predict_tower_name
(
towerid
=
0
,
prefix
=
''
):
def
get_predict_tower_name
(
towerid
=
0
,
prefix
=
''
):
"""
"""
Args:
Args:
prefix(str): an alphanumeric prefix.
towerid(int): an integer, the id of this predict tower, usually
towerid(int): an integer, the id of this predict tower, usually
used to choose the GPU id.
used to choose the GPU id.
prefix(str): an alphanumeric prefix.
Returns:
Returns:
str: the final tower name used to create a predict tower.
str: the final tower name used to create a predict tower.
Currently it is ``PREDICT_TOWER + prefix + towerid``.
Currently it is ``PREDICT_TOWER + prefix + towerid``.
...
...
tensorpack/train/input_data.py
View file @
3bc0bed2
...
@@ -23,12 +23,27 @@ __all__ = ['InputData', 'FeedfreeInput',
...
@@ -23,12 +23,27 @@ __all__ = ['InputData', 'FeedfreeInput',
class
InputData
(
object
):
class
InputData
(
object
):
""" Base class for the abstract InputData. """
""" Base class for the abstract InputData. """
@
abstractmethod
def
get_input_tensors
(
self
):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
"""
def
setup
(
self
,
model
):
def
setup
(
self
,
model
):
pass
pass
def
setup_training
(
self
,
trainer
):
def
setup_training
(
self
,
trainer
):
self
.
setup
(
trainer
.
model
)
self
.
setup
(
trainer
.
model
)
@
abstractmethod
def
reset_state
(
self
):
pass
def
next_feed
(
self
):
return
[]
class
FeedInput
(
InputData
):
class
FeedInput
(
InputData
):
""" Input by iterating over a DataFlow and feed datapoints. """
""" Input by iterating over a DataFlow and feed datapoints. """
...
@@ -49,30 +64,25 @@ class FeedInput(InputData):
...
@@ -49,30 +64,25 @@ class FeedInput(InputData):
rds
.
reset_state
()
rds
.
reset_state
()
self
.
data_producer
=
rds
.
get_data
()
self
.
data_producer
=
rds
.
get_data
()
def
next_feed
(
self
):
def
reset_state
(
self
):
data
=
next
(
self
.
data_producer
)
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
feed
=
dict
(
zip
(
self
.
input_placehdrs
,
data
))
rds
.
reset_state
()
self
.
_last_feed
=
feed
self
.
data_producer
=
rds
.
get_data
()
return
feed
def
last_feed
(
self
):
def
get_input_tensors
(
self
):
return
self
.
_last_feed
return
self
.
input_placehdrs
def
next_feed
(
self
):
return
next
(
self
.
data_producer
)
class
FeedfreeInput
(
InputData
):
class
FeedfreeInput
(
InputData
):
""" Abstract base for input without feed,
""" Abstract base for input without feed,
e.g. by queue or other operations. """
e.g. by queue or other operations. """
@
abstractmethod
def
reset_state
(
self
):
def
get_input_tensors
(
self
):
# TODO cannot reset
"""
pass
Returns:
list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
"""
def
get_client_threads
(
self
):
return
[]
class
EnqueueThread
(
ShareSessionThread
):
class
EnqueueThread
(
ShareSessionThread
):
...
@@ -234,6 +244,9 @@ class DummyConstantInput(FeedfreeInput):
...
@@ -234,6 +244,9 @@ class DummyConstantInput(FeedfreeInput):
self
.
shapes
=
shapes
self
.
shapes
=
shapes
logger
.
warn
(
"Using dummy input for debug!"
)
logger
.
warn
(
"Using dummy input for debug!"
)
def
setup
(
self
,
model
):
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
placehdrs
=
self
.
input_placehdrs
placehdrs
=
self
.
input_placehdrs
assert
len
(
self
.
shapes
)
==
len
(
placehdrs
)
assert
len
(
self
.
shapes
)
==
len
(
placehdrs
)
...
...
tensorpack/train/predict.py
View file @
3bc0bed2
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..predict
import
(
OnlinePredictor
,
from
..predict
import
(
OnlinePredictor
,
PredictorTowerBuilder
,
MultiTowerOfflinePredictor
)
PredictorTowerBuilder
)
__all__
=
[
'PredictorFactory'
]
__all__
=
[
'PredictorFactory'
]
...
@@ -39,7 +39,7 @@ class PredictorFactory(object):
...
@@ -39,7 +39,7 @@ class PredictorFactory(object):
self
.
_tower_builder
.
build
(
tower
)
self
.
_tower_builder
.
build
(
tower
)
placeholder_names
=
set
([
k
.
name
for
k
in
self
.
model
.
get_inputs_desc
()])
placeholder_names
=
set
([
k
.
name
for
k
in
self
.
model
.
get_inputs_desc
()])
get_tensor_fn
=
MultiTowerOfflinePredicto
r
.
get_tensors_maybe_in_tower
get_tensor_fn
=
PredictorTowerBuilde
r
.
get_tensors_maybe_in_tower
in_tensors
=
get_tensor_fn
(
placeholder_names
,
input_names
,
tower
)
in_tensors
=
get_tensor_fn
(
placeholder_names
,
input_names
,
tower
)
out_tensors
=
get_tensor_fn
(
placeholder_names
,
output_names
,
tower
)
out_tensors
=
get_tensor_fn
(
placeholder_names
,
output_names
,
tower
)
return
OnlinePredictor
(
in_tensors
,
out_tensors
)
return
OnlinePredictor
(
in_tensors
,
out_tensors
)
tensorpack/train/trainer.py
View file @
3bc0bed2
...
@@ -28,7 +28,8 @@ class SimpleTrainer(Trainer):
...
@@ -28,7 +28,8 @@ class SimpleTrainer(Trainer):
def
run_step
(
self
):
def
run_step
(
self
):
""" Feed data into the graph and run the updates. """
""" Feed data into the graph and run the updates. """
feed
=
self
.
_input_method
.
next_feed
()
dp
=
self
.
_input_method
.
next_feed
()
feed
=
dict
(
zip
(
self
.
inputs
,
dp
))
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
def
_setup
(
self
):
def
_setup
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment