Shashank Suhas / seminar-breakout · Commits

Commit 4d2a7b4c authored Jul 12, 2017 by Yuxin Wu
Parent: e46e6bca

Simplify the connection between ModelDesc and InputSource

Showing 6 changed files with 43 additions and 43 deletions (+43, -43)
examples/GAN/GAN.py                         +8   -4
tensorpack/callbacks/inference_runner.py    +2   -4
tensorpack/models/model_desc.py             +13  -5
tensorpack/train/base.py                    +9   -0
tensorpack/train/feedfree.py                +7   -22
tensorpack/train/simple.py                  +4   -8
examples/GAN/GAN.py

@@ -8,7 +8,8 @@ import numpy as np
 import time
 from tensorpack import (FeedfreeTrainerBase, QueueInput,
     ModelDesc, DataFlow, StagingInputWrapper,
-    MultiGPUTrainerBase, LeastLoadedDeviceSetter)
+    MultiGPUTrainerBase, LeastLoadedDeviceSetter,
+    TowerContext)
 from tensorpack.tfutils.summary import add_moving_summary

@@ -65,7 +66,8 @@ class GANTrainer(FeedfreeTrainerBase):
     def _setup(self):
         super(GANTrainer, self)._setup()
-        self.build_train_tower()
+        with TowerContext('', is_training=True):
+            self.model.build_graph(self._input_source)
         opt = self.model.get_optimizer()
         # by default, run one d_min after one g_min

@@ -91,7 +93,8 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
     def _setup(self):
         super(SeparateGANTrainer, self)._setup()
-        self.build_train_tower()
+        with TowerContext('', is_training=True):
+            self.model.build_graph(self._input_source)
         opt = self.model.get_optimizer()
         self.d_min = opt.minimize(

@@ -123,8 +126,9 @@ class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
         super(MultiGPUGANTrainer, self)._setup()
         devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices]

+        # NOTE trainer internal APIs subject to change in the future
         def get_cost():
-            self.build_train_tower()
+            self.model.build_graph(self._input_source)
             return [self.model.d_loss, self.model.g_loss]
         cost_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower, get_cost, devices)
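Both GAN trainers now follow the same recipe: open a training TowerContext and hand the InputSource straight to model.build_graph, instead of going through the removed build_train_tower() helper. Below is a minimal, framework-free sketch of that recipe; tower_context, ToyInputSource and ToyModel are illustrative stand-ins, not the real tensorpack API.

    from contextlib import contextmanager

    _training_flags = [False]

    @contextmanager
    def tower_context(name, is_training):
        # stand-in for tensorpack's TowerContext: marks graph construction
        # as "training" for the duration of the with-block
        _training_flags.append(is_training)
        try:
            yield
        finally:
            _training_flags.pop()

    class ToyInputSource:
        def get_input_tensors(self):
            return ["z_noise"]                       # placeholder for real tensors

    class ToyModel:
        def build_graph(self, input_source):
            tensors = input_source.get_input_tensors()
            print("building graph, training=%s, inputs=%s" % (_training_flags[-1], tensors))

    model, source = ToyModel(), ToyInputSource()
    with tower_context('', is_training=True):        # mirrors GANTrainer._setup above
        model.build_graph(source)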
tensorpack/callbacks/inference_runner.py

@@ -90,8 +90,7 @@ class InferenceRunnerBase(Callback):
         self._predict_tower_id = self.trainer.config.predict_tower[0]

         def fn(_):
-            in_tensors = self._input_source.get_input_tensors()
-            self.trainer.model.build_graph(in_tensors)
+            self.trainer.model.build_graph(self._input_source)
         with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
             PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)

@@ -203,8 +202,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         # build graph
         def build_tower(k):
             # inputs (placeholders) for this tower only
-            input_tensors = self._input_source.get_input_tensors()
-            model.build_graph(input_tensors)
+            model.build_graph(self._input_source)
         builder = PredictorTowerBuilder(build_tower, prefix=self._prefix)
         with tf.variable_scope(tf.get_variable_scope(), reuse=True):
tensorpack/models/model_desc.py

@@ -9,6 +9,8 @@ import tensorflow as tf
 import six

 from ..utils.argtools import memoized
+# TODO sort out import issues
+# from ..train.input_source import InputSource
 from .regularize import regularize_cost_from_collection

 __all__ = ['InputDesc', 'ModelDesc']

@@ -130,15 +132,21 @@ class ModelDesc(object):
         :returns: a list of InputDesc
         """

-    def build_graph(self, model_inputs):
+    # TODO only use InputSource in the future? Now mainly used in predict/
+    def build_graph(self, inputs):
         """
         Build the whole symbolic graph.

         Args:
-            model_inputs (list[tf.Tensor]): a list of inputs, corresponding to
-                InputDesc of this model.
+            inputs (list[tf.Tensor] or InputSource): a list of tensors, or an :class:`InputSource`,
+                that match the list of :class:`InputDesc` defined by ``_get_inputs``.
         """
-        self._build_graph(model_inputs)
+        if not isinstance(inputs, (list, tuple)):
+            inputs = inputs.get_input_tensors()
+        assert len(inputs) == len(self.get_inputs_desc()), \
+            "Number of inputs passed to the graph != number of inputs defined " \
+            "in ModelDesc! ({} != {})".format(len(inputs), len(self.get_inputs_desc()))
+        self._build_graph(inputs)

     @abstractmethod
     def _build_graph(self, inputs):
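The core of the commit is this dispatch: build_graph now accepts either the old list of tensors or any object that exposes get_input_tensors() (an InputSource), and checks the count against the declared inputs. A minimal, framework-free sketch of that logic, using simplified stand-in classes rather than the real tensorpack ones:

    class ToyInputSource:
        def get_input_tensors(self):
            return ["image", "label"]            # placeholders for real tensors

    class ToyModelDesc:
        def get_inputs_desc(self):
            return ["image", "label"]            # stand-in for a list of InputDesc

        def build_graph(self, inputs):
            # normalize: an InputSource is turned into its list of tensors
            if not isinstance(inputs, (list, tuple)):
                inputs = inputs.get_input_tensors()
            assert len(inputs) == len(self.get_inputs_desc()), \
                "Number of inputs passed to the graph != number of inputs defined " \
                "in ModelDesc! ({} != {})".format(len(inputs), len(self.get_inputs_desc()))
            self._build_graph(inputs)

        def _build_graph(self, inputs):
            print("building graph from", inputs)

    m = ToyModelDesc()
    m.build_graph(["image", "label"])            # old calling convention still works
    m.build_graph(ToyInputSource())              # new: pass the InputSource directly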
tensorpack/train/base.py

@@ -108,6 +108,15 @@ class Trainer(object):
         """ Abstract method: run one iteration. Subclass should define what is "iteration".
         """

+    def _setup_input_source(self, input_source):
+        """
+        Setup InputSource on this trainer.
+        """
+        input_source.setup(self.model.get_inputs_desc())
+        cbs = input_source.get_callbacks()
+        for cb in cbs:
+            self.register_callback(cb)
+
     def setup(self):
         """
         Setup the trainer and be ready for the main loop.
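The new _setup_input_source helper gives every trainer one shared way to wire an InputSource in: let the source set itself up against the model's input descriptions, then register whatever callbacks it needs. A small stand-alone sketch of that flow, with stand-in classes (not the real tensorpack Trainer):

    class ToyInputSource:
        def setup(self, inputs_desc):
            print("InputSource configured for", inputs_desc)

        def get_callbacks(self):
            return ["enqueue_thread_callback"]   # placeholder callback objects

    class ToyModel:
        def get_inputs_desc(self):
            return ["image", "label"]

    class ToyTrainer:
        def __init__(self, model):
            self.model = model
            self._callbacks = []

        def register_callback(self, cb):
            self._callbacks.append(cb)

        def _setup_input_source(self, input_source):
            # mirrors Trainer._setup_input_source added in this commit
            input_source.setup(self.model.get_inputs_desc())
            for cb in input_source.get_callbacks():
                self.register_callback(cb)

    trainer = ToyTrainer(ToyModel())
    trainer._setup_input_source(ToyInputSource())
    print(trainer._callbacks)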
tensorpack/train/feedfree.py

@@ -20,27 +20,10 @@ class FeedfreeTrainerBase(Trainer):
     """ A base trainer which runs iteration without feed_dict (therefore faster)
         Expect ``self.data`` to be a :class:`FeedfreeInput`.
     """
-    def build_train_tower(self):
-        """
-        Get input tensors from `self.input_source` and build the forward graph.
-        """
-        def f():
-            self._input_tensors = self._input_source.get_input_tensors()
-            self.model.build_graph(self._input_tensors)
-        ctx = get_current_tower_context()
-        if ctx is None:
-            # call without a context, use a default one
-            with TowerContext('', is_training=True):
-                f()
-        else:
-            assert ctx.is_training, ctx
-            f()
-
     def _setup(self):
         assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
-        self._input_source.setup(self.model.get_inputs_desc())
-        input_callbacks = self._input_source.get_callbacks()
-        for cb in input_callbacks:
-            self.register_callback(cb)
+        self._setup_input_source(self._input_source)

     def run_step(self):
         """ Simply run ``self.train_op``."""

@@ -51,10 +34,14 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
     """ A feedfree Trainer which assumes a single cost. """
     def _get_cost_and_grad(self):
         """ get the cost and gradient"""
-        self.build_train_tower()
+        ctx = get_current_tower_context()
+        assert ctx.is_training, ctx
+        self.model.build_graph(self._input_source)
         cost = self.model.get_cost()    # assume single cost
+        # produce gradients
         varlist = tf.trainable_variables()
-        ctx = get_current_tower_context()
         if ctx is not None and ctx.has_own_variables and ctx.vs_name:
             # only optimize w.r.t vars in this tower
             # TODO use ctx.vars?

@@ -93,8 +80,6 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
         cost, grads = self._get_cost_and_grad()
         opt = self.model.get_optimizer()
         self.train_op = opt.apply_gradients(grads, name='min_op')
-        # skip training
-        # self.train_op = tf.group(*self._input_tensors)


 def QueueInputTrainer(config, input_queue=None):
tensorpack/train/simple.py

@@ -34,15 +34,11 @@ class SimpleTrainer(Trainer):
         self.hooked_sess.run(self.train_op)

     def _setup(self):
-        model = self.model
-        self._input_source.setup(model.get_inputs_desc())
-        cbs = self._input_source.get_callbacks()
-        for cb in cbs:
-            self.register_callback(cb)
-        self.inputs = self._input_source.get_input_tensors()
+        self._setup_input_source(self._input_source)
         with TowerContext('', is_training=True):
-            model.build_graph(self.inputs)
-        cost_var = model.get_cost()
+            self.model.build_graph(self._input_source)
+        cost_var = self.model.get_cost()
         opt = self.model.get_optimizer()
         self.train_op = opt.minimize(cost_var, name='min_op')
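Taken together, the diffs above shrink every trainer's _setup to the same two steps: call the shared input-source helper, then hand the InputSource itself to model.build_graph. A toy end-to-end composition of the two sketches above (stand-in classes, not the real SimpleTrainer):

    class ToyInputSource:
        def setup(self, inputs_desc):
            print("InputSource configured for", inputs_desc)
        def get_callbacks(self):
            return []                                  # e.g. enqueue-thread callbacks
        def get_input_tensors(self):
            return ["image", "label"]

    class ToyModel:
        def get_inputs_desc(self):
            return ["image", "label"]
        def build_graph(self, inputs):
            if not isinstance(inputs, (list, tuple)):  # same dispatch as ModelDesc above
                inputs = inputs.get_input_tensors()
            print("graph built from", inputs)

    class ToySimpleTrainer:
        def __init__(self, model, input_source):
            self.model, self._input_source = model, input_source
            self._callbacks = []
        def register_callback(self, cb):
            self._callbacks.append(cb)
        def _setup_input_source(self, input_source):   # shared helper from Trainer
            input_source.setup(self.model.get_inputs_desc())
            for cb in input_source.get_callbacks():
                self.register_callback(cb)
        def _setup(self):
            # the whole body of SimpleTrainer._setup after this commit, in miniature
            self._setup_input_source(self._input_source)
            self.model.build_graph(self._input_source)

    ToySimpleTrainer(ToyModel(), ToyInputSource())._setup()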