Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ad5cb725
Commit
ad5cb725
authored
Oct 16, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
remove feedfree
parent
82bf74c9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
89 deletions
+28
-89
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+0
-67
tensorpack/train/simple.py
tensorpack/train/simple.py
+28
-22
No files found.
tensorpack/train/feedfree.py
deleted
100644 → 0
View file @
82bf74c9
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: feedfree.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
..utils
import
logger
from
..utils.develop
import
deprecated
from
..tfutils.tower
import
TowerContext
from
..graph_builder.input_source
import
QueueInput
,
FeedfreeInput
from
.simple
import
SimpleTrainer
from
.base
import
Trainer
__all__
=
[
'FeedfreeTrainerBase'
,
'SingleCostFeedfreeTrainer'
,
'QueueInputTrainer'
]
# TODO deprecate it some time
class
FeedfreeTrainerBase
(
Trainer
):
""" A base trainer which runs iteration without feed_dict (therefore faster)
Expect ``config.data`` to be a :class:`FeedfreeInput`.
"""
@
deprecated
(
"Please build the graph yourself, e.g. by self.model.build_graph(self._input_source)"
)
def
build_train_tower
(
self
):
with
TowerContext
(
''
,
is_training
=
True
):
self
.
model
.
build_graph
(
self
.
_input_source
)
def
_setup
(
self
):
assert
isinstance
(
self
.
_input_source
,
FeedfreeInput
),
type
(
self
.
_input_source
)
cbs
=
self
.
_input_source
.
setup
(
self
.
model
.
get_inputs_desc
())
self
.
config
.
callbacks
.
extend
(
cbs
)
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainerBase
):
""" A feedfree Trainer which assumes a single cost. """
@
deprecated
(
""
,
"2017-11-21"
)
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
SingleCostFeedfreeTrainer
,
self
)
.
__init__
(
*
args
,
**
kwargs
)
logger
.
warn
(
"SingleCostFeedfreeTrainer was deprecated!"
)
def
_get_cost_and_grad
(
self
):
""" get the cost and gradient"""
self
.
model
.
build_graph
(
self
.
_input_source
)
return
self
.
model
.
get_cost_and_grad
()
def
QueueInputTrainer
(
config
,
input_queue
=
None
):
"""
A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
Args:
config (TrainConfig): Must contain 'model' and 'dataflow'.
input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
"""
assert
(
config
.
data
is
not
None
or
config
.
dataflow
is
not
None
)
and
config
.
model
is
not
None
if
config
.
data
is
not
None
:
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
else
:
config
.
data
=
QueueInput
(
config
.
dataflow
,
input_queue
)
config
.
dataflow
=
None
# debug
# from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[128,224,224,3], [128]])
return
SimpleTrainer
(
config
)
tensorpack/train/simple.py
View file @
ad5cb725
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
from
.base
import
Trainer
from
.base
import
Trainer
from
..utils
import
logger
from
..utils
import
logger
from
..graph_builder.input_source
import
FeedInput
from
..graph_builder.input_source
import
FeedInput
,
QueueInput
from
..graph_builder.training
import
SimpleGraphBuilder
from
..graph_builder.training
import
SimpleGraphBuilder
__all__
=
[
'SimpleTrainer'
]
__all__
=
[
'SimpleTrainer'
]
...
@@ -39,29 +39,35 @@ class SimpleTrainer(Trainer):
...
@@ -39,29 +39,35 @@ class SimpleTrainer(Trainer):
"Consider QueueInput or other InputSource instead."
)
"Consider QueueInput or other InputSource instead."
)
super
(
SimpleTrainer
,
self
)
.
__init__
(
config
)
super
(
SimpleTrainer
,
self
)
.
__init__
(
config
)
@
staticmethod
def
_setup
(
self
):
def
setup_graph
(
model
,
input
):
cbs
=
self
.
_input_source
.
setup
(
self
.
model
.
get_inputs_desc
())
"""
Setup graph for SimpleTrainer. It simply build one tower and optimize `model.cost`.
Args:
model (ModelDesc):
input (InputSource):
Returns:
def
get_cost
(
*
inputs
):
tf.Operation: the training op
self
.
model
.
build_graph
(
inputs
)
return
self
.
model
.
get_cost
()
[Callback]: the callbacks to be added
self
.
train_op
=
SimpleGraphBuilder
()
.
build
(
self
.
_input_source
,
get_cost
,
self
.
model
.
get_optimizer
)
"""
self
.
config
.
callbacks
.
extend
(
cbs
)
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
def
get_cost
(
*
inputs
):
model
.
build_graph
(
inputs
)
return
model
.
get_cost
()
train_op
=
SimpleGraphBuilder
()
.
build
(
input
,
get_cost
,
model
.
get_optimizer
)
def
QueueInputTrainer
(
config
,
input_queue
=
None
):
return
train_op
,
cbs
"""
A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
def
_setup
(
self
):
Args:
self
.
train_op
,
callbacks
=
SimpleTrainer
.
setup_graph
(
self
.
model
,
self
.
_input_source
)
config (TrainConfig): Must contain 'model' and 'dataflow'.
self
.
config
.
callbacks
.
extend
(
callbacks
)
input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
"""
assert
(
config
.
data
is
not
None
or
config
.
dataflow
is
not
None
)
and
config
.
model
is
not
None
if
config
.
data
is
not
None
:
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
else
:
config
.
data
=
QueueInput
(
config
.
dataflow
,
input_queue
)
config
.
dataflow
=
None
# debug
# from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[128,224,224,3], [128]])
return
SimpleTrainer
(
config
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment