Shashank Suhas / seminar-breakout / Commits

Commit d43c8a28
Authored Oct 16, 2017 by Yuxin Wu
Parent: b82a1fda

    updates about `launch_train`

Showing 4 changed files with 28 additions and 11 deletions.
docs/tutorial/index.rst                    +11  −3
tensorpack/callbacks/inference_runner.py    +4  −0
tensorpack/train/base.py                   +12  −7
tensorpack/utils/logger.py                  +1  −1
docs/tutorial/index.rst

@@ -8,12 +8,13 @@ A High Level Glance

 .. image:: https://user-images.githubusercontent.com/1381301/29187907-2caaa740-7dc6-11e7-8220-e20ca52c3ca6.png

-* DataFlow is a library to load data efficiently in Python.
+* ``DataFlow`` is a library to load data efficiently in Python.
   Apart from DataFlow, native TF operators can be used for data loading as well.
-  They will eventually be wrapped under the same interface and go through prefetching.
+  They will eventually be wrapped under the same ``InputSource`` interface and go through prefetching.
 * You can use any TF-based symbolic function library to define a model, including
-  a small set of models within tensorpack. ``ModelDesc`` is an interface to connect symbolic graph to tensorpack trainers.
+  a small set of models within tensorpack. ``ModelDesc`` is an interface to connect the graph with the
+  ``InputSource`` interface.
 * tensorpack trainers manage the training loops for you.
   They also include data parallel logic for multi-GPU or distributed training.

@@ -22,6 +23,13 @@ A High Level Glance

 * Callbacks are like ``tf.train.SessionRunHook``, or plugins. During training,
   everything you want to do other than the main iterations can be defined through callbacks and easily reused.
+* All the components, though they work perfectly together, are highly decorrelated; you can:
+
+  * Use DataFlow alone as a data loading library, without TensorFlow at all.
+  * Use tensorpack to build the graph with multi-GPU or distributed support,
+    then train it with your own loops.
+  * Build the graph on your own, and train it with tensorpack callbacks.

 User Tutorials
 ========================
 ...
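The new bullet about decorrelated components lends itself to a concrete example. Below is a minimal sketch of the first point, using DataFlow as a standalone data-loading library, assuming the 2017-era dataflow API (`DataFromList`, `BatchData`, `reset_state()`, `get_data()`); the toy datapoints and batch size are made up for illustration.

    import numpy as np
    from tensorpack.dataflow import DataFromList, BatchData

    # A DataFlow datapoint is just a list of components; no TensorFlow involved.
    data = [[np.random.rand(28, 28), i % 10] for i in range(100)]  # toy (image, label) pairs
    df = DataFromList(data, shuffle=True)   # wrap a plain Python list as a DataFlow
    df = BatchData(df, 8)                   # group 8 datapoints into one batched datapoint

    df.reset_state()                        # required once before iteration
    for images, labels in df.get_data():    # plain Python generator, usable with any framework
        pass                                # e.g. feed (images, labels) to your own training loop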
tensorpack/callbacks/inference_runner.py

@@ -58,6 +58,8 @@ class InferenceRunnerBase(Callback):
     """ Base class for inference runner.
     Please note that InferenceRunner will use `input.size()` to determine
     how much iterations to run, so you want it to be accurate.
+
+    Also, InferenceRunner assumes that `trainer.model` exists.
     """

     def __init__(self, input, infs, extra_hooks=None):
         """

@@ -120,6 +122,7 @@ class InferenceRunner(InferenceRunnerBase):
             return InferencerToHook(inf, fetches)

     def _setup_graph(self):
+        assert self.trainer.model is not None
         # Use predict_tower in train config. either gpuid or -1
         tower_id = self.trainer.config.predict_tower[0]
         device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'

@@ -178,6 +181,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         self._gpus = gpus

     def _setup_graph(self):
+        assert self.trainer.model is not None
         cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
         # build each predict tower
         self._handles = []
 ...
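The assertions added above mean InferenceRunner is only usable when the trainer was configured with a ModelDesc. A hedged sketch of a setup that satisfies them, mirroring the TrainConfig example from the launch_train docstring below; MyModelDesc, my_dataflow, val_dataflow, the 'cost' tensor and the import locations are placeholders or assumptions, not names confirmed by this repository.

    from tensorpack import TrainConfig, SimpleTrainer, QueueInput
    from tensorpack.callbacks import InferenceRunner, ScalarStats

    config = TrainConfig(
        model=MyModelDesc(),            # must be set: InferenceRunner asserts trainer.model is not None
        data=QueueInput(my_dataflow),
        callbacks=[
            # run the inferencers over val_dataflow at the end of every epoch
            InferenceRunner(val_dataflow, [ScalarStats('cost')]),
        ],
        steps_per_epoch=1000,
    )
    SimpleTrainer(config).train()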
tensorpack/train/base.py

@@ -20,7 +20,7 @@ from ..tfutils.sessinit import JustCurrentSession
 from ..graph_builder.predictor_factory import PredictorFactory

-__all__ = ['Trainer', 'StopTraining']
+__all__ = ['Trainer', 'StopTraining', 'launch_train']


 class StopTraining(BaseException):

@@ -287,21 +287,25 @@ class Trainer(object):

 def launch_train(
-        run_step, callbacks=None, extra_callbacks=None, monitors=None,
+        run_step, model=None,
+        callbacks=None, extra_callbacks=None, monitors=None,
         session_creator=None, session_config=None, session_init=None,
         starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
     """
-    This is a simpler interface to start training after the graph has been built already.
+    This is another trainer interface, to start training **after** the graph has been built already.
     You can build the graph however you like
-    (with or without tensorpack), and invoke this function to start training.
+    (with or without tensorpack), and invoke this function to start training with callbacks & monitors.

     This provides the flexibility to define the training config after the graph has been built.
-    One typical use is that callbacks often depend on names that are unknown
-    only after the graph has been built.
+    One typical use is that callbacks often depend on names that are not known
+    until the graph has been built.

     Args:
         run_step (tf.Tensor or function): Define what the training iteration is.
             If given a Tensor/Operation, will eval it every step.
             If given a function, will invoke this function under the default session in every step.
+        model (None or ModelDesc): Certain callbacks (e.g. InferenceRunner) depend on
+            the existence of :class:`ModelDesc`. If you used a :class:`ModelDesc` to
+            build the graph, add it here to allow those callbacks to work.
+            If you didn't use :class:`ModelDesc`, leave it empty.
         Other arguments are the same as in :class:`TrainConfig`.

     Examples:

@@ -310,13 +314,14 @@ def launch_train(
         model = MyModelDesc()
         train_op, cbs = SimpleTrainer.setup_graph(model, QueueInput(mydataflow))
-        launch_train(train_op, callbacks=[...] + cbs, steps_per_epoch=mydataflow.size())
+        launch_train(train_op, model=model,
+                     callbacks=[...] + cbs, steps_per_epoch=mydataflow.size())
         # the above is equivalent to:
         config = TrainConfig(model=MyModelDesc(), data=QueueInput(mydataflow), callbacks=[...])
         SimpleTrainer(config).train()
     """
     assert steps_per_epoch is not None, steps_per_epoch
     trainer = Trainer(TrainConfig(
+        model=model,
         callbacks=callbacks,
         extra_callbacks=extra_callbacks,
         monitors=monitors,
 ...
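Besides the ModelDesc route shown in the docstring example, launch_train is described as working on a graph built entirely by hand. A hedged sketch of that route, assuming a plain TF 1.x graph and that launch_train accepts a bare training Operation as run_step, as the docstring states; the loss, optimizer and import path are made up for illustration.

    import tensorflow as tf
    from tensorpack.train.base import launch_train

    # Build the graph yourself, without ModelDesc.
    w = tf.get_variable('w', shape=[10], initializer=tf.constant_initializer(1.0))
    loss = tf.reduce_sum(tf.square(w), name='total_loss')
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    # run_step is a plain Operation here, so it is simply run once per step.
    # model is left unset: without a ModelDesc, callbacks such as InferenceRunner
    # cannot be used (see the assertions added in inference_runner.py above).
    # callbacks=, monitors= and the session arguments can be passed exactly as
    # they would be to TrainConfig.
    launch_train(train_op, steps_per_epoch=100, max_epoch=5)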
tensorpack/utils/logger.py

@@ -64,7 +64,7 @@ def _set_file(path):
     if os.path.isfile(path):
         backup_name = path + '.' + _get_time_str()
         shutil.move(path, backup_name)
-        info("Log file '{}' backuped to '{}'".format(path, backup_name))  # noqa: F821
+        _logger.info("Existing log file '{}' backuped to '{}'".format(path, backup_name))  # noqa: F821
     hdl = logging.FileHandler(
         filename=path, encoding='utf-8', mode='w')
     hdl.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
 ...
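The hunk above touches a small but reusable pattern: before attaching a fresh FileHandler, any existing log file is moved aside with a timestamp suffix and the backup is logged. A stdlib-only sketch of the same pattern, with illustrative names (set_log_file, the format strings) rather than tensorpack's internals.

    import logging
    import os
    import shutil
    import time

    def set_log_file(logger, path):
        """Back up any existing file at `path`, then start logging to a fresh file there."""
        if os.path.isfile(path):
            backup_name = path + '.' + time.strftime('%m%d-%H%M%S')
            shutil.move(path, backup_name)
            logger.info("Existing log file '%s' backed up to '%s'", path, backup_name)
        hdl = logging.FileHandler(filename=path, encoding='utf-8', mode='w')
        hdl.setFormatter(logging.Formatter('[%(asctime)s] %(message)s',
                                           datefmt='%m%d %H:%M:%S'))
        logger.addHandler(hdl)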