Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0b226adb
Commit
0b226adb
authored
Jul 31, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
allow empty model & data in config
parent
e791b9a5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
26 deletions
+24
-26
tensorpack/train/config.py
tensorpack/train/config.py
+24
-26
No files found.
tensorpack/train/config.py
View file @
0b226adb
...
...
@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
..callbacks
import
(
Callbacks
,
MovingAverageSummary
,
MovingAverageSummary
,
ProgressBar
,
MergeAllSummaries
,
TFEventWriter
,
JSONWriter
,
ScalarPrinter
,
RunUpdateOps
)
from
..dataflow.base
import
DataFlow
...
...
@@ -24,20 +24,23 @@ class TrainConfig(object):
"""
def
__init__
(
self
,
dataflow
=
None
,
data
=
None
,
model
=
None
,
callbacks
=
None
,
extra_callbacks
=
None
,
monitors
=
None
,
dataflow
=
None
,
data
=
None
,
model
=
None
,
callbacks
=
None
,
extra_callbacks
=
None
,
monitors
=
None
,
session_creator
=
None
,
session_config
=
None
,
session_init
=
None
,
starting_epoch
=
1
,
steps_per_epoch
=
None
,
max_epoch
=
99999
,
nr_tower
=
1
,
tower
=
None
,
predict_tower
=
[
0
]
,
nr_tower
=
1
,
tower
=
None
,
predict_tower
=
None
,
**
kwargs
):
"""
Note:
It depends on the specific trainer what fields are necessary.
Most existing trainers in tensorpack requires one of `dataflow` or `data`,
and `model` to be present in the config.
Args:
dataflow (DataFlow):
the dataflow to train.
data (InputSource):
an `InputSource` instance. Only one of ``dataflow``
or ``data`` has to be present.
model (ModelDesc): the model to train.
dataflow (DataFlow):
data (InputSource):
model (ModelDesc):
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are
...
...
@@ -45,14 +48,17 @@ class TrainConfig(object):
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFEventWriter(), JSONWriter(), ScalarPrinter()]``.
session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by :func:`tfutils.get_default_sess_config()`.
session_config (tf.ConfigProto): when session_creator is None, use this to create the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to do nothing.
starting_epoch (int): The index of the first epoch.
steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
Defaults to the input data size.
max_epoch (int): maximum number of epoch to run training.
nr_tower (int): number of training towers.
tower (list of int): list of training towers in relative id.
predict_tower (list of int): list of prediction towers in their relative gpu id. Use -1 for cpu.
...
...
@@ -62,7 +68,7 @@ class TrainConfig(object):
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
# process data
# process data
& model
if
'dataset'
in
kwargs
:
dataflow
=
kwargs
.
pop
(
'dataset'
)
log_deprecated
(
"TrainConfig.dataset"
,
"Use TrainConfig.dataflow instead."
,
"2017-09-11"
)
...
...
@@ -71,16 +77,16 @@ class TrainConfig(object):
self
.
dataflow
=
dataflow
assert_type
(
self
.
dataflow
,
DataFlow
)
self
.
data
=
None
els
e
:
if
data
is
not
Non
e
:
self
.
data
=
data
assert_type
(
self
.
data
,
InputSource
)
self
.
dataflow
=
None
if
model
is
not
None
:
assert_type
(
model
,
ModelDesc
)
self
.
model
=
model
if
callbacks
is
None
:
callbacks
=
[]
assert
not
isinstance
(
callbacks
,
Callbacks
),
\
"TrainConfig(callbacks=Callbacks([...]))"
\
"Change the argument 'callbacks=' to a *list* of callbacks without StatPrinter()."
assert_type
(
callbacks
,
list
)
if
extra_callbacks
is
None
:
extra_callbacks
=
[
...
...
@@ -89,16 +95,11 @@ class TrainConfig(object):
MergeAllSummaries
(),
RunUpdateOps
()]
self
.
_callbacks
=
callbacks
+
extra_callbacks
assert_type
(
self
.
_callbacks
,
list
)
if
monitors
is
None
:
monitors
=
[
TFEventWriter
(),
JSONWriter
(),
ScalarPrinter
()]
self
.
monitors
=
monitors
if
model
is
not
None
:
assert_type
(
model
,
ModelDesc
)
self
.
model
=
model
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
self
.
session_init
=
session_init
...
...
@@ -128,22 +129,19 @@ class TrainConfig(object):
self
.
starting_epoch
=
int
(
starting_epoch
)
self
.
max_epoch
=
int
(
max_epoch
)
assert
self
.
steps_per_epoch
>
=
0
and
self
.
max_epoch
>
0
assert
self
.
steps_per_epoch
>
0
and
self
.
max_epoch
>
0
self
.
nr_tower
=
nr_tower
if
tower
is
not
None
:
assert
self
.
nr_tower
==
1
,
"Cannot set both nr_tower and tower in TrainConfig!"
self
.
tower
=
tower
if
predict_tower
is
None
:
predict_tower
=
[
0
]
self
.
predict_tower
=
predict_tower
if
isinstance
(
self
.
predict_tower
,
int
):
self
.
predict_tower
=
[
self
.
predict_tower
]
assert
len
(
set
(
self
.
predict_tower
))
==
len
(
self
.
predict_tower
),
\
"Cannot have duplicated predict_tower!"
assert
'optimizer'
not
in
kwargs
,
\
"TrainConfig(optimizer=...) was already deprecated! "
\
"Use ModelDesc._get_optimizer() instead."
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
@
property
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment