Shashank Suhas / seminar-breakout / Commits

Commit f85c3003, authored Aug 02, 2017 by Yuxin Wu
Parent: b785bf77

    update documentation

Showing 6 changed files, with 31 additions and 12 deletions (+31 / -12):

  docs/conf.py                       +1   -0
  docs/tutorial/extend/callback.md   +2   -2
  docs/tutorial/trainer.md           +15  -3
  tensorpack/callbacks/base.py       +6   -2
  tensorpack/train/base.py           +3   -2
  tensorpack/train/multigpu.py       +4   -3
docs/conf.py
@@ -359,6 +359,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
             'replace_get_variable',
             'remap_get_variable',
             'freeze_get_variable',
+            'Triggerable',
             'ParamRestore']:
         return True
     if name in ['get_data', 'size', 'reset_state']:
 ...
docs/tutorial/extend/callback.md
@@
-## Write a callback
+## Write a Callback
 The places where each callback method gets called is demonstrated in this snippet:
 ...
@@ -20,7 +20,7 @@ def main_loop():
     callbacks.after_train()
 ```
-### Explain the callback methods
+### Explain the Callback Methods
 You can override any of the following methods to define a new callback:
 ...
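The tutorial's snippet shows a main loop invoking the callback hooks in a fixed order, ending with `callbacks.after_train()`. A minimal standalone sketch of that lifecycle (the `Callback` class here is a toy stand-in for tensorpack's, keeping only the hook names; the real methods also receive a trainer reference):

```python
# Toy model of the callback lifecycle described in the tutorial.

class Callback(object):
    def before_train(self): pass
    def trigger_step(self): pass
    def trigger_epoch(self): pass
    def after_train(self): pass


class RecordingCallback(Callback):
    """Records the order in which the loop invokes each hook."""
    def __init__(self):
        self.calls = []
    def before_train(self): self.calls.append('before_train')
    def trigger_step(self): self.calls.append('trigger_step')
    def trigger_epoch(self): self.calls.append('trigger_epoch')
    def after_train(self): self.calls.append('after_train')


def main_loop(cb, epochs=1, steps_per_epoch=2):
    # Same shape as the tutorial snippet that ends with
    # `callbacks.after_train()`.
    cb.before_train()
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            cb.trigger_step()
        cb.trigger_epoch()
    cb.after_train()
```

Overriding any subset of these methods is enough to define a new callback; unoverridden hooks are no-ops.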
docs/tutorial/trainer.md
@@ -17,8 +17,8 @@ To use trainers, pass a `TrainConfig` to configure them:
 ```python
 config = TrainConfig(
            model=MyModel()
            dataflow=my_dataflow,
            # data=my_inputsource, # alternatively, use a customized InputSource
            callbacks=[ ... ]
          )
 ...
@@ -45,4 +45,16 @@ would be multiplied by the number of GPUs.
 ### Custom Trainers
 Trainers just run __some__ iterations, so there is no limit in where the data come from or what to do in an iteration.
+For example, [GAN trainer](../examples/GAN/GAN.py) minimizes two cost functions alternatively.
+The existing trainers implement the default logic, but you can implement them yourself by using the base `Trainer` class.
+
+* Two ways to customize the graph:
+
+  1. Create the graph, add any tensors and ops before creating the trainer.
+  2. Subclass `Trainer` and override the `_setup()` method which will be called in `Trainer.__init__`.
+
+* Two ways to customize the iteration:
+
+  1. Set `Trainer.train_op`. This op will be run by default.
+  2. Subclass `Trainer` and override the `run_step()` method.
+
+There are several different [GAN trainers](../examples/GAN/GAN.py) for reference.
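The added tutorial text names two customization hooks: `_setup()` (called from `Trainer.__init__`) for the graph, and `run_step()` for the iteration. A standalone sketch of that pattern, using a toy base class rather than tensorpack's real `Trainer` (which would run TensorFlow ops inside `run_step()`):

```python
# Toy model of customizing a trainer via _setup() and run_step(),
# mirroring the two hooks the tutorial describes.

class Trainer(object):
    def __init__(self):
        self.global_step = 0
        self._setup()          # hook for customizing the graph

    def _setup(self):
        pass                   # subclass may add tensors/ops here

    def run_step(self):        # hook for customizing the iteration
        raise NotImplementedError

    def train(self, steps):
        for _ in range(steps):
            self.run_step()
            self.global_step += 1


class AlternatingTrainer(Trainer):
    """GAN-style trainer: alternate between two cost functions."""
    def _setup(self):
        self.history = []

    def run_step(self):
        # A real trainer would alternate sess.run() on two minimize ops.
        op = 'min_d' if self.global_step % 2 == 0 else 'min_g'
        self.history.append(op)
```

This is the shape of the GAN trainers referenced above: the base class owns the loop, and the subclass decides what one iteration does.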
tensorpack/callbacks/base.py
@@ -13,7 +13,10 @@ __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
 @six.add_metaclass(ABCMeta)
 class Callback(object):
-    """ Base class for all callbacks.
+    """ Base class for all callbacks. See
+    `Write a Callback
+    <http://tensorpack.readthedocs.io/en/latest/tutorial/extend/callback.html>`_
+    for more detailed explanation of the callback methods.

     Attributes:
         epoch_num(int): the number of the current epoch.
 ...
@@ -261,7 +264,8 @@ class CallbackFactory(Callback):
         """
         Each lambda takes ``self`` as the only argument.
-        trigger_epoch was deprecated.
+
+        Note:
+            trigger_epoch was deprecated.
         """
         self._cb_setup_graph = setup_graph
 ...
tensorpack/train/base.py
@@ -37,8 +37,6 @@ class Trainer(object):
         sess (tf.Session): the current session in use.
         hooked_sess (tf.MonitoredSession): the session with hooks.
         monitors (Monitors): the monitors. Callbacks can use it for logging.
-        epoch_num (int): the number of epochs that have finished.
         local_step (int): the number of steps that have finished in the current epoch.
     """
     # step attr only available after before_train?
 ...
@@ -64,6 +62,9 @@ class Trainer(object):
     @property
     def epoch_num(self):
+        """
+        The number of epochs that have finished.
+        """
         if self._epoch_num is not None:
             # has started training
             return self._epoch_num
 ...
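The change above moves `epoch_num` from a documented attribute to a documented property that reads `_epoch_num` once training has started. The diff elides the fallback branch; a sketch of the pattern, where the pre-training fallback (`starting_epoch`) is an illustrative assumption rather than tensorpack's actual code:

```python
# Sketch of the epoch_num accessor pattern from the diff above.
# `starting_epoch` and the fallback branch are assumptions for
# illustration; the real else-branch is not shown in the diff.

class Trainer(object):
    def __init__(self, starting_epoch=1):
        self._epoch_num = None            # set once training starts
        self._starting_epoch = starting_epoch

    @property
    def epoch_num(self):
        """
        The number of epochs that have finished.
        """
        if self._epoch_num is not None:
            # has started training
            return self._epoch_num
        # hypothetical fallback: no epoch has finished yet
        return self._starting_epoch - 1
```

Exposing this as a property lets callbacks read a consistent value both before and during training, instead of a plain attribute that may not exist yet.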
tensorpack/train/multigpu.py
@@ -18,10 +18,11 @@ from ..callbacks.graph import RunOp
 from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
 from .feedfree import FeedfreeTrainerBase

-__all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
-           'LeastLoadedDeviceSetter', 'AsyncMultiGPUTrainer',
+__all__ = ['MultiGPUTrainerBase',
+           'LeastLoadedDeviceSetter',
            'SyncMultiGPUTrainerReplicated',
-           'SyncMultiGPUTrainerParameterServer']
+           'SyncMultiGPUTrainerParameterServer',
+           'AsyncMultiGPUTrainer',
+           'SyncMultiGPUTrainer']

 def _check_tf_version():
 ...