Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
77bcc8b1
Commit
77bcc8b1
authored
Nov 10, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
input_names instead of input_var_names
parent
c6c9a4ba
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
57 additions
and
47 deletions
+57
-47
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+3
-3
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+2
-2
examples/HED/hed.py
examples/HED/hed.py
+2
-2
examples/OpenAIGym/run-atari.py
examples/OpenAIGym/run-atari.py
+2
-2
examples/OpenAIGym/train-atari.py
examples/OpenAIGym/train-atari.py
+2
-2
examples/ResNet/README.md
examples/ResNet/README.md
+1
-1
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+2
-2
examples/ResNet/load-resnet.py
examples/ResNet/load-resnet.py
+4
-4
examples/SpatialTransformer/mnist-addition.py
examples/SpatialTransformer/mnist-addition.py
+2
-2
examples/load-alexnet.py
examples/load-alexnet.py
+2
-2
examples/load-vgg16.py
examples/load-vgg16.py
+2
-2
tensorpack/dataflow/remote.py
tensorpack/dataflow/remote.py
+1
-0
tensorpack/predict/base.py
tensorpack/predict/base.py
+14
-14
tensorpack/predict/common.py
tensorpack/predict/common.py
+18
-9
No files found.
examples/Atari2600/DQN.py
View file @
77bcc8b1
...
...
@@ -123,7 +123,7 @@ class Model(ModelDesc):
target
=
reward
+
(
1.0
-
tf
.
cast
(
isOver
,
tf
.
float32
))
*
GAMMA
*
tf
.
stop_gradient
(
best_v
)
self
.
cost
=
symbf
.
huber_loss
(
target
-
pred_action_value
,
name
=
'cost'
)
self
.
cost
=
tf
.
truediv
(
symbf
.
huber_loss
(
target
-
pred_action_value
),
BATCH_SIZE
,
name
=
'cost'
)
summary
.
add_param_summary
([(
'conv.*/W'
,
[
'histogram'
,
'rms'
]),
(
'fc.*/W'
,
[
'histogram'
,
'rms'
])
])
# monitor all W
...
...
@@ -200,8 +200,8 @@ if __name__ == '__main__':
cfg
=
PredictConfig
(
model
=
Model
(),
session_init
=
SaverRestore
(
args
.
load
),
input_
var_
names
=
[
'state'
],
output_
var_
names
=
[
'Qvalue'
])
input_names
=
[
'state'
],
output_names
=
[
'Qvalue'
])
if
args
.
task
==
'play'
:
play_model
(
cfg
)
elif
args
.
task
==
'eval'
:
...
...
examples/DoReFa-Net/alexnet-dorefa.py
View file @
77bcc8b1
...
...
@@ -245,8 +245,8 @@ def run_image(model, sess_init, inputs):
model
=
model
,
session_init
=
sess_init
,
session_config
=
get_default_sess_config
(
0.9
),
input_
var_
names
=
[
'input'
],
output_
var_
names
=
[
'output'
]
input_names
=
[
'input'
],
output_names
=
[
'output'
]
)
predict_func
=
get_predict_func
(
pred_config
)
meta
=
dataset
.
ILSVRCMeta
()
...
...
examples/HED/hed.py
View file @
77bcc8b1
...
...
@@ -184,9 +184,9 @@ def get_config():
def
run
(
model_path
,
image_path
):
pred_config
=
PredictConfig
(
model
=
Model
(),
input_data_mapping
=
[
0
],
session_init
=
get_model_loader
(
model_path
),
output_var_names
=
[
'output'
+
str
(
k
)
for
k
in
range
(
1
,
7
)])
input_names
=
[
'image'
],
output_names
=
[
'output'
+
str
(
k
)
for
k
in
range
(
1
,
7
)])
predict_func
=
get_predict_func
(
pred_config
)
im
=
cv2
.
imread
(
image_path
)
assert
im
is
not
None
...
...
examples/OpenAIGym/run-atari.py
View file @
77bcc8b1
...
...
@@ -95,6 +95,6 @@ if __name__ == '__main__':
cfg
=
PredictConfig
(
model
=
Model
(),
session_init
=
SaverRestore
(
args
.
load
),
input_
var_
names
=
[
'state'
],
output_
var_
names
=
[
'logits'
])
input_names
=
[
'state'
],
output_names
=
[
'logits'
])
run_submission
(
cfg
,
args
.
output
,
args
.
episode
)
examples/OpenAIGym/train-atari.py
View file @
77bcc8b1
...
...
@@ -235,8 +235,8 @@ if __name__ == '__main__':
cfg
=
PredictConfig
(
model
=
Model
(),
session_init
=
SaverRestore
(
args
.
load
),
input_
var_
names
=
[
'state'
],
output_
var_
names
=
[
'logits'
])
input_names
=
[
'state'
],
output_names
=
[
'logits'
])
if
args
.
task
==
'play'
:
play_model
(
cfg
)
elif
args
.
task
==
'eval'
:
...
...
examples/ResNet/README.md
View file @
77bcc8b1
## imagenet-resnet.py
ImageNet training code of pre-activation Res
Net. It follows the setup in
Training code of pre-activation ResNet on Image
Net. It follows the setup in
[
fb.resnet.torch
](
https://github.com/facebook/fb.resnet.torch
)
and gets similar performance (with much fewer lines of code).
More results to come.
...
...
examples/ResNet/imagenet-resnet.py
View file @
77bcc8b1
...
...
@@ -213,9 +213,9 @@ def eval_on_ILSVRC12(model_file, data_dir):
ds
=
get_data
(
'val'
)
pred_config
=
PredictConfig
(
model
=
Model
(),
input_var_names
=
[
'input'
,
'label'
],
session_init
=
get_model_loader
(
model_file
),
output_var_names
=
[
'wrong-top1'
,
'wrong-top5'
]
input_names
=
[
'input'
,
'label'
],
output_names
=
[
'wrong-top1'
,
'wrong-top5'
]
)
pred
=
SimpleDatasetPredictor
(
pred_config
,
ds
)
acc1
,
acc5
=
RatioCounter
(),
RatioCounter
()
...
...
examples/ResNet/load-resnet.py
View file @
77bcc8b1
...
...
@@ -111,8 +111,8 @@ def run_test(params, input):
pred_config
=
PredictConfig
(
model
=
Model
(),
session_init
=
ParamRestore
(
params
),
input_
var_
names
=
[
'input'
],
output_
var_
names
=
[
'prob'
]
input_names
=
[
'input'
],
output_names
=
[
'prob'
]
)
predict_func
=
get_predict_func
(
pred_config
)
...
...
@@ -134,9 +134,9 @@ def eval_on_ILSVRC12(params, data_dir):
ds
=
BatchData
(
ds
,
128
,
remainder
=
True
)
pred_config
=
PredictConfig
(
model
=
Model
(),
input_var_names
=
[
'input'
,
'label'
],
session_init
=
ParamRestore
(
params
),
output_var_names
=
[
'wrong-top1'
,
'wrong-top5'
]
input_names
=
[
'input'
,
'label'
],
output_names
=
[
'wrong-top1'
,
'wrong-top5'
]
)
pred
=
SimpleDatasetPredictor
(
pred_config
,
ds
)
acc1
,
acc5
=
RatioCounter
(),
RatioCounter
()
...
...
examples/SpatialTransformer/mnist-addition.py
View file @
77bcc8b1
...
...
@@ -109,8 +109,8 @@ def view_warp(modelpath):
pred
=
OfflinePredictor
(
PredictConfig
(
session_init
=
get_model_loader
(
modelpath
),
model
=
Model
(),
input_
var_
names
=
[
'input'
],
output_
var_
names
=
[
'viz'
,
'STN1/affine'
,
'STN2/affine'
]))
input_names
=
[
'input'
],
output_names
=
[
'viz'
,
'STN1/affine'
,
'STN2/affine'
]))
xys
=
np
.
array
([[
0
,
0
,
1
],
[
WARP_TARGET_SIZE
,
0
,
1
],
...
...
examples/load-alexnet.py
View file @
77bcc8b1
...
...
@@ -53,10 +53,10 @@ def run_test(path, input):
pred_config
=
PredictConfig
(
model
=
Model
(),
input_var_names
=
[
'input'
],
session_init
=
ParamRestore
(
param_dict
),
session_config
=
get_default_sess_config
(
0.9
),
output_var_names
=
[
'output'
]
# the variable 'output' is the probability distribution
input_names
=
[
'input'
],
output_names
=
[
'output'
]
# the variable 'output' is the probability distribution
)
predict_func
=
get_predict_func
(
pred_config
)
...
...
examples/load-vgg16.py
View file @
77bcc8b1
...
...
@@ -72,10 +72,10 @@ def run_test(path, input):
param_dict
=
np
.
load
(
path
)
.
item
()
pred_config
=
PredictConfig
(
model
=
Model
(),
input_
var_
names
=
[
'input'
],
input_names
=
[
'input'
],
session_init
=
ParamRestore
(
param_dict
),
session_config
=
get_default_sess_config
(
0.9
),
output_
var_
names
=
[
'output'
]
# output:0 is the probability distribution
output_names
=
[
'output'
]
# output:0 is the probability distribution
)
predict_func
=
get_predict_func
(
pred_config
)
...
...
tensorpack/dataflow/remote.py
View file @
77bcc8b1
...
...
@@ -26,6 +26,7 @@ def serve_data(ds, addr):
try
:
ds
.
reset_state
()
logger
.
info
(
"Serving data at {}"
.
format
(
addr
))
# TODO print statistics here
while
True
:
for
dp
in
ds
.
get_data
():
socket
.
send
(
dumps
(
dp
),
copy
=
False
)
...
...
tensorpack/predict/base.py
View file @
77bcc8b1
...
...
@@ -8,7 +8,7 @@ import tensorflow as tf
import
six
from
..utils
import
logger
from
..tfutils
import
get_
va
rs_by_names
,
TowerContext
from
..tfutils
import
get_
tenso
rs_by_names
,
TowerContext
__all__
=
[
'OnlinePredictor'
,
'OfflinePredictor'
,
'AsyncPredictorBase'
,
...
...
@@ -41,7 +41,7 @@ class PredictorBase(object):
@
abstractmethod
def
_do_call
(
self
,
dp
):
"""
:param dp: input datapoint. must have the same length as input_
var_
names
:param dp: input datapoint. must have the same length as input_names
:return: output as defined by the config
"""
...
...
@@ -67,18 +67,18 @@ class AsyncPredictorBase(PredictorBase):
return
fut
.
result
()
class
OnlinePredictor
(
PredictorBase
):
def
__init__
(
self
,
sess
,
input_
vars
,
output_va
rs
,
return_input
=
False
):
def
__init__
(
self
,
sess
,
input_
tensors
,
output_tenso
rs
,
return_input
=
False
):
self
.
session
=
sess
self
.
return_input
=
return_input
self
.
input_
vars
=
input_va
rs
self
.
output_
vars
=
output_va
rs
self
.
input_
tensors
=
input_tenso
rs
self
.
output_
tensors
=
output_tenso
rs
def
_do_call
(
self
,
dp
):
assert
len
(
dp
)
==
len
(
self
.
input_
va
rs
),
\
"{} != {}"
.
format
(
len
(
dp
),
len
(
self
.
input_
va
rs
))
feed
=
dict
(
zip
(
self
.
input_
va
rs
,
dp
))
output
=
self
.
session
.
run
(
self
.
output_
va
rs
,
feed_dict
=
feed
)
assert
len
(
dp
)
==
len
(
self
.
input_
tenso
rs
),
\
"{} != {}"
.
format
(
len
(
dp
),
len
(
self
.
input_
tenso
rs
))
feed
=
dict
(
zip
(
self
.
input_
tenso
rs
,
dp
))
output
=
self
.
session
.
run
(
self
.
output_
tenso
rs
,
feed_dict
=
feed
)
return
output
...
...
@@ -91,8 +91,8 @@ class OfflinePredictor(OnlinePredictor):
with
TowerContext
(
''
,
False
):
config
.
model
.
build_graph
(
input_vars
)
input_vars
=
get_
vars_by_names
(
config
.
input_var
_names
)
output_vars
=
get_
vars_by_names
(
config
.
output_var
_names
)
input_vars
=
get_
tensors_by_names
(
config
.
input
_names
)
output_vars
=
get_
tensors_by_names
(
config
.
output
_names
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
sess
)
...
...
@@ -124,12 +124,12 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self
.
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
self
.
sess
)
input_vars
=
get_
vars_by_names
(
config
.
input_var
_names
)
input_vars
=
get_
tensors_by_names
(
config
.
input
_names
)
for
k
in
towers
:
output_vars
=
get_
va
rs_by_names
(
output_vars
=
get_
tenso
rs_by_names
(
[
'{}{}/'
.
format
(
self
.
PREFIX
,
k
)
+
n
\
for
n
in
config
.
output_
var_
names
])
for
n
in
config
.
output_names
])
self
.
predictors
.
append
(
OnlinePredictor
(
self
.
sess
,
input_vars
,
output_vars
,
config
.
return_input
))
...
...
tensorpack/predict/common.py
View file @
77bcc8b1
...
...
@@ -26,9 +26,9 @@ class PredictConfig(object):
:param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session.
:param input_var_names: a list of input variable names.
:param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output tensors to predict, the
:param input_names: a list of input variable names.
:param output_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param return_input: whether to return (input, output) pair or just output. default to False.
...
...
@@ -45,15 +45,24 @@ class PredictConfig(object):
assert_type
(
self
.
model
,
ModelDesc
)
# inputs & outputs
self
.
input_var_names
=
kwargs
.
pop
(
'input_var_names'
,
None
)
if
self
.
input_var_names
is
None
:
# TODO add deprecated warning later
self
.
input_names
=
kwargs
.
pop
(
'input_names'
,
None
)
if
self
.
input_names
is
None
:
self
.
input_names
=
kwargs
.
pop
(
'input_var_names'
,
None
)
if
self
.
input_names
is
not
None
:
pass
#logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
if
self
.
input_names
is
None
:
# neither options is set, assume all inputs
raw_vars
=
self
.
model
.
get_input_vars_desc
()
self
.
input_var_names
=
[
k
.
name
for
k
in
raw_vars
]
self
.
output_var_names
=
kwargs
.
pop
(
'output_var_names'
)
assert
len
(
self
.
input_var_names
),
self
.
input_var_names
for
v
in
self
.
input_var_names
:
assert_type
(
v
,
six
.
string_types
)
assert
len
(
self
.
output_var_names
),
self
.
output_var_names
self
.
input_names
=
[
k
.
name
for
k
in
raw_vars
]
self
.
output_names
=
kwargs
.
pop
(
'output_names'
,
None
)
if
self
.
output_names
is
None
:
self
.
output_names
=
kwargs
.
pop
(
'output_var_names'
)
#logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
assert
len
(
self
.
input_names
),
self
.
input_names
for
v
in
self
.
input_names
:
assert_type
(
v
,
six
.
string_types
)
assert
len
(
self
.
output_names
),
self
.
output_names
self
.
return_input
=
kwargs
.
pop
(
'return_input'
,
False
)
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment