Commit bbe8f42a, authored Aug 31, 2016 by Yuxin Wu
remove test callback. add .print in linearwrap
Parent: 7a110067
Showing 5 changed files, with 26 additions and 52 deletions (+26, -52):
examples/DoReFa-Net/alexnet-dorefa.py  +0  -1
tensorpack/callbacks/base.py           +1  -13
tensorpack/callbacks/group.py          +6  -29
tensorpack/models/__init__.py          +4  -0
tensorpack/models/batch_norm.py        +15 -9
examples/DoReFa-Net/alexnet-dorefa.py

@@ -102,7 +102,6 @@ class Model(ModelDesc):
         logits = (LinearWrap(image)
                   .Conv2D('conv0', 96, 12, stride=4, padding='VALID')
                   .apply(activate)
                   .Conv2D('conv1', 256, 5, padding='SAME', split=2)
                   .apply(fg)
                   .BatchNorm('bn1')
...
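For context, LinearWrap threads a tensor through layer calls one at a time. The chain above is roughly equivalent to the following unrolled sketch (a sketch only, reusing the names visible in the diff; it is not part of the commit):

l = Conv2D('conv0', image, 96, 12, stride=4, padding='VALID')
l = activate(l)
l = Conv2D('conv1', l, 256, 5, padding='SAME', split=2)
l = fg(l)
l = BatchNorm('bn1', l)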
tensorpack/callbacks/base.py

@@ -10,24 +10,12 @@ from abc import abstractmethod, ABCMeta
 from ..utils import *

-__all__ = ['Callback', 'PeriodicCallback', 'TrainCallbackType', 'TestCallbackType']
+__all__ = ['Callback', 'PeriodicCallback']

-class TrainCallbackType(object):
-    pass
-class TestCallbackType(object):
-    pass
-
 class Callback(object):
     """ Base class for all callbacks """
     __metaclass__ = ABCMeta

-    type = TrainCallbackType()
-    """ Determine the graph that this callback should run on.
-        Either `TrainCallbackType()` or `TestCallbackType()`.
-        Default is `TrainCallbackType()`
-    """
-
     def before_train(self):
         """
         Called right before the first iteration.
...
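With TrainCallbackType and TestCallbackType gone, a callback no longer declares which graph it runs on: every callback now runs in the training graph. A minimal sketch of a subclass under the simplified base class (a hypothetical example; the hook name follows the Callback interface seen in this commit, but treat it as illustrative rather than the exact contract):

from tensorpack.callbacks.base import Callback

class EpochNotifier(Callback):
    # hypothetical callback: no `type` attribute is needed anymore
    def trigger_epoch(self):
        print("epoch finished")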
tensorpack/callbacks/group.py

@@ -6,7 +6,7 @@ import tensorflow as tf
 from contextlib import contextmanager
 import time

-from .base import Callback, TrainCallbackType, TestCallbackType
+from .base import Callback
 from .stat import *
 from ..utils import *
...
@@ -50,9 +50,6 @@ class Callbacks(Callback):
         # check type
         for cb in cbs:
             assert isinstance(cb, Callback), cb.__class__
-            if not isinstance(cb.type, (TrainCallbackType, TestCallbackType)):
-                raise ValueError(
-                    "Unknown callback running graph {}!".format(str(cb.type)))
         # move "StatPrinter" to the last
         for cb in cbs:
             if isinstance(cb, StatPrinter):
...
@@ -62,24 +59,15 @@ class Callbacks(Callback):
                 break
         self.cbs = cbs
-        self.test_callback_context = TestCallbackContext()

     def _setup_graph(self):
         with tf.name_scope(None):
             for cb in self.cbs:
-                if isinstance(cb.type, TrainCallbackType):
-                    cb.setup_graph(self.trainer)
-                else:
-                    with self.test_callback_context.create_context(self.trainer):
-                        cb.setup_graph(self.trainer)
+                cb.setup_graph(self.trainer)

     def _before_train(self):
         for cb in self.cbs:
-            if isinstance(cb.type, TrainCallbackType):
-                cb.before_train()
-            else:
-                with self.test_callback_context.test_context():
-                    cb.before_train()
+            cb.before_train()

     def _after_train(self):
         for cb in self.cbs:
...
@@ -87,9 +75,7 @@ class Callbacks(Callback):
     def trigger_step(self):
         for cb in self.cbs:
-            if isinstance(cb.type, TrainCallbackType):
-                cb.trigger_step()
-            # test callback don't have trigger_step
+            cb.trigger_step()

     def _trigger_epoch(self):
         tm = CallbackTimeLogger()
...
@@ -97,15 +83,6 @@ class Callbacks(Callback):
-        test_sess_restored = False
         for cb in self.cbs:
             display_name = str(cb)
-            if isinstance(cb.type, TrainCallbackType):
-                with tm.timed_callback(display_name):
-                    cb.trigger_epoch()
-            else:
-                if not test_sess_restored:
-                    with tm.timed_callback('restore checkpoint'):
-                        self.test_callback_context.restore_checkpoint()
-                    test_sess_restored = True
-                with self.test_callback_context.test_context(), \
-                        tm.timed_callback(display_name):
-                    cb.trigger_epoch()
+            with tm.timed_callback(display_name):
+                cb.trigger_epoch()

         tm.log()
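The '# move "StatPrinter" to the last' loop above is truncated at its `break`; its job is to reorder the list so statistics are printed only after every other callback has updated them for the epoch. A sketch of what that reordering plausibly looks like (reconstructed for illustration, not verbatim from the file):

for cb in cbs:
    if isinstance(cb, StatPrinter):
        cbs.remove(cb)
        cbs.append(cb)
        break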
tensorpack/models/__init__.py

@@ -70,4 +70,8 @@ class LinearWrap(object):
     def tensor(self):
         return self._t
+
+    def print(self):
+        print(self._t)
+        return self
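The new print() method prints the wrapped tensor and returns self, so it can be dropped into the middle of a chain to inspect an intermediate result without breaking the chain. A usage sketch reusing the alexnet-dorefa.py chain above (illustrative only, not from the commit):

logits = (LinearWrap(image)
          .Conv2D('conv0', 96, 12, stride=4, padding='VALID')
          .print()   # show the intermediate tensor, then keep chaining
          .apply(activate)
          .tensor())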
tensorpack/models/batch_norm.py

@@ -55,7 +55,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         batch_var = tf.identity(batch_var, 'variance')

     emaname = 'EMA'
-    ctx = get_current_model_context()
+    ctx = get_current_tower_context()
     if use_local_stat is None:
         use_local_stat = ctx.is_training
     assert use_local_stat == ctx.is_training
...
@@ -73,17 +73,23 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     else:
         assert not use_local_stat
         with tf.name_scope(None):
-            # figure out the var name
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-            # use statistics in another tower
-            G = tf.get_default_graph()
-            mean_var_name = ema.average_name(batch_mean) + ':0'
-            var_var_name = ema.average_name(batch_var) + ':0'
-            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
-            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
-            #logger.info("In prediction, using {} instead of {} for {}".format(
-            #    mean_name, ema_mean.name, batch_mean.name))
+            if ctx.is_main_tower:
+                # not training, but main tower. need to create the vars
+                with tf.name_scope(None):
+                    ema_apply_op = ema.apply([batch_mean, batch_var])
+                ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+            else:
+                # use statistics in another tower
+                G = tf.get_default_graph()
+                # figure out the var name
+                mean_var_name = ema.average_name(batch_mean) + ':0'
+                var_var_name = ema.average_name(batch_var) + ':0'
+                ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
+                ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
+                #logger.info("In prediction, using {} instead of {} for {}".format(
+                #    mean_name, ema_mean.name, batch_mean.name))

     if use_local_stat:
         with tf.control_dependencies([ema_apply_op]):
...
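The ':0' suffix in the lookup matters because ema.average_name() returns the shadow variable's op name, while tf.Graph.get_tensor_by_name() expects a tensor name. A self-contained sketch of the same pattern (TF 1.x-era API; the variable and names here are hypothetical, not from the commit):

import tensorflow as tf

v = tf.Variable(0.0, name='mean')
ema = tf.train.ExponentialMovingAverage(decay=0.9, name='EMA')
ema_apply_op = ema.apply([v])                 # creates the shadow variable 'mean/EMA'
shadow_name = ema.average_name(v) + ':0'      # op name -> tensor name
shadow = tf.get_default_graph().get_tensor_by_name(shadow_name)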