Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
95ab1563
Commit
95ab1563
authored
Jan 15, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
support prefix for FeedfreeInferenceRunner
parent
16216c6a
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
45 additions
and
16 deletions
+45
-16
examples/GAN/Image2Image.py
examples/GAN/Image2Image.py
+1
-1
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+11
-4
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+7
-3
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+2
-2
tensorpack/predict/base.py
tensorpack/predict/base.py
+6
-3
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+14
-0
tensorpack/train/base.py
tensorpack/train/base.py
+2
-1
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+2
-2
No files found.
examples/GAN/Image2Image.py
View file @
95ab1563
...
@@ -159,7 +159,7 @@ def get_data():
...
@@ -159,7 +159,7 @@ def get_data():
augs
=
[
imgaug
.
Resize
(
286
),
imgaug
.
RandomCrop
(
256
)]
augs
=
[
imgaug
.
Resize
(
286
),
imgaug
.
RandomCrop
(
256
)]
ds
=
AugmentImageComponents
(
ds
,
augs
,
(
0
,
1
))
ds
=
AugmentImageComponents
(
ds
,
augs
,
(
0
,
1
))
ds
=
BatchData
(
ds
,
BATCH
)
ds
=
BatchData
(
ds
,
BATCH
)
ds
=
PrefetchData
ZMQ
(
ds
,
1
)
ds
=
PrefetchData
(
ds
,
100
,
1
)
return
ds
return
ds
...
...
tensorpack/callbacks/base.py
View file @
95ab1563
...
@@ -11,7 +11,17 @@ __all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback', 'CallbackFactory']
...
@@ -11,7 +11,17 @@ __all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback', 'CallbackFactory']
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
Callback
(
object
):
class
Callback
(
object
):
""" Base class for all callbacks """
""" Base class for all callbacks
Attributes:
epoch_num(int): the number of epochs that have completed the update
trainer(Trainer): the trainer
graph(tf.Graph): the graph
Note:
These attributes are available only after (and including)
:meth:`_setup_graph`.
"""
def
setup_graph
(
self
,
trainer
):
def
setup_graph
(
self
,
trainer
):
"""
"""
...
@@ -24,7 +34,6 @@ class Callback(object):
...
@@ -24,7 +34,6 @@ class Callback(object):
self
.
trainer
=
trainer
self
.
trainer
=
trainer
self
.
graph
=
tf
.
get_default_graph
()
self
.
graph
=
tf
.
get_default_graph
()
self
.
epoch_num
=
self
.
trainer
.
config
.
starting_epoch
-
1
self
.
epoch_num
=
self
.
trainer
.
config
.
starting_epoch
-
1
# self.epoch_num is always the number of epochs that finished updating parameters.
with
tf
.
name_scope
(
type
(
self
)
.
__name__
):
with
tf
.
name_scope
(
type
(
self
)
.
__name__
):
self
.
_setup_graph
()
self
.
_setup_graph
()
...
@@ -50,8 +59,6 @@ class Callback(object):
...
@@ -50,8 +59,6 @@ class Callback(object):
def
trigger_epoch
(
self
):
def
trigger_epoch
(
self
):
"""
"""
Triggered after every epoch.
Triggered after every epoch.
In this function, ``self.epoch_num`` would be the number of epoch finished.
"""
"""
self
.
epoch_num
+=
1
self
.
epoch_num
+=
1
self
.
_trigger_epoch
()
self
.
_trigger_epoch
()
...
...
tensorpack/callbacks/inference_runner.py
View file @
95ab1563
...
@@ -11,6 +11,7 @@ from six.moves import zip, range
...
@@ -11,6 +11,7 @@ from six.moves import zip, range
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
from
..utils
import
logger
,
get_tqdm
,
PREDICT_TOWER
,
SUMMARY_BACKUP_KEYS
from
..utils
import
logger
,
get_tqdm
,
PREDICT_TOWER
,
SUMMARY_BACKUP_KEYS
from
..tfutils.common
import
get_op_tensor_name
,
freeze_collection
from
..tfutils.common
import
get_op_tensor_name
,
freeze_collection
from
..tfutils
import
TowerContext
from
..train.input_data
import
FeedfreeInput
from
..train.input_data
import
FeedfreeInput
from
..predict
import
build_prediction_graph
from
..predict
import
build_prediction_graph
...
@@ -151,12 +152,14 @@ class FeedfreeInferenceRunner(Callback):
...
@@ -151,12 +152,14 @@ class FeedfreeInferenceRunner(Callback):
pipeline.
pipeline.
"""
"""
def
__init__
(
self
,
input
,
infs
,
input_names
=
None
):
def
__init__
(
self
,
input
,
infs
,
input_names
=
None
,
prefix
=
''
):
"""
"""
Args:
Args:
input (FeedfreeInput): the input to use. Must have ``size()``.
input (FeedfreeInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
infs (list): list of :class:`Inferencer` to run.
input_names (list): must be a subset of the names of InputVar.
input_names (list): must be a subset of the names of InputVar.
prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used.
"""
"""
assert
isinstance
(
input
,
FeedfreeInput
),
input
assert
isinstance
(
input
,
FeedfreeInput
),
input
self
.
_input_data
=
input
self
.
_input_data
=
input
...
@@ -174,6 +177,7 @@ class FeedfreeInferenceRunner(Callback):
...
@@ -174,6 +177,7 @@ class FeedfreeInferenceRunner(Callback):
self
.
_size
=
input
.
size
()
self
.
_size
=
input
.
size
()
except
NotImplementedError
:
except
NotImplementedError
:
raise
ValueError
(
"Input used in FeedfreeInferencecRunner must have a size!"
)
raise
ValueError
(
"Input used in FeedfreeInferencecRunner must have a size!"
)
self
.
_prefix
=
prefix
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_find_input_tensors
()
# tensors
self
.
_find_input_tensors
()
# tensors
...
@@ -185,8 +189,8 @@ class FeedfreeInferenceRunner(Callback):
...
@@ -185,8 +189,8 @@ class FeedfreeInferenceRunner(Callback):
freeze_collection
(
SUMMARY_BACKUP_KEYS
):
freeze_collection
(
SUMMARY_BACKUP_KEYS
):
def
fn
(
_
):
def
fn
(
_
):
self
.
trainer
.
model
.
build_graph
(
self
.
_input_tensors
)
self
.
trainer
.
model
.
build_graph
(
self
.
_input_tensors
)
build_prediction_graph
(
fn
,
[
0
]
)
build_prediction_graph
(
fn
,
[
0
]
,
prefix
=
self
.
_prefix
)
# TODO use towerp1 to support multiple FeedfreeInferenceRunner
self
.
_tower_prefix
=
PREDICT_TOWER
+
'0'
self
.
_tower_prefix
=
TowerContext
.
get_predict_tower_name
(
self
.
_prefix
,
0
)
self
.
_find_output_tensors
()
self
.
_find_output_tensors
()
...
...
tensorpack/models/model_desc.py
View file @
95ab1563
...
@@ -138,9 +138,9 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
...
@@ -138,9 +138,9 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
if
ctx
is
not
None
and
ctx
.
is_main_training_tower
:
if
ctx
is
not
None
and
ctx
.
is_main_training_tower
:
non_grad_updates
=
set
(
tf
.
get_collection
(
tf
.
GraphKeys
.
UPDATE_OPS
))
non_grad_updates
=
set
(
tf
.
get_collection
(
tf
.
GraphKeys
.
UPDATE_OPS
))
if
non_grad_updates
:
if
non_grad_updates
:
logger
.
info
(
"Apply UPDATE_OPS collection on cost."
)
with
tf
.
control_dependencies
(
non_grad_updates
):
with
tf
.
control_dependencies
(
non_grad_updates
):
barrier
=
tf
.
control_flow_ops
.
no_op
(
name
=
'update_ops_barrier'
)
cost
=
tf
.
identity
(
cost
)
cost
=
tf
.
control_flow_ops
.
with_dependencies
([
barrier
],
cost
)
return
cost
return
cost
def
_get_cost
(
self
,
*
args
):
def
_get_cost
(
self
,
*
args
):
...
...
tensorpack/predict/base.py
View file @
95ab1563
...
@@ -144,17 +144,20 @@ def get_predict_func(config):
...
@@ -144,17 +144,20 @@ def get_predict_func(config):
return
OfflinePredictor
(
config
)
return
OfflinePredictor
(
config
)
def
build_prediction_graph
(
build_tower_fn
,
towers
=
[
0
]):
def
build_prediction_graph
(
build_tower_fn
,
towers
=
[
0
]
,
prefix
=
''
):
"""
"""
Args:
Args:
build_tower_fn: a function that will be called inside each tower,
build_tower_fn: a function that will be called inside each tower,
taking tower id as the argument.
taking tower id as the argument.
towers: a list of relative GPU id.
towers: a list of relative GPU id.
prefix: an extra prefix in tower name. The final tower prefix will be
determined by :meth:`TowerContext.get_predict_tower_name`.
"""
"""
for
k
in
towers
:
for
k
in
towers
:
logger
.
info
(
logger
.
info
(
"Building graph for predictor tower {}..."
.
format
(
k
))
"Building prediction graph for towerid={} with prefix='{}' ..."
.
format
(
k
,
prefix
))
towername
=
TowerContext
.
get_predict_tower_name
(
prefix
,
k
)
with
tf
.
device
(
'/gpu:{}'
.
format
(
k
)
if
k
>=
0
else
'/cpu:0'
),
\
with
tf
.
device
(
'/gpu:{}'
.
format
(
k
)
if
k
>=
0
else
'/cpu:0'
),
\
TowerContext
(
'{}{}'
.
format
(
PREDICT_TOWER
,
k
)
):
TowerContext
(
towername
,
is_training
=
False
):
build_tower_fn
(
k
)
build_tower_fn
(
k
)
tf
.
get_variable_scope
()
.
reuse_variables
()
tf
.
get_variable_scope
()
.
reuse_variables
()
tensorpack/tfutils/tower.py
View file @
95ab1563
...
@@ -72,6 +72,20 @@ class TowerContext(object):
...
@@ -72,6 +72,20 @@ class TowerContext(object):
newname
=
re
.
sub
(
predict_tower_prefix
,
'tower0/'
,
name
)
newname
=
re
.
sub
(
predict_tower_prefix
,
'tower0/'
,
name
)
return
graph
.
get_tensor_by_name
(
newname
)
return
graph
.
get_tensor_by_name
(
newname
)
@
staticmethod
def
get_predict_tower_name
(
prefix
,
towerid
=
0
):
"""
Args:
prefix(str): an alphanumeric prefix.
towerid(int): an integer, the id of this predict tower, usually
used to choose the GPU id.
Returns:
str: the final tower name used to create a predict tower.
Currently it is ``PREDICT_TOWER + prefix + towerid``.
"""
assert
prefix
==
''
or
prefix
.
isalnum
()
return
PREDICT_TOWER
+
prefix
+
str
(
towerid
)
def
__enter__
(
self
):
def
__enter__
(
self
):
global
_CurrentTowerContext
global
_CurrentTowerContext
assert
_CurrentTowerContext
is
None
,
\
assert
_CurrentTowerContext
is
None
,
\
...
...
tensorpack/train/base.py
View file @
95ab1563
...
@@ -105,8 +105,9 @@ class Trainer(object):
...
@@ -105,8 +105,9 @@ class Trainer(object):
summary (tf.Summary or str): a summary object, or a str which will
summary (tf.Summary or str): a summary object, or a str which will
be interpreted as a serialized tf.Summary protobuf.
be interpreted as a serialized tf.Summary protobuf.
"""
"""
if
isinstance
(
summary
,
six
.
string_types
):
if
isinstance
(
summary
,
six
.
binary_type
):
summary
=
tf
.
Summary
.
FromString
(
summary
)
summary
=
tf
.
Summary
.
FromString
(
summary
)
assert
isinstance
(
summary
,
tf
.
Summary
),
type
(
summary
)
for
val
in
summary
.
value
:
for
val
in
summary
.
value
:
if
val
.
WhichOneof
(
'value'
)
==
'simple_value'
:
if
val
.
WhichOneof
(
'value'
)
==
'simple_value'
:
val
.
tag
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
val
.
tag
)
# TODO move to subclasses
val
.
tag
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
val
.
tag
)
# TODO move to subclasses
...
...
tensorpack/train/feedfree.py
View file @
95ab1563
...
@@ -75,8 +75,8 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
...
@@ -75,8 +75,8 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
class
SimpleFeedfreeTrainer
(
class
SimpleFeedfreeTrainer
(
MultiPredictorTower
Trainer
,
SingleCostFeedfree
Trainer
,
SingleCostFeedfree
Trainer
):
MultiPredictorTower
Trainer
):
"""
"""
A trainer with single cost, single training tower, any number of
A trainer with single cost, single training tower, any number of
prediction tower, and feed-free input.
prediction tower, and feed-free input.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment