Shashank Suhas / seminar-breakout · Commits

Commit e465842d, authored Jun 09, 2018 by Yuxin Wu

    update docs

Parent: c346e924

Showing 6 changed files with 34 additions and 16 deletions (+34, -16).
- docs/tutorial/extend/trainer.md (+10, -2)
- examples/FasterRCNN/train.py (+1, -1)
- tensorpack/callbacks/hooks.py (+0, -3)
- tensorpack/dataflow/remote.py (+9, -6)
- tensorpack/input_source/input_source.py (+1, -1)
- tensorpack/train/base.py (+13, -3)
docs/tutorial/extend/trainer.md

```diff
@@ -68,11 +68,19 @@ You can customize the trainer by either using or inheriting the base `Trainer` c
 You will need to do two things for a new Trainer:

 1. Define the graph. There are 2 ways you can do this:
-   1. Create any tensors and ops you like, before creating the trainer.
+   1. Create any tensors and ops you need, before creating the trainer.
    2. Create them inside `Trainer.__init__`.

 2. Define what is the iteration. There are 2 ways to define the iteration:
    1. Set `Trainer.train_op` to a TensorFlow operation. This op will be run by default.
-   2. Subclass `Trainer` and override the `run_step()` method. This way you can do something more than running an op.
+   2. Subclass `Trainer` and override the `run_step()` method. This way you can
+      do something more than running an op.
+
+Note that trainer has `self.sess` and `self.hooked_sess`: only the hooked
+session will trigger the `before_run`/`after_run` callbacks.
+If you need more than one `Session.run` in one step, special care needs
+to be taken to choose which session to use, because many states
+(global steps, StagingArea, summaries) are maintained through `before_run`/`after_run`.

 There are several different [GAN trainers](../../examples/GAN/GAN.py) for reference.
```
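The two iteration styles described in the doc change can be sketched with plain-Python stand-ins. `HookedSession` and `Trainer` below are simplified stubs that only mimic the shape of the real tensorpack/TensorFlow classes; only the override pattern itself is from the source:

```python
# Schematic sketch of the two ways to define the iteration.
# These classes are illustrative stubs, NOT the real tensorpack API.

class HookedSession:
    """Stand-in for tf.train.MonitoredSession: wraps run() with callbacks."""
    def __init__(self):
        self.before_run_calls = 0
        self.after_run_calls = 0

    def run(self, op):
        self.before_run_calls += 1   # before_run callbacks fire here
        result = op()                # the actual Session.run
        self.after_run_calls += 1    # after_run callbacks fire here
        return result

class Trainer:
    def __init__(self):
        self.hooked_sess = HookedSession()
        self.train_op = None

    def run_step(self):
        # Way 1: by default, just run self.train_op.
        return self.hooked_sess.run(self.train_op)

class MultiStepTrainer(Trainer):
    # Way 2: override run_step() to do more than running one op.
    def __init__(self, ops):
        super().__init__()
        self.ops = ops

    def run_step(self):
        # Per the caveat above: every run through the hooked session
        # triggers before_run/after_run, so callbacks fire once per op here.
        return [self.hooked_sess.run(op) for op in self.ops]

t = MultiStepTrainer([lambda: 1, lambda: 2])
results = t.run_step()
```

This makes the caveat concrete: a `run_step()` that issues two runs triggers the hooks twice, which is why the choice between `self.sess` and `self.hooked_sess` matters.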
examples/FasterRCNN/train.py

```diff
@@ -418,7 +418,7 @@ class ResNetFPNModel(DetectionModel):
             mrcnn_loss = 0.0

         wd_cost = regularize_cost(
-            '(?:group1|group2|group3|rpn|fpn|fastrcnn|maskrcnn)/.*W',
+            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
             l2_regularizer(1e-4), name='wd_cost')
         total_cost = tf.add_n(rpn_loss_collection + [
```
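The effect of dropping `fpn` from the alternation can be checked with the stdlib `re` module. The variable names below are hypothetical examples in the style of the FasterRCNN model, not taken from the repository:

```python
import re

# The pattern after this commit: 'fpn' is no longer in the alternation,
# so FPN weights are excluded from the weight-decay cost.
pattern = r'(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W'

# Hypothetical variable names for illustration.
names = ['rpn/conv0/W', 'fpn/lateral_1x1_c2/W', 'fastrcnn/fc6/W']

# regularize_cost matches variable names against the regex from the start,
# which plain re.match reproduces.
regularized = [n for n in names if re.match(pattern, n)]
```

Only the `rpn` and `fastrcnn` weights match; the `fpn/...` name no longer does.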
tensorpack/callbacks/hooks.py

```diff
@@ -7,7 +7,6 @@
 import tensorflow as tf

 from .base import Callback

 __all__ = ['CallbackToHook', 'HookToCallback']
@@ -17,8 +16,6 @@ class CallbackToHook(tf.train.SessionRunHook):
     You shouldn't need to use this.
     """

-    _chief_only = False
-
     def __init__(self, cb):
         self._cb = cb
```
...
tensorpack/dataflow/remote.py

```diff
@@ -22,18 +22,19 @@ else:
 def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
     """
     Run DataFlow and send data to a ZMQ socket addr.
-    It will __connect__ to this addr,
-    serialize and send each datapoint to this addr with a PUSH socket.
-    This function never returns.
+    It will serialize and send each datapoint to this address with a PUSH socket.
+    This function never returns unless an error is encountered.

     Args:
         df (DataFlow): Will infinitely loop over the DataFlow.
         addr: a ZMQ socket endpoint.
         hwm (int): ZMQ high-water mark (buffer size)
         format (str): The serialization format.
-             Default format would use :mod:`tensorpack.utils.serialize`.
-             An alternate format is 'zmq_ops', used by https://github.com/tensorpack/zmq_ops.
-             This format works with :class:`dataflow.RemoteDataZMQ`.
-        bind (bool): whether to bind or connect to the endpoint.
+             Default format uses :mod:`tensorpack.utils.serialize`.
+             An alternate format is 'zmq_ops', used by https://github.com/tensorpack/zmq_ops
+             and :class:`input_source.ZMQInput`.
+        bind (bool): whether to bind or connect to the endpoint address.
     """
     assert format in [None, 'zmq_op', 'zmq_ops']
     if format is None:
@@ -82,6 +83,8 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
 class RemoteDataZMQ(DataFlow):
     """
     Produce data from ZMQ PULL socket(s).
+    It is the receiver-side counterpart of :func:`send_dataflow_zmq`, which uses :mod:`tensorpack.utils.serialize`
+    for serialization.

     See http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html#distributed-dataflow

     Attributes:
```
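The producer loop that the revised docstring describes — loop over the DataFlow forever, serialize each datapoint, and push it into a bounded buffer — can be sketched with stdlib stand-ins. `FakePushSocket` models a ZMQ PUSH socket's high-water mark with a plain deque, and the loop is capped for the sketch; this is an analogy, not the real implementation:

```python
import pickle
from collections import deque
from itertools import islice

class FakePushSocket:
    """Stand-in for a ZMQ PUSH socket with a high-water mark (hwm).
    A real PUSH socket would block once hwm messages are buffered."""
    def __init__(self, hwm=50):
        self.hwm = hwm
        self.buffer = deque()

    def send(self, msg):
        if len(self.buffer) >= self.hwm:
            raise BufferError("would block: hwm reached")
        self.buffer.append(msg)

def send_dataflow_sketch(df, sock, max_steps):
    # send_dataflow_zmq never returns unless an error is encountered;
    # the sketch caps the loop so it terminates.
    for dp in islice(df, max_steps):
        sock.send(pickle.dumps(dp))  # stand-in for the default serializer

def dataflow():
    # Stand-in for a DataFlow that is looped over infinitely.
    i = 0
    while True:
        yield [i, i * i]
        i += 1

sock = FakePushSocket(hwm=50)
send_dataflow_sketch(dataflow(), sock, max_steps=3)

# The receiver side (RemoteDataZMQ in the real library) pulls and deserializes.
received = [pickle.loads(m) for m in sock.buffer]
```

The receiver-side counterpart, `RemoteDataZMQ`, does the mirror image: PULL, deserialize, yield datapoints.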
tensorpack/input_source/input_source.py

```diff
@@ -373,7 +373,7 @@ class DummyConstantInput(TensorInput):
 class ZMQInput(TensorInput):
     """
     Receive tensors from a ZMQ endpoint, with ops from https://github.com/tensorpack/zmq_ops.
-    It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_op')`.
+    It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_ops')`.
     """

     def __init__(self, end_point, hwm, bind=True):
         """
```
tensorpack/train/base.py

```diff
@@ -196,10 +196,8 @@ class Trainer(object):
         logger.info("Creating the session ...")

-        hooks = self._callbacks.get_hooks()
         self.sess = session_creator.create_session()
-        self.hooked_sess = tf.train.MonitoredSession(
-            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
+        self.initialize_hooks()

         if self.is_chief:
             logger.info("Initializing the session ...")
@@ -211,6 +209,18 @@ class Trainer(object):
         self.sess.graph.finalize()
         logger.info("Graph Finalized.")

+    @call_only_once
+    def initialize_hooks(self):
+        """
+        Create SessionRunHooks for all callbacks, and hook it onto self.sess.
+        A new trainer may override this method to create multiple groups of hooks,
+        which can be useful when the training is not done by a single `train_op`.
+        """
+        hooks = self._callbacks.get_hooks()
+        self.hooked_sess = tf.train.MonitoredSession(
+            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
+
     @call_only_once
     def main_loop(self, steps_per_epoch, starting_epoch, max_epoch):
         """
```
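The point of factoring hook creation into `initialize_hooks` is that a subclass can now build more than one hooked session, as the new docstring suggests. A schematic sketch of that override pattern, using stand-in classes rather than the real tensorpack/TensorFlow API:

```python
# Illustrative stubs: only the override pattern mirrors the real code.
class HookedSession:
    """Stand-in for tf.train.MonitoredSession(session_creator=..., hooks=...)."""
    def __init__(self, sess, hooks):
        self.sess = sess
        self.hooks = hooks

class Trainer:
    def __init__(self, sess, hooks):
        self.sess = sess
        self._hooks = hooks

    def initialize_hooks(self):
        # Default behavior: one hooked session wrapping self.sess
        # with all callback hooks attached.
        self.hooked_sess = HookedSession(self.sess, self._hooks)

class TwoOpTrainer(Trainer):
    def initialize_hooks(self):
        # Override: split the hooks into two groups, e.g. when one
        # training step consists of two distinct Session.run calls
        # and each group of hooks should fire on only one of them.
        first, rest = self._hooks[:1], self._hooks[1:]
        self.hooked_sess = HookedSession(self.sess, first)
        self.aux_hooked_sess = HookedSession(self.sess, rest)

t = TwoOpTrainer(sess="sess", hooks=["global_step_hook", "summary_hook"])
t.initialize_hooks()
```

Each hooked session shares the same underlying session (via `ReuseSessionCreator` in the real code) but triggers only its own group of `before_run`/`after_run` callbacks.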