Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
82bf74c9
Commit
82bf74c9
authored
Oct 16, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
graph builder for simple trainer
parent
94eace54
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
6 deletions
+53
-6
tensorpack/graph_builder/training.py
tensorpack/graph_builder/training.py
+45
-0
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+1
-0
tensorpack/train/simple.py
tensorpack/train/simple.py
+7
-6
No files found.
tensorpack/graph_builder/training.py
0 → 100644
View file @
82bf74c9
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: training.py
from
abc
import
ABCMeta
,
abstractmethod
import
tensorflow
as
tf
import
six
from
..tfutils.gradproc
import
FilterNoneGrad
from
..tfutils.tower
import
TowerContext
@six.add_metaclass(ABCMeta)
class GraphBuilder(object):
    """Abstract base class for objects that build a training graph.

    Concrete builders (e.g. ``SimpleGraphBuilder``) implement :meth:`build`
    and return the resulting training op.
    """

    @abstractmethod
    def build(self, *args, **kwargs):
        """Build the training graph.

        Subclasses define their own concrete signatures; implementations
        are expected to return the training op (a ``tf.Operation``).
        """
        pass
class SimpleGraphBuilder(GraphBuilder):
    """
    Build the graph for single-cost single-optimizer single-tower training.
    """
    def build(self, input, get_cost_fn, get_opt_fn):
        """
        Args:
            input (InputSource): should have been setup already
            get_cost_fn ([tf.Tensor] -> tf.Tensor): a callable,
                taking several tensors as input and returns a cost tensor.
            get_opt_fn (None -> tf.train.Optimizer): a callable that returns an optimizer
        Returns:
            tf.Operation: the training op
        """
        # All graph construction happens under a single (training) tower context.
        with TowerContext('', is_training=True) as tower_ctx:
            input_tensors = input.get_input_tensors()
            cost = get_cost_fn(*input_tensors)

            # Keep only the trainable variables belonging to this tower's
            # variable scope (per filter_vars_by_vs_name).
            all_trainable = tf.trainable_variables()
            var_list = tower_ctx.filter_vars_by_vs_name(all_trainable)

            optimizer = get_opt_fn()
            grad_list = optimizer.compute_gradients(
                cost,
                var_list=var_list,
                gate_gradients=False,
                colocate_gradients_with_ops=True)
            # Drop (grad, var) pairs whose gradient is None, i.e. variables
            # the cost does not depend on.
            grad_list = FilterNoneGrad().process(grad_list)
            return optimizer.apply_gradients(grad_list, name='min_op')
tensorpack/tfutils/tower.py
View file @
82bf74c9
...
@@ -118,6 +118,7 @@ class TowerContext(object):
...
@@ -118,6 +118,7 @@ class TowerContext(object):
assert
ns
==
self
.
_name
,
\
assert
ns
==
self
.
_name
,
\
"Name conflict: name_scope inside tower '{}' becomes '{}'!"
.
format
(
self
.
_name
,
ns
)
\
"Name conflict: name_scope inside tower '{}' becomes '{}'!"
.
format
(
self
.
_name
,
ns
)
\
+
" You may need a different name for the tower!"
+
" You may need a different name for the tower!"
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
global
_CurrentTowerContext
global
_CurrentTowerContext
...
...
tensorpack/train/simple.py
View file @
82bf74c9
...
@@ -6,8 +6,8 @@
...
@@ -6,8 +6,8 @@
from
.base
import
Trainer
from
.base
import
Trainer
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
TowerContext
from
..graph_builder.input_source
import
FeedInput
from
..graph_builder.input_source
import
FeedInput
from
..graph_builder.training
import
SimpleGraphBuilder
__all__
=
[
'SimpleTrainer'
]
__all__
=
[
'SimpleTrainer'
]
...
@@ -54,11 +54,12 @@ class SimpleTrainer(Trainer):
...
@@ -54,11 +54,12 @@ class SimpleTrainer(Trainer):
[Callback]: the callbacks to be added
[Callback]: the callbacks to be added
"""
"""
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
with
TowerContext
(
''
,
is_training
=
True
):
model
.
build_graph
(
input
)
def
get_cost
(
*
inputs
):
_
,
grads
=
model
.
get_cost_and_grad
()
model
.
build_graph
(
inputs
)
opt
=
model
.
get_optimizer
()
return
model
.
get_cost
()
train_op
=
opt
.
apply_gradients
(
grads
,
name
=
'min_op'
)
train_op
=
SimpleGraphBuilder
()
.
build
(
input
,
get_cost
,
model
.
get_optimizer
)
return
train_op
,
cbs
return
train_op
,
cbs
def
_setup
(
self
):
def
_setup
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment