Shashank Suhas / seminar-breakout / Commits / 260936ff

Commit 260936ff
authored Oct 18, 2017 by Yuxin Wu
Fix imports; Add call_only_once decorator;

parent 9fff46d5
Showing 4 changed files with 59 additions and 8 deletions (+59 -8):

    tensorpack/input_source/input_source_base.py    +2   -1
    tensorpack/trainv2/base.py                      +9   -0
    tensorpack/trainv2/trainers.py                  +8   -6
    tensorpack/utils/argtools.py                    +40  -1
tensorpack/input_source/input_source_base.py

@@ -9,7 +9,7 @@ from six.moves import zip
 from contextlib import contextmanager
 import tensorflow as tf

-from ..utils.argtools import memoized
+from ..utils.argtools import memoized, call_only_once
 from ..callbacks.base import CallbackFactory
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger

@@ -85,6 +85,7 @@ class InputSource(object):
     def _get_input_tensors(self):
         pass

+    @call_only_once
     def setup(self, inputs_desc):
         """
         Args:
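With setup() now wrapped by call_only_once, building the same InputSource twice is rejected at the second call. A schematic sketch of the intended effect, assuming a QueueInput instance (the dataflow and inputs-desc arguments below are hypothetical placeholders, not part of this commit):

# Schematic only -- `my_dataflow` and `my_inputs_desc` are hypothetical placeholders.
inp = QueueInput(my_dataflow)
inp.setup(my_inputs_desc)   # first call: builds the input pipeline
inp.setup(my_inputs_desc)   # second call: AssertionError
                            # "Method QueueInput.setup can only be called once per object!"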
tensorpack/trainv2/base.py

@@ -10,6 +10,7 @@ import six
 from abc import abstractmethod, ABCMeta

 from ..utils import logger
+from ..utils.argtools import call_only_once
 from ..callbacks import Callback, Callbacks
 from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils.model_utils import describe_trainable_vars

@@ -73,6 +74,7 @@ class Trainer(object):
             "of Trainer.run_step()!")
         self.hooked_sess.run(self.train_op)

+    @call_only_once
     def setup_callbacks(self, callbacks, monitors):
         """
         Setup callbacks and monitors. Must be called after the main graph is built.

@@ -92,6 +94,7 @@ class Trainer(object):
         self._callbacks = Callbacks(self._callbacks)
         self._callbacks.setup_graph(weakref.proxy(self))

+    @call_only_once
     def initialize(self, session_creator, session_init):
         """
         Initialize self.sess and self.hooked_sess.

@@ -120,6 +123,7 @@ class Trainer(object):
             and self.hooked_sess (the session with hooks and coordinator)
         """

+    @call_only_once
     def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
         """
         Run the main training loop.

@@ -213,6 +217,10 @@ class SingleCostTrainer(Trainer):
               callbacks, monitors,
               session_creator, session_init,
               steps_per_epoch, starting_epoch, max_epoch):
+        """
+        Same as :meth:`Trainer.train()`, except that the callbacks this
+        trainer needs are automatically added.
+        """
         callbacks = callbacks + self._internal_callbacks
         Trainer.train(
             self,

@@ -220,6 +228,7 @@ class SingleCostTrainer(Trainer):
             session_creator, session_init,
             steps_per_epoch, starting_epoch, max_epoch)

+    @call_only_once
     def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
         """
         Build the main training graph. Defaults to do nothing.
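The decorated methods above are the one-shot phases of a trainer's life cycle: set up callbacks after the graph is built, initialize the sessions, then run the main loop. A minimal, self-contained sketch of that pattern, using the call_only_once helper this commit adds (ToyTrainer and its no-op bodies are illustrative stand-ins, not real tensorpack classes):

# Minimal sketch of the one-shot lifecycle pattern introduced by this commit.
# ToyTrainer is a hypothetical stand-in; only call_only_once comes from tensorpack.
from tensorpack.utils.argtools import call_only_once


class ToyTrainer(object):
    @call_only_once
    def setup_callbacks(self, callbacks, monitors):
        print("callbacks set up")        # must happen after the graph is built

    @call_only_once
    def initialize(self, session_creator, session_init):
        print("sessions created")        # would create self.sess / self.hooked_sess

    @call_only_once
    def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
        print("running {} steps per epoch".format(steps_per_epoch))


t = ToyTrainer()
t.setup_callbacks([], [])
t.initialize(None, None)
t.main_loop(steps_per_epoch=100)
t.main_loop(steps_per_epoch=100)   # raises AssertionError:
                                   # "Method ToyTrainer.main_loop can only be called once per object!"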
tensorpack/trainv2/trainers.py

@@ -6,17 +6,19 @@ import os
 from ..callbacks.graph import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
-from ..utils import logger
-from ..tfutils import get_global_step_var
-from ..tfutils.distributed import get_distributed_session_creator
-from ..input_source import QueueInput

 from ..graph_builder.training import (
     SimpleBuilder, SyncMultiGPUParameterServerBuilder,
     SyncMultiGPUReplicatedBuilder,
-    AsyncMultiGPUBuilder, DistributedReplicatedBuilder)
+    AsyncMultiGPUBuilder)
+from ..graph_builder.distributed import DistributedReplicatedBuilder
 from ..graph_builder.utils import override_to_local_variable

+from ..utils import logger
+from ..tfutils import get_global_step_var
+from ..tfutils.distributed import get_distributed_session_creator
+from ..input_source import QueueInput
+
 from .base import SingleCostTrainer
tensorpack/utils/argtools.py

@@ -12,7 +12,7 @@ else:
     import functools

 __all__ = ['map_arg', 'memoized', 'graph_memoized', 'shape2d', 'shape4d',
-           'memoized_ignoreargs', 'log_once']
+           'memoized_ignoreargs', 'log_once', 'call_only_once']


 def map_arg(**maps):

@@ -140,3 +140,42 @@ def log_once(message, func):
         func(str): the name of the logger method. e.g. "info", "warn", "error".
     """
     getattr(logger, func)(message)
+
+
+_FUNC_CALLED = set()
+
+
+def call_only_once(func):
+    """
+    Decorate a method of a class, so that this method can only
+    be called once for every instance.
+    Calling it more than once will result in exception.
+    """
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        self = args[0]
+        assert hasattr(self, func.__name__), "call_only_once can only be used on method!"
+        key = (self, func)
+        assert key not in _FUNC_CALLED, \
+            "Method {}.{} can only be called once per object!".format(
+                type(self).__name__, func.__name__)
+        _FUNC_CALLED.add(key)
+
+        func(*args, **kwargs)
+    return wrapper
+
+
+if __name__ == '__main__':
+    class A():
+        @call_only_once
+        def f(self, x):
+            print(x)
+
+    a = A()
+    a.f(1)
+
+    b = A()
+    b.f(2)
+    b.f(1)
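Run as a script, the demo above prints 1 (from a.f(1)) and 2 (from b.f(2)); the final b.f(1) then fails the assertion, since f has already been called once on b. The bookkeeping is a module-level _FUNC_CALLED set keyed by the (instance, function) pair, so the restriction is per method and per instance rather than global, which is what lets each Trainer or InputSource object run its own one-shot setup methods independently.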