Commit e839c50d, authored Jul 13, 2017 by Yuxin Wu

move PredictorFactory to graph_builder

Parent: efe3dfb5
Showing 5 changed files, with 41 additions and 21 deletions:
tensorpack/graph_builder/predictor_factory.py   +22 -8
tensorpack/tfutils/tower.py                     +13 -9
tensorpack/train/base.py                        +4 -2
tensorpack/train/feedfree.py                    +1 -1
tensorpack/train/multigpu.py                    +1 -1
tensorpack/train/predict.py → tensorpack/graph_builder/predictor_factory.py
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# File: predict.py
-# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
+# File: predictor_factory.py

 import tensorflow as tf
+# from ..tfutils.tower import TowerContext
 from ..predict import (OnlinePredictor,
                        PredictorTowerBuilder)

 __all__ = ['PredictorFactory']

+# class PredictorTowerBuilder(object):
+#     def __init__(self, model):
+#         self._model = model
+#         self._towers = []
+#
+#     def build(self, tower_name, device, input=None):
+#         with tf.device(device), TowerContext(tower_name, is_training=False):
+#             if input is None:
+#                 input = self._model.get_reused_placehdrs()
+#             self._model.build_graph(input)
+#
+#
+# SMART

 class PredictorFactory(object):
-    """ Make predictors from a trainer."""
+    """ Make predictors from :class:`ModelDesc` and cache them."""

-    def __init__(self, trainer):
+    def __init__(self, model, towers, vs_name):
         """
         Args:
-            towers (list[int]): list of gpu id
+            towers (list[int]): list of available gpu id
         """
-        self.model = trainer.model
-        self.towers = trainer.config.predict_tower
-        self.vs_name = trainer.vs_name_for_predictor
+        self.model = model
+        self.towers = towers
+        self.vs_name = vs_name

         def fn(_):
             self.model.build_graph(self.model.get_reused_placehdrs())
...
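In short, the factory no longer reaches into a whole trainer; the caller now passes the model, the predict towers and the variable-scope name explicitly. A minimal sketch of the call-site change, with `trainer` standing in for any tensorpack Trainer instance (it mirrors the updated call in tensorpack/train/base.py further below and is not itself part of the commit):

    # old: the factory pulled everything out of the trainer by itself
    factory = PredictorFactory(trainer)

    # new: the same three pieces are passed in explicitly
    factory = PredictorFactory(trainer.model,                   # the ModelDesc
                               trainer.config.predict_tower,    # list[int] of gpu ids to predict on
                               trainer.vs_name_for_predictor)   # variable scope name for prediction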
tensorpack/tfutils/tower.py
...
@@ -14,22 +14,26 @@ _CurrentTowerContext = None

 class TowerContext(object):
     """ A context where the current model is being built in. """

     def __init__(self, tower_name, is_training=None, index=0, vs_name=''):
         """
         Args:
             tower_name (str): The name scope of the tower. Currently used
                 values are like: 'tower0', 'towerp0', or ''
             is_training (bool): if None, automatically determine from tower_name.
-            index (int): index of this tower
+            index (int): index of this tower.
+            vs_name (str): Open a variable scope with this name, if given.
         """
         self._name = tower_name
         if is_training is None:
             # TODO remove this
             is_training = not self._name.startswith(PREDICT_TOWER)
         self._is_training = bool(is_training)
+        if not self._is_training:
+            # TODO ugly
+            assert index == 0 and vs_name == '', "vs_name and index are meaningless in prediction!"
         self._index = int(index)
         self._vs_name = str(vs_name)
...
@@ -40,10 +44,6 @@ class TowerContext(object):
     def is_main_training_tower(self):
         return self.is_training and self._index == 0

-    @property
-    def is_main_tower(self):
-        return self._index == 0
-
     @property
     def is_training(self):
         return self._is_training
...
@@ -113,11 +113,15 @@ class TowerContext(object):
         if self.is_training:
             reuse = self._index > 0
             if reuse is True:
-                # clear old name_scope and re-enter the current variable_scope
+                # clear old name_scope (due to the existing variable_scope)
+                # and re-enter the current variable_scope
+                self._ctxs.append(tf.name_scope(None))
                 self._ctxs.append(tf.variable_scope(
                     tf.get_variable_scope(), reuse=True))
-            # if not training, should handle vs outside (TODO not good)
         else:
+            # if not training, should handle reuse outside
+            # but still good to clear name_scope first
+            self._ctxs.append(tf.name_scope(None))
             self._ctxs.append(tf.name_scope(self._name))
         for c in self._ctxs:
             c.__enter__()
...
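The new assertion only affects prediction towers: `index` and `vs_name` must keep their defaults there. A minimal sketch of what it enforces (not from the commit; the import path is assumed from the file location above):

    from tensorpack.tfutils.tower import TowerContext

    with TowerContext('towerp0', is_training=False):
        pass   # fine: prediction tower with default index=0 and vs_name=''

    # TowerContext('towerp0', is_training=False, index=1, vs_name='foo')
    # would now fail with: "vs_name and index are meaningless in prediction!"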
tensorpack/train/base.py
...
@@ -10,7 +10,7 @@ from six.moves import range
 import tensorflow as tf

-from .predict import PredictorFactory
+from ..graph_builder.predictor_factory import PredictorFactory
 from .config import TrainConfig
 from ..utils import logger
 from ..callbacks import Callback, Callbacks, MaintainStepCounter
...
@@ -217,6 +217,7 @@ class Trainer(object):
         """
         The variable scope name a predictor should be built in.
         """
+        # TODO graphbuilder knows it
         return ""

     def get_predictor(self, input_names, output_names, tower=0):
...
@@ -229,7 +230,8 @@ class Trainer(object):
             an :class:`OnlinePredictor`.
         """
         if not hasattr(self, '_predictor_factory'):
-            self._predictor_factory = PredictorFactory(self)
+            self._predictor_factory = PredictorFactory(
+                self.model, self.config.predict_tower, self.vs_name_for_predictor)
         nr_tower = len(self.config.predict_tower)
         if nr_tower < tower:
             logger.warn(
...
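For callers nothing changes: `Trainer.get_predictor` still builds the factory lazily and returns an :class:`OnlinePredictor`. A hedged usage sketch; `trainer`, the tensor names 'input'/'prob' and `batch_of_inputs` are placeholders from a hypothetical training script, not from this commit:

    pred = trainer.get_predictor(['input'], ['prob'], tower=0)   # -> OnlinePredictor
    prob = pred(batch_of_inputs)                                 # called like a function on the inputs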
tensorpack/train/feedfree.py
...
@@ -17,7 +17,7 @@ __all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',

 class FeedfreeTrainerBase(Trainer):
     """ A base trainer which runs iteration without feed_dict (therefore faster)
-        Expect ``self.data`` to be a :class:`FeedfreeInput`.
+        Expect ``config.data`` to be a :class:`FeedfreeInput`.
     """

     @deprecated("Please build the graph yourself, e.g. by self.model.build_graph(self._input_source)")
...
tensorpack/train/multigpu.py
...
@@ -54,7 +54,7 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
         Args:
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
-            devices: a list of devices to be used. By default will use GPUs in towers.
+            devices: a list of devices to be used. By default will use GPUs in ``towers``.
             var_strategy (str): 'shared' or 'replicated'
             vs_names (list[str]): list of variable scope names to use.
...
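The docstring above describes the per-tower build arguments; the builder method itself is elided from this hunk, so it is not named here. A purely illustrative sketch of what those arguments typically look like (`model` and `inputs` are placeholders, not part of the diff):

    towers = [0, 1]                            # gpu relative ids
    devices = ['/gpu:0', '/gpu:1']             # optional; defaults to the GPUs in ``towers``
    func = lambda: model.build_graph(inputs)   # called once inside each tower's context
    var_strategy = 'shared'                    # or 'replicated'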