Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
047579df
Commit
047579df
authored
Aug 22, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
TowerFuncWrapper -> TowerFunc
parent
50ff9036
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
41 additions
and
37 deletions
+41
-37
docs/conf.py
docs/conf.py
+2
-1
examples/GAN/GAN.py
examples/GAN/GAN.py
+4
-4
examples/keras/mnist-keras.py
examples/keras/mnist-keras.py
+1
-1
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-2
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+0
-5
tensorpack/predict/config.py
tensorpack/predict/config.py
+7
-7
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+0
-2
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+9
-5
tensorpack/train/config.py
tensorpack/train/config.py
+9
-4
tensorpack/train/interface.py
tensorpack/train/interface.py
+1
-1
tensorpack/train/tower.py
tensorpack/train/tower.py
+4
-4
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+2
-1
No files found.
docs/conf.py
View file @
047579df
...
...
@@ -384,6 +384,7 @@ _DEPRECATED_NAMES = set([
'get_nr_gpu'
,
'TrainingMonitor'
,
'PeakMemoryTracker'
,
'TowerFuncWrapper'
,
'PrefetchData'
,
'MultiProcessPrefetchData'
,
...
...
@@ -391,7 +392,7 @@ _DEPRECATED_NAMES = set([
'MultiThreadPrefetchData'
,
# deprecated or renamed symbolic code
'Deconv2D'
,
'psnr'
,
'Deconv2D'
,
# shouldn't appear in doc:
'l2_regularizer'
,
'l1_regularizer'
,
...
...
examples/GAN/GAN.py
View file @
047579df
...
...
@@ -8,7 +8,7 @@ import tensorflow as tf
from
tensorpack
import
BatchNorm
,
DataFlow
,
ModelDescBase
,
StagingInput
,
TowerTrainer
,
argscope
from
tensorpack.graph_builder
import
DataParallelBuilder
,
LeastLoadedDeviceSetter
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.tower
import
TowerContext
,
TowerFunc
Wrapper
from
tensorpack.tfutils.tower
import
TowerContext
,
TowerFunc
from
tensorpack.utils
import
logger
from
tensorpack.utils.argtools
import
memoized_method
...
...
@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer):
not needed. Just calling model.build_graph directly is OK.
"""
# Build the graph
self
.
tower_func
=
TowerFunc
Wrapper
(
model
.
build_graph
,
model
.
inputs
())
self
.
tower_func
=
TowerFunc
(
model
.
build_graph
,
model
.
inputs
())
with
TowerContext
(
''
,
is_training
=
True
):
self
.
tower_func
(
*
input
.
get_input_tensors
())
opt
=
model
.
get_optimizer
()
...
...
@@ -127,7 +127,7 @@ class GANTrainer(TowerTrainer):
model
.
build_graph
(
*
inputs
)
return
[
model
.
d_loss
,
model
.
g_loss
]
self
.
tower_func
=
TowerFunc
Wrapper
(
get_cost
,
model
.
get_input_signature
())
self
.
tower_func
=
TowerFunc
(
get_cost
,
model
.
get_input_signature
())
devices
=
[
LeastLoadedDeviceSetter
(
d
,
raw_devices
)
for
d
in
raw_devices
]
cost_list
=
DataParallelBuilder
.
build_on_towers
(
list
(
range
(
num_gpu
)),
...
...
@@ -167,7 +167,7 @@ class SeparateGANTrainer(TowerTrainer):
self
.
register_callback
(
cbs
)
# Build the graph
self
.
tower_func
=
TowerFunc
Wrapper
(
model
.
build_graph
,
model
.
inputs
())
self
.
tower_func
=
TowerFunc
(
model
.
build_graph
,
model
.
inputs
())
with
TowerContext
(
''
,
is_training
=
True
),
\
argscope
(
BatchNorm
,
ema_update
=
'internal'
):
# should not hook the EMA updates to both train_op, it will hurt training speed.
...
...
examples/keras/mnist-keras.py
View file @
047579df
...
...
@@ -120,7 +120,7 @@ if __name__ == '__main__':
if
get_num_gpu
()
<=
1
:
# single GPU:
launch_train_with_config
(
cfg
,
QueueInput
Trainer
())
launch_train_with_config
(
cfg
,
Simple
Trainer
())
else
:
# multi GPU:
launch_train_with_config
(
cfg
,
SyncMultiGPUTrainerParameterServer
(
2
))
...
...
tensorpack/callbacks/inference_runner.py
View file @
047579df
...
...
@@ -114,7 +114,7 @@ class InferenceRunner(InferenceRunnerBase):
infs (list): a list of :class:`Inferencer` instances.
tower_name (str): the name scope of the tower to build. Need to set a
different one if multiple InferenceRunner are used.
tower_func (tfutils.TowerFunc
Wrapper
or None): the tower function to be used to build the graph.
tower_func (tfutils.TowerFunc or None): the tower function to be used to build the graph.
By defaults to call `trainer.tower_func` under a `training=False` TowerContext,
but you can change it to a different tower function
if you need to inference with several different graphs.
...
...
@@ -196,7 +196,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
gpus (int or list[int]): #gpus, or list of GPU id
tower_name (str): the name scope of the tower to build. Need to set a
different one if multiple InferenceRunner are used.
tower_func (tfutils.TowerFunc
Wrapper
or None): the tower function to be used to build the graph.
tower_func (tfutils.TowerFunc or None): the tower function to be used to build the graph.
The tower function will be called under a `training=False` TowerContext.
The default is `trainer.tower_func`,
but you can change it to a different tower function
...
...
tensorpack/graph_builder/model_desc.py
View file @
047579df
...
...
@@ -6,7 +6,6 @@ from collections import namedtuple
import
tensorflow
as
tf
from
..utils.argtools
import
memoized_method
from
..utils.develop
import
deprecated
from
..tfutils.common
import
get_op_tensor_name
from
..compat
import
backport_tensor_spec
,
tfv1
...
...
@@ -174,7 +173,3 @@ class ModelDesc(ModelDescBase):
A subclass is expected to implement this method.
"""
raise
NotImplementedError
()
@
deprecated
(
"Just use `build_graph` instead!"
)
def
_build_graph_get_cost
(
self
,
*
inputs
):
return
self
.
build_graph
(
*
inputs
)
tensorpack/predict/config.py
View file @
047579df
...
...
@@ -9,7 +9,7 @@ from ..graph_builder import ModelDescBase
from
..tfutils
import
get_default_sess_config
from
..tfutils.sessinit
import
JustCurrentSession
,
SessionInit
from
..tfutils.sesscreate
import
NewSessionCreator
from
..tfutils.tower
import
TowerFunc
Wrapper
from
..tfutils.tower
import
TowerFunc
from
..utils
import
logger
__all__
=
[
'PredictConfig'
]
...
...
@@ -36,7 +36,7 @@ class PredictConfig(object):
This can be provided in the following ways:
1. `model`: a :class:`ModelDesc` instance. It will contain a tower function by itself.
2. `tower_func`: a :class:`tfutils.TowerFunc
Wrapper
` instance.
2. `tower_func`: a :class:`tfutils.TowerFunc` instance.
Provide a tower function instance directly.
3. `tower_func`: a symbolic function and `input_signature`: the signature of the function.
Provide both a function and its signature.
...
...
@@ -52,8 +52,8 @@ class PredictConfig(object):
Args:
model (ModelDescBase): to be used to construct a tower function.
tower_func: a callable which takes input tensors (by positional args) and construct a tower.
or a :class:`tfutils.TowerFunc
Wrapper
` instance.
input_signature ([tf.TensorSpec]): if tower_func is a plain function (instead of a TowerFunc
Wrapper
),
or a :class:`tfutils.TowerFunc` instance.
input_signature ([tf.TensorSpec]): if tower_func is a plain function (instead of a TowerFunc),
this describes the list of inputs it takes.
input_names (list): a list of input tensor names. Defaults to match input_signature.
...
...
@@ -85,13 +85,13 @@ class PredictConfig(object):
assert_type
(
model
,
ModelDescBase
,
'model'
)
assert
input_signature
is
None
and
tower_func
is
None
self
.
input_signature
=
model
.
get_input_signature
()
self
.
tower_func
=
TowerFunc
Wrapper
(
model
.
build_graph
,
self
.
input_signature
)
self
.
tower_func
=
TowerFunc
(
model
.
build_graph
,
self
.
input_signature
)
else
:
if
isinstance
(
tower_func
,
TowerFunc
Wrapper
):
if
isinstance
(
tower_func
,
TowerFunc
):
input_signature
=
tower_func
.
input_signature
assert
input_signature
is
not
None
and
tower_func
is
not
None
self
.
input_signature
=
input_signature
self
.
tower_func
=
TowerFunc
Wrapper
(
tower_func
,
input_signature
)
self
.
tower_func
=
TowerFunc
(
tower_func
,
input_signature
)
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
...
...
tensorpack/tfutils/symbolic_functions.py
View file @
047579df
...
...
@@ -5,7 +5,6 @@
import
tensorflow
as
tf
from
..compat
import
tfv1
from
..utils.develop
import
deprecated
__all__
=
[
'print_stat'
,
'rms'
]
...
...
@@ -37,7 +36,6 @@ def rms(x, name=None):
# don't hurt to leave it here
@
deprecated
(
"Please implement it by yourself."
,
"2018-04-28"
)
def
psnr
(
prediction
,
ground_truth
,
maxp
=
None
,
name
=
'psnr'
):
"""`Peak Signal to Noise Ratio <https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio>`_.
...
...
tensorpack/tfutils/tower.py
View file @
047579df
...
...
@@ -15,7 +15,8 @@ from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from
.collection
import
CollectionGuard
from
.common
import
get_op_or_tensor_by_name
,
get_op_tensor_name
__all__
=
[
'get_current_tower_context'
,
'BaseTowerContext'
,
'TowerContext'
,
'TowerFuncWrapper'
,
__all__
=
[
'get_current_tower_context'
,
'BaseTowerContext'
,
'TowerContext'
,
'TowerFuncWrapper'
,
'TowerFunc'
,
'TowerTensorHandle'
,
'TowerTensorHandles'
]
_CurrentTowerContext
=
None
...
...
@@ -245,9 +246,9 @@ def TowerContext(tower_name, is_training, vs_name=''):
return
PredictTowerContext
(
tower_name
,
vs_name
=
vs_name
)
class
TowerFunc
Wrapper
(
object
):
class
TowerFunc
(
object
):
"""
A
wrapper around a
tower function (see
A tower function (see
[tutorial on tower function](http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer)).
It keeps track of the name scope, variable scope and input/output tensors
each time the function is called.
...
...
@@ -279,10 +280,10 @@ class TowerFuncWrapper(object):
def
__new__
(
cls
,
tower_fn
,
_
):
# to avoid double-wrapping a function
if
isinstance
(
tower_fn
,
TowerFunc
Wrapper
):
if
isinstance
(
tower_fn
,
TowerFunc
):
return
tower_fn
else
:
return
super
(
TowerFunc
Wrapper
,
cls
)
.
__new__
(
cls
)
return
super
(
TowerFunc
,
cls
)
.
__new__
(
cls
)
def
__call__
(
self
,
*
args
):
ctx
=
get_current_tower_context
()
...
...
@@ -311,6 +312,9 @@ class TowerFuncWrapper(object):
return
self
.
_input_signature
TowerFuncWrapper
=
TowerFunc
class
TowerTensorHandles
(
object
):
"""
Wrap a list of :class:`TowerTensorHandle`,
...
...
tensorpack/train/config.py
View file @
047579df
...
...
@@ -176,9 +176,14 @@ class AutoResumeTrainConfig(TrainConfig):
Note that the functionality requires the logging directory to obtain
necessary information from a previous run.
In some cases (e.g. when using Horovod), the directory is not
available, or the directories are different for different workers,
then this class may not function properly.
If you have unconventional setup of logging directory, this class will not
work for you, for example:
1. If you save the checkpoint to a different directory rather than the
logging directory.
2. If in distributed training the directory is not
available to every worker, or the directories are different for different workers.
"""
def
__init__
(
self
,
always_resume
=
True
,
**
kwargs
):
"""
...
...
@@ -189,7 +194,7 @@ class AutoResumeTrainConfig(TrainConfig):
kwargs: same as in :class:`TrainConfig`.
Note:
The main goal of this class is to let a training job
to
resume
The main goal of this class is to let a training job resume
without changing any line of code or command line arguments.
So it's useful to let resume take priority over user-provided arguments sometimes.
...
...
tensorpack/train/interface.py
View file @
047579df
...
...
@@ -85,7 +85,7 @@ def launch_train_with_config(config, trainer):
# This is the only place where the `ModelDesc` abstraction is useful.
# We should gradually stay away from this unuseful abstraction.
# TowerFunc
Wrapper is a better abstraction (similar to tf.defu
n in the future)
# TowerFunc
is a better abstraction (similar to tf.functio
n in the future)
trainer
.
setup_graph
(
model
.
get_input_signature
(),
input
,
model
.
build_graph
,
model
.
get_optimizer
)
...
...
tensorpack/train/tower.py
View file @
047579df
...
...
@@ -9,7 +9,7 @@ from ..compat import tfv1, is_tfv2
from
..input_source
import
PlaceholderInput
from
..predict.base
import
OnlinePredictor
from
..tfutils.gradproc
import
FilterNoneGrad
from
..tfutils.tower
import
PredictTowerContext
,
TowerFunc
Wrapper
,
get_current_tower_context
from
..tfutils.tower
import
PredictTowerContext
,
TowerFunc
,
get_current_tower_context
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
,
memoized
from
..utils.develop
import
HIDE_DOC
...
...
@@ -38,13 +38,13 @@ class TowerTrainer(Trainer):
@
call_only_once
def
_set_tower_func
(
self
,
tower_func
):
assert
isinstance
(
tower_func
,
TowerFunc
Wrapper
),
tower_func
assert
isinstance
(
tower_func
,
TowerFunc
),
tower_func
self
.
_tower_func
=
tower_func
@
property
def
tower_func
(
self
):
"""
A :class:`TowerFunc
Wrapper
` instance.
A :class:`TowerFunc` instance.
See [tutorial on tower function](http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer)
for more information.
"""
...
...
@@ -215,7 +215,7 @@ class SingleCostTrainer(TowerTrainer):
It must follows the `rules of tower function.
<http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer>`_.
"""
get_cost_fn
=
TowerFunc
Wrapper
(
get_cost_fn
,
input_signature
)
get_cost_fn
=
TowerFunc
(
get_cost_fn
,
input_signature
)
get_opt_fn
=
memoized
(
get_opt_fn
)
self
.
tower_func
=
get_cost_fn
...
...
tensorpack/train/trainers.py
View file @
047579df
...
...
@@ -18,7 +18,7 @@ from ..tfutils.sesscreate import NewSessionCreator
from
..tfutils.tower
import
TrainTowerContext
from
..utils
import
logger
from
..utils.argtools
import
map_arg
from
..utils.develop
import
HIDE_DOC
from
..utils.develop
import
HIDE_DOC
,
deprecated
from
.tower
import
SingleCostTrainer
__all__
=
[
'NoOpTrainer'
,
'SimpleTrainer'
,
...
...
@@ -66,6 +66,7 @@ class NoOpTrainer(SimpleTrainer):
# Only exists for type check & back-compatibility
class
QueueInputTrainer
(
SimpleTrainer
):
@
deprecated
(
"SimpleTrainer is sufficient!"
,
"2019-12-31"
)
def
_setup_graph
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
assert
isinstance
(
input
,
QueueInput
),
input
return
super
(
QueueInputTrainer
,
self
)
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment