Shashank Suhas / seminar-breakout / Commits

Commit 847fae12, authored Nov 26, 2018 by Yuxin Wu
Parent: 7b33a43c

    clean-up deprecation

Showing 12 changed files, with 44 additions and 144 deletions (+44 −144):
```
docs/conf.py                              +4    -9
docs/tutorial/efficient-dataflow.md       +8    -2
docs/tutorial/extend/dataflow.md          +10   -4
tensorpack/callbacks/trigger.py           +2    -0
tensorpack/dataflow/base.py               +6    -6
tensorpack/dataflow/common.py             +4    -6
tensorpack/dataflow/parallel.py           +2    -0
tensorpack/tfutils/sessinit.py            +1    -21
tensorpack/tfutils/symbolic_functions.py  +3    -80
tensorpack/train/interface.py             +1    -1
tensorpack/utils/rect.py                  +2    -0
tensorpack/utils/viz.py                   +1    -15
```
docs/conf.py:

```diff
@@ -369,37 +369,32 @@ def process_signature(app, what, name, obj, options, signature,
 _DEPRECATED_NAMES = set([
     # deprecated stuff:
-    'TryResumeTraining',
     'QueueInputTrainer',
     'SimplePredictBuilder',
     'LMDBDataPoint',
     'TFRecordData',
     'dump_dataflow_to_lmdb',
     'dump_dataflow_to_tfrecord',
-    'pyplot2img',
     'IntBox', 'FloatBox',
+    'PrefetchOnGPUs',

     # renamed stuff:
     'DumpTensor',
     'DumpParamAsImage',
-    'StagingInputWrapper',
     'PeriodicRunHooks',
     'get_nr_gpu',
+    'start_test',  # TestDataSpeed

     # deprecated or renamed symbolic code
     'ImageSample',
     'BilinearUpSample'
     'Deconv2D',
-    'psnr', 'get_scalar_var',
+    'psnr', 'prediction_incorrect', 'huber_loss',

     # internal only
     'SessionUpdate',
-    'apply_default_prefetch',
     'average_grads',
     'aggregate_grads',
     'allreduce_grads',
-    'PrefetchOnGPUs',
 ])

 def autodoc_skip_member(app, what, name, obj, skip, options):
@@ -414,7 +409,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
     # Hide some names that are deprecated or not intended to be used
     if name in _DEPRECATED_NAMES:
         return True
-    if name in ['get_data', 'size', 'reset_state']:
+    if name in ['__iter__', '__len__', 'reset_state']:
         # skip these methods with empty docstring
         if not obj.__doc__ and inspect.isfunction(obj):
             # https://stackoverflow.com/questions/3589311/get-defining-class-of-unbound-method-object-in-python-3
```
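For orientation: `autodoc_skip_member` only takes effect once it is connected to Sphinx's `autodoc-skip-member` event. A minimal sketch of that wiring, assuming the usual `setup()` hook (the hook itself is not part of this diff):

```python
# Hypothetical wiring, not shown in this commit: Sphinx fires
# 'autodoc-skip-member' for every member it documents; a handler
# returning True hides that member from the generated docs.
def setup(app):
    app.connect('autodoc-skip-member', autodoc_skip_member)
```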
docs/tutorial/efficient-dataflow.md:

```diff
@@ -6,6 +6,8 @@ a __Python generator__ which yields preprocessed ImageNet images and labels as f
 Since it is simply a generator interface, you can use the DataFlow in any Python-based frameworks (e.g. PyTorch, Keras)
 or your own code as well.

 **What we are going to do**: We'll use ILSVRC12 dataset, which contains 1.28 million images.
 The original images (JPEG compressed) are 140G in total.
 The average resolution is about 400x350 <sup>[[1]]</sup>.
@@ -27,8 +29,12 @@ Some things to know before reading:
    But in validation we often need the exact set of data, to be able to compute a correct and comparable score.
    This will affect how we build the DataFlow.
 4. The actual performance would depend on not only the disk, but also memory (for caching) and CPU (for data processing).
    You may need to tune the parameters (#processes, #threads, size of buffer, etc.)
    or change the pipeline for new tasks and new machines to achieve the best performance.
+   The solutions in this tutorial may not help you.
+   To improve your own DataFlow, read the
+   [performance tuning tutorial](performance-tuning.html#investigate-dataflow)
+   before doing any optimizations.

 The benchmark code for this tutorial can be found in [tensorpack/benchmarks](https://github.com/tensorpack/benchmarks/tree/master/ImageNet),
 including comparison with a similar (but simpler) pipeline built with `tf.data`.
```
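The added caveat points readers at tuning knobs (#processes, #threads, buffer sizes) rather than fixed recipes. A sketch of where those knobs typically live in a tensorpack pipeline; `my_dataflow` and `preprocess` are placeholders, and the parameter values are things to tune, not recommendations:

```python
from tensorpack.dataflow import MapData, BatchData, PrefetchDataZMQ

df = my_dataflow                     # any source DataFlow
df = MapData(df, preprocess)         # CPU-bound per-datapoint work
df = PrefetchDataZMQ(df, nr_proc=8)  # number of worker processes: tune per machine
df = BatchData(df, 128)              # batch size
```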
docs/tutorial/extend/dataflow.md:

````diff
@@ -12,11 +12,11 @@ and then compose it with existing modules (e.g. mapping, batching, prefetching,
 The easiest way to create a DataFlow to load custom data, is to wrap a custom generator, e.g.:

 ```python
 def my_data_loader():
-  while True:
-    # load data from somewhere with Python
+  # load data from somewhere with Python, and yield them
+  for k in range(100):
     yield [my_array, my_label]

-dataflow = DataFromGenerator(my_data_loader)
+df = DataFromGenerator(my_data_loader)
 ```

 To write more complicated DataFlow, you need to inherit the base `DataFlow` class.
@@ -24,6 +24,7 @@ Usually, you just need to implement the `__iter__()` method which yields a datap
 ```python
 class MyDataFlow(DataFlow):
   def __iter__(self):
+    # load data from somewhere with Python, and yield them
     for k in range(100):
       digit = np.random.rand(28, 28)
       label = np.random.randint(10)
@@ -38,6 +39,8 @@ for datapoint in df:
 Optionally, you can implement the `__len__` and `reset_state` method.
 The detailed semantics of these three methods are explained
 in the [API documentation](../../modules/dataflow.html#tensorpack.dataflow.DataFlow).
+If you're writing a complicated DataFlow, make sure to read the API documentation
+for the semantics.

 DataFlow implementations for several well-known datasets are provided in the
 [dataflow.dataset](../../modules/dataflow.dataset.html)
@@ -52,9 +55,12 @@ processing on top of the source DataFlow, e.g.:
 class ProcessingDataFlow(DataFlow):
   def __init__(self, ds):
     self.ds = ds

+  def reset_state(self):
+    self.ds.reset_state()
+
   def __iter__(self):
-    for datapoint in self.ds.get_data():
+    for datapoint in self.ds:
       # do something
       yield new_datapoint
 ```
````
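Pieced together from the hunks above, the tutorial's running example now reads roughly as follows (the final `yield` and the consuming loop come from the surrounding tutorial text, not this diff):

```python
import numpy as np
from tensorpack.dataflow import DataFlow

class MyDataFlow(DataFlow):
    def __iter__(self):
        # load data from somewhere with Python, and yield them
        for k in range(100):
            digit = np.random.rand(28, 28)
            label = np.random.randint(10)
            yield [digit, label]

    def __len__(self):
        return 100

df = MyDataFlow()
df.reset_state()   # required before iterating, per the base.py docs below
for datapoint in df:
    print(datapoint[0].shape, datapoint[1])
```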
tensorpack/callbacks/trigger.py:

```diff
@@ -3,6 +3,7 @@
 from .base import ProxyCallback, Callback
+from ..utils.develop import log_deprecated

 __all__ = ['PeriodicTrigger', 'PeriodicCallback', 'EnableCallbackIf']
@@ -77,6 +78,7 @@ class PeriodicRunHooks(ProxyCallback):
         """
         self._every_k_steps = int(every_k_steps)
         super(PeriodicRunHooks, self).__init__(callback)
+        log_deprecated("PeriodicRunHooks", "Use PeriodicCallback instead!", "2019-02-28")

     def _before_run(self, ctx):
         if self.global_step % self._every_k_steps == 0:
```
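The migration the new warning suggests is mechanical; a sketch assuming `PeriodicCallback` accepts the same `every_k_steps` keyword as `PeriodicRunHooks` above (its exact signature is not shown in this diff):

```python
from tensorpack.callbacks import ModelSaver, PeriodicCallback

# before (now warns until 2019-02-28):
#   cb = PeriodicRunHooks(ModelSaver(), every_k_steps=100)
# after:
cb = PeriodicCallback(ModelSaver(), every_k_steps=100)
```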
tensorpack/dataflow/base.py:

```diff
@@ -65,7 +65,7 @@ class DataFlow(object):
     """
     * A dataflow is an iterable. The :meth:`__iter__` method should yield a list each time.
       Each element in the list should be either a number or a numpy array.
-      For now, tensorpack also partially supports dict instead of list.
+      For now, tensorpack also **partially** supports dict instead of list.
     * The :meth:`__iter__` method can be either finite (will stop iteration) or infinite
       (will not stop iteration). For a finite dataflow, :meth:`__iter__` can be called
@@ -107,7 +107,7 @@ class DataFlow(object):
           it yourself, especially when using data-parallel trainer.
         + The length of progress bar when processing a dataflow.
         + Used by :class:`InferenceRunner` to get the number of iterations in inference.
-          In this case users are responsible for making sure that :meth:`__len__` is accurate.
+          In this case users are **responsible** for making sure that :meth:`__len__` is accurate.
           This is to guarantee that inference is run on a fixed set of images.

         Returns:
@@ -127,11 +127,11 @@ class DataFlow(object):
           by the **process that uses the dataflow** before :meth:`__iter__` is called.
           The caller thread of this method should stay alive to keep this dataflow alive.
-        * It is meant for initialization works that involve processes,
-          e.g., initialize random number generator (RNG), create worker processes.
+        * It is meant for certain initialization that involves processes,
+          e.g., initialize random number generators (RNG), create worker processes.
           Because it's very common to use RNG in data processing,
-          developers of dataflow can also subclass :class:`RNGDataFlow` to simplify the work.
+          developers of dataflow can also subclass :class:`RNGDataFlow` to have easier access to an RNG.
         * A dataflow is not fork-safe after :meth:`reset_state` is called (because this will violate the guarantee).
           A few number of dataflow is not fork-safe anytime, which will be mentioned in the docs.
@@ -158,7 +158,7 @@ class RNGDataFlow(DataFlow):
 class ProxyDataFlow(DataFlow):
     """ Base class for DataFlow that proxies another.
-        Every method is proxied to ``self.ds`` unless override by subclass.
+        Every method is proxied to ``self.ds`` unless overriden by a subclass.
     """

     def __init__(self, ds):
```
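The `RNGDataFlow` shortcut that the reworded docstring mentions follows a simple pattern: the base class seeds `self.rng` inside `reset_state()`, so subclasses only draw from it. A sketch under that assumption:

```python
from tensorpack.dataflow import RNGDataFlow

class RandomDigits(RNGDataFlow):
    """Yields random 28x28 arrays; self.rng is expected to be
    (re-)seeded per process by RNGDataFlow.reset_state()."""
    def __iter__(self):
        for _ in range(100):
            yield [self.rng.rand(28, 28), self.rng.randint(10)]

df = RandomDigits()
df.reset_state()   # must run in the consuming process, as documented above
for dp in df:
    pass
```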
tensorpack/dataflow/common.py:

```diff
@@ -43,6 +43,10 @@ class TestDataSpeed(ProxyDataFlow):
             yield dp

     def start_test(self):
+        log_deprecated("TestDataSpeed.start_test() was renamed to start()", "2019-03-30")
+        self.start()
+
+    def start(self):
         """
         Start testing with a progress bar.
         """
@@ -59,12 +63,6 @@ class TestDataSpeed(ProxyDataFlow):
             if idx == self.test_size - 1:
                 break

-    def start(self):
-        """
-        Alias of start_test.
-        """
-        self.start_test()
-

 class BatchData(ProxyDataFlow):
     """
```
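For callers, the rename looks like this (`df` stands for any DataFlow; `size` is the number of datapoints to time, matching `self.test_size` above):

```python
from tensorpack.dataflow import TestDataSpeed

speed = TestDataSpeed(df, size=5000)
speed.start_test()   # still works, but now logs a deprecation warning citing 2019-03-30
speed.start()        # the supported spelling going forward
```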
tensorpack/dataflow/parallel.py:

```diff
@@ -21,6 +21,7 @@ from ..utils.concurrency import (ensure_proc_terminate,
 from ..utils.serialize import loads, dumps
 from ..utils import logger
 from ..utils.gpu import change_gpu
+from ..utils.develop import log_deprecated

 __all__ = ['PrefetchData', 'MultiProcessPrefetchData',
            'PrefetchDataZMQ', 'PrefetchOnGPUs', 'MultiThreadPrefetchData']
@@ -339,6 +340,7 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
             ds (DataFlow): input DataFlow.
             gpus (list[int]): list of GPUs to use. Will also start this number of processes.
         """
+        log_deprecated("PrefetchOnGPUs", "It does not seem useful, and please implement it yourself.", "2019-02-28")
         self.gpus = gpus
         super(PrefetchOnGPUs, self).__init__(ds, len(gpus))
```
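Both hunks use the same three-argument convention: `log_deprecated(name, suggestion, end_date)`. A minimal illustration of what such a helper does; this is a sketch of the pattern, not tensorpack's actual implementation in `utils/develop.py`:

```python
import logging

def log_deprecated(name, text, eos=""):
    # Sketch only: announce that `name` is deprecated, point at the
    # replacement (`text`), and state the planned removal date (`eos`).
    msg = "{} is deprecated. {}".format(name, text)
    if eos:
        msg += " It will be removed after {}.".format(eos)
    logging.warning(msg)
```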
tensorpack/tfutils/sessinit.py:

```diff
@@ -2,20 +2,18 @@
 # File: sessinit.py

-import os
 import numpy as np
 import tensorflow as tf
 import six

 from ..utils import logger
-from ..utils.develop import deprecated
 from .common import get_op_tensor_name
 from .varmanip import (SessionUpdate, get_savename_from_varname,
                        is_training_name, get_checkpoint_path)

 __all__ = ['SessionInit', 'ChainInit',
            'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore',
-           'JustCurrentSession', 'get_model_loader', 'TryResumeTraining']
+           'JustCurrentSession', 'get_model_loader']


 class SessionInit(object):
@@ -260,21 +258,3 @@ def get_model_loader(filename):
         return DictRestore(dict(obj))
     else:
         return SaverRestore(filename)
-
-
-@deprecated("It's better to write the logic yourself or use AutoResumeTrainConfig!", "2018-07-01")
-def TryResumeTraining():
-    """
-    Try loading latest checkpoint from ``logger.get_logger_dir()``, only if there is one.
-
-    Actually not very useful... better to write your own one.
-
-    Returns:
-        SessInit: either a :class:`JustCurrentSession`, or a :class:`SaverRestore`.
-    """
-    if not logger.get_logger_dir():
-        return JustCurrentSession()
-    path = os.path.join(logger.get_logger_dir(), 'checkpoint')
-    if not tf.gfile.Exists(path):
-        return JustCurrentSession()
-    logger.info("Found checkpoint at {}.".format(path))
-    return SaverRestore(path)
```
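Since the deleted helper was self-contained, users who relied on it can keep the removed body verbatim on their side (the calls below are exactly those of the deleted function):

```python
import os
import tensorflow as tf
from tensorpack.utils import logger
from tensorpack.tfutils.sessinit import JustCurrentSession, SaverRestore

def try_resume_training():
    # Same logic as the removed TryResumeTraining(), now user-side.
    if not logger.get_logger_dir():
        return JustCurrentSession()
    path = os.path.join(logger.get_logger_dir(), 'checkpoint')
    if not tf.gfile.Exists(path):
        return JustCurrentSession()
    logger.info("Found checkpoint at {}.".format(path))
    return SaverRestore(path)
```

The deprecation message also names `AutoResumeTrainConfig` as the preferred route.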
tensorpack/tfutils/symbolic_functions.py:

```diff
@@ -3,35 +3,10 @@
 import tensorflow as tf
-import numpy as np
 from ..utils.develop import deprecated

-__all__ = ['get_scalar_var', 'prediction_incorrect', 'flatten', 'batch_flatten', 'print_stat', 'rms', 'huber_loss']
+__all__ = ['print_stat', 'rms']


+# this function exists for backwards-compatibility
 def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
     return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
                    tf.float32, name=name)


-@deprecated("Please implement it yourself!", "2018-08-01")
-def flatten(x):
-    """
-    Flatten the tensor.
-    """
-    return tf.reshape(x, [-1])
-
-
-@deprecated("Please implement it yourself!", "2018-08-01")
-def batch_flatten(x):
-    """
-    Flatten the tensor except the first dimension.
-    """
-    shape = x.get_shape().as_list()[1:]
-    if None not in shape:
-        return tf.reshape(x, [-1, int(np.prod(shape))])
-    return tf.reshape(x, tf.stack([tf.shape(x)[0], -1]))


 def print_stat(x, message=None):
@@ -47,8 +22,7 @@ def print_stat(x, message=None):
                      message=message, name='print_' + x.op.name)

-# after deprecated, keep it for internal use only
-# @deprecated("Please implement it yourself!", "2018-08-01")
+# for internal use only
 def rms(x, name=None):
     """
     Returns:
@@ -61,58 +35,7 @@ def rms(x, name=None):
     return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)

-@deprecated("Please use tf.losses.huber_loss instead!", "2018-08-01")
-# don't hurt to leave it here
-def huber_loss(x, delta=1, name='huber_loss'):
-    r"""
-    Huber loss of x.
-
-    .. math::
-
-        y = \begin{cases} \frac{x^2}{2}, & |x| < \delta \\
-        \delta |x| - \frac{\delta^2}{2}, & |x| \ge \delta
-        \end{cases}
-
-    Args:
-        x: the difference vector.
-        delta (float):
-
-    Returns:
-        a tensor of the same shape of x.
-    """
-    with tf.name_scope('huber_loss'):
-        sqrcost = tf.square(x)
-        abscost = tf.abs(x)
-        cond = abscost < delta
-        l2 = sqrcost * 0.5
-        l1 = abscost * delta - 0.5 * delta ** 2
-        return tf.where(cond, l2, l1, name=name)
-
-
-# TODO deprecate this in the future
-# doesn't hurt to keep it here for now
-@deprecated("Simply use tf.get_variable instead!", "2018-08-01")
-def get_scalar_var(name, init_value, summary=False, trainable=False):
-    """
-    Get a scalar float variable with certain initial value.
-    You can just call `tf.get_variable(name, initializer=init_value, trainable=False)` instead.
-
-    Args:
-        name (str): name of the variable.
-        init_value (float): initial value.
-        summary (bool): whether to summary this variable.
-        trainable (bool): trainable or not.
-
-    Returns:
-        tf.Variable: the variable
-    """
-    ret = tf.get_variable(name, initializer=float(init_value),
-                          trainable=trainable)
-    if summary:
-        # this is recognized in callbacks.StatHolder
-        tf.summary.scalar(name + '-summary', ret)
-    return ret
-
-
 @deprecated("Please implement it by yourself.", "2018-04-28")
 def psnr(prediction, ground_truth, maxp=None, name='psnr'):
     """`Peek Signal to Noise Ratio <https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio>`_.
```
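Migrations for the two deleted helpers, taken from their own deprecation messages (TF 1.x API; `labels`/`predictions` are placeholders):

```python
import tensorflow as tf

labels = tf.placeholder(tf.float32, [None])
predictions = tf.placeholder(tf.float32, [None])

# huber_loss(predictions - labels, delta=1.0) becomes (note: the old helper
# took the difference vector, while tf.losses.huber_loss takes both tensors):
loss = tf.losses.huber_loss(labels, predictions, delta=1.0)

# get_scalar_var('learning_rate', 0.1, trainable=False) becomes,
# exactly as its removed docstring suggested:
lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
```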
tensorpack/train/interface.py:

```diff
@@ -12,7 +12,7 @@ from .config import TrainConfig
 from .tower import SingleCostTrainer
 from .trainers import SimpleTrainer

-__all__ = ['launch_train_with_config', 'apply_default_prefetch']
+__all__ = ['launch_train_with_config']


 def apply_default_prefetch(input_source_or_dataflow, trainer):
```
tensorpack/utils/rect.py:

```diff
@@ -3,6 +3,7 @@
 import numpy as np
+from .develop import log_deprecated

 __all__ = ['IntBox', 'FloatBox']
@@ -11,6 +12,7 @@ class BoxBase(object):
     __slots__ = ['x1', 'y1', 'x2', 'y2']

     def __init__(self, x1, y1, x2, y2):
+        log_deprecated("IntBox and FloatBox", "Please implement them by your own.", "2019-02-28")
         self.x1 = x1
         self.y1 = y1
         self.x2 = x2
```
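The new warning asks users to carry their own box type; a minimal self-contained substitute (the class name and the `w`/`h` helpers are mine, not tensorpack's):

```python
class Box:
    """Axis-aligned box stored as corner coordinates, covering the
    basic use of the deprecated IntBox/FloatBox."""
    def __init__(self, x1, y1, x2, y2):
        self.x1, self.y1, self.x2, self.y2 = x1, y1, x2, y2

    @property
    def w(self):
        return self.x2 - self.x1

    @property
    def h(self):
        return self.y2 - self.y1
```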
tensorpack/utils/viz.py:

```diff
@@ -5,7 +5,6 @@
 import numpy as np
 import os
 import sys
-import io
 from .fs import mkdir_p
 from .argtools import shape2d
 from .palette import PALETTE_RGB
@@ -16,24 +15,12 @@ except ImportError:
     pass

-__all__ = ['pyplot2img', 'interactive_imshow',
+__all__ = ['interactive_imshow',
            'stack_patches', 'gen_stack_patches',
            'dump_dataflow_images', 'intensity_to_rgb',
            'draw_boxes']


-def pyplot2img(plt):
-    """ Convert a pyplot instance to image """
-    buf = io.BytesIO()
-    plt.axis('off')
-    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
-    buf.seek(0)
-    rawbuf = np.frombuffer(buf.getvalue(), dtype='uint8')
-    im = cv2.imdecode(rawbuf, cv2.IMREAD_COLOR)
-    buf.close()
-    return im
-
-
 def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
     """
     Args:
@@ -428,7 +415,6 @@ from ..utils.develop import create_dummy_func  # noqa
 try:
     import matplotlib.pyplot as plt
 except (ImportError, RuntimeError):
-    pyplot2img = create_dummy_func('pyplot2img', 'matplotlib')   # noqa
     intensity_to_rgb = create_dummy_func('intensity_to_rgb', 'matplotlib')   # noqa

 if __name__ == '__main__':
```
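`pyplot2img` was self-contained apart from its imports, so it can move into user code unchanged; this restates the removed body as a standalone function:

```python
import io
import cv2
import numpy as np

def pyplot2img(plt):
    """Convert a pyplot instance to an image array (same body as the removed helper)."""
    buf = io.BytesIO()
    plt.axis('off')
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    rawbuf = np.frombuffer(buf.getvalue(), dtype='uint8')
    im = cv2.imdecode(rawbuf, cv2.IMREAD_COLOR)
    buf.close()
    return im
```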