Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e28d616e
Commit
e28d616e
authored
Aug 08, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
statholder move to callbacks
parent
6edb1f0d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
8 deletions
+15
-8
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+8
-0
tensorpack/callbacks/stat.py
tensorpack/callbacks/stat.py
+4
-0
tensorpack/train/base.py
tensorpack/train/base.py
+0
-2
tensorpack/train/config.py
tensorpack/train/config.py
+3
-6
No files found.
tensorpack/callbacks/group.py
View file @
e28d616e
...
@@ -108,6 +108,14 @@ class Callbacks(Callback):
...
@@ -108,6 +108,14 @@ class Callbacks(Callback):
if
not
isinstance
(
cb
.
type
,
(
TrainCallbackType
,
TestCallbackType
)):
if
not
isinstance
(
cb
.
type
,
(
TrainCallbackType
,
TestCallbackType
)):
raise
ValueError
(
raise
ValueError
(
"Unknown callback running graph {}!"
.
format
(
str
(
cb
.
type
)))
"Unknown callback running graph {}!"
.
format
(
str
(
cb
.
type
)))
# move "StatPrinter" to the last
for
cb
in
cbs
:
if
isinstance
(
cb
,
StatPrinter
):
sp
=
cb
cbs
.
remove
(
sp
)
cbs
.
append
(
sp
)
break
print
(
cbs
)
self
.
cbs
=
cbs
self
.
cbs
=
cbs
self
.
test_callback_context
=
TestCallbackContext
()
self
.
test_callback_context
=
TestCallbackContext
()
...
...
tensorpack/callbacks/s
ummary
.py
→
tensorpack/callbacks/s
tat
.py
View file @
e28d616e
...
@@ -103,3 +103,7 @@ class StatPrinter(Callback):
...
@@ -103,3 +103,7 @@ class StatPrinter(Callback):
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
trainer
.
stat_holder
.
set_print_tag
(
self
.
print_tag
)
self
.
trainer
.
stat_holder
.
set_print_tag
(
self
.
print_tag
)
def
_trigger_epoch
(
self
):
self
.
trainer
.
stat_holder
.
add_stat
(
'global_step'
,
self
.
global_step
)
self
.
trainer
.
stat_holder
.
finalize
()
tensorpack/train/base.py
View file @
e28d616e
...
@@ -68,8 +68,6 @@ class Trainer(object):
...
@@ -68,8 +68,6 @@ class Trainer(object):
self
.
_trigger_epoch
()
self
.
_trigger_epoch
()
self
.
config
.
callbacks
.
trigger_epoch
()
self
.
config
.
callbacks
.
trigger_epoch
()
self
.
summary_writer
.
flush
()
self
.
summary_writer
.
flush
()
self
.
stat_holder
.
add_stat
(
'global_step'
,
self
.
global_step
)
self
.
stat_holder
.
finalize
()
@
abstractmethod
@
abstractmethod
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
...
...
tensorpack/train/config.py
View file @
e28d616e
...
@@ -21,13 +21,12 @@ class TrainConfig(object):
...
@@ -21,13 +21,12 @@ class TrainConfig(object):
:param dataset: the dataset to train. a `DataFlow` instance.
:param dataset: the dataset to train. a `DataFlow` instance.
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for trainig.
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for trainig.
:param callbacks: a `callback.Callbacks` instance. Define
:param callbacks: a `callback.Callbacks` instance. Define
the callbacks to perform during training. It has to contain a
the callbacks to perform during training.
SummaryWriter and a PeriodicSaver
:param session_config: a `tf.ConfigProto` instance to instantiate the
:param session_config: a `tf.ConfigProto` instance to instantiate the
session. default to a session running 1 GPU.
session. default to a session running 1 GPU.
:param session_init: a `sessinit.SessionInit` instance to
:param session_init: a `sessinit.SessionInit` instance to
initialize variables of a session. default to a new session.
initialize variables of a session. default to a new session.
:param model: a `ModelDesc` instance.
j
:param model: a `ModelDesc` instance.
:param starting_epoch: int. default to be 1.
:param starting_epoch: int. default to be 1.
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
:param max_epoch: maximum number of epoch to run training. default to inf
:param max_epoch: maximum number of epoch to run training. default to inf
...
@@ -63,9 +62,7 @@ class TrainConfig(object):
...
@@ -63,9 +62,7 @@ class TrainConfig(object):
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
def
set_tower
(
self
,
**
kwargs
):
def
set_tower
(
self
,
nr_tower
=
None
,
tower
=
None
):
nr_tower
=
kwargs
.
pop
(
'nr_tower'
,
None
)
tower
=
kwargs
.
pop
(
'tower'
,
None
)
assert
nr_tower
is
None
or
tower
is
None
,
"Cannot set both nr_tower and tower!"
assert
nr_tower
is
None
or
tower
is
None
,
"Cannot set both nr_tower and tower!"
if
nr_tower
:
if
nr_tower
:
tower
=
list
(
range
(
nr_tower
))
tower
=
list
(
range
(
nr_tower
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment