Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
5ccaea83
Commit
5ccaea83
authored
May 28, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'new-infer'
parents
8644248a
818e3faf
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
91 additions
and
68 deletions
+91
-68
examples/cifar-convnet.py
examples/cifar-convnet.py
+5
-9
examples/mnist-convnet.py
examples/mnist-convnet.py
+1
-5
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+29
-27
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+9
-12
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+16
-3
tensorpack/train/base.py
tensorpack/train/base.py
+5
-0
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+26
-12
No files found.
examples/cifar-convnet.py
View file @
5ccaea83
...
...
@@ -8,14 +8,9 @@ import argparse
import
numpy
as
np
import
os
from
tensorpack.train
import
TrainConfig
,
QueueInputTrainer
from
tensorpack.models
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.utils
import
*
from
tensorpack.tfutils
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack
import
*
import
tensorpack.tfutils.symbolic_functions
as
symbf
from
tensorpack.tfutils.summary
import
*
from
tensorpack.dataflow
import
*
"""
A small convnet model for cifar 10 or cifar100 dataset.
...
...
@@ -65,7 +60,7 @@ class Model(ModelDesc):
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
cost
)
# compute the number of failed samples, for ClassificationError to use at test time
wrong
=
prediction_incorrect
(
logits
,
label
)
wrong
=
symbf
.
prediction_incorrect
(
logits
,
label
)
nr_wrong
=
tf
.
reduce_sum
(
wrong
,
name
=
'wrong'
)
# monitor training error
tf
.
add_to_collection
(
...
...
@@ -161,4 +156,5 @@ if __name__ == '__main__':
config
.
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
gpu
:
config
.
nr_tower
=
len
(
args
.
gpu
.
split
(
','
))
QueueInputTrainer
(
config
)
.
train
()
#QueueInputTrainer(config).train()
SimpleTrainer
(
config
)
.
train
()
examples/mnist-convnet.py
View file @
5ccaea83
...
...
@@ -9,9 +9,6 @@ import os, sys
import
argparse
from
tensorpack
import
*
from
tensorpack.models
import
*
from
tensorpack.utils
import
*
from
tensorpack.callbacks
import
*
"""
MNIST ConvNet example.
...
...
@@ -117,6 +114,5 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
#QueueInputTrainer(config).train()
SimpleInputTrainer
(
config
)
.
train
()
QueueInputTrainer
(
config
)
.
train
()
tensorpack/callbacks/group.py
View file @
5ccaea83
...
...
@@ -12,6 +12,7 @@ from ..utils import *
__all__
=
[
'Callbacks'
]
# --- Test-Callback related stuff seems not very useful.
@
contextmanager
def
create_test_graph
(
trainer
):
model
=
trainer
.
model
...
...
@@ -31,33 +32,6 @@ def create_test_session(trainer):
with
tf
.
Session
()
as
sess
:
yield
sess
class
CallbackTimeLogger
(
object
):
def
__init__
(
self
):
self
.
times
=
[]
self
.
tot
=
0
def
add
(
self
,
name
,
time
):
self
.
tot
+=
time
self
.
times
.
append
((
name
,
time
))
@
contextmanager
def
timed_callback
(
self
,
name
):
s
=
time
.
time
()
yield
self
.
add
(
name
,
time
.
time
()
-
s
)
def
log
(
self
):
""" log the time of some heavy callbacks """
if
self
.
tot
<
3
:
return
msgs
=
[]
for
name
,
t
in
self
.
times
:
if
t
/
self
.
tot
>
0.3
and
t
>
1
:
msgs
.
append
(
"{}:{:.3f}sec"
.
format
(
name
,
t
))
logger
.
info
(
"Callbacks took {:.3f} sec in total. {}"
.
format
(
self
.
tot
,
'; '
.
join
(
msgs
)))
class
TestCallbackContext
(
object
):
"""
A class holding the context needed for running TestCallback
...
...
@@ -91,6 +65,34 @@ class TestCallbackContext(object):
def
test_context
(
self
):
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
yield
# ---
class
CallbackTimeLogger
(
object
):
def
__init__
(
self
):
self
.
times
=
[]
self
.
tot
=
0
def
add
(
self
,
name
,
time
):
self
.
tot
+=
time
self
.
times
.
append
((
name
,
time
))
@
contextmanager
def
timed_callback
(
self
,
name
):
s
=
time
.
time
()
yield
self
.
add
(
name
,
time
.
time
()
-
s
)
def
log
(
self
):
""" log the time of some heavy callbacks """
if
self
.
tot
<
3
:
return
msgs
=
[]
for
name
,
t
in
self
.
times
:
if
t
/
self
.
tot
>
0.3
and
t
>
1
:
msgs
.
append
(
"{}:{:.3f}sec"
.
format
(
name
,
t
))
logger
.
info
(
"Callbacks took {:.3f} sec in total. {}"
.
format
(
self
.
tot
,
'; '
.
join
(
msgs
)))
class
Callbacks
(
Callback
):
"""
...
...
tensorpack/callbacks/inference.py
View file @
5ccaea83
...
...
@@ -13,7 +13,7 @@ from ..utils import *
from
..utils.stat
import
*
from
..tfutils
import
*
from
..tfutils.summary
import
*
from
.base
import
Callback
,
TestCallbackType
from
.base
import
Callback
__all__
=
[
'InferenceRunner'
,
'ClassificationError'
,
'ScalarStats'
,
'Inferencer'
,
'BinaryClassificationStats'
]
...
...
@@ -63,7 +63,6 @@ class InferenceRunner(Callback):
"""
A callback that runs different kinds of inferencer.
"""
type
=
TestCallbackType
()
def
__init__
(
self
,
ds
,
vcs
):
"""
...
...
@@ -82,12 +81,15 @@ class InferenceRunner(Callback):
def
_before_train
(
self
):
self
.
input_vars
=
self
.
trainer
.
model
.
reuse_input_vars
()
self
.
_find_output_tensors
()
input_names
=
[
x
.
name
for
x
in
self
.
input_vars
]
self
.
pred_func
=
self
.
trainer
.
get_predict_func
(
input_names
,
self
.
output_tensors
)
for
v
in
self
.
vcs
:
v
.
trainer
=
self
.
trainer
def
_find_output_tensors
(
self
):
self
.
output_tensors
=
[]
self
.
vc_to_vars
=
[]
self
.
output_tensors
=
[]
# list of names
self
.
vc_to_vars
=
[]
# list of list of (var_name: output_idx)
for
vc
in
self
.
vcs
:
vc_vars
=
vc
.
_get_output_tensors
()
def
find_oid
(
var
):
...
...
@@ -99,12 +101,6 @@ class InferenceRunner(Callback):
vc_vars
=
[(
var
,
find_oid
(
var
))
for
var
in
vc_vars
]
self
.
vc_to_vars
.
append
(
vc_vars
)
# convert name to tensors
def
get_tensor
(
name
):
_
,
varname
=
get_op_var_name
(
name
)
return
self
.
graph
.
get_tensor_by_name
(
varname
)
self
.
output_tensors
=
list
(
map
(
get_tensor
,
self
.
output_tensors
))
def
_trigger_epoch
(
self
):
for
vc
in
self
.
vcs
:
vc
.
before_inference
()
...
...
@@ -112,8 +108,9 @@ class InferenceRunner(Callback):
sess
=
tf
.
get_default_session
()
with
tqdm
(
total
=
self
.
ds
.
size
(),
ascii
=
True
)
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
# TODO custom dp mapping?
outputs
=
sess
.
run
(
self
.
output_tensors
,
feed_dict
=
feed
)
#feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
#outputs = sess.run(self.output_tensors, feed_dict=feed)
outputs
=
self
.
pred_func
(
dp
)
for
vc
,
varsmap
in
zip
(
self
.
vcs
,
self
.
vc_to_vars
):
vc_output
=
[
outputs
[
k
[
1
]]
for
k
in
varsmap
]
vc
.
datapoint
(
dp
,
vc_output
)
...
...
tensorpack/tfutils/common.py
View file @
5ccaea83
...
...
@@ -9,7 +9,9 @@ import tensorflow as tf
__all__
=
[
'get_default_sess_config'
,
'get_global_step'
,
'get_global_step_var'
,
'get_op_var_name'
]
'get_op_var_name'
,
'get_vars_by_names'
]
def
get_default_sess_config
(
mem_fraction
=
0.9
):
"""
...
...
@@ -53,3 +55,14 @@ def get_op_var_name(name):
return
name
[:
-
2
],
name
else
:
return
name
,
name
+
':0'
def
get_vars_by_names
(
names
):
"""
Get a list of variables in the default graph by a list of names
"""
ret
=
[]
G
=
tf
.
get_default_graph
()
for
n
in
names
:
opn
,
varn
=
get_op_var_name
(
n
)
ret
.
append
(
G
.
get_tensor_by_name
(
varn
))
return
ret
tensorpack/train/base.py
View file @
5ccaea83
...
...
@@ -50,6 +50,11 @@ class Trainer(object):
""" run an iteration"""
pass
@
abstractmethod
def
get_predict_func
(
self
,
input_names
,
output_names
):
""" return a predict function"""
pass
def
trigger_epoch
(
self
):
self
.
_trigger_epoch
()
self
.
config
.
callbacks
.
trigger_epoch
()
...
...
tensorpack/train/trainer.py
View file @
5ccaea83
...
...
@@ -53,25 +53,16 @@ class SimpleTrainer(Trainer):
self
.
_process_summary
(
summary_str
)
def
get_predict_func
(
self
,
input_names
,
output_names
):
input_vars
=
[]
for
n
in
input_names
:
opn
,
varn
=
get_op_var_name
(
n
)
v
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
varn
)
input_vars
=
get_vars_by_names
(
input_names
)
for
v
in
input_vars
:
assert
v
in
self
.
input_vars
input_vars
.
append
(
v
)
output_vars
=
[]
for
n
in
output_names
:
opn
,
varn
=
get_op_var_name
(
n
)
v
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
varn
)
output_vars
.
append
(
v
)
output_vars
=
get_vars_by_names
(
output_names
)
def
func
(
inputs
):
assert
len
(
inputs
)
==
len
(
input_vars
)
feed
=
dict
(
zip
(
input_vars
,
inputs
))
return
self
.
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
func
class
EnqueueThread
(
threading
.
Thread
):
def
__init__
(
self
,
trainer
,
queue
,
enqueue_op
,
raw_input_var
):
super
(
EnqueueThread
,
self
)
.
__init__
()
...
...
@@ -126,6 +117,7 @@ class QueueInputTrainer(Trainer):
self
.
async
=
async
if
self
.
async
:
assert
self
.
config
.
nr_tower
>
1
self
.
_dequed_inputs
=
[]
@
staticmethod
def
_average_grads
(
tower_grads
):
...
...
@@ -148,6 +140,7 @@ class QueueInputTrainer(Trainer):
assert
len
(
ret
)
==
len
(
self
.
input_vars
)
for
qv
,
v
in
zip
(
ret
,
self
.
input_vars
):
qv
.
set_shape
(
v
.
get_shape
())
self
.
_dequed_inputs
.
append
(
ret
)
return
ret
def
_single_tower_grad
(
self
):
...
...
@@ -248,6 +241,27 @@ class QueueInputTrainer(Trainer):
summary_str
=
self
.
summary_op
.
eval
()
self
.
_process_summary
(
summary_str
)
def
get_predict_func
(
self
,
input_names
,
output_names
):
raw_input_vars
=
get_vars_by_names
(
input_names
)
input_var_idxs
=
[
self
.
input_vars
.
index
(
v
)
for
v
in
raw_input_vars
]
if
self
.
config
.
nr_tower
==
1
:
dequed
=
self
.
_dequed_inputs
[
0
]
input_vars
=
[
dequed
[
k
]
for
k
in
input_var_idxs
]
output_vars
=
get_vars_by_names
(
output_names
)
else
:
# TODO naive impl: use the first tower only
dequed
=
self
.
_dequed_inputs
[
0
]
input_vars
=
[
dequed
[
k
]
for
k
in
input_var_idxs
]
output_names
=
[
'tower0/'
+
n
for
n
in
output_names
]
output_vars
=
get_vars_by_names
(
output_names
)
def
func
(
inputs
):
assert
len
(
inputs
)
==
len
(
input_vars
)
feed
=
dict
(
zip
(
input_vars
,
inputs
))
return
self
.
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
func
def
start_train
(
config
):
tr
=
QueueInputTrainer
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment