Commit 7a0e8747, authored Aug 30, 2016 by Yuxin Wu
parent aabab2cc

build_graph with ctx
Showing 6 changed files with 76 additions and 81 deletions.
tensorpack/callbacks/group.py    +0  -55
tensorpack/models/batch_norm.py  +2  -0
tensorpack/models/model_desc.py  +49 -4
tensorpack/predict/base.py       +9  -8
tensorpack/train/multigpu.py     +5  -4
tensorpack/train/trainer.py      +11 -10
tensorpack/callbacks/group.py (view file @ 7a0e8747)
@@ -12,61 +12,6 @@ from ..utils import *
 __all__ = ['Callbacks']

-# --- Test-Callback related stuff seems not very useful.
-@contextmanager
-def create_test_graph(trainer):
-    model = trainer.model
-    with tf.Graph().as_default() as Gtest:
-        # create a global step var in test graph
-        global_step_var = tf.Variable(
-            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
-        input_vars = model.get_input_vars()
-        model.build_graph(input_vars, False)
-        cost = model.get_cost()
-        yield Gtest
-
-@contextmanager
-def create_test_session(trainer):
-    """ create a test-time session from trainer"""
-    with create_test_graph(trainer):
-        with tf.Session() as sess:
-            yield sess
-
-class TestCallbackContext(object):
-    """
-    A class holding the context needed for running TestCallback
-    """
-    def __init__(self):
-        self.sess = None
-
-    @contextmanager
-    def create_context(self, trainer):
-        if self.sess is None:
-            with create_test_session(trainer) as sess:
-                self.sess = sess
-                self.graph = sess.graph
-                # no tower in test graph. just keep it as what it is
-                self.saver = tf.train.Saver()
-        with self.graph.as_default(), self.sess.as_default():
-            yield
-
-    # TODO also do this for after_train?
-    def restore_checkpoint(self):
-        ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
-        if ckpt is None:
-            raise RuntimeError(
-                "Cannot find a checkpoint state. Do you forget to use ModelSaver before all TestCallback?")
-        logger.info("Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
-        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
-
-    @contextmanager
-    def test_context(self):
-        with self.graph.as_default(), self.sess.as_default():
-            yield
-# ---
-
 class CallbackTimeLogger(object):
     def __init__(self):
         self.times = []
...
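For reference, the removed helpers composed like this; a minimal usage sketch, assuming `trainer` is a tensorpack Trainer with a `.model` attribute and that a ModelSaver checkpoint exists under `logger.LOG_DIR`:

    ctx = TestCallbackContext()
    with ctx.create_context(trainer):   # builds the test graph and session once
        ctx.restore_checkpoint()        # loads the checkpoint written by ModelSaver
    with ctx.test_context():            # re-enters the cached graph and session
        pass                            # run test-time ops here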
tensorpack/models/batch_norm.py (view file @ 7a0e8747)
@@ -7,6 +7,7 @@ import tensorflow as tf
 from copy import copy
 import re

+from .model_desc import get_current_tower_context
 from ..utils import logger, EXTRA_SAVE_VARS_KEY
 from ._common import layer_register
...
@@ -54,6 +55,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     # XXX a hack to handle training tower & prediction tower together....
     emaname = 'EMA'
+    #ctx = get_current_model_context()
     if not batch_mean.name.startswith('towerp'):
         # training tower
         with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
...
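Note that this commit only imports `get_current_tower_context` here and leaves a commented-out hint; the `'towerp'` name-prefix check is still what separates training towers from prediction towers. A hedged sketch of the direction the hint points at (not code from this commit):

    ctx = get_current_tower_context()
    if ctx is None or ctx.is_training:
        pass  # training tower: update the moving averages under the EMA name
    else:
        pass  # prediction tower ('towerp*'): reuse the training tower's EMA stats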
tensorpack/models/model_desc.py (view file @ 7a0e8747)
@@ -6,14 +6,54 @@
 from abc import ABCMeta, abstractmethod
 import tensorflow as tf
 from collections import namedtuple
+import inspect

 from ..utils import logger, INPUT_VARS_KEY
 from ..tfutils import *

-__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
+__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph',
+           'get_current_tower_context', 'TowerContext']

 InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])

+_CurrentTowerContext = None
+
+class TowerContext(object):
+    def __init__(self, tower_name, is_training=None):
+        """ tower_name: 'tower0', 'towerp0', or '' """
+        self._name = tower_name
+        if is_training is None:
+            is_training = not self._name.startswith('towerp')
+        self._is_training = is_training
+
+    @property
+    def is_main_tower(self):
+        return self._name == '' or self._name == 'tower0'
+
+    @property
+    def is_training(self):
+        return self._is_training
+
+    def __enter__(self):
+        global _CurrentTowerContext
+        assert _CurrentTowerContext is None, \
+            "Nesting TowerContext!"
+        _CurrentTowerContext = self
+        if len(self._name):
+            self._scope = tf.name_scope(self._name)
+            return self._scope.__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global _CurrentTowerContext
+        _CurrentTowerContext = None
+        if len(self._name):
+            self._scope.__exit__(exc_type, exc_val, exc_tb)
+        return False
+
+def get_current_tower_context():
+    global _CurrentTowerContext
+    return _CurrentTowerContext
+
 class ModelDesc(object):
     """ Base class for a model description """
     __metaclass__ = ABCMeta
...
@@ -49,7 +89,7 @@ class ModelDesc(object):
     def _get_input_vars(self):
         """:returns: a list of InputVar """

-    def build_graph(self, model_inputs, is_training):
+    def build_graph(self, model_inputs):
         """
         Setup the whole graph.
...
@@ -57,10 +97,15 @@ class ModelDesc(object):
         :param is_training: a boolean
         :returns: the cost to minimize. a scalar variable
         """
-        self._build_graph(model_inputs, is_training)
+        if len(inspect.getargspec(self._build_graph).args) == 3:
+            logger.warn("_build_graph(self, input_vars, is_training) is deprecated! \
+Use _build_graph(self, input_vars) and get_current_tower_context().is_training instead.")
+            self._build_graph(model_inputs, get_current_tower_context().is_training)
+        else:
+            self._build_graph(model_inputs)

     @abstractmethod
-    def _build_graph(self, inputs, is_training):
+    def _build_graph(self, inputs):
         pass

     def get_cost(self):
...
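`build_graph` now dispatches on the arity of `_build_graph`, so two-argument models keep working with a warning. A minimal sketch of a migrated model under the new one-argument convention (`ToyModel`, its dropout cost, and the `self.cost` wiring are hypothetical, not from this commit):

    import tensorflow as tf
    from tensorpack.models import (ModelDesc, InputVar,
                                   TowerContext, get_current_tower_context)

    class ToyModel(ModelDesc):
        def _get_input_vars(self):
            return [InputVar(tf.float32, [None, 4], 'input')]

        def _build_graph(self, inputs):        # no is_training argument anymore
            x, = inputs
            ctx = get_current_tower_context()  # replaces the removed parameter
            keep_prob = 0.5 if ctx.is_training else 1.0
            # how get_cost() picks this up depends on the base class; assumed here
            self.cost = tf.reduce_mean(tf.nn.dropout(x, keep_prob), name='cost')

    # graph building now happens inside a TowerContext:
    model = ToyModel()
    with TowerContext('', is_training=True):
        model.build_graph(model.get_input_vars())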
tensorpack/predict/base.py (view file @ 7a0e8747)
@@ -7,6 +7,7 @@ from abc import abstractmethod, ABCMeta, abstractproperty
 import tensorflow as tf
 import six

+from ..models import TowerContext
 from ..utils import logger
 from ..tfutils import get_vars_by_names
...
@@ -88,7 +89,8 @@ class OfflinePredictor(OnlinePredictor):
         self.graph = tf.Graph()
         with self.graph.as_default():
             input_vars = config.model.get_input_vars()
-            config.model._build_graph(input_vars, False)
+            with TowerContext('', False):
+                config.model.build_graph(input_vars)

             input_vars = get_vars_by_names(config.input_var_names)
             output_vars = get_vars_by_names(config.output_var_names)
...
@@ -99,7 +101,7 @@ class OfflinePredictor(OnlinePredictor):
             sess, input_vars, output_vars, config.return_input)

-def build_multi_tower_prediction_graph(model, towers, prefix='towerp'):
+def build_multi_tower_prediction_graph(model, towers):
     """
     :param towers: a list of gpu relative id.
     """
...
@@ -107,26 +109,24 @@ def build_multi_tower_prediction_graph(model, towers):
     for k in towers:
         logger.info(
             "Building graph for predictor tower {}...".format(k))
         with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                tf.name_scope('{}{}'.format(prefix, k)):
-            model._build_graph(input_vars, False)
+                TowerContext('towerp{}'.format(k)):
+            model.build_graph(input_vars)
             tf.get_variable_scope().reuse_variables()

 class MultiTowerOfflinePredictor(OnlinePredictor):
     PREFIX = 'towerp'

     def __init__(self, config, towers):
         self.graph = tf.Graph()
         self.predictors = []
         with self.graph.as_default():
             # TODO backup summary keys?
-            build_multi_tower_prediction_graph(config.model, towers, self.PREFIX)
+            build_multi_tower_prediction_graph(config.model, towers)
             self.sess = tf.Session(config=config.session_config)
             config.session_init.init(self.sess)

             input_vars = get_vars_by_names(config.input_var_names)

             # use the first tower for compatible PredictorBase interface
             for k in towers:
                 output_vars = get_vars_by_names(
                     ['{}{}/'.format(self.PREFIX, k) + n \
...
@@ -135,6 +135,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
                     self.sess, input_vars, output_vars, config.return_input))

     def _do_call(self, dp):
+        # use the first tower for compatible PredictorBase interface
        return self.predictors[0]._do_call(dp)

     def get_predictors(self, n):
...
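The predictor code above leans on `TowerContext`'s name heuristic: when `is_training` is not given, a tower name starting with `'towerp'` marks a prediction tower. A small sketch exercising that behavior, assuming this commit's tensorpack is on the import path:

    from tensorpack.models import TowerContext

    assert TowerContext('tower0').is_training         # training tower by default
    assert not TowerContext('towerp0').is_training    # 'towerp*' => prediction
    assert not TowerContext('', False).is_training    # explicit, as OfflinePredictor does
    assert not TowerContext('towerp0').is_main_tower
    assert TowerContext('').is_main_tower             # '' and 'tower0' are main towers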
tensorpack/train/multigpu.py (view file @ 7a0e8747)
@@ -7,6 +7,7 @@ import tensorflow as tf
 import itertools, re
 from six.moves import zip, range

+from ..models import TowerContext
 from ..utils import *
 from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
...
@@ -26,7 +27,7 @@ class MultiGPUTrainer(QueueInputTrainer):
     @staticmethod
     def _average_grads(tower_grads):
         ret = []
-        with tf.name_scope('average_grad'):
+        with tf.name_scope('AvgGrad'):
             for grad_and_vars in zip(*tower_grads):
                 v = grad_and_vars[0][1]
                 try:
...
@@ -44,12 +45,12 @@ class MultiGPUTrainer(QueueInputTrainer):
         grad_list = []
         for idx, t in enumerate(self.config.tower):
             with tf.device('/gpu:{}'.format(t)), \
-                    tf.name_scope('tower{}'.format(idx)) as scope:
+                    TowerContext('tower{}'.format(idx)) as scope:
                 logger.info("Building graph for training tower {}...".format(idx))
                 model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
                 self.dequed_inputs.append(model_inputs)

-                self.model.build_graph(model_inputs, True)
+                self.model.build_graph(model_inputs)
                 cost_var = self.model.get_cost() # build tower
                 # TODO gate_gradienst=0 seems to be faster?
...
@@ -92,7 +93,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
         # pretend to average the grads, in order to make async and
         # sync have consistent effective learning rate
         def scale(grads):
-            with tf.name_scope('async_scale_grad'):
+            with tf.name_scope('AsyncScaleGrad'):
                 return [(grad / len(self.config.tower) if grad is not None else None, var)
                         for grad, var in grads]
         grad_list = map(scale, grad_list)
...
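`_average_grads` (context partially shown above) transposes the per-tower `(grad, var)` lists so that gradients of the same variable line up, then averages them. A pure-Python sketch of that `zip(*tower_grads)` structure, with numbers standing in for gradient tensors (the function mirrors the method but is not tensorpack code):

    def average_grads(tower_grads):
        # tower_grads: one [(grad, var), ...] list per tower, same variable order
        ret = []
        for grad_and_vars in zip(*tower_grads):   # group entries by variable
            v = grad_and_vars[0][1]               # the variable, shared across towers
            grads = [g for g, _ in grad_and_vars]
            ret.append((sum(grads) / len(grads), v))
        return ret

    tower0 = [(1.0, 'w'), (3.0, 'b')]
    tower1 = [(2.0, 'w'), (5.0, 'b')]
    print(average_grads([tower0, tower1]))        # [(1.5, 'w'), (4.0, 'b')]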
tensorpack/train/trainer.py (view file @ 7a0e8747)
@@ -10,19 +10,18 @@ from six.moves import zip
 from .base import Trainer
 from ..dataflow.common import RepeatedData
-from ..tfutils.summary import summary_moving_average
-from ..tfutils.modelutils import describe_model
+from ..models import TowerContext
 from ..utils import *
 from ..tfutils import *
-from ..tfutils.summary import add_moving_summary
+from ..tfutils.summary import summary_moving_average, add_moving_summary
+from ..tfutils.modelutils import describe_model
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph

 __all__ = ['SimpleTrainer', 'QueueInputTrainer']

 class PredictorFactory(object):
     """ Make predictors for a trainer"""
-    PREFIX = 'towerp'

     def __init__(self, sess, model, towers):
         """
...
@@ -42,7 +41,7 @@ class PredictorFactory(object):
         self._build_predict_tower()
         tower = self.towers[tower % len(self.towers)]
         raw_input_vars = get_vars_by_names(input_names)
-        output_names = ['{}{}/'.format(self.PREFIX, tower) + n for n in output_names]
+        output_names = ['towerp{}/'.format(tower) + n for n in output_names]
         output_vars = get_vars_by_names(output_names)
         return OnlinePredictor(self.sess, raw_input_vars, output_vars)
...
@@ -52,7 +51,7 @@ class PredictorFactory(object):
         with tf.name_scope(None), \
                 freeze_collection(SUMMARY_BACKUP_KEYS):
             build_multi_tower_prediction_graph(
-                self.model, self.towers, prefix=self.PREFIX)
+                self.model, self.towers)
         self.tower_built = True

 class SimpleTrainer(Trainer):
...
@@ -64,8 +63,9 @@ class SimpleTrainer(Trainer):
     def train(self):
         model = self.model
         self.input_vars = model.get_input_vars()
-        model.build_graph(self.input_vars, True)
-        cost_var = model.get_cost() # TODO assert scalar
+        with TowerContext(''):
+            model.build_graph(self.input_vars)
+            cost_var = model.get_cost() # TODO assert scalar
         add_moving_summary(cost_var)

         grads = self.config.optimizer.compute_gradients(cost_var)
...
@@ -180,8 +180,9 @@ class QueueInputTrainer(Trainer):
         #self.dequed_inputs = [tf.Variable(tf.random_normal([128,224,224,3],
         #dtype=tf.float32), trainable=False),
         #tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
-        self.model.build_graph(self.dequed_inputs, True)
-        cost_var = self.model.get_cost()
+        with TowerContext(''):
+            self.model.build_graph(self.dequed_inputs)
+            cost_var = self.model.get_cost()
         grads = self.config.optimizer.compute_gradients(
             cost_var, gate_gradients=0) # GATE_NONE
         add_moving_summary(cost_var)
...
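Both trainers now wrap graph building in `TowerContext('')`: the empty name adds no name scope, so single-GPU variable names are unchanged, while layer code can still query the context. A small sketch of those semantics, assuming this commit's tensorpack, a TensorFlow of the same era, and a fresh default graph:

    import tensorflow as tf
    from tensorpack.models import TowerContext, get_current_tower_context

    assert get_current_tower_context() is None
    with TowerContext(''):                        # what SimpleTrainer now does
        ctx = get_current_tower_context()
        assert ctx.is_training and ctx.is_main_tower
        x = tf.constant(1.0, name='x')
        assert x.name == 'x:0'                    # no extra name scope added
    assert get_current_tower_context() is None    # context is cleared on exit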