Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3e61aacd
Commit
3e61aacd
authored
Feb 22, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
standardize the name "predictor" instead of "predict_func"
parent
088521fc
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
41 additions
and
31 deletions
+41
-31
docs/casestudies/colorize.md
docs/casestudies/colorize.md
+3
-3
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+2
-2
examples/DeepQNetwork/common.py
examples/DeepQNetwork/common.py
+3
-3
examples/DeepQNetwork/expreplay.py
examples/DeepQNetwork/expreplay.py
+1
-1
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+2
-2
examples/HED/hed.py
examples/HED/hed.py
+2
-2
examples/Saliency/saliency-maps.py
examples/Saliency/saliency-maps.py
+2
-2
examples/load-alexnet.py
examples/load-alexnet.py
+2
-2
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-2
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+3
-2
tensorpack/predict/base.py
tensorpack/predict/base.py
+2
-1
tensorpack/train/base.py
tensorpack/train/base.py
+17
-8
tensorpack/train/predict.py
tensorpack/train/predict.py
+0
-1
No files found.
docs/casestudies/colorize.md
View file @
3e61aacd
...
...
@@ -349,7 +349,7 @@ class OnlineExport(Callback):
self
.
example_input
=
color
.
rgb2lab
(
cv2
.
imread
(
'myimage.jpg'
)[:,
:,
::
-
1
])[:,
:,
0
]
# read rgb image and extract luminance
def
_setup_graph
(
self
):
self
.
predictor
=
self
.
trainer
.
get_predict
_func
([
'luminance'
],
[
'prediction/output'
])
self
.
predictor
=
self
.
trainer
.
get_predict
or
([
'luminance'
],
[
'prediction/output'
])
def
_trigger_epoch
(
self
):
pass
...
...
@@ -367,7 +367,7 @@ you can simply `print(prediction)` to find out the name.
These two names allows us to build the inference part of the network in
```
python
self
.
trainer
.
get_predict
_func
([
'luminance'
,
'prediction/output'
])
self
.
trainer
.
get_predict
or
([
'luminance'
,
'prediction/output'
])
```
This is very convenient because in the
`_tigger_epoch`
we can use:
...
...
@@ -385,7 +385,7 @@ class OnlineExport(Callback):
self
.
example_input
=
color
.
rgb2lab
(
cv2
.
imread
(
'myimage.jpg'
)[:,
:,
[
2
,
1
,
0
]])[:,
:,
0
]
def
_setup_graph
(
self
):
self
.
trainer
.
get_predict
_func
([
'luminance'
,
'prediction/output'
])
self
.
trainer
.
get_predict
or
([
'luminance'
,
'prediction/output'
])
def
_trigger_epoch
(
self
):
hopefully_cool_rgb
=
self
.
pred
([[
self
.
example_input
]])[
0
][
0
]
...
...
examples/A3C-Gym/train-atari.py
View file @
3e61aacd
...
...
@@ -151,7 +151,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def
_setup_graph
(
self
):
self
.
async_predictor
=
MultiThreadAsyncPredictor
(
self
.
trainer
.
get_predict
_func
s
([
'state'
],
[
'logitsT'
,
'pred_value'
],
self
.
trainer
.
get_predict
or
s
([
'state'
],
[
'logitsT'
,
'pred_value'
],
PREDICTOR_THREAD
),
batch_size
=
15
)
def
_before_train
(
self
):
...
...
examples/DeepQNetwork/common.py
View file @
3e61aacd
...
...
@@ -38,7 +38,7 @@ def play_model(cfg):
print
(
"Total:"
,
score
)
def
eval_with_funcs
(
predict
_func
s
,
nr_eval
):
def
eval_with_funcs
(
predict
or
s
,
nr_eval
):
class
Worker
(
StoppableThread
,
ShareSessionThread
):
def
__init__
(
self
,
func
,
queue
):
super
(
Worker
,
self
)
.
__init__
()
...
...
@@ -62,7 +62,7 @@ def eval_with_funcs(predict_funcs, nr_eval):
self
.
queue_put_stoppable
(
self
.
q
,
score
)
q
=
queue
.
Queue
()
threads
=
[
Worker
(
f
,
q
)
for
f
in
predict
_func
s
]
threads
=
[
Worker
(
f
,
q
)
for
f
in
predict
or
s
]
for
k
in
threads
:
k
.
start
()
...
...
@@ -103,7 +103,7 @@ class Evaluator(Triggerable):
def
_setup_graph
(
self
):
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
20
)
self
.
pred_funcs
=
[
self
.
trainer
.
get_predict
_func
(
self
.
pred_funcs
=
[
self
.
trainer
.
get_predict
or
(
self
.
input_names
,
self
.
output_names
)]
*
NR_PROC
def
_trigger
(
self
):
...
...
examples/DeepQNetwork/expreplay.py
View file @
3e61aacd
...
...
@@ -229,7 +229,7 @@ class ExpReplay(DataFlow, Callback):
return
[
state
,
action
,
reward
,
isOver
]
def
_setup_graph
(
self
):
self
.
predictor
=
self
.
trainer
.
get_predict
_func
(
*
self
.
predictor_io_names
)
self
.
predictor
=
self
.
trainer
.
get_predict
or
(
*
self
.
predictor_io_names
)
def
_before_train
(
self
):
self
.
_init_memory
()
...
...
examples/DoReFa-Net/alexnet-dorefa.py
View file @
3e61aacd
...
...
@@ -258,7 +258,7 @@ def run_image(model, sess_init, inputs):
input_names
=
[
'input'
],
output_names
=
[
'output'
]
)
predict
_func
=
OfflinePredictor
(
pred_config
)
predict
or
=
OfflinePredictor
(
pred_config
)
meta
=
dataset
.
ILSVRCMeta
()
pp_mean
=
meta
.
get_per_pixel_mean
()
pp_mean_224
=
pp_mean
[
16
:
-
16
,
16
:
-
16
,
:]
...
...
@@ -282,7 +282,7 @@ def run_image(model, sess_init, inputs):
assert
img
is
not
None
img
=
transformers
.
augment
(
img
)[
np
.
newaxis
,
:,
:,
:]
outputs
=
predict
_func
([
img
])[
0
]
outputs
=
predict
or
([
img
])[
0
]
prob
=
outputs
[
0
]
ret
=
prob
.
argsort
()[
-
10
:][::
-
1
]
...
...
examples/HED/hed.py
View file @
3e61aacd
...
...
@@ -192,11 +192,11 @@ def run(model_path, image_path, output):
session_init
=
get_model_loader
(
model_path
),
input_names
=
[
'image'
],
output_names
=
[
'output'
+
str
(
k
)
for
k
in
range
(
1
,
7
)])
predict
_func
=
OfflinePredictor
(
pred_config
)
predict
or
=
OfflinePredictor
(
pred_config
)
im
=
cv2
.
imread
(
image_path
)
assert
im
is
not
None
im
=
cv2
.
resize
(
im
,
(
im
.
shape
[
1
]
//
16
*
16
,
im
.
shape
[
0
]
//
16
*
16
))
outputs
=
predict
_func
([[
im
.
astype
(
'float32'
)]])
outputs
=
predict
or
([[
im
.
astype
(
'float32'
)]])
if
output
is
None
:
for
k
in
range
(
6
):
pred
=
outputs
[
k
][
0
]
...
...
examples/Saliency/saliency-maps.py
View file @
3e61aacd
...
...
@@ -30,7 +30,7 @@ class Model(tp.ModelDesc):
def
run
(
model_path
,
image_path
):
predict
_func
=
tp
.
OfflinePredictor
(
tp
.
PredictConfig
(
predict
or
=
tp
.
OfflinePredictor
(
tp
.
PredictConfig
(
model
=
Model
(),
session_init
=
tp
.
get_model_loader
(
model_path
),
input_names
=
[
'image'
],
...
...
@@ -42,7 +42,7 @@ def run(model_path, image_path):
im
=
cv2
.
resize
(
im
,
(
IMAGE_SIZE
,
IMAGE_SIZE
))
im
=
im
.
astype
(
np
.
float32
)[:,
:,
::
-
1
]
saliency_images
=
predict
_func
([
im
])[
0
]
saliency_images
=
predict
or
([
im
])[
0
]
abs_saliency
=
np
.
abs
(
saliency_images
)
.
max
(
axis
=-
1
)
pos_saliency
=
np
.
maximum
(
0
,
saliency_images
)
...
...
examples/load-alexnet.py
View file @
3e61aacd
...
...
@@ -54,7 +54,7 @@ class Model(ModelDesc):
def
run_test
(
path
,
input
):
param_dict
=
np
.
load
(
path
,
encoding
=
'latin1'
)
.
item
()
predict
_func
=
OfflinePredictor
(
PredictConfig
(
predict
or
=
OfflinePredictor
(
PredictConfig
(
model
=
Model
(),
session_init
=
ParamRestore
(
param_dict
),
input_names
=
[
'input'
],
...
...
@@ -65,7 +65,7 @@ def run_test(path, input):
assert
im
is
not
None
,
input
im
=
cv2
.
resize
(
im
,
(
227
,
227
))[:,
:,
::
-
1
]
.
reshape
(
(
1
,
227
,
227
,
3
))
.
astype
(
'float32'
)
-
110
outputs
=
predict
_func
([
im
])[
0
]
outputs
=
predict
or
([
im
])[
0
]
prob
=
outputs
[
0
]
ret
=
prob
.
argsort
()[
-
10
:][::
-
1
]
print
(
"Top10 predictions:"
,
ret
)
...
...
tensorpack/callbacks/inference_runner.py
View file @
3e61aacd
...
...
@@ -92,7 +92,7 @@ class InferenceRunner(Triggerable):
def
_setup_graph
(
self
):
self
.
_find_input_tensors
()
# these are all tensor names
self
.
_find_output_tensors
()
# may be either tensor name or op name
self
.
pred
_func
=
self
.
trainer
.
get_predict_func
(
self
.
pred
ictor
=
self
.
trainer
.
get_predictor
(
self
.
input_tensors
,
self
.
output_tensors
)
def
_find_input_tensors
(
self
):
...
...
@@ -135,7 +135,7 @@ class InferenceRunner(Triggerable):
self
.
ds
.
reset_state
()
with
get_tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
outputs
=
self
.
pred
_func
(
dp
)
outputs
=
self
.
pred
ictor
(
dp
)
for
inf
,
tensormap
in
zip
(
self
.
infs
,
self
.
inf_to_tensors
):
inf_output
=
[(
outputs
if
k
.
isOutput
else
dp
)[
k
.
index
]
for
k
in
tensormap
]
...
...
tensorpack/callbacks/param.py
View file @
3e61aacd
...
...
@@ -218,8 +218,8 @@ class ScheduledHyperParamSetter(HyperParamSetter):
param: same as in :class:`HyperParamSetter`.
schedule (list): with the format ``[(epoch1, val1), (epoch2, val2), (epoch3, val3)]``.
Each ``(ep, val)`` pair means to set the param
to "val"
after
the completion of `ep` th epoch.
If ep == 0, the value will be set before t
raining
.
to "val"
__after__
the completion of `ep` th epoch.
If ep == 0, the value will be set before t
he first epoch
.
interp: None: no interpolation. 'linear': linear interpolation
Example:
...
...
@@ -263,6 +263,7 @@ class HyperParamSetterWithFunc(HyperParamSetter):
Args:
param: same as in :class:`HyperParamSetter`.
func: ``param`` will be set by ``new_value = func(epoch_num, old_value)``.
``epoch_num`` is the number of epochs that have finished.
Example:
Decrease by a factor of 0.9 every two epochs:
...
...
tensorpack/predict/base.py
View file @
3e61aacd
...
...
@@ -7,7 +7,7 @@ from abc import abstractmethod, ABCMeta
import
tensorflow
as
tf
import
six
from
..utils
import
logger
from
..utils
import
logger
,
deprecated
from
..utils.argtools
import
memoized
from
..utils.naming
import
SUMMARY_BACKUP_KEYS
from
..tfutils
import
get_tensors_by_names
,
TowerContext
...
...
@@ -146,6 +146,7 @@ class OfflinePredictor(OnlinePredictor):
input_tensors
,
output_tensors
,
config
.
return_input
,
sess
)
@
deprecated
(
"Use OfflinePredictor instead!"
,
"2017-05-20"
)
def
get_predict_func
(
config
):
"""
Equivalent to ``OfflinePredictor(config)``.
...
...
tensorpack/train/base.py
View file @
3e61aacd
...
...
@@ -189,7 +189,7 @@ class Trainer(object):
self
.
summary_writer
.
close
()
self
.
monitored_sess
.
close
()
def
get_predict
_func
(
self
,
input_names
,
output_names
,
tower
=
0
):
def
get_predict
or
(
self
,
input_names
,
output_names
,
tower
=
0
):
"""
Args:
input_names (list), output_names(list): list of names
...
...
@@ -200,16 +200,25 @@ class Trainer(object):
"""
if
not
hasattr
(
self
,
'_predictor_factory'
):
self
.
_predictor_factory
=
PredictorFactory
(
self
)
nr_tower
=
len
(
self
.
config
.
predict_tower
)
if
nr_tower
<
tower
:
logger
.
warn
(
"Requested the {}th predictor but only have {} predict towers! "
"Predictors will be assigned to GPUs in round-robin."
.
format
(
tower
,
nr_tower
))
tower
=
tower
%
nr_tower
return
self
.
_predictor_factory
.
get_predictor
(
input_names
,
output_names
,
tower
)
def
get_predict
_func
s
(
self
,
input_names
,
output_names
,
n
):
def
get_predict
or
s
(
self
,
input_names
,
output_names
,
n
):
""" Return n predictors. """
nr_tower
=
len
(
self
.
config
.
predict_tower
)
if
nr_tower
<
n
:
logger
.
warn
(
"Requested {} predictor but only have {} predict towers! "
"Predictors will be assigned to GPUs in round-robin."
.
format
(
n
,
nr_tower
))
return
[
self
.
get_predict_func
(
input_names
,
output_names
,
k
%
nr_tower
)
for
k
in
range
(
n
)]
return
[
self
.
get_predictor
(
input_names
,
output_names
,
k
)
for
k
in
range
(
n
)]
@
deprecated
(
"Use get_predictor instead!"
,
"2017-05-20"
)
def
get_predict_func
(
self
,
input_names
,
output_names
,
tower
=
0
):
return
self
.
get_predictor
(
input_names
,
output_names
,
tower
)
@
deprecated
(
"Use get_predictors instead!"
,
"2017-05-20"
)
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
return
self
.
get_predictors
(
input_names
,
output_names
,
n
)
@
deprecated
(
"Don't need to call it any more!"
,
"2017-03-20"
)
def
_setup_predictor_factory
(
self
):
...
...
tensorpack/train/predict.py
View file @
3e61aacd
...
...
@@ -26,7 +26,6 @@ class PredictorFactory(object):
self
.
_tower_builder
=
PredictorTowerBuilder
(
fn
)
assert
isinstance
(
self
.
towers
,
list
)
# TODO sess option
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
):
"""
Args:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment