Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
5a8d500c
Commit
5a8d500c
authored
Jun 19, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ChainInit & verbose download error
parent
7e32ccc7
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
37 additions
and
10 deletions
+37
-10
examples/ResNet/README.md
examples/ResNet/README.md
+3
-0
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+2
-2
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+2
-1
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+16
-3
tensorpack/train/base.py
tensorpack/train/base.py
+1
-0
tensorpack/train/config.py
tensorpack/train/config.py
+1
-1
tensorpack/utils/fs.py
tensorpack/utils/fs.py
+9
-3
tensorpack/utils/stat.py
tensorpack/utils/stat.py
+3
-0
No files found.
examples/ResNet/README.md
View file @
5a8d500c
...
@@ -8,3 +8,6 @@ The train error shown here is a moving average of the error rate of each batch i
...
@@ -8,3 +8,6 @@ The train error shown here is a moving average of the error rate of each batch i
The validation error here is computed on test set.
The validation error here is computed on test set.


Download model:
[
Cifar10 n=18
](
https://drive.google.com/open?id=0B308TeQzmFDLeHpSaHAxWGV1WDg
)
tensorpack/callbacks/inference.py
View file @
5a8d500c
...
@@ -178,7 +178,7 @@ class ClassificationError(Inferencer):
...
@@ -178,7 +178,7 @@ class ClassificationError(Inferencer):
return
[
self
.
wrong_var_name
]
return
[
self
.
wrong_var_name
]
def
_before_inference
(
self
):
def
_before_inference
(
self
):
self
.
err_stat
=
Accuracy
()
self
.
err_stat
=
RatioCounter
()
def
_datapoint
(
self
,
dp
,
outputs
):
def
_datapoint
(
self
,
dp
,
outputs
):
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
...
@@ -186,7 +186,7 @@ class ClassificationError(Inferencer):
...
@@ -186,7 +186,7 @@ class ClassificationError(Inferencer):
self
.
err_stat
.
feed
(
wrong
,
batch_size
)
self
.
err_stat
.
feed
(
wrong
,
batch_size
)
def
_after_inference
(
self
):
def
_after_inference
(
self
):
self
.
trainer
.
write_scalar_summary
(
self
.
summary_name
,
self
.
err_stat
.
accuracy
)
self
.
trainer
.
write_scalar_summary
(
self
.
summary_name
,
self
.
err_stat
.
ratio
)
class
BinaryClassificationStats
(
Inferencer
):
class
BinaryClassificationStats
(
Inferencer
):
""" Compute precision/recall in binary classification, given the
""" Compute precision/recall in binary classification, given the
...
...
tensorpack/dataflow/dataset/ilsvrc.py
View file @
5a8d500c
...
@@ -50,7 +50,8 @@ class ILSVRCMeta(object):
...
@@ -50,7 +50,8 @@ class ILSVRCMeta(object):
proto_path
=
download
(
CAFFE_PROTO_URL
,
self
.
dir
)
proto_path
=
download
(
CAFFE_PROTO_URL
,
self
.
dir
)
ret
=
os
.
system
(
'cd {} && protoc caffe.proto --python_out .'
.
format
(
self
.
dir
))
ret
=
os
.
system
(
'cd {} && protoc caffe.proto --python_out .'
.
format
(
self
.
dir
))
assert
ret
==
0
,
"caffe proto compilation failed!"
assert
ret
==
0
,
\
"caffe proto compilation failed! Did you install protoc?"
def
get_image_list
(
self
,
name
):
def
get_image_list
(
self
,
name
):
"""
"""
...
...
tensorpack/tfutils/sessinit.py
View file @
5a8d500c
...
@@ -13,7 +13,7 @@ import six
...
@@ -13,7 +13,7 @@ import six
from
..utils
import
logger
from
..utils
import
logger
__all__
=
[
'SessionInit'
,
'NewSession'
,
'SaverRestore'
,
__all__
=
[
'SessionInit'
,
'NewSession'
,
'SaverRestore'
,
'ParamRestore'
,
'ParamRestore'
,
'ChainInit'
,
'JustCurrentSession'
,
'JustCurrentSession'
,
'dump_session_params'
]
'dump_session_params'
]
...
@@ -65,7 +65,6 @@ class SaverRestore(SessionInit):
...
@@ -65,7 +65,6 @@ class SaverRestore(SessionInit):
def
_init
(
self
,
sess
):
def
_init
(
self
,
sess
):
logger
.
info
(
logger
.
info
(
"Restoring checkpoint from {}."
.
format
(
self
.
path
))
"Restoring checkpoint from {}."
.
format
(
self
.
path
))
sess
.
run
(
tf
.
initialize_all_variables
())
chkpt_vars
=
SaverRestore
.
_read_checkpoint_vars
(
self
.
path
)
chkpt_vars
=
SaverRestore
.
_read_checkpoint_vars
(
self
.
path
)
vars_map
=
SaverRestore
.
_get_vars_to_restore_multimap
(
chkpt_vars
)
vars_map
=
SaverRestore
.
_get_vars_to_restore_multimap
(
chkpt_vars
)
for
dic
in
SaverRestore
.
_produce_restore_dict
(
vars_map
):
for
dic
in
SaverRestore
.
_produce_restore_dict
(
vars_map
):
...
@@ -131,7 +130,6 @@ class ParamRestore(SessionInit):
...
@@ -131,7 +130,6 @@ class ParamRestore(SessionInit):
self
.
prms
=
param_dict
self
.
prms
=
param_dict
def
_init
(
self
,
sess
):
def
_init
(
self
,
sess
):
sess
.
run
(
tf
.
initialize_all_variables
())
# allow restore non-trainable variables
# allow restore non-trainable variables
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
VARIABLES
)
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
VARIABLES
)
var_dict
=
dict
([
v
.
name
,
v
]
for
v
in
variables
)
var_dict
=
dict
([
v
.
name
,
v
]
for
v
in
variables
)
...
@@ -152,6 +150,21 @@ class ParamRestore(SessionInit):
...
@@ -152,6 +150,21 @@ class ParamRestore(SessionInit):
value
=
value
.
reshape
(
varshape
)
value
=
value
.
reshape
(
varshape
)
sess
.
run
(
var
.
assign
(
value
))
sess
.
run
(
var
.
assign
(
value
))
def
ChainInit
(
SessionInit
):
""" Init a session by a list of SessionInit instance."""
def
__init__
(
self
,
sess_inits
,
new_session
=
True
):
"""
:params sess_inits: list of `SessionInit` instances.
:params new_session: add a `NewSession()` and the beginning, if not there
"""
if
new_session
and
not
isinstance
(
sess_inits
[
0
],
NewSession
):
sess_inits
.
insert
(
0
,
NewSession
())
self
.
inits
=
sess_inits
def
_init
(
self
,
sess
):
for
i
in
self
.
inits
:
i
.
init
(
sess
)
def
dump_session_params
(
path
):
def
dump_session_params
(
path
):
""" Dump value of all trainable variables to a dict and save to `path` as
""" Dump value of all trainable variables to a dict and save to `path` as
npy format, loadable by ParamRestore
npy format, loadable by ParamRestore
...
...
tensorpack/train/base.py
View file @
5a8d500c
...
@@ -105,6 +105,7 @@ class Trainer(object):
...
@@ -105,6 +105,7 @@ class Trainer(object):
get_global_step_var
()
# ensure there is such var, before finalizing the graph
get_global_step_var
()
# ensure there is such var, before finalizing the graph
callbacks
=
self
.
config
.
callbacks
callbacks
=
self
.
config
.
callbacks
callbacks
.
setup_graph
(
self
)
callbacks
.
setup_graph
(
self
)
self
.
sess
.
run
(
tf
.
initialize_all_variables
())
self
.
config
.
session_init
.
init
(
self
.
sess
)
self
.
config
.
session_init
.
init
(
self
.
sess
)
tf
.
get_default_graph
()
.
finalize
()
tf
.
get_default_graph
()
.
finalize
()
self
.
_start_concurrency
()
self
.
_start_concurrency
()
...
...
tensorpack/train/config.py
View file @
5a8d500c
...
@@ -47,7 +47,7 @@ class TrainConfig(object):
...
@@ -47,7 +47,7 @@ class TrainConfig(object):
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
get_default_sess_config
())
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
get_default_sess_config
())
assert_type
(
self
.
session_config
,
tf
.
ConfigProto
)
assert_type
(
self
.
session_config
,
tf
.
ConfigProto
)
self
.
session_init
=
kwargs
.
pop
(
'session_init'
,
New
Session
())
self
.
session_init
=
kwargs
.
pop
(
'session_init'
,
JustCurrent
Session
())
assert_type
(
self
.
session_init
,
SessionInit
)
assert_type
(
self
.
session_init
,
SessionInit
)
self
.
step_per_epoch
=
int
(
kwargs
.
pop
(
'step_per_epoch'
))
self
.
step_per_epoch
=
int
(
kwargs
.
pop
(
'step_per_epoch'
))
self
.
starting_epoch
=
int
(
kwargs
.
pop
(
'starting_epoch'
,
1
))
self
.
starting_epoch
=
int
(
kwargs
.
pop
(
'starting_epoch'
,
1
))
...
...
tensorpack/utils/fs.py
View file @
5a8d500c
...
@@ -27,10 +27,16 @@ def download(url, dir):
...
@@ -27,10 +27,16 @@ def download(url, dir):
sys
.
stdout
.
write
(
'
\r
>> Downloading
%
s
%.1
f
%%
'
%
sys
.
stdout
.
write
(
'
\r
>> Downloading
%
s
%.1
f
%%
'
%
(
fname
,
float
(
count
*
block_size
)
/
float
(
total_size
)
*
100.0
))
(
fname
,
float
(
count
*
block_size
)
/
float
(
total_size
)
*
100.0
))
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
fpath
,
_
=
urllib
.
request
.
urlretrieve
(
url
,
fpath
,
reporthook
=
_progress
)
try
:
statinfo
=
os
.
stat
(
fpath
)
fpath
,
_
=
urllib
.
request
.
urlretrieve
(
url
,
fpath
,
reporthook
=
_progress
)
statinfo
=
os
.
stat
(
fpath
)
size
=
statinfo
.
st_size
except
:
logger
.
error
(
"Failed to download {}"
.
format
(
url
))
raise
assert
size
>
0
,
"Download an empty file!"
sys
.
stdout
.
write
(
'
\n
'
)
sys
.
stdout
.
write
(
'
\n
'
)
print
(
'Succesfully downloaded '
+
fname
+
" "
+
str
(
s
tatinfo
.
st_s
ize
)
+
' bytes.'
)
print
(
'Succesfully downloaded '
+
fname
+
" "
+
str
(
size
)
+
' bytes.'
)
return
fpath
return
fpath
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tensorpack/utils/stat.py
View file @
5a8d500c
...
@@ -6,6 +6,7 @@ import numpy as np
...
@@ -6,6 +6,7 @@ import numpy as np
__all__
=
[
'StatCounter'
,
'Accuracy'
,
'BinaryStatistics'
,
'RatioCounter'
]
__all__
=
[
'StatCounter'
,
'Accuracy'
,
'BinaryStatistics'
,
'RatioCounter'
]
class
StatCounter
(
object
):
class
StatCounter
(
object
):
""" A simple counter"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
reset
()
self
.
reset
()
...
@@ -35,6 +36,7 @@ class StatCounter(object):
...
@@ -35,6 +36,7 @@ class StatCounter(object):
return
max
(
self
.
_values
)
return
max
(
self
.
_values
)
class
RatioCounter
(
object
):
class
RatioCounter
(
object
):
""" A counter to count ratio of something"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
reset
()
self
.
reset
()
...
@@ -57,6 +59,7 @@ class RatioCounter(object):
...
@@ -57,6 +59,7 @@ class RatioCounter(object):
return
self
.
_tot
return
self
.
_tot
class
Accuracy
(
RatioCounter
):
class
Accuracy
(
RatioCounter
):
""" A RatioCounter with a fancy name """
@
property
@
property
def
accuracy
(
self
):
def
accuracy
(
self
):
return
self
.
ratio
return
self
.
ratio
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment