Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a976e871
Commit
a976e871
authored
Jan 01, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
split-out callbacks dir
parent
eea48e2e
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
180 additions
and
136 deletions
+180
-136
example_cifar10.py
example_cifar10.py
+1
-2
example_mnist.py
example_mnist.py
+1
-2
tensorpack/callbacks/__init__.py
tensorpack/callbacks/__init__.py
+19
-0
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+62
-0
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+48
-0
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+40
-94
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+6
-5
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+2
-2
tensorpack/train.py
tensorpack/train.py
+1
-1
tensorpack/utils/__init__.py
tensorpack/utils/__init__.py
+0
-30
No files found.
example_cifar10.py
View file @
a976e871
...
@@ -13,8 +13,7 @@ from tensorpack.models import *
...
@@ -13,8 +13,7 @@ from tensorpack.models import *
from
tensorpack.utils
import
*
from
tensorpack.utils
import
*
from
tensorpack.utils.symbolic_functions
import
*
from
tensorpack.utils.symbolic_functions
import
*
from
tensorpack.utils.summary
import
*
from
tensorpack.utils.summary
import
*
from
tensorpack.utils.callback
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.utils.validation_callback
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow
import
imgaug
from
tensorpack.dataflow
import
imgaug
...
...
example_mnist.py
View file @
a976e871
...
@@ -15,8 +15,7 @@ from tensorpack.models import *
...
@@ -15,8 +15,7 @@ from tensorpack.models import *
from
tensorpack.utils
import
*
from
tensorpack.utils
import
*
from
tensorpack.utils.symbolic_functions
import
*
from
tensorpack.utils.symbolic_functions
import
*
from
tensorpack.utils.summary
import
*
from
tensorpack.utils.summary
import
*
from
tensorpack.utils.callback
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.utils.validation_callback
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow
import
*
BATCH_SIZE
=
128
BATCH_SIZE
=
128
...
...
tensorpack/callbacks/__init__.py
0 → 100644
View file @
a976e871
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
pkgutil
import
walk_packages
import
os
def
global_import
(
name
):
p
=
__import__
(
name
,
globals
(),
locals
())
lst
=
p
.
__all__
if
'__all__'
in
dir
(
p
)
else
dir
(
p
)
for
k
in
lst
:
globals
()[
k
]
=
p
.
__dict__
[
k
]
for
_
,
module_name
,
_
in
walk_packages
(
[
os
.
path
.
dirname
(
__file__
)]):
if
not
module_name
.
startswith
(
'_'
):
global_import
(
module_name
)
tensorpack/callbacks/base.py
0 → 100644
View file @
a976e871
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
sys
import
os
import
time
from
abc
import
abstractmethod
,
ABCMeta
from
..utils
import
*
__all__
=
[
'Callback'
,
'PeriodicCallback'
]
class
Callback
(
object
):
__metaclass__
=
ABCMeta
running_graph
=
'train'
""" The graph that this callback should run on.
Either 'train' or 'test'
"""
def
before_train
(
self
):
self
.
graph
=
tf
.
get_default_graph
()
self
.
sess
=
tf
.
get_default_session
()
self
.
_before_train
()
def
_before_train
(
self
):
"""
Called before starting iterative training
"""
def
trigger_step
(
self
,
inputs
,
outputs
,
cost
):
"""
Callback to be triggered after every step (every backpropagation)
Args:
inputs: the list of input values
outputs: list of output values after running this inputs
cost: the cost value after running this input
"""
def
trigger_epoch
(
self
):
"""
Callback to be triggered after every epoch (full iteration of input dataset)
"""
class
PeriodicCallback
(
Callback
):
def
__init__
(
self
,
period
):
self
.
__period
=
period
self
.
epoch_num
=
0
def
trigger_epoch
(
self
):
self
.
epoch_num
+=
1
if
self
.
epoch_num
%
self
.
__period
==
0
:
self
.
global_step
=
get_global_step
()
self
.
_trigger
()
@
abstractmethod
def
_trigger
(
self
):
pass
tensorpack/callbacks/common.py
0 → 100644
View file @
a976e871
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
os
from
.base
import
Callback
,
PeriodicCallback
from
..utils
import
*
__all__
=
[
'PeriodicSaver'
,
'SummaryWriter'
]
class
PeriodicSaver
(
PeriodicCallback
):
def
__init__
(
self
,
period
=
1
,
keep_recent
=
10
,
keep_freq
=
0.5
):
super
(
PeriodicSaver
,
self
)
.
__init__
(
period
)
self
.
path
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'model'
)
self
.
keep_recent
=
keep_recent
self
.
keep_freq
=
keep_freq
def
_before_train
(
self
):
self
.
saver
=
tf
.
train
.
Saver
(
max_to_keep
=
self
.
keep_recent
,
keep_checkpoint_every_n_hours
=
self
.
keep_freq
)
def
_trigger
(
self
):
self
.
saver
.
save
(
tf
.
get_default_session
(),
self
.
path
,
global_step
=
self
.
global_step
)
class
SummaryWriter
(
Callback
):
def
__init__
(
self
):
self
.
log_dir
=
logger
.
LOG_DIR
def
_before_train
(
self
):
self
.
writer
=
tf
.
train
.
SummaryWriter
(
self
.
log_dir
,
graph_def
=
self
.
sess
.
graph_def
)
tf
.
add_to_collection
(
SUMMARY_WRITER_COLLECTION_KEY
,
self
.
writer
)
self
.
summary_op
=
tf
.
merge_all_summaries
()
def
trigger_epoch
(
self
):
# check if there is any summary
if
self
.
summary_op
is
None
:
return
summary_str
=
self
.
summary_op
.
eval
()
self
.
writer
.
add_summary
(
summary_str
,
get_global_step
())
tensorpack/
utils/callback
.py
→
tensorpack/
callbacks/group
.py
View file @
a976e871
#!/usr/bin/env python2
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File:
callback
.py
# File:
group
.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
import
sys
from
contextlib
import
contextmanager
import
numpy
as
np
import
os
from
.base
import
Callback
import
time
from
.common
import
*
from
abc
import
abstractmethod
,
ABCMeta
from
..utils
import
*
from
.
import
create_test_session
,
get_global_step
__all__
=
[
'Callbacks'
]
from
.naming
import
*
import
logger
@
contextmanager
def
create_test_graph
():
class
Callback
(
object
):
G
=
tf
.
get_default_graph
()
__metaclass__
=
ABCMeta
input_vars_train
=
G
.
get_collection
(
INPUT_VARS_KEY
)
running_graph
=
'train'
forward_func
=
G
.
get_collection
(
FORWARD_FUNC_KEY
)[
0
]
""" The graph that this callback should run on.
with
tf
.
Graph
()
.
as_default
()
as
Gtest
:
Either 'train' or 'test'
# create a global step var in test graph
"""
global_step_var
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
def
before_train
(
self
):
input_vars
=
[]
self
.
graph
=
tf
.
get_default_graph
()
for
v
in
input_vars_train
:
self
.
sess
=
tf
.
get_default_session
()
name
=
v
.
name
self
.
_before_train
()
assert
name
.
endswith
(
':0'
),
"I think placeholder variable should all ends with ':0'"
name
=
name
[:
-
2
]
def
_before_train
(
self
):
input_vars
.
append
(
tf
.
placeholder
(
"""
v
.
dtype
,
shape
=
v
.
get_shape
(),
name
=
name
Called before starting iterative training
))
"""
for
v
in
input_vars
:
Gtest
.
add_to_collection
(
INPUT_VARS_KEY
,
v
)
def
trigger_step
(
self
,
inputs
,
outputs
,
cost
):
output_vars
,
cost
=
forward_func
(
input_vars
,
is_training
=
False
)
"""
for
v
in
output_vars
:
Callback to be triggered after every step (every backpropagation)
Gtest
.
add_to_collection
(
OUTPUT_VARS_KEY
,
v
)
Args:
yield
Gtest
inputs: the list of input values
outputs: list of output values after running this inputs
@
contextmanager
cost: the cost value after running this input
def
create_test_session
():
"""
with
create_test_graph
():
with
tf
.
Session
()
as
sess
:
def
trigger_epoch
(
self
):
yield
sess
"""
Callback to be triggered after every epoch (full iteration of input dataset)
"""
class
PeriodicCallback
(
Callback
):
def
__init__
(
self
,
period
):
self
.
__period
=
period
self
.
epoch_num
=
0
def
trigger_epoch
(
self
):
self
.
epoch_num
+=
1
if
self
.
epoch_num
%
self
.
__period
==
0
:
self
.
global_step
=
get_global_step
()
self
.
_trigger
()
@
abstractmethod
def
_trigger
(
self
):
pass
class
PeriodicSaver
(
PeriodicCallback
):
def
__init__
(
self
,
period
=
1
,
keep_recent
=
10
,
keep_freq
=
0.5
):
super
(
PeriodicSaver
,
self
)
.
__init__
(
period
)
self
.
path
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'model'
)
self
.
keep_recent
=
keep_recent
self
.
keep_freq
=
keep_freq
def
_before_train
(
self
):
self
.
saver
=
tf
.
train
.
Saver
(
max_to_keep
=
self
.
keep_recent
,
keep_checkpoint_every_n_hours
=
self
.
keep_freq
)
def
_trigger
(
self
):
self
.
saver
.
save
(
tf
.
get_default_session
(),
self
.
path
,
global_step
=
self
.
global_step
)
class
SummaryWriter
(
Callback
):
def
__init__
(
self
):
self
.
log_dir
=
logger
.
LOG_DIR
def
_before_train
(
self
):
self
.
writer
=
tf
.
train
.
SummaryWriter
(
self
.
log_dir
,
graph_def
=
self
.
sess
.
graph_def
)
tf
.
add_to_collection
(
SUMMARY_WRITER_COLLECTION_KEY
,
self
.
writer
)
self
.
summary_op
=
tf
.
merge_all_summaries
()
def
trigger_epoch
(
self
):
# check if there is any summary
if
self
.
summary_op
is
None
:
return
summary_str
=
self
.
summary_op
.
eval
()
self
.
writer
.
add_summary
(
summary_str
,
get_global_step
())
class
CallbackTimeLogger
(
object
):
class
CallbackTimeLogger
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -126,7 +73,7 @@ class TrainCallbacks(Callback):
...
@@ -126,7 +73,7 @@ class TrainCallbacks(Callback):
self
.
cbs
.
insert
(
0
,
self
.
cbs
.
pop
(
idx
))
self
.
cbs
.
insert
(
0
,
self
.
cbs
.
pop
(
idx
))
break
break
else
:
else
:
raise
Runtim
eError
(
"Callbacks must contain a SummaryWriter!"
)
raise
Valu
eError
(
"Callbacks must contain a SummaryWriter!"
)
def
before_train
(
self
):
def
before_train
(
self
):
for
cb
in
self
.
cbs
:
for
cb
in
self
.
cbs
:
...
@@ -199,7 +146,7 @@ class Callbacks(Callback):
...
@@ -199,7 +146,7 @@ class Callbacks(Callback):
elif
cb
.
running_graph
==
'train'
:
elif
cb
.
running_graph
==
'train'
:
train_cbs
.
append
(
cb
)
train_cbs
.
append
(
cb
)
else
:
else
:
raise
Runtim
eError
(
raise
Valu
eError
(
"Unknown callback running graph {}!"
.
format
(
cb
.
running_graph
))
"Unknown callback running graph {}!"
.
format
(
cb
.
running_graph
))
self
.
train
=
TrainCallbacks
(
train_cbs
)
self
.
train
=
TrainCallbacks
(
train_cbs
)
self
.
test
=
TestCallbacks
(
test_cbs
)
self
.
test
=
TestCallbacks
(
test_cbs
)
...
@@ -216,4 +163,3 @@ class Callbacks(Callback):
...
@@ -216,4 +163,3 @@ class Callbacks(Callback):
self
.
train
.
trigger_epoch
()
self
.
train
.
trigger_epoch
()
# TODO test callbacks can be run async?
# TODO test callbacks can be run async?
self
.
test
.
trigger_epoch
()
self
.
test
.
trigger_epoch
()
tensorpack/
util
s/validation_callback.py
→
tensorpack/
callback
s/validation_callback.py
View file @
a976e871
...
@@ -6,11 +6,12 @@
...
@@ -6,11 +6,12 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
.stat
import
*
from
..utils
import
*
from
.callback
import
PeriodicCallback
,
Callback
from
..utils.stat
import
*
from
.naming
import
*
from
..utils.summary
import
*
from
.summary
import
*
from
.base
import
PeriodicCallback
,
Callback
import
logger
__all__
=
[
'ValidationError'
]
class
ValidationError
(
PeriodicCallback
):
class
ValidationError
(
PeriodicCallback
):
running_graph
=
'test'
running_graph
=
'test'
...
...
tensorpack/models/conv2d.py
View file @
a976e871
...
@@ -18,7 +18,7 @@ def Conv2D(x, out_channel, kernel_shape,
...
@@ -18,7 +18,7 @@ def Conv2D(x, out_channel, kernel_shape,
kernel_shape: (h, w) or a int
kernel_shape: (h, w) or a int
stride: (h, w) or a int
stride: (h, w) or a int
padding: 'valid' or 'same'
padding: 'valid' or 'same'
split: split channels. used in
a
lexnet
split: split channels. used in
A
lexnet
"""
"""
in_shape
=
x
.
get_shape
()
.
as_list
()
in_shape
=
x
.
get_shape
()
.
as_list
()
in_channel
=
in_shape
[
-
1
]
in_channel
=
in_shape
[
-
1
]
...
@@ -31,7 +31,7 @@ def Conv2D(x, out_channel, kernel_shape,
...
@@ -31,7 +31,7 @@ def Conv2D(x, out_channel, kernel_shape,
stride
=
shape4d
(
stride
)
stride
=
shape4d
(
stride
)
if
W_init
is
None
:
if
W_init
is
None
:
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
1e-4
)
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
4e-3
)
if
b_init
is
None
:
if
b_init
is
None
:
b_init
=
tf
.
constant_initializer
()
b_init
=
tf
.
constant_initializer
()
...
...
tensorpack/train.py
View file @
a976e871
...
@@ -10,7 +10,7 @@ import argparse
...
@@ -10,7 +10,7 @@ import argparse
import
tqdm
import
tqdm
from
utils
import
*
from
utils
import
*
from
utils.concurrency
import
EnqueueThread
,
coordinator_guard
from
utils.concurrency
import
EnqueueThread
,
coordinator_guard
from
utils.callback
import
Callbacks
from
callbacks
import
*
from
utils.summary
import
summary_moving_average
from
utils.summary
import
summary_moving_average
from
utils.modelutils
import
describe_model
from
utils.modelutils
import
describe_model
from
utils
import
logger
from
utils
import
logger
...
...
tensorpack/utils/__init__.py
View file @
a976e871
...
@@ -31,36 +31,6 @@ def timed_operation(msg, log_start=False):
...
@@ -31,36 +31,6 @@ def timed_operation(msg, log_start=False):
logger
.
info
(
'finished {}, time={:.2f}sec.'
.
format
(
logger
.
info
(
'finished {}, time={:.2f}sec.'
.
format
(
msg
,
time
.
time
()
-
start
))
msg
,
time
.
time
()
-
start
))
@
contextmanager
def
create_test_graph
():
G
=
tf
.
get_default_graph
()
input_vars_train
=
G
.
get_collection
(
INPUT_VARS_KEY
)
forward_func
=
G
.
get_collection
(
FORWARD_FUNC_KEY
)[
0
]
with
tf
.
Graph
()
.
as_default
()
as
Gtest
:
# create a global step var in test graph
global_step_var
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
input_vars
=
[]
for
v
in
input_vars_train
:
name
=
v
.
name
assert
name
.
endswith
(
':0'
),
"I think placeholder variable should all ends with ':0'"
name
=
name
[:
-
2
]
input_vars
.
append
(
tf
.
placeholder
(
v
.
dtype
,
shape
=
v
.
get_shape
(),
name
=
name
))
for
v
in
input_vars
:
Gtest
.
add_to_collection
(
INPUT_VARS_KEY
,
v
)
output_vars
,
cost
=
forward_func
(
input_vars
,
is_training
=
False
)
for
v
in
output_vars
:
Gtest
.
add_to_collection
(
OUTPUT_VARS_KEY
,
v
)
yield
Gtest
@
contextmanager
def
create_test_session
():
with
create_test_graph
():
with
tf
.
Session
()
as
sess
:
yield
sess
def
get_default_sess_config
():
def
get_default_sess_config
():
"""
"""
Return a better config to use as default.
Return a better config to use as default.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment