Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
651a5aea
Commit
651a5aea
authored
Jan 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
deprecate TrainConfig.dataset and use 'dataflow' instead
parent
069c0b9c
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
24 changed files
with
58 additions
and
44 deletions
+58
-44
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-1
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+1
-1
examples/DoReFa-Net/svhn-digit-dorefa.py
examples/DoReFa-Net/svhn-digit-dorefa.py
+1
-1
examples/GAN/DCGAN-CelebA.py
examples/GAN/DCGAN-CelebA.py
+1
-1
examples/GAN/GAN.py
examples/GAN/GAN.py
+2
-2
examples/GAN/Image2Image.py
examples/GAN/Image2Image.py
+1
-1
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+1
-1
examples/HED/hed.py
examples/HED/hed.py
+1
-1
examples/Inception/inception-bn.py
examples/Inception/inception-bn.py
+1
-1
examples/Inception/inceptionv3.py
examples/Inception/inceptionv3.py
+1
-1
examples/OpenAIGym/train-atari.py
examples/OpenAIGym/train-atari.py
+1
-1
examples/ResNet/cifar10-resnet.py
examples/ResNet/cifar10-resnet.py
+1
-1
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+1
-1
examples/ResNet/svhn-resnet.py
examples/ResNet/svhn-resnet.py
+1
-1
examples/SpatialTransformer/mnist-addition.py
examples/SpatialTransformer/mnist-addition.py
+1
-1
examples/TIMIT/train-timit.py
examples/TIMIT/train-timit.py
+1
-1
examples/char-rnn/char-rnn.py
examples/char-rnn/char-rnn.py
+1
-1
examples/cifar-convnet.py
examples/cifar-convnet.py
+1
-1
examples/mnist-convnet.py
examples/mnist-convnet.py
+1
-1
examples/svhn-digit-convnet.py
examples/svhn-digit-convnet.py
+1
-1
tensorpack/train/config.py
tensorpack/train/config.py
+15
-9
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+16
-8
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+4
-4
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+2
-2
No files found.
examples/Atari2600/DQN.py
View file @
651a5aea
...
...
@@ -173,7 +173,7 @@ def get_config():
lr
=
symbf
.
get_scalar_var
(
'learning_rate'
,
1e-3
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset_train
,
data
flow
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/DoReFa-Net/alexnet-dorefa.py
View file @
651a5aea
...
...
@@ -233,7 +233,7 @@ def get_config():
lr
=
get_scalar_var
(
'learning_rate'
,
1e-4
,
summary
=
True
)
return
TrainConfig
(
data
set
=
data_train
,
data
flow
=
data_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-5
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/DoReFa-Net/svhn-digit-dorefa.py
View file @
651a5aea
...
...
@@ -161,7 +161,7 @@ def get_config():
tf
.
summary
.
scalar
(
'lr'
,
lr
)
return
TrainConfig
(
data
set
=
data_train
,
data
flow
=
data_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-5
),
callbacks
=
Callbacks
([
StatPrinter
(),
...
...
examples/GAN/DCGAN-CelebA.py
View file @
651a5aea
...
...
@@ -109,7 +109,7 @@ def get_config():
dataset
=
get_data
()
lr
=
symbolic_functions
.
get_scalar_var
(
'learning_rate'
,
2e-4
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset
,
data
flow
=
dataset
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
beta1
=
0.5
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/GAN/GAN.py
View file @
651a5aea
...
...
@@ -6,13 +6,13 @@
import
tensorflow
as
tf
import
numpy
as
np
import
time
from
tensorpack
import
(
FeedfreeTrainer
,
TowerContext
,
from
tensorpack
import
(
FeedfreeTrainer
Base
,
TowerContext
,
get_global_step_var
,
QueueInput
)
from
tensorpack.tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
tensorpack.dataflow
import
DataFlow
class
GANTrainer
(
FeedfreeTrainer
):
class
GANTrainer
(
FeedfreeTrainer
Base
):
def
__init__
(
self
,
config
):
self
.
_input_method
=
QueueInput
(
config
.
dataset
)
...
...
examples/GAN/Image2Image.py
View file @
651a5aea
...
...
@@ -168,7 +168,7 @@ def get_config():
dataset
=
get_data
()
lr
=
symbolic_functions
.
get_scalar_var
(
'learning_rate'
,
2e-4
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset
,
data
flow
=
dataset
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
beta1
=
0.5
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
PeriodicCallback
(
ModelSaver
(),
3
),
...
...
examples/GAN/InfoGAN-mnist.py
View file @
651a5aea
...
...
@@ -104,7 +104,7 @@ def get_config():
dataset
=
get_data
()
lr
=
symbf
.
get_scalar_var
(
'learning_rate'
,
2e-4
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset
,
data
flow
=
dataset
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
beta1
=
0.5
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/HED/hed.py
View file @
651a5aea
...
...
@@ -171,7 +171,7 @@ def get_config():
lr
=
get_scalar_var
(
'learning_rate'
,
3e-5
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset_train
,
data
flow
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/Inception/inception-bn.py
View file @
651a5aea
...
...
@@ -158,7 +158,7 @@ def get_config():
lr
=
get_scalar_var
(
'learning_rate'
,
0.045
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset_train
,
data
flow
=
dataset_train
,
optimizer
=
tf
.
train
.
MomentumOptimizer
(
lr
,
0.9
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/Inception/inceptionv3.py
View file @
651a5aea
...
...
@@ -266,7 +266,7 @@ def get_config():
lr
=
get_scalar_var
(
'learning_rate'
,
0.045
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset_train
,
data
flow
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/OpenAIGym/train-atari.py
View file @
651a5aea
...
...
@@ -204,7 +204,7 @@ def get_config():
lr
=
symbf
.
get_scalar_var
(
'learning_rate'
,
0.001
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataflow
,
data
flow
=
dataflow
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/ResNet/cifar10-resnet.py
View file @
651a5aea
...
...
@@ -139,7 +139,7 @@ def get_config():
lr
=
get_scalar_var
(
'learning_rate'
,
0.01
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset_train
,
data
flow
=
dataset_train
,
optimizer
=
tf
.
train
.
MomentumOptimizer
(
lr
,
0.9
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/ResNet/imagenet-resnet.py
View file @
651a5aea
...
...
@@ -187,7 +187,7 @@ def get_config():
lr
=
get_scalar_var
(
'learning_rate'
,
0.1
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset_train
,
data
flow
=
dataset_train
,
optimizer
=
tf
.
train
.
MomentumOptimizer
(
lr
,
0.9
,
use_nesterov
=
True
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/ResNet/svhn-resnet.py
View file @
651a5aea
...
...
@@ -68,7 +68,7 @@ def get_config():
lr
=
get_scalar_var
(
'learning_rate'
,
0.01
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset_train
,
data
flow
=
dataset_train
,
optimizer
=
tf
.
train
.
MomentumOptimizer
(
lr
,
0.9
),
callbacks
=
Callbacks
([
StatPrinter
(),
...
...
examples/SpatialTransformer/mnist-addition.py
View file @
651a5aea
...
...
@@ -153,7 +153,7 @@ def get_config():
lr
=
symbf
.
get_scalar_var
(
'learning_rate'
,
5e-4
,
summary
=
True
)
return
TrainConfig
(
data
set
=
dataset_train
,
data
flow
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/TIMIT/train-timit.py
View file @
651a5aea
...
...
@@ -94,7 +94,7 @@ def get_config(ds_train, ds_test):
lr
=
symbolic_functions
.
get_scalar_var
(
'learning_rate'
,
5e-3
,
summary
=
True
)
return
TrainConfig
(
data
set
=
ds_train
,
data
flow
=
ds_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/char-rnn/char-rnn.py
View file @
651a5aea
...
...
@@ -107,7 +107,7 @@ def get_config():
lr
=
symbolic_functions
.
get_scalar_var
(
'learning_rate'
,
2e-3
,
summary
=
True
)
return
TrainConfig
(
data
set
=
ds
,
data
flow
=
ds
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/cifar-convnet.py
View file @
651a5aea
...
...
@@ -122,7 +122,7 @@ def get_config(cifar_classnum):
return
lr
*
0.31
return
TrainConfig
(
data
set
=
dataset_train
,
data
flow
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
examples/mnist-convnet.py
View file @
651a5aea
...
...
@@ -114,7 +114,7 @@ def get_config():
# get the config which contains everything necessary in a training
return
TrainConfig
(
data
set
=
dataset_train
,
# the DataFlow instance for training
data
flow
=
dataset_train
,
# the DataFlow instance for training
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
),
callbacks
=
Callbacks
([
StatPrinter
(),
# print statistics in terminal after every epoch
...
...
examples/svhn-digit-convnet.py
View file @
651a5aea
...
...
@@ -99,7 +99,7 @@ def get_config():
tf
.
summary
.
scalar
(
'lr'
,
lr
)
return
TrainConfig
(
data
set
=
data_train
,
data
flow
=
data_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
...
...
tensorpack/train/config.py
View file @
651a5aea
...
...
@@ -20,7 +20,7 @@ class TrainConfig(object):
Config for trainer.
"""
def
__init__
(
self
,
data
set
=
None
,
data
=
None
,
def
__init__
(
self
,
data
flow
=
None
,
data
=
None
,
model
=
None
,
optimizer
=
None
,
callbacks
=
None
,
session_config
=
get_default_sess_config
(),
session_init
=
None
,
...
...
@@ -29,8 +29,8 @@ class TrainConfig(object):
**
kwargs
):
"""
Args:
data
set (DataFlow): the dataset
to train.
data (InputData): an `InputData` instance. Only one of ``data
set
``
data
flow (DataFlow): the dataflow
to train.
data (InputData): an `InputData` instance. Only one of ``data
flow
``
or ``data`` has to be present.
model (ModelDesc): the model to train.
optimizer (tf.train.Optimizer): the optimizer for trainig.
...
...
@@ -49,13 +49,19 @@ class TrainConfig(object):
# TODO type checker decorator
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
if
dataset
is
not
None
:
assert
data
is
None
,
"dataset and data cannot be both presented in TrainConfig!"
self
.
dataset
=
dataset
assert_type
(
self
.
dataset
,
DataFlow
)
if
'dataset'
in
kwargs
:
dataflow
=
kwargs
.
pop
(
'dataset'
)
logger
.
warn
(
"[Deprecated] TrainConfig.dataset has been deprecated. Use TrainConfig.dataflow instead."
)
if
dataflow
is
not
None
:
assert
data
is
None
,
"dataflow and data cannot be both presented in TrainConfig!"
self
.
dataflow
=
dataflow
assert_type
(
self
.
dataflow
,
DataFlow
)
self
.
data
=
None
else
:
self
.
data
=
data
assert_type
(
self
.
data
,
InputData
)
self
.
dataflow
=
None
self
.
optimizer
=
optimizer
assert_type
(
self
.
optimizer
,
tf
.
train
.
Optimizer
)
...
...
@@ -74,8 +80,8 @@ class TrainConfig(object):
self
.
step_per_epoch
=
step_per_epoch
if
self
.
step_per_epoch
is
None
:
try
:
if
data
set
is
not
None
:
self
.
step_per_epoch
=
self
.
data
set
.
size
()
if
data
flow
is
not
None
:
self
.
step_per_epoch
=
self
.
data
flow
.
size
()
else
:
self
.
step_per_epoch
=
self
.
data
.
size
()
except
NotImplementedError
:
...
...
tensorpack/train/feedfree.py
View file @
651a5aea
...
...
@@ -15,12 +15,12 @@ from .input_data import QueueInput, FeedfreeInput
from
.base
import
Trainer
from
.trainer
import
MultiPredictorTowerTrainer
__all__
=
[
'FeedfreeTrainer'
,
'SingleCostFeedfreeTrainer'
,
__all__
=
[
'FeedfreeTrainer
Base
'
,
'SingleCostFeedfreeTrainer'
,
'SimpleFeedfreeTrainer'
,
'QueueInputTrainer'
]
class
FeedfreeTrainer
(
Trainer
):
""" A trainer which runs iteration without feed_dict (therefore faster)
class
FeedfreeTrainer
Base
(
Trainer
):
""" A
base
trainer which runs iteration without feed_dict (therefore faster)
Expect ``self.data`` to be a :class:`FeedfreeInput`.
"""
...
...
@@ -39,7 +39,7 @@ class FeedfreeTrainer(Trainer):
self
.
_input_method
.
_setup
(
self
)
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainer
):
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainer
Base
):
""" A feedfree Trainer which assumes a single cost. """
def
_get_cost_and_grad
(
self
):
""" get the cost and gradient on a new tower"""
...
...
@@ -78,11 +78,16 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
class
SimpleFeedfreeTrainer
(
MultiPredictorTowerTrainer
,
SingleCostFeedfreeTrainer
):
"""
A trainer with single cost, single training tower, any number of
prediction tower, and feed-free input.
"""
def
__init__
(
self
,
config
):
"""
A trainer with single cost, single training tower and feed-free input
config.data must exists
Args:
config (TrainConfig): ``config.data`` must exist and is a
:class:`FeedfreeInput`.
"""
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
self
.
_input_method
...
...
@@ -105,17 +110,20 @@ class SimpleFeedfreeTrainer(
class
QueueInputTrainer
(
SimpleFeedfreeTrainer
):
"""
A trainer which automatically wraps ``config.dataflow``
"""
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
"""
Single tower Trainer, takes input from a queue
:param config: a `TrainConfig` instance. config.data
set
must exist
:param config: a `TrainConfig` instance. config.data
flow
must exist
:param input_queue: a `tf.QueueBase` instance
:param predict_tower: list of gpu relative idx to run prediction. default to be [0].
Use -1 for cpu.
"""
config
.
data
=
QueueInput
(
config
.
data
set
,
input_queue
)
config
.
data
=
QueueInput
(
config
.
data
flow
,
input_queue
)
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig.predict_tower instead!"
)
...
...
tensorpack/train/multigpu.py
View file @
651a5aea
...
...
@@ -53,8 +53,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
MultiPredictorTowerTrainer
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
if
hasattr
(
config
,
'dataset'
)
:
self
.
_input_method
=
QueueInput
(
config
.
data
set
,
input_queue
)
if
config
.
dataflow
is
not
None
:
self
.
_input_method
=
QueueInput
(
config
.
data
flow
,
input_queue
)
else
:
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
...
...
@@ -122,8 +122,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
input_queue
=
None
,
average_gradient
=
True
,
predict_tower
=
None
):
if
hasattr
(
config
,
'dataset'
)
:
self
.
_input_method
=
QueueInput
(
config
.
data
set
,
input_queue
)
if
config
.
dataflow
is
not
None
:
self
.
_input_method
=
QueueInput
(
config
.
data
flow
,
input_queue
)
else
:
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
...
...
tensorpack/train/trainer.py
View file @
651a5aea
...
...
@@ -59,11 +59,11 @@ class SimpleTrainer(Trainer):
def
__init__
(
self
,
config
):
super
(
SimpleTrainer
,
self
)
.
__init__
(
config
)
self
.
_predictor_factory
=
PredictorFactory
(
self
.
sess
,
self
.
model
,
[
0
])
if
not
hasattr
(
config
,
'dataset'
)
:
if
config
.
dataflow
is
None
:
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
FeedInput
)
else
:
self
.
_input_method
=
FeedInput
(
config
.
data
set
)
self
.
_input_method
=
FeedInput
(
config
.
data
flow
)
def
run_step
(
self
):
feed
=
self
.
_input_method
.
next_feed
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment