Commit b82a1fda authored Oct 16, 2017 by Yuxin Wu

fix import; introducing new interface to train after graph has been built.

parent f7993410
Showing 4 changed files with 51 additions and 25 deletions (+51 / -25)

examples/GAN/GAN.py                       +2   -1
tensorpack/graph_builder/model_desc.py    +1  -20
tensorpack/train/base.py                 +48   -0
tensorpack/train/config.py                +0   -4
examples/GAN/GAN.py

...
@@ -8,8 +8,9 @@ import numpy as np
 import time
 from tensorpack import (Trainer, QueueInput,
                         ModelDescBase, DataFlow, StagingInputWrapper,
-                        MultiGPUTrainerBase, LeastLoadedDeviceSetter,
+                        MultiGPUTrainerBase,
                         TowerContext)
+from tensorpack.train.utility import LeastLoadedDeviceSetter
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils.argtools import memoized
...
tensorpack/graph_builder/model_desc.py

...
@@ -32,30 +32,11 @@ class InputDesc(
             shape (tuple):
             name (str):
         """
-        shape = tuple(shape)    # has to be tuple for self to be hashable
+        shape = tuple(shape)    # has to be tuple for "self" to be hashable
         self = super(InputDesc, cls).__new__(cls, type, shape, name)
         self._cached_placeholder = None
         return self
-
-    # TODO in serialization, skip _cached_placeholder
-    # def dumps(self):
-    #     """
-    #     Returns:
-    #         str: serialized string
-    #     """
-    #     return pickle.dumps(self)
-
-    # @staticmethod
-    # def loads(buf):
-    #     """
-    #     Args:
-    #         buf (str): serialized string
-    #     Returns:
-    #         InputDesc:
-    #     """
-    #     return pickle.loads(buf)

     def build_placeholder(self, prefix=''):
         """
         Build a tf.placeholder from the metadata, with an optional prefix.
...
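The re-quoted comment above ("has to be tuple for "self" to be hashable") refers to InputDesc being built on a namedtuple: a namedtuple instance is hashable only if every field is hashable. A minimal sketch of the distinction, where the Desc namedtuple is a stand-in for illustration and not tensorpack's actual class:

import collections

# Stand-in for InputDesc's namedtuple base; hypothetical, not the real class.
Desc = collections.namedtuple('Desc', ['type', 'shape', 'name'])

# A tuple field keeps the record hashable, so it can live in sets/dict keys.
print(hash(Desc('float32', (None, 28, 28), 'input')))

try:
    # A list field makes the whole record unhashable.
    hash(Desc('float32', [None, 28, 28], 'input'))
except TypeError as err:
    print(err)   # "unhashable type: 'list'"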
tensorpack/train/base.py

...
@@ -284,3 +284,51 @@ class Trainer(object):
             self._predictor_factory = PredictorFactory(
                 self.model, self.vs_name_for_predictor)
         return self._predictor_factory
+
+
+def launch_train(
+        run_step, callbacks=None, extra_callbacks=None, monitors=None,
+        session_creator=None, session_config=None, session_init=None,
+        starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
+    """
+    This is a simpler interface to start training after the graph has been built already.
+    You can build the graph however you like
+    (with or without tensorpack), and invoke this function to start training.
+
+    This provides the flexibility to define the training config after the graph has been built.
+    One typical use is that callbacks often depend on names that are known
+    only after the graph has been built.
+
+    Args:
+        run_step (tf.Tensor or function): Define what the training iteration is.
+            If given a Tensor/Operation, will eval it every step.
+            If given a function, will invoke this function under the default session in every step.
+
+    Other arguments are the same as in :class:`TrainConfig`.
+
+    Examples:
+
+    .. code-block:: python
+
+        model = MyModelDesc()
+        train_op, cbs = SimpleTrainer.setup_graph(model, QueueInput(mydataflow))
+        launch_train(train_op, callbacks=[...] + cbs, steps_per_epoch=mydataflow.size())
+
+        # the above is equivalent to:
+        config = TrainConfig(model=MyModelDesc(), data=QueueInput(mydataflow), callbacks=[...])
+        SimpleTrainer(config).train()
+    """
+    assert steps_per_epoch is not None, steps_per_epoch
+    trainer = Trainer(TrainConfig(
+        callbacks=callbacks,
+        extra_callbacks=extra_callbacks,
+        monitors=monitors,
+        session_creator=session_creator,
+        session_config=session_config,
+        session_init=session_init,
+        starting_epoch=starting_epoch,
+        steps_per_epoch=steps_per_epoch,
+        max_epoch=max_epoch))
+    if isinstance(run_step, (tf.Tensor, tf.Operation)):
+        trainer.train_op = run_step
+    else:
+        assert callable(run_step), run_step
+        trainer.run_step = lambda self: run_step()
+    trainer.train()
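The docstring above covers the tensorpack-built graph; the other use the commit message points at is a graph built outside tensorpack. The sketch below is a hypothetical illustration of that path, not code from this commit: the toy network, random inputs, and hyper-parameters are made up, and only launch_train() itself comes from the diff above.

import tensorflow as tf
from tensorpack.train.base import launch_train

# A throwaway graph: random inputs so the train op can run without a feed_dict.
x = tf.random_normal([32, 784])
label = tf.random_uniform([32], maxval=10, dtype=tf.int64)
logits = tf.layers.dense(x, 10)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logits))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# run_step may be a Tensor/Operation (evaluated every step) or a callable.
launch_train(train_op, callbacks=[], steps_per_epoch=100, max_epoch=2)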
tensorpack/train/config.py

...
@@ -9,7 +9,6 @@ from ..callbacks import (
 from ..dataflow.base import DataFlow
 from ..graph_builder.model_desc import ModelDescBase
 from ..utils import logger
-from ..utils.develop import log_deprecated
 from ..tfutils import (JustCurrentSession,
                        get_default_sess_config, SessionInit)
 from ..tfutils.sesscreate import NewSessionCreator
...
@@ -69,9 +68,6 @@ class TrainConfig(object):
             assert isinstance(v, tp), v.__class__

         # process data & model
-        if 'dataset' in kwargs:
-            dataflow = kwargs.pop('dataset')
-            log_deprecated("TrainConfig.dataset", "Use TrainConfig.dataflow instead.", "2017-09-11")
         assert data is None or dataflow is None, "dataflow and data cannot be both presented in TrainConfig!"
         if dataflow is not None:
             assert_type(dataflow, DataFlow)
...
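With the deprecation shim removed, a TrainConfig has to receive its data through dataflow= (a DataFlow) or data= (an InputSource); dataset= is no longer translated. A hedged sketch of the surviving spelling, where MyModelDesc is the hypothetical ModelDesc subclass from the docstring example above and FakeData only supplies dummy batches for illustration:

from tensorpack import TrainConfig
from tensorpack.dataflow import FakeData

df = FakeData([[32, 28, 28, 3], [32]], size=100)    # dummy DataFlow for illustration

# `dataflow=` is the supported keyword; `dataset=df` is no longer handled after this commit.
config = TrainConfig(model=MyModelDesc(), dataflow=df,
                     callbacks=[], steps_per_epoch=df.size())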