Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
079eb3a9
Commit
079eb3a9
authored
Jul 13, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Make SimpleTrainer support InputSource
parent
20ee19bc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
37 deletions
+24
-37
examples/PennTreebank/PTB-LSTM.py
examples/PennTreebank/PTB-LSTM.py
+1
-1
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+10
-27
tensorpack/train/simple.py
tensorpack/train/simple.py
+13
-9
No files found.
examples/PennTreebank/PTB-LSTM.py
View file @
079eb3a9
...
...
@@ -173,4 +173,4 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
Simple
Feedfree
Trainer
(
config
)
.
train
()
SimpleTrainer
(
config
)
.
train
()
tensorpack/train/feedfree.py
View file @
079eb3a9
...
...
@@ -4,9 +4,11 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
..utils
import
logger
from
..utils.develop
import
deprecated
from
..tfutils.tower
import
TowerContext
from
..graph_builder.input_source
import
QueueInput
,
FeedfreeInput
from
.simple
import
SimpleTrainer
from
.base
import
Trainer
__all__
=
[
'FeedfreeTrainerBase'
,
'SingleCostFeedfreeTrainer'
,
...
...
@@ -34,6 +36,7 @@ class FeedfreeTrainerBase(Trainer):
self
.
hooked_sess
.
run
(
self
.
train_op
)
# TODO Kept for now for back-compat
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainerBase
):
""" A feedfree Trainer which assumes a single cost. """
def
_get_cost_and_grad
(
self
):
...
...
@@ -42,30 +45,10 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
return
self
.
model
.
get_cost_and_grad
()
class
SimpleFeedfreeTrainer
(
SingleCostFeedfreeTrainer
):
"""
A trainer with single cost, single training tower, any number of
prediction tower, and feed-free input.
"""
def
__init__
(
self
,
config
):
"""
Args:
config (TrainConfig): ``config.data`` must exist and is a :class:`FeedfreeInput`.
"""
self
.
_input_source
=
config
.
data
assert
isinstance
(
self
.
_input_source
,
FeedfreeInput
),
self
.
_input_source
super
(
SimpleFeedfreeTrainer
,
self
)
.
__init__
(
config
)
assert
len
(
self
.
config
.
tower
)
==
1
,
\
"Got nr_tower={}, but doesn't support multigpu!"
\
" Use Sync/AsyncMultiGPUTrainer instead."
.
format
(
len
(
self
.
config
.
tower
))
def
_setup
(
self
):
super
(
SimpleFeedfreeTrainer
,
self
)
.
_setup
()
with
TowerContext
(
''
,
is_training
=
True
):
cost
,
grads
=
self
.
_get_cost_and_grad
()
opt
=
self
.
model
.
get_optimizer
()
self
.
train_op
=
opt
.
apply_gradients
(
grads
,
name
=
'min_op'
)
@
deprecated
(
"Use SimpleTrainer with config.data instead!"
)
def
SimpleFeedfreeTrainer
(
config
):
assert
isinstance
(
config
.
data
,
FeedfreeInput
),
config
.
data
return
SimpleTrainer
(
config
)
def
QueueInputTrainer
(
config
,
input_queue
=
None
):
...
...
@@ -76,16 +59,16 @@ def QueueInputTrainer(config, input_queue=None):
Args:
config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue (tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default.
input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
"""
if
config
.
data
is
not
None
:
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
else
:
config
.
data
=
QueueInput
(
config
.
dataflow
,
input_queue
)
config
.
dataflow
=
None
# debug
# from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[128,224,224,3], [128]])
return
Simple
Feedfree
Trainer
(
config
)
return
SimpleTrainer
(
config
)
tensorpack/train/simple.py
View file @
079eb3a9
...
...
@@ -13,8 +13,10 @@ __all__ = ['SimpleTrainer']
class
SimpleTrainer
(
Trainer
):
""" A naive demo trainer which iterates over a DataFlow and feed into the
graph. It's not efficient compared to QueueInputTrainer or others."""
""" A naive single-tower single-cost demo trainer.
Support both InputSource and DataFlow.
When DataFlow is given, the InputSource to be used will be ``FeedInput(df)``.
"""
def
__init__
(
self
,
config
):
"""
...
...
@@ -22,23 +24,25 @@ class SimpleTrainer(Trainer):
config (TrainConfig): the training config.
"""
super
(
SimpleTrainer
,
self
)
.
__init__
(
config
)
assert
len
(
self
.
config
.
tower
)
==
1
,
\
"Got nr_tower={}, but doesn't support multigpu!"
\
" Use Sync/AsyncMultiGPUTrainer instead."
.
format
(
len
(
self
.
config
.
tower
))
if
config
.
dataflow
is
None
:
self
.
_input_source
=
config
.
data
assert
isinstance
(
self
.
_input_source
,
FeedInput
),
type
(
self
.
_input_source
)
else
:
self
.
_input_source
=
FeedInput
(
config
.
dataflow
)
logger
.
warn
(
"SimpleTrainer is slow! Do you really want to use it?"
)
logger
.
warn
(
"FeedInput is slow (and this is the default of SimpleTrainer). "
"Consider QueueInput or other InputSource instead."
)
def
run_step
(
self
):
""" Feed data into the graph and run the updates. """
self
.
hooked_sess
.
run
(
self
.
train_op
)
def
_setup
(
self
):
self
.
_setup_input_source
(
self
.
_input_source
)
with
TowerContext
(
''
,
is_training
=
True
):
self
.
model
.
build_graph
(
self
.
_input_source
)
cost_var
=
self
.
model
.
get_cost
()
cost
,
grads
=
self
.
model
.
get_cost_and_grad
()
opt
=
self
.
model
.
get_optimizer
()
self
.
train_op
=
opt
.
minimize
(
cost_var
,
name
=
'min_op'
)
self
.
train_op
=
opt
.
apply_gradients
(
grads
,
name
=
'min_op'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment