Shashank Suhas / seminar-breakout / Commits / a4371695

Commit a4371695, authored Apr 06, 2016 by Yuxin Wu

    docs

Parent: e1fbdca1

Showing 11 changed files with 163 additions and 95 deletions (+163 −95):
tensorpack/predict.py                     +46 −40
tensorpack/tfutils/__init__.py             +4  −4
tensorpack/tfutils/common.py              +13  −4
tensorpack/tfutils/gradproc.py            +13  −7
tensorpack/tfutils/modelutils.py           +5  −2
tensorpack/tfutils/sessinit.py            +30  −7
tensorpack/tfutils/summary.py             +10  −9
tensorpack/tfutils/symbolic_functions.py  +19  −3
tensorpack/train/base.py                   +6  −1
tensorpack/train/config.py                +14 −15
tensorpack/utils/__init__.py               +3  −3
tensorpack/predict.py  (view file @ a4371695)

...
@@ -15,38 +15,41 @@ from .utils import logger
 from .tfutils.modelutils import describe_model
 from .dataflow import DataFlow, BatchData

 __all__ = ['PredictConfig', 'DatasetPredictor', 'get_predict_func']

 class PredictConfig(object):
     def __init__(self, **kwargs):
         """
-        The config used by `get_predict_func`
-        Args:
-            session_config: a tf.ConfigProto instance to instantiate the
-                session. default to a session running 1 GPU.
-            session_init: a tensorpack.utils.sessinit.SessionInit instance to
-                initialize variables of a session.
-            input_data_mapping: Decide the mapping from each component in data
-                to the input tensor, since you may not need all input variables
-                of the graph to run the graph for prediction (for example
-                the `label` input is not used if you only need probability
-                distribution).
-                It should be a list with size=len(one_data_point),
-                where each element is an index of the input variables each
-                component of the data point should be fed into.
-                If not given, defaults to range(len(input_vars))
-                For example, with image classification task, the testing
-                dataset only provides datapoints of images (no labels). When
-                the input variables of the model is:
-                    input_vars: [image_var, label_var]
-                the mapping should look like:
-                    input_data_mapping: [0]
-                If this argument is not set in this case, the inputs and the data points won't be aligned.
-            model: a ModelDesc instance
-            output_var_names: a list of names of the output variable to predict, the
-                variables can be any computable tensor in the graph.
-                if None, will only calculate the cost returned by `get_model_func`.
-                Predict only specific output (instead of the cost)
-                might be faster and might require only some of the input variables.
+        The config used by `get_predict_func`.
+        :param session_config: a `tf.ConfigProto` instance to instantiate the
+            session. default to a session running 1 GPU.
+        :param session_init: a `utils.sessinit.SessionInit` instance to
+            initialize variables of a session.
+        :param input_data_mapping: Decide the mapping from each component in data
+            to the input tensor, since you may not need all input variables
+            of the graph to run the graph for prediction (for example
+            the `label` input is not used if you only need probability
+            distribution).
+            It should be a list with size=len(data_point),
+            where each element is an index of the input variables each
+            component of the data point should be fed into.
+            If not given, defaults to range(len(input_vars))
+            For example, in image classification task, the testing
+            dataset only provides datapoints of images (no labels). When
+            the input variables of the model is: ::
+                input_vars: [image_var, label_var]
+            the mapping should look like: ::
+                input_data_mapping: [0] # the first component in a datapoint should map to `image_var`
+        :param model: a `ModelDesc` instance
+        :param output_var_names: a list of names of the output variables to predict, the
+            variables can be any computable tensor in the graph.
+            Predict specific output might not require all input variables.
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
...
@@ -55,18 +58,14 @@ class PredictConfig(object):
         self.session_init = kwargs.pop('session_init')
         self.model = kwargs.pop('model')
         self.input_data_mapping = kwargs.pop('input_data_mapping', None)
-        self.output_var_names = kwargs.pop('output_var_names', None)
+        self.output_var_names = kwargs.pop('output_var_names')
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

 def get_predict_func(config):
     """
-    Args:
-        config: a PredictConfig
-    Returns:
-        A prediction function that takes a list of inputs value, and return
-        one/a list of output values.
-        If `output_var_names` is set, then the prediction function will
-        return a list of output values. If not, will return a cost.
+    :param config: a `PredictConfig` instance.
+    :returns: A prediction function that takes a list of input values, and return
+        a list of output values defined in ``config.output_var_names``.
     """
     output_var_names = config.output_var_names
...
@@ -106,10 +105,14 @@ def get_predict_func(config):
 PredictResult = namedtuple('PredictResult', ['input', 'output'])

 class DatasetPredictor(object):
+    """
+    Run the predict_config on a given `DataFlow`.
+    """
     def __init__(self, predict_config, dataset, batch=0):
         """
-        A predictor with the given predict_config, run on the given dataset
-        if batch is larger than zero, the dataset will be batched
+        :param predict_config: a `PredictConfig` instance.
+        :param dataset: a `DataFlow` instance.
+        :param batch: if batch > zero, will batch the dataset before running.
         """
         assert isinstance(dataset, DataFlow)
         self.ds = dataset
...
@@ -118,11 +121,14 @@ class DatasetPredictor(object):
         self.predict_func = get_predict_func(predict_config)

     def get_result(self):
-        """ a generator to return prediction for each data"""
+        """ A generator to produce prediction for each data"""
        with tqdm(total=self.ds.size()) as pbar:
             for dp in self.ds.get_data():
                 yield PredictResult(dp, self.predict_func(dp))
                 pbar.update()

     def get_all_result(self):
+        """
+        Run over the dataset and return a list of all predictions.
+        """
         return list(self.get_result())
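The reworked docstrings spell out how a `PredictConfig` is assembled and how `input_data_mapping` pairs datapoint components with input variables. A minimal usage sketch, assuming the API exactly as documented above; `MyModel`, the checkpoint path, `some_image`, and `my_test_dataflow` are hypothetical placeholders:

```python
from tensorpack.predict import PredictConfig, get_predict_func, DatasetPredictor
from tensorpack.tfutils.sessinit import SaverRestore

config = PredictConfig(
    model=MyModel(),                                     # hypothetical ModelDesc instance
    session_init=SaverRestore('train_log/model-10000'),  # hypothetical checkpoint path
    input_data_mapping=[0],          # feed datapoint[0] into the first input var (the image)
    output_var_names=['output:0'],   # any computable tensors in the graph; now mandatory
)

# One-off prediction: the returned function maps input values to output values.
predict = get_predict_func(config)
outputs = predict([some_image])      # `some_image` is a placeholder input

# Or run over a whole DataFlow and collect every result:
predictor = DatasetPredictor(config, my_test_dataflow, batch=128)
for result in predictor.get_result():   # yields PredictResult(input, output)
    pass
```

Note that after this commit `output_var_names` is popped without a default, so omitting it raises a `KeyError` rather than falling back to the cost.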
tensorpack/tfutils/__init__.py  (view file @ a4371695)

...
@@ -5,13 +5,13 @@
 from pkgutil import walk_packages
 import os

-def global_import(name):
+def _global_import(name):
     p = __import__(name, globals(), None, level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     for k in lst:
         globals()[k] = p.__dict__[k]

-global_import('sessinit')
-global_import('common')
-global_import('gradproc')
+_global_import('sessinit')
+_global_import('common')
+_global_import('gradproc')
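The rename matters because names without a leading underscore are picked up by `from tensorpack.tfutils import *` when a module defines no `__all__`; the underscore keeps the helper itself out of the flattened namespace. An annotated copy of the pattern (the `level=1` relative import only works inside a package, so this is a sketch of the package code above, not a standalone script):

```python
def _global_import(name):
    # Import the submodule relative to this package ...
    p = __import__(name, globals(), None, level=1)
    # ... then re-export its public names (honoring __all__ when present)
    # at package level, so users can write e.g.
    # `from tensorpack.tfutils import SaverRestore` instead of the full path.
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        globals()[k] = p.__dict__[k]
```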
tensorpack/tfutils/common.py  (view file @ a4371695)

...
@@ -8,8 +8,11 @@ import tensorflow as tf
 def get_default_sess_config(mem_fraction=0.5):
     """
-    Return a better config to use as default.
-    Tensorflow default session config consume too much resources
+    Return a better session config to use as default.
+    Tensorflow default session config consume too much resources.
+    :param mem_fraction: fraction of memory to use.
+    :returns: a `tf.ConfigProto` object.
     """
     conf = tf.ConfigProto()
     conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
...
@@ -18,7 +21,7 @@ def get_default_sess_config(mem_fraction=0.5):
     return conf

 def get_global_step_var():
-    """ get global_step variable in the current graph """
+    """ :returns: the global_step variable in the current graph. create if not existed """
     try:
         return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     except KeyError:
...
@@ -27,13 +30,19 @@ def get_global_step_var():
     return var

 def get_global_step():
-    """ get global_step value with current graph and session"""
+    """ :returns: global_step value in current graph and session"""
     return tf.train.global_step(
         tf.get_default_session(),
         get_global_step_var())

 def get_op_var_name(name):
+    """
+    Variable name is assumed to be ``op_name + ':0'``
+    :param name: an op or a variable name
+    :returns: (op_name, variable_name)
+    """
     if name.endswith(':0'):
         return name[:-2], name
     else:
...
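A short sketch of the two helpers whose docstrings changed above, assuming the TF 0.x-era API this repository targets; the variable name is a hypothetical example:

```python
import tensorflow as tf
from tensorpack.tfutils.common import get_default_sess_config, get_op_var_name

# Cap per-process GPU memory at 40% instead of letting TF grab everything.
sess = tf.Session(config=get_default_sess_config(mem_fraction=0.4))

# get_op_var_name splits a name into (op_name, variable_name); per the new
# docstring, the variable name is assumed to be op_name + ':0'.
assert get_op_var_name('conv1/W:0') == ('conv1/W', 'conv1/W:0')
```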
tensorpack/tfutils/gradproc.py  (view file @ a4371695)

...
@@ -14,18 +14,24 @@ __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
 class GradientProcessor(object):
     __metaclass__ = ABCMeta

-    @abstractmethod
     def process(self, grads):
         """
-        Process the symbolic gradients, return symbolic gradients
-        grads: list of (grad, var)
+        Process the symbolic gradients.
+        :param grads: list of (grad, var)
+        :returns: symbolic gradients with the same type as input
         """
+        self._process(grads)
+
+    @abstractmethod
+    def _process(self, grads):
+        pass

 class SummaryGradient(GradientProcessor):
     """
     Summary history and RMS for each graident variable
     """
-    def process(self, grads):
+    def _process(self, grads):
         for grad, var in grads:
             tf.histogram_summary(var.op.name + '/grad', grad)
             tf.scalar_summary(var.op.name + '/gradRMS',
...
@@ -37,7 +43,7 @@ class CheckGradient(GradientProcessor):
     """
     Check for numeric issue
     """
-    def process(self, grads):
+    def _process(self, grads):
         for grad, var in grads:
             assert grad is not None, "Grad is None for variable {}".format(var.name)
             # TODO make assert work
...
@@ -50,11 +56,11 @@ class ScaleGradient(GradientProcessor):
     """
     def __init__(self, multipliers):
         """
-        multipliers: list of (regex, float)
+        :param multipliers: list of (regex, float)
         """
         self.multipliers = multipliers

-    def process(self, grads):
+    def _process(self, grads):
         # TODO use None for zero to speed up?
         ret = []
         for grad, var in grads:
...
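This is a template-method refactor: callers keep invoking `process`, while subclasses now override `_process`. (As written above, `process` calls `self._process(grads)` without returning it, despite the documented `:returns:`; a caller relying on the return value would need that propagated.) A minimal sketch of a custom processor under the new scheme; `ClipGradient` is hypothetical and not part of this commit:

```python
import tensorflow as tf
from tensorpack.tfutils.gradproc import GradientProcessor

class ClipGradient(GradientProcessor):
    """ Clip every gradient to [-1, 1] before it is applied. """
    def _process(self, grads):
        # grads is a list of (grad, var) pairs, as the docstring specifies
        return [(tf.clip_by_value(grad, -1.0, 1.0), var)
                for grad, var in grads]
```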
tensorpack/tfutils/modelutils.py  (view file @ a4371695)

...
@@ -7,7 +7,7 @@ import tensorflow as tf
 from ..utils import logger

 def describe_model():
-    """ describe the current model parameters """
+    """ print a description of the current model parameters """
     train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
     msg = [""]
     total = 0
...
@@ -22,7 +22,10 @@ def describe_model():
 def get_shape_str(tensors):
-    """ return the shape string for a tensor or a list of tensors"""
+    """
+    :param tensors: a tensor or a list of tensors
+    :returns: a string to describe the shape
+    """
     if isinstance(tensors, (list, tuple)):
         for v in tensors:
             assert isinstance(v, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(v))
...
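A quick sketch of both helpers; the exact shape-string format is elided from the diff, so the comments only describe intent:

```python
import tensorflow as tf
from tensorpack.tfutils.modelutils import describe_model, get_shape_str

x = tf.placeholder(tf.float32, [None, 28, 28, 3])
w = tf.Variable(tf.zeros([784, 10]), name='fc/W')

print(get_shape_str(x))       # shape string for a single tensor
print(get_shape_str([x, w]))  # or for a list of tensors
describe_model()              # logs each trainable variable and the parameter total
```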
tensorpack/tfutils/sessinit.py  (view file @ a4371695)

...
@@ -14,18 +14,36 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore', 'ParamRestore',
            'dump_session_params']

 class SessionInit(object):
+    """ Base class for utilities to initialize a session"""
     __metaclass__ = ABCMeta

-    @abstractmethod
     def init(self, sess):
-        """ Method to initialize a session"""
+        """ Initialize a session
+        :param sess: a `tf.Session`
+        """
+        self._init(sess)
+
+    @abstractmethod
+    def _init(self, sess):
+        pass

 class NewSession(SessionInit):
-    def init(self, sess):
+    """
+    Create a new session. All variables will be initialized by their
+    initializer.
+    """
+    def _init(self, sess):
         sess.run(tf.initialize_all_variables())

 class SaverRestore(SessionInit):
+    """
+    Restore an old model saved by `tf.Saver`.
+    """
     def __init__(self, model_path):
+        """
+        :param model_path: a model file or a ``checkpoint`` file.
+        """
         assert os.path.isfile(model_path)
         if os.path.basename(model_path) == 'checkpoint':
             model_path = tf.train.get_checkpoint_state(
...
@@ -33,7 +51,7 @@ class SaverRestore(SessionInit):
         assert os.path.isfile(model_path)
         self.set_path(model_path)

-    def init(self, sess):
+    def _init(self, sess):
         saver = tf.train.Saver()
         saver.restore(sess, self.path)
         logger.info(
...
@@ -44,12 +62,15 @@ class SaverRestore(SessionInit):
 class ParamRestore(SessionInit):
     """
-    Restore trainable variables from a dictionary
+    Restore trainable variables from a dictionary.
     """
     def __init__(self, param_dict):
+        """
+        :param param_dict: a dict of {name: value}
+        """
         self.prms = param_dict

-    def init(self, sess):
+    def _init(self, sess):
         sess.run(tf.initialize_all_variables())
         variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
         var_dict = dict([v.name, v] for v in variables)
...
@@ -70,7 +91,9 @@ class ParamRestore(SessionInit):
             sess.run(var.assign(value))

 def dump_session_params(path):
-    """ dump value of all trainable variables to a dict"""
+    """ Dump value of all trainable variables to a dict and save to `path` as
+    npy format.
+    """
     var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
     result = {}
     for v in var:
...
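The same template-method refactor as in gradproc.py: `init` is now the stable public entry point and subclasses implement `_init`. A sketch of the three initializers, assuming TF 0.x; the paths are hypothetical, and the `as_default()` around `dump_session_params` assumes it reads variables from the default session (its body is elided above):

```python
import numpy as np
import tensorflow as tf
from tensorpack.tfutils.sessinit import (
    NewSession, SaverRestore, ParamRestore, dump_session_params)

sess = tf.Session()
NewSession().init(sess)   # run every variable's initializer

# Restoring instead of initializing (hypothetical paths):
# SaverRestore('train_log/model-10000').init(sess)        # tf.Saver checkpoint
# ParamRestore(np.load('params.npy').item()).init(sess)   # dict of {name: value}

with sess.as_default():
    dump_session_params('params.npy')   # save trainable variables in npy format
```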
tensorpack/tfutils/summary.py  (view file @ a4371695)

...
@@ -10,9 +10,7 @@ from . import get_global_step_var
 def create_summary(name, v):
     """
-    Return a tf.Summary object with name and simple value v
-    Args: v: a value
+    Return a tf.Summary object with name and simple scalar value v
     """
     assert isinstance(name, six.string_types), type(name)
     v = float(v)
...
@@ -22,8 +20,8 @@ def create_summary(name, v):
 def add_activation_summary(x, name=None):
     """
-    Summary for an activation tensor x.
-    If name is None, use x.name
+    Add summary to graph for an activation tensor x.
+    If name is None, use x.name.
     """
     ndim = x.get_shape().ndims
     assert ndim >= 2, \
...
@@ -35,9 +33,10 @@ def add_activation_summary(x, name=None):
 def add_param_summary(summary_lists):
     """
-    summary_lists: list of (regex, [list of action to perform])
-    action can be 'mean', 'scalar', 'histogram', 'sparsity'
     Add summary for all trainable variables matching the regex
+    :param summary_lists: list of (regex, [list of action to perform]).
+        Action can be 'mean', 'scalar', 'histogram', 'sparsity'.
     """
     def perform(var, action):
         ndim = var.get_shape().ndims
...
@@ -67,10 +66,12 @@ def add_param_summary(summary_lists):
         for act in actions:
             perform(p, act)

+# TODO use name of cost_var
 def summary_moving_average(cost_var):
     """ Create a MovingAverage op and summary for all variables in
-    MOVING_SUMMARY_VARS_KEY, as well as the argument
-    Return a op to maintain these average
+    MOVING_SUMMARY_VARS_KEY, as well as `cost_var`.
+    :returns: a op to maintain these average.
     """
     global_step_var = get_global_step_var()
     averager = tf.train.ExponentialMovingAverage(
...
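A sketch of the argument format that the new `add_param_summary` docstring pins down; it must be called while building the graph, after the variables exist, and the regexes here are hypothetical examples:

```python
from tensorpack.tfutils.summary import add_param_summary

# list of (regex, [actions]); actions: 'mean', 'scalar', 'histogram', 'sparsity'
add_param_summary([
    ('conv.*/W', ['histogram', 'mean']),  # weight stats for every conv layer
    ('.*/b',     ['sparsity']),           # sparsity of all biases
])
```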
tensorpack/tfutils/symbolic_functions.py  (view file @ a4371695)

...
@@ -6,6 +6,11 @@ import tensorflow as tf
 import numpy as np

 def one_hot(y, num_labels):
+    """
+    :param y: prediction. an Nx1 int tensor.
+    :param num_labels: an int. number of output classes
+    :returns: an NxC onehot matrix.
+    """
     with tf.op_scope([y, num_labels], 'one_hot'):
         batch_size = tf.size(y)
         y = tf.expand_dims(y, 1)
...
@@ -18,9 +23,9 @@ def one_hot(y, num_labels):
 def prediction_incorrect(logits, label):
     """
-    logits: batchxN
-    label: batch
-    return a binary vector with 1 means incorrect prediction
+    :param logits: NxC
+    :param label: N
+    :returns: a binary vector of length N with 1 meaning incorrect prediction
     """
     with tf.op_scope([logits, label], 'incorrect'):
         wrong = tf.not_equal(
...
@@ -30,13 +35,24 @@ def prediction_incorrect(logits, label):
     return wrong

 def flatten(x):
+    """
+    Flatten the tensor.
+    """
     return tf.reshape(x, [-1])

 def batch_flatten(x):
+    """
+    Flatten the tensor except the first dimension.
+    """
     total_dim = np.prod(x.get_shape()[1:].as_list())
     return tf.reshape(x, [-1, total_dim])

 def logSoftmax(x):
+    """
+    Batch log softmax.
+    :param x: NxC tensor.
+    :returns: NxC tensor.
+    """
     with tf.op_scope([x], 'logSoftmax'):
         z = x - tf.reduce_max(x, 1, keep_dims=True)
         logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True))
...
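The `logSoftmax` body uses the standard max-shift identity, log softmax(x)_i = (x_i - m) - log(sum_j exp(x_j - m)) with m = max_j x_j, which keeps `exp` from overflowing. A self-contained NumPy check of the same computation:

```python
import numpy as np

x = np.array([[1000.0, 1001.0, 999.0]])     # naive exp(x) would overflow
z = x - x.max(axis=1, keepdims=True)        # subtract the per-row max
logprob = z - np.log(np.exp(z).sum(axis=1, keepdims=True))

print(logprob)                              # finite log-probabilities
assert np.allclose(np.exp(logprob).sum(axis=1), 1.0)   # rows exponentiate to a distribution
```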
tensorpack/train/base.py  (view file @ a4371695)

...
@@ -17,11 +17,14 @@ from ..tfutils.modelutils import describe_model
 __all__ = ['Trainer']

 class Trainer(object):
+    """
+    Base class for a trainer.
+    """
     __metaclass__ = ABCMeta

     def __init__(self, config):
         """
-        Config: a `TrainConfig` instance
+        :param config: a `TrainConfig` instance
         """
         assert isinstance(config, TrainConfig), type(config)
         self.config = config
...
@@ -29,10 +32,12 @@ class Trainer(object):
     @abstractmethod
     def train(self):
+        """ Start training"""
         pass

     @abstractmethod
     def run_step(self):
+        """ run an iteration"""
         pass

     def trigger_epoch(self):
...
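Under this contract a concrete trainer fills in `train` and `run_step`. A bare-bones hypothetical subclass, sketched from the documented `TrainConfig` fields (`starting_epoch`, `max_epoch`, `step_per_epoch`) and the `trigger_epoch` hook shown above:

```python
from tensorpack.train.base import Trainer

class ToyTrainer(Trainer):
    def train(self):
        """ Start training: run step_per_epoch steps per epoch. """
        for epoch in range(self.config.starting_epoch, self.config.max_epoch + 1):
            for _ in range(self.config.step_per_epoch):
                self.run_step()
            self.trigger_epoch()   # fire the per-epoch callbacks

    def run_step(self):
        pass   # one SGD update would go here
```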
tensorpack/train/config.py  (view file @ a4371695)

...
@@ -18,21 +18,20 @@ class TrainConfig(object):
     """
     def __init__(self, **kwargs):
         """
-        Args:
-            dataset: the dataset to train. a tensorpack.dataflow.DataFlow instance.
-            optimizer: a tf.train.Optimizer instance defining the optimizer for trainig.
-            callbacks: a tensorpack.utils.callback.Callbacks instance. Define
-                the callbacks to perform during training. has to contain a
-                SummaryWriter and a PeriodicSaver
-            session_config: a tf.ConfigProto instance to instantiate the
-                session. default to a session running 1 GPU.
-            session_init: a tensorpack.utils.sessinit.SessionInit instance to
-                initialize variables of a session. default to a new session.
-            model: a ModelDesc instance
-            starting_epoch: int. default to be 1.
-            step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
-            max_epoch: maximum number of epoch to run training. default to 100
-            nr_tower: int. number of towers. default to 1.
+        :param dataset: the dataset to train. a `DataFlow` instance.
+        :param optimizer: a `tf.train.Optimizer` instance defining the optimizer for trainig.
+        :param callbacks: a `callback.Callbacks` instance. Define
+            the callbacks to perform during training. It has to contain a
+            SummaryWriter and a PeriodicSaver
+        :param session_config: a `tf.ConfigProto` instance to instantiate the
+            session. default to a session running 1 GPU.
+        :param session_init: a `sessinit.SessionInit` instance to
+            initialize variables of a session. default to a new session.
+        :param model: a `ModelDesc` instance.
+        :param starting_epoch: int. default to be 1.
+        :param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
+        :param max_epoch: maximum number of epoch to run training. default to 100
+        :param nr_tower: int. number of towers. default to 1.
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
...
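A sketch constructing a `TrainConfig` with the keyword arguments documented above; `my_dataflow`, `my_callbacks`, and `MyModel` are hypothetical placeholders for the required instances:

```python
import tensorflow as tf
from tensorpack.train.config import TrainConfig

config = TrainConfig(
    dataset=my_dataflow,                      # a DataFlow instance
    optimizer=tf.train.AdamOptimizer(1e-3),   # any tf.train.Optimizer
    callbacks=my_callbacks,                   # Callbacks with SummaryWriter and PeriodicSaver
    model=MyModel(),                          # a ModelDesc instance
    step_per_epoch=1000,                      # SGD updates per epoch
    max_epoch=100,                            # default is 100
)
```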
tensorpack/utils/__init__.py  (view file @ a4371695)

...
@@ -10,10 +10,10 @@ Common utils.
 These utils should be irrelevant to tensorflow.
 """

-def global_import(name):
+def _global_import(name):
     p = __import__(name, globals(), None, level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     for k in lst:
         globals()[k] = p.__dict__[k]

-global_import('naming')
-global_import('utils')
+_global_import('naming')
+_global_import('utils')