Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
375123f5
Commit
375123f5
authored
Jul 19, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
refactor inferencer
parent
6728b686
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
47 additions
and
36 deletions
+47
-36
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-1
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+17
-8
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+0
-9
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+1
-0
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+28
-18
No files found.
examples/Atari2600/DQN.py
View file @
375123f5
...
@@ -92,7 +92,7 @@ class Model(ModelDesc):
...
@@ -92,7 +92,7 @@ class Model(ModelDesc):
def
_build_graph
(
self
,
inputs
,
is_training
):
def
_build_graph
(
self
,
inputs
,
is_training
):
state
,
action
,
reward
,
next_state
,
isOver
=
inputs
state
,
action
,
reward
,
next_state
,
isOver
=
inputs
self
.
predict_value
=
self
.
_get_DQN_prediction
(
state
,
is_training
)
self
.
predict_value
=
self
.
_get_DQN_prediction
(
state
,
is_training
)
action_onehot
=
tf
.
one_hot
(
action
,
NUM_ACTIONS
,
1.0
,
0.0
)
action_onehot
=
tf
.
one_hot
(
action
,
NUM_ACTIONS
)
pred_action_value
=
tf
.
reduce_sum
(
self
.
predict_value
*
action_onehot
,
1
)
#N,
pred_action_value
=
tf
.
reduce_sum
(
self
.
predict_value
*
action_onehot
,
1
)
#N,
max_pred_reward
=
tf
.
reduce_mean
(
tf
.
reduce_max
(
max_pred_reward
=
tf
.
reduce_mean
(
tf
.
reduce_max
(
self
.
predict_value
,
1
),
name
=
'predict_reward'
)
self
.
predict_value
,
1
),
name
=
'predict_reward'
)
...
...
tensorpack/callbacks/inference.py
View file @
375123f5
...
@@ -6,6 +6,7 @@ import tensorflow as tf
...
@@ -6,6 +6,7 @@ import tensorflow as tf
import
numpy
as
np
import
numpy
as
np
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
six
from
six.moves
import
zip
,
map
from
six.moves
import
zip
,
map
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
...
@@ -43,8 +44,9 @@ class Inferencer(object):
...
@@ -43,8 +44,9 @@ class Inferencer(object):
def
after_inference
(
self
):
def
after_inference
(
self
):
"""
"""
Called after a round of inference ends.
Called after a round of inference ends.
Returns a dict of statistics.
"""
"""
self
.
_after_inference
()
return
self
.
_after_inference
()
def
_after_inference
(
self
):
def
_after_inference
(
self
):
pass
pass
...
@@ -84,8 +86,6 @@ class InferenceRunner(Callback):
...
@@ -84,8 +86,6 @@ class InferenceRunner(Callback):
input_names
=
[
x
.
name
for
x
in
self
.
input_vars
]
input_names
=
[
x
.
name
for
x
in
self
.
input_vars
]
self
.
pred_func
=
self
.
trainer
.
get_predict_func
(
self
.
pred_func
=
self
.
trainer
.
get_predict_func
(
input_names
,
self
.
output_tensors
)
input_names
,
self
.
output_tensors
)
for
v
in
self
.
vcs
:
v
.
trainer
=
self
.
trainer
def
_find_output_tensors
(
self
):
def
_find_output_tensors
(
self
):
self
.
output_tensors
=
[]
# list of names
self
.
output_tensors
=
[]
# list of names
...
@@ -118,7 +118,14 @@ class InferenceRunner(Callback):
...
@@ -118,7 +118,14 @@ class InferenceRunner(Callback):
pbar
.
update
()
pbar
.
update
()
for
vc
in
self
.
vcs
:
for
vc
in
self
.
vcs
:
vc
.
after_inference
()
ret
=
vc
.
after_inference
()
for
k
,
v
in
six
.
iteritems
(
ret
):
try
:
v
=
float
(
v
)
except
:
logger
.
warn
(
"{} returns a non-scalar statistics!"
.
format
(
type
(
vc
)
.
__name__
))
continue
self
.
trainer
.
write_scalar_summary
(
k
,
v
)
class
ScalarStats
(
Inferencer
):
class
ScalarStats
(
Inferencer
):
"""
"""
...
@@ -150,10 +157,12 @@ class ScalarStats(Inferencer):
...
@@ -150,10 +157,12 @@ class ScalarStats(Inferencer):
self
.
stats
=
np
.
mean
(
self
.
stats
,
axis
=
0
)
self
.
stats
=
np
.
mean
(
self
.
stats
,
axis
=
0
)
assert
len
(
self
.
stats
)
==
len
(
self
.
names
)
assert
len
(
self
.
stats
)
==
len
(
self
.
names
)
ret
=
{}
for
stat
,
name
in
zip
(
self
.
stats
,
self
.
names
):
for
stat
,
name
in
zip
(
self
.
stats
,
self
.
names
):
opname
,
_
=
get_op_var_name
(
name
)
opname
,
_
=
get_op_var_name
(
name
)
name
=
'{}_{}'
.
format
(
self
.
prefix
,
opname
)
if
self
.
prefix
else
opname
name
=
'{}_{}'
.
format
(
self
.
prefix
,
opname
)
if
self
.
prefix
else
opname
self
.
trainer
.
write_scalar_summary
(
name
,
stat
)
ret
[
name
]
=
stat
return
ret
class
ClassificationError
(
Inferencer
):
class
ClassificationError
(
Inferencer
):
"""
"""
...
@@ -187,7 +196,7 @@ class ClassificationError(Inferencer):
...
@@ -187,7 +196,7 @@ class ClassificationError(Inferencer):
self
.
err_stat
.
feed
(
wrong
,
batch_size
)
self
.
err_stat
.
feed
(
wrong
,
batch_size
)
def
_after_inference
(
self
):
def
_after_inference
(
self
):
self
.
trainer
.
write_scalar_summary
(
self
.
summary_name
,
self
.
err_stat
.
ratio
)
return
{
self
.
summary_name
:
self
.
err_stat
.
ratio
}
class
BinaryClassificationStats
(
Inferencer
):
class
BinaryClassificationStats
(
Inferencer
):
""" Compute precision/recall in binary classification, given the
""" Compute precision/recall in binary classification, given the
...
@@ -214,5 +223,5 @@ class BinaryClassificationStats(Inferencer):
...
@@ -214,5 +223,5 @@ class BinaryClassificationStats(Inferencer):
self
.
stat
.
feed
(
pred
,
label
)
self
.
stat
.
feed
(
pred
,
label
)
def
_after_inference
(
self
):
def
_after_inference
(
self
):
self
.
trainer
.
write_scalar_summary
(
self
.
prefix
+
'_precision'
,
self
.
stat
.
precision
)
return
{
self
.
prefix
+
'_precision'
:
self
.
stat
.
precision
,
self
.
trainer
.
write_scalar_summary
(
self
.
prefix
+
'_recall'
,
self
.
stat
.
recall
)
self
.
prefix
+
'_recall'
:
self
.
stat
.
recall
}
tensorpack/tfutils/symbolic_functions.py
View file @
375123f5
...
@@ -6,15 +6,6 @@ import tensorflow as tf
...
@@ -6,15 +6,6 @@ import tensorflow as tf
import
numpy
as
np
import
numpy
as
np
from
..utils
import
logger
from
..utils
import
logger
def
one_hot
(
y
,
num_labels
):
"""
:param y: prediction. an Nx1 int tensor.
:param num_labels: an int. number of output classes
:returns: an NxC onehot matrix.
"""
logger
.
warn
(
"symbf.one_hot is deprecated in favor of more general tf.one_hot"
)
return
tf
.
one_hot
(
y
,
num_labels
,
1.0
,
0.0
,
name
=
'one_hot'
)
def
prediction_incorrect
(
logits
,
label
,
topk
=
1
):
def
prediction_incorrect
(
logits
,
label
,
topk
=
1
):
"""
"""
:param logits: NxC
:param logits: NxC
...
...
tensorpack/utils/logger.py
View file @
375123f5
...
@@ -10,6 +10,7 @@ from datetime import datetime
...
@@ -10,6 +10,7 @@ from datetime import datetime
from
six.moves
import
input
from
six.moves
import
input
import
sys
import
sys
from
.utils
import
memoized
from
.fs
import
mkdir_p
from
.fs
import
mkdir_p
__all__
=
[]
__all__
=
[]
...
...
tensorpack/utils/utils.py
View file @
375123f5
...
@@ -11,8 +11,6 @@ import collections
...
@@ -11,8 +11,6 @@ import collections
import
numpy
as
np
import
numpy
as
np
import
six
import
six
from
.
import
logger
__all__
=
[
'change_env'
,
__all__
=
[
'change_env'
,
'map_arg'
,
'map_arg'
,
'get_rng'
,
'memoized'
,
'get_rng'
,
'memoized'
,
...
@@ -50,28 +48,39 @@ class memoized(object):
...
@@ -50,28 +48,39 @@ class memoized(object):
(not reevaluated).
(not reevaluated).
'''
'''
def
__init__
(
self
,
func
):
def
__init__
(
self
,
func
):
self
.
func
=
func
self
.
func
=
func
self
.
cache
=
{}
self
.
cache
=
{}
def
__call__
(
self
,
*
args
):
def
__call__
(
self
,
*
args
):
if
not
isinstance
(
args
,
collections
.
Hashable
):
if
not
isinstance
(
args
,
collections
.
Hashable
):
# uncacheable. a list, for instance.
# uncacheable. a list, for instance.
# better to not cache than blow up.
# better to not cache than blow up.
return
self
.
func
(
*
args
)
return
self
.
func
(
*
args
)
if
args
in
self
.
cache
:
if
args
in
self
.
cache
:
return
self
.
cache
[
args
]
return
self
.
cache
[
args
]
else
:
else
:
value
=
self
.
func
(
*
args
)
value
=
self
.
func
(
*
args
)
self
.
cache
[
args
]
=
value
self
.
cache
[
args
]
=
value
return
value
return
value
def
__repr__
(
self
):
def
__repr__
(
self
):
'''Return the function's docstring.'''
'''Return the function's docstring.'''
return
self
.
func
.
__doc__
return
self
.
func
.
__doc__
def
__get__
(
self
,
obj
,
objtype
):
def
__get__
(
self
,
obj
,
objtype
):
'''Support instance methods.'''
'''Support instance methods.'''
return
functools
.
partial
(
self
.
__call__
,
obj
)
return
functools
.
partial
(
self
.
__call__
,
obj
)
#_GLOBAL_MEMOIZED_CACHE = dict()
#def global_memoized(func):
#""" Make sure that the same `memoized` object is returned on different
#calls to global_memoized(func)
#"""
#ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
#if ret is None:
#ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
#return ret
def
map_arg
(
**
maps
):
def
map_arg
(
**
maps
):
"""
"""
...
@@ -96,6 +105,7 @@ def get_rng(obj=None):
...
@@ -96,6 +105,7 @@ def get_rng(obj=None):
return
np
.
random
.
RandomState
(
seed
)
return
np
.
random
.
RandomState
(
seed
)
def
get_dataset_path
(
*
args
):
def
get_dataset_path
(
*
args
):
from
.
import
logger
d
=
os
.
environ
.
get
(
'TENSORPACK_DATASET'
,
None
)
d
=
os
.
environ
.
get
(
'TENSORPACK_DATASET'
,
None
)
if
d
is
None
:
if
d
is
None
:
d
=
os
.
path
.
abspath
(
os
.
path
.
join
(
d
=
os
.
path
.
abspath
(
os
.
path
.
join
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment