Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
53549d52
Commit
53549d52
authored
Feb 11, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[WIP] start moving optimizer to model
parent
5b29bda9
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
31 additions
and
10 deletions
+31
-10
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+13
-1
tensorpack/train/config.py
tensorpack/train/config.py
+8
-5
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+5
-2
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+5
-2
No files found.
tensorpack/models/model_desc.py
View file @
53549d52
...
...
@@ -49,6 +49,7 @@ InputVar = InputDesc
class
ModelDesc
(
object
):
""" Base class for a model description """
# inputs:
def
get_reused_placehdrs
(
self
):
"""
Create or return (if already created) raw input TF placeholders in the graph.
...
...
@@ -97,13 +98,14 @@ class ModelDesc(object):
"""
:returns: a list of InputDesc
"""
# TODO deprecate @
Ma
r 11
# TODO deprecate @
Ap
r 11
logger
.
warn
(
"[Deprecated] _get_input_vars() is renamed to _get_inputs()"
)
return
self
.
_get_input_vars
()
def
_get_input_vars
(
self
):
# keep backward compatibility
raise
NotImplementedError
()
# graph, cost, optimizer:
def
build_graph
(
self
,
model_inputs
):
"""
Build the whole symbolic graph.
...
...
@@ -153,6 +155,16 @@ class ModelDesc(object):
def
_get_cost
(
self
,
*
args
):
return
self
.
cost
def
get_optimizer
(
self
):
"""
Returns:
a :class:`tf.train.Optimizer` instance.
"""
return
self
.
_get_optimizer
()
def
_get_optimizer
(
self
):
raise
NotImplementedError
()
def
get_gradient_processor
(
self
):
""" Return a list of :class:`tensorpack.tfutils.GradientProcessor`.
They will be executed by the trainer in the given order.
...
...
tensorpack/train/config.py
View file @
53549d52
...
...
@@ -24,7 +24,7 @@ class TrainConfig(object):
"""
def
__init__
(
self
,
dataflow
=
None
,
data
=
None
,
model
=
None
,
optimizer
=
None
,
model
=
None
,
callbacks
=
None
,
extra_callbacks
=
None
,
session_config
=
get_default_sess_config
(),
session_init
=
None
,
...
...
@@ -37,7 +37,6 @@ class TrainConfig(object):
data (InputData): an `InputData` instance. Only one of ``dataflow``
or ``data`` has to be present.
model (ModelDesc): the model to train.
optimizer (tf.train.Optimizer): the optimizer for trainig.
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are
...
...
@@ -74,9 +73,6 @@ class TrainConfig(object):
assert_type
(
self
.
data
,
InputData
)
self
.
dataflow
=
None
self
.
optimizer
=
optimizer
assert_type
(
self
.
optimizer
,
tf
.
train
.
Optimizer
)
if
isinstance
(
callbacks
,
Callbacks
):
# keep quiet now because I haven't determined the final API yet.
logger
.
warn
(
"[Deprecated] API of TrainConfig(callbacks=) has changed!"
)
...
...
@@ -133,10 +129,17 @@ class TrainConfig(object):
if
isinstance
(
self
.
predict_tower
,
int
):
self
.
predict_tower
=
[
self
.
predict_tower
]
if
'optimizer'
in
kwargs
:
self
.
optimizer
=
kwargs
.
pop
(
'optimizer'
)
assert_type
(
self
.
optimizer
,
tf
.
train
.
Optimizer
)
else
:
self
.
optimizer
=
None
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
def
set_tower
(
self
,
nr_tower
=
None
,
tower
=
None
):
# this is a deprecated function
# TODO Deprecate @ Mar 15
logger
.
warn
(
"config.set_tower is deprecated. set config.tower or config.nr_tower directly"
)
assert
nr_tower
is
None
or
tower
is
None
,
"Cannot set both nr_tower and tower!"
if
nr_tower
:
...
...
tensorpack/train/feedfree.py
View file @
53549d52
...
...
@@ -21,7 +21,6 @@ class FeedfreeTrainerBase(Trainer):
""" A base trainer which runs iteration without feed_dict (therefore faster)
Expect ``self.data`` to be a :class:`FeedfreeInput`.
"""
def
_trigger_epoch
(
self
):
# run summary_op every epoch
# TODO FIXME summary_op will take a data! This is not good for TensorInput.
...
...
@@ -45,7 +44,11 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
self
.
model
.
build_graph
(
actual_inputs
)
cost_var
=
self
.
model
.
get_cost
()
# GATE_NONE faster?
grads
=
self
.
config
.
optimizer
.
compute_gradients
(
opt
=
self
.
config
.
optimizer
if
opt
is
None
:
opt
=
self
.
model
.
get_optimizer
()
# XXX TODO not gonna work if optimizer modifies grad
self
.
config
.
optimizer
=
opt
grads
=
opt
.
compute_gradients
(
cost_var
,
gate_gradients
=
tf
.
train
.
Optimizer
.
GATE_NONE
,
colocate_gradients_with_ops
=
False
)
...
...
tensorpack/train/trainer.py
View file @
53549d52
...
...
@@ -84,11 +84,14 @@ class SimpleTrainer(Trainer):
model
.
build_graph
(
self
.
input_vars
)
cost_var
=
model
.
get_cost
()
grads
=
self
.
config
.
optimizer
.
compute_gradients
(
cost_var
)
opt
=
self
.
config
.
optimizer
if
not
opt
:
opt
=
model
.
get_optimizer
()
grads
=
opt
.
compute_gradients
(
cost_var
)
grads
=
apply_grad_processors
(
grads
,
self
.
model
.
get_gradient_processor
())
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
name
=
'min_op'
)
self
.
train_op
=
opt
.
apply_gradients
(
grads
,
name
=
'min_op'
)
def
_trigger_epoch
(
self
):
if
self
.
summary_op
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment