Shashank Suhas / seminar-breakout / Commits

Commit 7a0e8747, authored Aug 30, 2016 by Yuxin Wu
Parent: aabab2cc

    build_graph with ctx
Showing 6 changed files with 76 additions and 81 deletions (+76 -81):
tensorpack/callbacks/group.py     +0  -55
tensorpack/models/batch_norm.py   +2  -0
tensorpack/models/model_desc.py   +49 -4
tensorpack/predict/base.py        +9  -8
tensorpack/train/multigpu.py      +5  -4
tensorpack/train/trainer.py       +11 -10
tensorpack/callbacks/group.py

@@ -12,61 +12,6 @@ from ..utils import *
 __all__ = ['Callbacks']
 
-# --- Test-Callback related stuff seems not very useful.
-@contextmanager
-def create_test_graph(trainer):
-    model = trainer.model
-    with tf.Graph().as_default() as Gtest:
-        # create a global step var in test graph
-        global_step_var = tf.Variable(
-            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
-        input_vars = model.get_input_vars()
-        model.build_graph(input_vars, False)
-        cost = model.get_cost()
-        yield Gtest
-
-@contextmanager
-def create_test_session(trainer):
-    """ create a test-time session from trainer"""
-    with create_test_graph(trainer):
-        with tf.Session() as sess:
-            yield sess
-
-class TestCallbackContext(object):
-    """
-    A class holding the context needed for running TestCallback
-    """
-    def __init__(self):
-        self.sess = None
-
-    @contextmanager
-    def create_context(self, trainer):
-        if self.sess is None:
-            with create_test_session(trainer) as sess:
-                self.sess = sess
-                self.graph = sess.graph
-                # no tower in test graph. just keep it as what it is
-                self.saver = tf.train.Saver()
-        with self.graph.as_default(), self.sess.as_default():
-            yield
-
-    # TODO also do this for after_train?
-    def restore_checkpoint(self):
-        ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
-        if ckpt is None:
-            raise RuntimeError(
-                "Cannot find a checkpoint state. Do you forget to use ModelSaver before all TestCallback?")
-        logger.info("Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
-        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
-
-    @contextmanager
-    def test_context(self):
-        with self.graph.as_default(), self.sess.as_default():
-            yield
-# ---
 
 class CallbackTimeLogger(object):
     def __init__(self):
         self.times = []
tensorpack/models/batch_norm.py

@@ -7,6 +7,7 @@ import tensorflow as tf
 from copy import copy
 import re
 
+from .model_desc import get_current_tower_context
 from ..utils import logger, EXTRA_SAVE_VARS_KEY
 from ._common import layer_register
 
@@ -54,6 +55,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     # XXX a hack to handle training tower & prediction tower together....
     emaname = 'EMA'
+    #ctx = get_current_model_context()
     if not batch_mean.name.startswith('towerp'):
         # training tower
         with tf.name_scope(None):  # https://github.com/tensorflow/tensorflow/issues/2740
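
The newly imported get_current_tower_context (and the commented-out line, which still uses an older get_current_model_context name) point at the intended follow-up: deciding training vs. prediction from the ambient tower context instead of parsing the 'towerp' prefix out of a variable name. A minimal sketch of that direction, purely illustrative and not part of this commit:

    from tensorpack.models.model_desc import get_current_tower_context

    # Hypothetical follow-up, not shipped in this commit: ask the ambient
    # tower context instead of string-matching batch_mean.name.
    ctx = get_current_tower_context()
    is_training_tower = ctx is not None and ctx.is_training
    # ...the function could then branch on is_training_tower where it
    # currently tests batch_mean.name.startswith('towerp').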
tensorpack/models/model_desc.py

@@ -6,14 +6,54 @@
 from abc import ABCMeta, abstractmethod
 import tensorflow as tf
 from collections import namedtuple
+import inspect
 
 from ..utils import logger, INPUT_VARS_KEY
 from ..tfutils import *
 
-__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
+__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph',
+           'get_current_tower_context', 'TowerContext']
 
 InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
 
+_CurrentTowerContext = None
+
+class TowerContext(object):
+    def __init__(self, tower_name, is_training=None):
+        """ tower_name: 'tower0', 'towerp0', or '' """
+        self._name = tower_name
+        if is_training is None:
+            is_training = not self._name.startswith('towerp')
+        self._is_training = is_training
+
+    @property
+    def is_main_tower(self):
+        return self._name == '' or self._name == 'tower0'
+
+    @property
+    def is_training(self):
+        return self._is_training
+
+    def __enter__(self):
+        global _CurrentTowerContext
+        assert _CurrentTowerContext is None, \
+            "Nesting TowerContext!"
+        _CurrentTowerContext = self
+        if len(self._name):
+            self._scope = tf.name_scope(self._name)
+            return self._scope.__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global _CurrentTowerContext
+        _CurrentTowerContext = None
+        if len(self._name):
+            self._scope.__exit__(exc_type, exc_val, exc_tb)
+        return False
+
+def get_current_tower_context():
+    global _CurrentTowerContext
+    return _CurrentTowerContext
+
 class ModelDesc(object):
     """ Base class for a model description """
     __metaclass__ = ABCMeta
 
@@ -49,7 +89,7 @@ class ModelDesc(object):
     def _get_input_vars(self):
         """:returns: a list of InputVar """
 
-    def build_graph(self, model_inputs, is_training):
+    def build_graph(self, model_inputs):
         """
         Setup the whole graph.
 
@@ -57,10 +97,15 @@ class ModelDesc(object):
         :param is_training: a boolean
         :returns: the cost to minimize. a scalar variable
         """
-        self._build_graph(model_inputs, is_training)
+        if len(inspect.getargspec(self._build_graph).args) == 3:
+            logger.warn("_build_graph(self, input_vars, is_training) is deprecated! \
+Use _build_graph(self, input_vars) and get_current_tower_context().is_training instead.")
+            self._build_graph(model_inputs, get_current_tower_context().is_training)
+        else:
+            self._build_graph(model_inputs)
 
     @abstractmethod
-    def _build_graph(self, inputs, is_training):
+    def _build_graph(self, inputs):
         pass
 
     def get_cost(self):
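
This is the core of the commit: build_graph no longer takes is_training, and models are expected to read the flag from the surrounding TowerContext. The inspect.getargspec check keeps old three-argument _build_graph implementations working behind a deprecation warning. A minimal sketch of a model written against the new API (MyModel, its inputs, and the dropout detail are illustrative assumptions, not from the diff):

    import tensorflow as tf
    from tensorpack.models.model_desc import (
        ModelDesc, InputVar, TowerContext, get_current_tower_context)

    class MyModel(ModelDesc):
        def _get_input_vars(self):
            return [InputVar(tf.float32, [None, 28, 28], 'input'),
                    InputVar(tf.int32, [None], 'label')]

        # New two-argument signature: build_graph calls it directly, no warning.
        def _build_graph(self, inputs):
            image, label = inputs
            is_training = get_current_tower_context().is_training
            # branch on the ambient flag where old code used the is_training
            # argument, e.g. to disable dropout at prediction time:
            keep_prob = 0.5 if is_training else 1.0

    # Callers now enter a TowerContext before building:
    model = MyModel()
    with TowerContext(''):
        model.build_graph(model.get_input_vars())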
tensorpack/predict/base.py

@@ -7,6 +7,7 @@ from abc import abstractmethod, ABCMeta, abstractproperty
 import tensorflow as tf
 import six
 
+from ..models import TowerContext
 from ..utils import logger
 from ..tfutils import get_vars_by_names
 
@@ -88,7 +89,8 @@ class OfflinePredictor(OnlinePredictor):
         self.graph = tf.Graph()
         with self.graph.as_default():
             input_vars = config.model.get_input_vars()
-            config.model._build_graph(input_vars, False)
+            with TowerContext('', False):
+                config.model.build_graph(input_vars)
 
             input_vars = get_vars_by_names(config.input_var_names)
             output_vars = get_vars_by_names(config.output_var_names)
 
@@ -99,7 +101,7 @@ class OfflinePredictor(OnlinePredictor):
                 sess, input_vars, output_vars, config.return_input)
 
-def build_multi_tower_prediction_graph(model, towers, prefix='towerp'):
+def build_multi_tower_prediction_graph(model, towers):
     """
     :param towers: a list of gpu relative id.
     """
 
@@ -107,26 +109,24 @@ def build_multi_tower_prediction_graph(model, towers, prefix='towerp'):
     for k in towers:
         logger.info(
             "Building graph for predictor tower {}...".format(k))
-        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                tf.name_scope('{}{}'.format(prefix, k)):
-            model._build_graph(input_vars, False)
+        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
+                TowerContext('towerp{}'.format(k)):
+            model.build_graph(input_vars)
             tf.get_variable_scope().reuse_variables()
 
 class MultiTowerOfflinePredictor(OnlinePredictor):
     PREFIX = 'towerp'
 
     def __init__(self, config, towers):
         self.graph = tf.Graph()
         self.predictors = []
         with self.graph.as_default():
             # TODO backup summary keys?
-            build_multi_tower_prediction_graph(config.model, towers, self.PREFIX)
+            build_multi_tower_prediction_graph(config.model, towers)
             self.sess = tf.Session(config=config.session_config)
             config.session_init.init(self.sess)
 
             input_vars = get_vars_by_names(config.input_var_names)
 
-            # use the first tower for compatible PredictorBase interface
             for k in towers:
                 output_vars = get_vars_by_names(
                     ['{}{}/'.format(self.PREFIX, k) + n \
 
@@ -135,6 +135,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
                     self.sess, input_vars, output_vars, config.return_input))
 
     def _do_call(self, dp):
+        # use the first tower for compatible PredictorBase interface
         return self.predictors[0]._do_call(dp)
 
     def get_predictors(self, n):
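
Why OfflinePredictor passes an explicit False: TowerContext only infers is_training from the tower name when the flag is omitted, and it treats any name not starting with 'towerp' as a training tower, so the anonymous '' tower would otherwise come out as training. A small sketch of the inference rule, assuming the TowerContext defined in model_desc.py above:

    from tensorpack.models.model_desc import TowerContext

    assert TowerContext('tower0').is_training is True     # training-tower name
    assert TowerContext('towerp0').is_training is False   # 'towerp' = prediction tower
    assert TowerContext('').is_training is True           # inferred as training!
    assert TowerContext('', False).is_training is False   # hence the explicit False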
tensorpack/train/multigpu.py

@@ -7,6 +7,7 @@ import tensorflow as tf
 import itertools, re
 from six.moves import zip, range
 
+from ..models import TowerContext
 from ..utils import *
 from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
 
@@ -26,7 +27,7 @@ class MultiGPUTrainer(QueueInputTrainer):
     @staticmethod
     def _average_grads(tower_grads):
         ret = []
-        with tf.name_scope('average_grad'):
+        with tf.name_scope('AvgGrad'):
             for grad_and_vars in zip(*tower_grads):
                 v = grad_and_vars[0][1]
                 try:
 
@@ -44,12 +45,12 @@ class MultiGPUTrainer(QueueInputTrainer):
         grad_list = []
         for idx, t in enumerate(self.config.tower):
             with tf.device('/gpu:{}'.format(t)), \
-                    tf.name_scope('tower{}'.format(idx)) as scope:
+                    TowerContext('tower{}'.format(idx)) as scope:
                 logger.info("Building graph for training tower {}...".format(idx))
                 model_inputs = self._get_model_inputs()  # each tower dequeue from input queue
                 self.dequed_inputs.append(model_inputs)
-                self.model.build_graph(model_inputs, True)
+                self.model.build_graph(model_inputs)
                 cost_var = self.model.get_cost()  # build tower
 
                 # TODO gate_gradienst=0 seems to be faster?
 
@@ -92,7 +93,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
         # pretend to average the grads, in order to make async and
        # sync have consistent effective learning rate
         def scale(grads):
-            with tf.name_scope('async_scale_grad'):
+            with tf.name_scope('AsyncScaleGrad'):
                 return [(grad / len(self.config.tower) if grad is not None else None, var)
                         for grad, var in grads]
         grad_list = map(scale, grad_list)
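
Replacing tf.name_scope('tower{}') with TowerContext('tower{}') keeps op names unchanged, since the context's __enter__ opens a tf.name_scope of the same name whenever the name is non-empty; the difference is that code deep inside the model can now query the ambient context. A sketch of that behavior, under the TowerContext semantics shown in model_desc.py above:

    import tensorflow as tf
    from tensorpack.models.model_desc import (
        TowerContext, get_current_tower_context)

    with TowerContext('tower1'):              # also opens tf.name_scope('tower1')
        x = tf.constant(0.0, name='x')
        assert x.op.name == 'tower1/x'        # same op names as before the change
        ctx = get_current_tower_context()
        assert ctx.is_training and not ctx.is_main_tower
    assert get_current_tower_context() is None  # cleared on __exit__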
tensorpack/train/trainer.py

@@ -10,19 +10,18 @@ from six.moves import zip
 from .base import Trainer
 from ..dataflow.common import RepeatedData
-from ..tfutils.summary import summary_moving_average
-from ..tfutils.modelutils import describe_model
+from ..models import TowerContext
 from ..utils import *
 from ..tfutils import *
-from ..tfutils.summary import add_moving_summary
+from ..tfutils.summary import summary_moving_average, add_moving_summary
+from ..tfutils.modelutils import describe_model
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
 
 __all__ = ['SimpleTrainer', 'QueueInputTrainer']
 
 class PredictorFactory(object):
     """ Make predictors for a trainer"""
-    PREFIX = 'towerp'
 
     def __init__(self, sess, model, towers):
         """
 
@@ -42,7 +41,7 @@ class PredictorFactory(object):
             self._build_predict_tower()
         tower = self.towers[tower % len(self.towers)]
         raw_input_vars = get_vars_by_names(input_names)
-        output_names = ['{}{}/'.format(self.PREFIX, tower) + n for n in output_names]
+        output_names = ['towerp{}/'.format(tower) + n for n in output_names]
         output_vars = get_vars_by_names(output_names)
 
         return OnlinePredictor(self.sess, raw_input_vars, output_vars)
 
@@ -52,7 +51,7 @@ class PredictorFactory(object):
         with tf.name_scope(None), \
                 freeze_collection(SUMMARY_BACKUP_KEYS):
             build_multi_tower_prediction_graph(
-                self.model, self.towers, prefix=self.PREFIX)
+                self.model, self.towers)
         self.tower_built = True
 
 class SimpleTrainer(Trainer):
 
@@ -64,8 +63,9 @@ class SimpleTrainer(Trainer):
     def train(self):
         model = self.model
         self.input_vars = model.get_input_vars()
-        model.build_graph(self.input_vars, True)
-        cost_var = model.get_cost()  # TODO assert scalar
+        with TowerContext(''):
+            model.build_graph(self.input_vars)
+            cost_var = model.get_cost()  # TODO assert scalar
         add_moving_summary(cost_var)
 
         grads = self.config.optimizer.compute_gradients(cost_var)
 
@@ -180,8 +180,9 @@ class QueueInputTrainer(Trainer):
         #self.dequed_inputs = [tf.Variable(tf.random_normal([128,224,224,3],
         #dtype=tf.float32), trainable=False),
         #tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
-        self.model.build_graph(self.dequed_inputs, True)
-        cost_var = self.model.get_cost()
+        with TowerContext(''):
+            self.model.build_graph(self.dequed_inputs)
+            cost_var = self.model.get_cost()
 
         grads = self.config.optimizer.compute_gradients(
             cost_var, gate_gradients=0)  # GATE_NONE
         add_moving_summary(cost_var)
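
Note the asymmetry: the single-tower trainers now build inside TowerContext('') while PredictorFactory keeps hard-coding the 'towerp{}/' prefix it used to take from PREFIX. The empty name matters because __enter__ only opens a tf.name_scope when the name is non-empty, so training ops keep their unprefixed names and the predictor's prefixed lookups still resolve against the separately built 'towerp' towers. A sketch, again assuming the TowerContext from model_desc.py:

    import tensorflow as tf
    from tensorpack.models.model_desc import TowerContext

    with TowerContext(''):            # empty name: no tf.name_scope is opened
        cost = tf.constant(1.0, name='cost')
    assert cost.op.name == 'cost'     # unprefixed, unlike 'tower0/cost'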