Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8cea7d8a
Commit
8cea7d8a
authored
Oct 27, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add train_with_defaults
parent
af667ff4
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
81 additions
and
53 deletions
+81
-53
examples/GAN/ConditionalGAN-mnist.py
examples/GAN/ConditionalGAN-mnist.py
+2
-4
examples/GAN/CycleGAN.py
examples/GAN/CycleGAN.py
+1
-3
examples/GAN/DCGAN.py
examples/GAN/DCGAN.py
+3
-4
examples/GAN/DiscoGAN-CelebA.py
examples/GAN/DiscoGAN-CelebA.py
+3
-5
examples/GAN/Image2Image.py
examples/GAN/Image2Image.py
+3
-5
examples/GAN/Improved-WGAN.py
examples/GAN/Improved-WGAN.py
+3
-4
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+2
-3
examples/GAN/WGAN.py
examples/GAN/WGAN.py
+7
-6
tensorpack/train/base.py
tensorpack/train/base.py
+42
-5
tensorpack/train/interface.py
tensorpack/train/interface.py
+1
-0
tensorpack/trainv1/config.py
tensorpack/trainv1/config.py
+14
-14
No files found.
examples/GAN/ConditionalGAN-mnist.py
View file @
8cea7d8a
...
...
@@ -134,11 +134,9 @@ if __name__ == '__main__':
sample
(
args
.
load
)
else
:
logger
.
auto_set_dir
()
config
=
TrainConfig
(
GANTrainer
(
QueueInput
(
get_data
()),
Model
())
.
train_with_defaults
(
callbacks
=
[
ModelSaver
()],
steps_per_epoch
=
500
,
max_epoch
=
100
,
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
)
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
GANTrainer
(
QueueInput
(
get_data
()),
Model
())
.
train_with_config
(
config
)
examples/GAN/CycleGAN.py
View file @
8cea7d8a
...
...
@@ -218,7 +218,7 @@ if __name__ == '__main__':
data
=
get_data
(
args
.
data
)
data
=
PrintData
(
data
)
config
=
TrainConfig
(
GANTrainer
(
QueueInput
(
data
),
Model
())
.
train_with_defaults
(
callbacks
=
[
ModelSaver
(),
ScheduledHyperParamSetter
(
...
...
@@ -230,5 +230,3 @@ if __name__ == '__main__':
steps_per_epoch
=
data
.
size
(),
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
)
GANTrainer
(
QueueInput
(
data
),
Model
())
.
train_with_config
(
config
)
examples/GAN/DCGAN.py
View file @
8cea7d8a
...
...
@@ -156,12 +156,11 @@ if __name__ == '__main__':
else
:
assert
args
.
data
logger
.
auto_set_dir
()
config
=
TrainConfig
(
GANTrainer
(
input
=
QueueInput
(
get_data
(
args
.
data
)),
model
=
Model
())
.
train_with_defaults
(
callbacks
=
[
ModelSaver
()],
steps_per_epoch
=
300
,
max_epoch
=
200
,
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
)
GANTrainer
(
input
=
QueueInput
(
get_data
(
args
.
data
)),
model
=
Model
())
.
train_with_config
(
config
)
examples/GAN/DiscoGAN-CelebA.py
View file @
8cea7d8a
...
...
@@ -217,13 +217,11 @@ if __name__ == '__main__':
data
=
get_celebA_data
(
args
.
data
,
args
.
style_A
,
args
.
style_B
)
config
=
TrainConfig
(
# train 1 D after 2 G
SeparateGANTrainer
(
QueueInput
(
data
),
Model
(),
d_period
=
3
)
.
train_with_defaults
(
callbacks
=
[
ModelSaver
()],
steps_per_epoch
=
300
,
max_epoch
=
250
,
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
)
# train 1 D after 2 G
SeparateGANTrainer
(
QueueInput
(
data
),
Model
(),
d_period
=
3
)
.
train_with_config
(
config
)
examples/GAN/Image2Image.py
View file @
8cea7d8a
...
...
@@ -210,15 +210,13 @@ if __name__ == '__main__':
logger
.
auto_set_dir
()
data
=
QueueInput
(
get_data
())
config
=
TrainConfig
(
GANTrainer
(
data
,
Model
())
.
train_with_defaults
(
callbacks
=
[
PeriodicTrigger
(
ModelSaver
(),
every_k_epochs
=
3
),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
200
,
1e-4
)])
],
steps_per_epoch
=
data
.
size
(),
max_epoch
=
300
,
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
)
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
GANTrainer
(
data
,
Model
())
.
train_with_config
(
config
)
examples/GAN/Improved-WGAN.py
View file @
8cea7d8a
...
...
@@ -95,12 +95,11 @@ if __name__ == '__main__':
else
:
assert
args
.
data
logger
.
auto_set_dir
()
config
=
TrainConfig
(
SeparateGANTrainer
(
QueueInput
(
DCGAN
.
get_data
(
args
.
data
)),
Model
(),
g_period
=
6
)
.
train_with_defaults
(
callbacks
=
[
ModelSaver
()],
steps_per_epoch
=
300
,
max_epoch
=
200
,
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
)
SeparateGANTrainer
(
QueueInput
(
DCGAN
.
get_data
(
args
.
data
)),
Model
(),
g_period
=
6
)
.
train_with_config
(
config
)
examples/GAN/InfoGAN-mnist.py
View file @
8cea7d8a
...
...
@@ -245,11 +245,10 @@ if __name__ == '__main__':
sample
(
args
.
load
)
else
:
logger
.
auto_set_dir
()
cfg
=
TrainConfig
(
GANTrainer
(
QueueInput
(
get_data
()),
Model
())
.
train_with_defaults
(
callbacks
=
[
ModelSaver
(
keep_freq
=
0.1
)],
steps_per_epoch
=
500
,
max_epoch
=
100
,
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
)
GANTrainer
(
QueueInput
(
get_data
()),
Model
())
.
train_with_config
(
cfg
)
examples/GAN/WGAN.py
View file @
8cea7d8a
...
...
@@ -76,14 +76,15 @@ if __name__ == '__main__':
else
:
assert
args
.
data
logger
.
auto_set_dir
()
config
=
TrainConfig
(
# The original code uses a different schedule, but this seems to work well.
# Train 1 D after 2 G
SeparateGANTrainer
(
input
=
QueueInput
(
DCGAN
.
get_data
(
args
.
data
)),
model
=
Model
(),
d_period
=
3
)
.
train_with_defaults
(
callbacks
=
[
ModelSaver
(),
ClipCallback
()],
steps_per_epoch
=
500
,
max_epoch
=
200
,
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
)
# The original code uses a different schedule, but this seems to work well.
# Train 1 D after 2 G
SeparateGANTrainer
(
input
=
QueueInput
(
DCGAN
.
get_data
(
args
.
data
)),
model
=
Model
(),
d_period
=
3
)
.
train_with_config
(
config
)
tensorpack/train/base.py
View file @
8cea7d8a
...
...
@@ -9,13 +9,16 @@ from six.moves import range
import
six
from
abc
import
abstractmethod
,
ABCMeta
from
..callbacks
import
(
Callback
,
Callbacks
,
Monitors
,
TrainingMonitor
,
MovingAverageSummary
,
ProgressBar
,
MergeAllSummaries
,
TFEventWriter
,
JSONWriter
,
ScalarPrinter
,
RunUpdateOps
)
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
,
memoized
from
..callbacks
import
Callback
,
Callbacks
from
..callbacks.monitor
import
Monitors
,
TrainingMonitor
from
..tfutils.model_utils
import
describe_trainable_vars
from
..tfutils.sessinit
import
JustCurrentSession
from
..tfutils.sesscreate
import
ReuseSessionCreator
from
..tfutils.sesscreate
import
ReuseSessionCreator
,
NewSessionCreator
from
..tfutils.tower
import
TowerFuncWrapper
,
get_current_tower_context
from
..tfutils.gradproc
import
FilterNoneGrad
from
..callbacks.steps
import
MaintainStepCounter
...
...
@@ -31,6 +34,18 @@ from ..trainv1.config import TrainConfig
__all__
=
[
'TrainConfig'
,
'Trainer'
,
'SingleCostTrainer'
,
'TowerTrainer'
]
def
DEFAULT_CALLBACKS
():
return
[
MovingAverageSummary
(),
ProgressBar
(),
MergeAllSummaries
(),
RunUpdateOps
()]
def
DEFAULT_MONITORS
():
return
[
TFEventWriter
(),
JSONWriter
(),
ScalarPrinter
()]
class
Trainer
(
object
):
""" Base class for a trainer.
"""
...
...
@@ -220,8 +235,8 @@ class Trainer(object):
def
train_with_config
(
self
,
config
):
"""
An alias to simplify the use of `TrainConfig`.
It is equivalent to
the following:
An alias to simplify the use of `TrainConfig`
with `Trainer`
.
This method is literally
the following:
.. code-block:: python
...
...
@@ -240,6 +255,28 @@ class Trainer(object):
config
.
session_creator
,
config
.
session_init
,
config
.
steps_per_epoch
,
config
.
starting_epoch
,
config
.
max_epoch
)
def
train_with_defaults
(
self
,
callbacks
=
None
,
monitors
=
None
,
session_creator
=
None
,
session_init
=
None
,
steps_per_epoch
=
None
,
starting_epoch
=
1
,
max_epoch
=
9999
):
"""
Same as :meth:`train()`, but will:
1. Append `DEFAULT_CALLBACKS()` to callbacks.
2. Append `DEFAULT_MONITORS()` to monitors.
3. Provide default values for every option except `steps_per_epoch`.
"""
callbacks
=
(
callbacks
or
[])
+
DEFAULT_CALLBACKS
()
monitors
=
(
monitors
or
[])
+
DEFAULT_MONITORS
()
assert
steps_per_epoch
is
not
None
session_creator
=
session_creator
or
NewSessionCreator
()
session_init
=
session_init
or
JustCurrentSession
()
self
.
train
(
callbacks
,
monitors
,
session_creator
,
session_init
,
steps_per_epoch
,
starting_epoch
,
max_epoch
)
# create the old trainer when called with TrainConfig
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
(
len
(
args
)
>
0
and
isinstance
(
args
[
0
],
TrainConfig
))
\
...
...
tensorpack/train/interface.py
View file @
8cea7d8a
...
...
@@ -86,4 +86,5 @@ def launch_train_with_config(config, trainer):
trainer
.
setup_graph
(
inputs_desc
,
input
,
model
.
_build_graph_get_cost
,
model
.
get_optimizer
)
config
.
data
=
config
.
dataflow
=
config
.
model
=
None
trainer
.
train_with_config
(
config
)
tensorpack/trainv1/config.py
View file @
8cea7d8a
...
...
@@ -17,6 +17,18 @@ from ..utils.develop import log_deprecated
__all__
=
[
'TrainConfig'
]
def
DEFAULT_CALLBACKS
():
return
[
MovingAverageSummary
(),
ProgressBar
(),
MergeAllSummaries
(),
RunUpdateOps
()]
def
DEFAULT_MONITORS
():
return
[
TFEventWriter
(),
JSONWriter
(),
ScalarPrinter
()]
class
TrainConfig
(
object
):
"""
A collection of options to be used for trainers.
...
...
@@ -84,9 +96,9 @@ class TrainConfig(object):
callbacks
=
[]
assert_type
(
callbacks
,
list
)
self
.
_callbacks
=
callbacks
+
\
(
extra_callbacks
or
TrainConfig
.
DEFAULT_EXTRA
_CALLBACKS
())
(
extra_callbacks
or
DEFAULT
_CALLBACKS
())
self
.
monitors
=
monitors
or
TrainConfig
.
DEFAULT_MONITORS
()
self
.
monitors
=
monitors
or
DEFAULT_MONITORS
()
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
...
...
@@ -148,15 +160,3 @@ class TrainConfig(object):
@
property
def
callbacks
(
self
):
# disable setter
return
self
.
_callbacks
@
staticmethod
def
DEFAULT_EXTRA_CALLBACKS
():
return
[
MovingAverageSummary
(),
ProgressBar
(),
MergeAllSummaries
(),
RunUpdateOps
()]
@
staticmethod
def
DEFAULT_MONITORS
():
return
[
TFEventWriter
(),
JSONWriter
(),
ScalarPrinter
()]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment