Commit fb2a051c, authored Jan 02, 2017 by Yuxin Wu

    run autopep8 over tensorpack/

parent 59553585

Showing 100 changed files with 1047 additions and 532 deletions (+1047, -532):
tensorpack/RL/__init__.py                       +2   -1
tensorpack/RL/common.py                         +9   -1
tensorpack/RL/envbase.py                        +13  -1
tensorpack/RL/expreplay.py                      +50  -45
tensorpack/RL/gymenv.py                         +4   -2
tensorpack/RL/history.py                        +2   -1
tensorpack/RL/simulator.py                      +22  -5
tensorpack/__init__.py                          +1   -1
tensorpack/callbacks/__init__.py                +2   -1
tensorpack/callbacks/base.py                    +5   -1
tensorpack/callbacks/concurrency.py             +4   -2
tensorpack/callbacks/dispatcher.py              +2   -0
tensorpack/callbacks/dump.py                    +2   -1
tensorpack/callbacks/graph.py                   +3   -1
tensorpack/callbacks/group.py                   +4   -0
tensorpack/callbacks/inference.py               +7   -1
tensorpack/callbacks/inference_runner.py        +17  -12
tensorpack/callbacks/param.py                   +24  -10
tensorpack/callbacks/saver.py                   +14  -10
tensorpack/callbacks/stats.py                   +13  -4
tensorpack/dataflow/__init__.py                 +2   -2
tensorpack/dataflow/base.py                     +4   -1
tensorpack/dataflow/common.py                   +28  -4
tensorpack/dataflow/dataset/__init__.py         +2   -1
tensorpack/dataflow/dataset/bsds500.py          +5   -3
tensorpack/dataflow/dataset/cifar.py            +16  -7
tensorpack/dataflow/dataset/ilsvrc.py           +13  -8
tensorpack/dataflow/dataset/mnist.py            +12  -6
tensorpack/dataflow/dataset/ptb.py              +2   -2
tensorpack/dataflow/dataset/svhn.py             +4   -3
tensorpack/dataflow/dataset/visualqa.py         +6   -2
tensorpack/dataflow/dftools.py                  +9   -4
tensorpack/dataflow/format.py                   +22  -8
tensorpack/dataflow/image.py                    +6   -2
tensorpack/dataflow/imgaug/__init__.py          +1   -1
tensorpack/dataflow/imgaug/_test.py             +4   -4
tensorpack/dataflow/imgaug/base.py              +5   -1
tensorpack/dataflow/imgaug/crop.py              +26  -16
tensorpack/dataflow/imgaug/deform.py            +18  -12
tensorpack/dataflow/imgaug/geometry.py          +22  -18
tensorpack/dataflow/imgaug/imgproc.py           +21  -8
tensorpack/dataflow/imgaug/meta.py              +11  -2
tensorpack/dataflow/imgaug/noise.py             +5   -0
tensorpack/dataflow/imgaug/noname.py            +13  -6
tensorpack/dataflow/imgaug/paste.py             +10  -3
tensorpack/dataflow/prefetch.py                 +11  -2
tensorpack/dataflow/raw.py                      +8   -1
tensorpack/dataflow/remote.py                   +4   -2
tensorpack/dataflow/tf_func.py                  +14  -12
tensorpack/models/__init__.py                   +2   -1
tensorpack/models/_common.py                    +9   -4
tensorpack/models/_test.py                      +3   -2
tensorpack/models/batch_norm.py                 +19  -16
tensorpack/models/conv2d.py                     +10  -4
tensorpack/models/fc.py                         +3   -1
tensorpack/models/image_sample.py               +31  -26
tensorpack/models/model_desc.py                 +14  -6
tensorpack/models/nonlin.py                     +5   -1
tensorpack/models/pool.py                       +30  -19
tensorpack/models/regularize.py                 +2   -1
tensorpack/models/shapes.py                     +1   -0
tensorpack/models/softmax.py                    +2   -1
tensorpack/predict/__init__.py                  +1   -1
tensorpack/predict/base.py                      +22  -12
tensorpack/predict/common.py                    +6   -3
tensorpack/predict/concurrency.py               +18  -10
tensorpack/predict/dataset.py                   +18  -10
tensorpack/tfutils/argscope.py                  +2   -0
tensorpack/tfutils/common.py                    +13  -3
tensorpack/tfutils/gradproc.py                  +12  -1
tensorpack/tfutils/modelutils.py                +1   -2
tensorpack/tfutils/sessinit.py                  +14  -4
tensorpack/tfutils/summary.py                   +6   -1
tensorpack/tfutils/symbolic_functions.py        +21  -12
tensorpack/tfutils/tower.py                     +4   -2
tensorpack/tfutils/varmanip.py                  +11  -4
tensorpack/train/__init__.py                    +1   -1
tensorpack/train/base.py                        +7   -4
tensorpack/train/config.py                      +3   -1
tensorpack/train/feedfree.py                    +28  -21
tensorpack/train/input_data.py                  +26  -13
tensorpack/train/multigpu.py                    +22  -16
tensorpack/train/trainer.py                     +9   -4
tensorpack/utils/__init__.py                    +2   -3
tensorpack/utils/argtools.py                    +20  -11
tensorpack/utils/concurrency.py                 +13  -2
tensorpack/utils/debug.py                       +5   -3
tensorpack/utils/discretize.py                  +14  -8
tensorpack/utils/fs.py                          +6   -2
tensorpack/utils/globvars.py                    +3   -1
tensorpack/utils/gpu.py                         +3   -1
tensorpack/utils/loadcaffe.py                   +11  -9
tensorpack/utils/logger.py                      +14  -4
tensorpack/utils/lut.py                         +3   -1
tensorpack/utils/rect.py                        +4   -3
tensorpack/utils/serialize.py                   +4   -2
tensorpack/utils/stats.py                       +11  -2
tensorpack/utils/timer.py                       +6   -1
tensorpack/utils/utils.py                       +22  -14
tensorpack/utils/viz.py                         +26  -17
tensorpack/RL/__init__.py

@@ -8,6 +8,8 @@ import os
import os.path

__all__ = []


def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
@@ -20,4 +22,3 @@ for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        _global_import(module_name)
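The two hunks above show tensorpack's package auto-import idiom: every public submodule is imported, and its `__all__` names are re-exported from the package itself. Below is a minimal standalone sketch of the same idiom; the body of the copy loop is an assumption (the diff only shows the helper's first two lines), so treat it as illustrative rather than the commit's exact code.

# Sketch of the auto-import idiom, assuming each submodule defines __all__.
from pkgutil import walk_packages
import os
import os.path

__all__ = []

def _global_import(name):
    # import the submodule relative to this package
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:                          # assumed: copy public names upward
        globals()[k] = p.__dict__[k]
        __all__.append(k)

for _, module_name, _ in walk_packages([os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        _global_import(module_name)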
tensorpack/RL/common.py

@@ -11,12 +11,14 @@ from .envbase import ProxyPlayer
__all__ = ['PreventStuckPlayer', 'LimitLengthPlayer', 'AutoRestartPlayer',
           'MapPlayerState']


class PreventStuckPlayer(ProxyPlayer):
    """ Prevent the player from getting stuck (repeating a no-op)
    by inserting a different action. Useful in games such as Atari Breakout
    where the agent needs to press the 'start' button to start playing.
    """
    # TODO hash the state as well?

    def __init__(self, player, nr_repeat, action):
        """
        It does auto-reset, but doesn't auto-restart the underlying player.
@@ -40,10 +42,12 @@ class PreventStuckPlayer(ProxyPlayer):
        super(PreventStuckPlayer, self).restart_episode()
        self.act_que.clear()


class LimitLengthPlayer(ProxyPlayer):
    """ Limit the total number of actions in an episode.
    Will auto restart the underlying player on timeout
    """

    def __init__(self, player, limit):
        super(LimitLengthPlayer, self).__init__(player)
        self.limit = limit
@@ -64,9 +68,11 @@ class LimitLengthPlayer(ProxyPlayer):
        self.player.restart_episode()
        self.cnt = 0


class AutoRestartPlayer(ProxyPlayer):
    """ Auto-restart the player on episode ends,
    in case some player wasn't designed to do so. """

    def action(self, act):
        r, isOver = self.player.action(act)
        if isOver:
@@ -74,7 +80,9 @@ class AutoRestartPlayer(ProxyPlayer):
            self.player.restart_episode()
        return r, isOver


class MapPlayerState(ProxyPlayer):

    def __init__(self, player, func):
        super(MapPlayerState, self).__init__(player)
        self.func = func
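Since every class in this file is a ProxyPlayer, the wrappers compose by nesting. A hedged usage sketch follows; the base `get_player()` helper and the concrete argument values are assumptions for illustration, not lines from the commit.

# Compose the proxies: normalize states, break no-op loops with the
# 'start' action, and cap episode length. get_player() is hypothetical.
player = get_player()
player = MapPlayerState(player, lambda s: s.astype('float32') / 255.0)
player = PreventStuckPlayer(player, nr_repeat=30, action=1)
player = LimitLengthPlayer(player, limit=30000)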
tensorpack/RL/envbase.py

@@ -13,8 +13,10 @@ from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
           'DiscreteActionSpace']


@six.add_metaclass(ABCMeta)
class RLEnvironment(object):

    def __init__(self):
        self.reset_stat()
@@ -60,13 +62,15 @@ class RLEnvironment(object):
            s = self.current_state()
            act = func(s)
            r, isOver = self.action(act)
            # print r
            if isOver:
                s = [self.stats[k] for k in stat]
                self.reset_stat()
                return s if len(s) > 1 else s[0]


class ActionSpace(object):

    def __init__(self):
        self.rng = get_rng(self)
@@ -77,7 +81,9 @@ class ActionSpace(object):
    def num_actions(self):
        raise NotImplementedError()


class DiscreteActionSpace(ActionSpace):

    def __init__(self, num):
        super(DiscreteActionSpace, self).__init__()
        self.num = num
@@ -94,19 +100,25 @@ class DiscreteActionSpace(ActionSpace):
    def __str__(self):
        return "DiscreteActionSpace({})".format(self.num)


class NaiveRLEnvironment(RLEnvironment):
    """ for testing only"""

    def __init__(self):
        self.k = 0

    def current_state(self):
        self.k += 1
        return self.k

    def action(self, act):
        self.k = act
        return (self.k, self.k > 10)


class ProxyPlayer(RLEnvironment):
    """ Serve as a proxy another player """

    def __init__(self, player):
        self.player = player
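The interaction loop visible in the first hunk (read `current_state()`, pick an action, call `action(act)` and get back `(reward, isOver)`) is the whole protocol a player has to implement. A minimal custom environment under that protocol, written for illustration only and not part of the commit:

# Illustrative only: a 1-D random walk that ends when position leaves [-5, 5].
class RandomWalkEnv(RLEnvironment):

    def __init__(self):
        super(RandomWalkEnv, self).__init__()
        self.pos = 0

    def current_state(self):
        return self.pos

    def action(self, act):
        # act: 0 moves left, 1 moves right; reward is paid on termination
        self.pos += 1 if act == 1 else -1
        isOver = abs(self.pos) > 5
        if isOver:
            self.pos = 0    # assumed auto-reset, mirroring NaiveRLEnvironment's style
        return (1.0 if isOver else 0.0, isOver)

    def get_action_space(self):
        return DiscreteActionSpace(2)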
tensorpack/RL/expreplay.py

@@ -10,7 +10,7 @@ import six
from six.moves import queue

from ..dataflow import DataFlow
-from ..utils import logger, get_tqdm
+from ..utils import logger, get_tqdm, get_rng
from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback
@@ -19,6 +19,7 @@ __all__ = ['ExpReplay']
Experience = namedtuple('Experience',
                        ['state', 'action', 'reward', 'isOver'])


class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
@@ -27,6 +28,7 @@ class ExpReplay(DataFlow, Callback):
    This implementation provides the interface as an DataFlow.
    This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
    """

    def __init__(self,
                 predictor_io_names,
                 player,
@@ -78,10 +80,10 @@ class ExpReplay(DataFlow, Callback):
    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        # if len(self.mem):
        #     from copy import deepcopy  # quickly fill the memory for debug
        #     self.mem.append(deepcopy(self.mem[0]))
        #     return
        old_s = self.player.current_state()
        if self.rng.rand() <= self.exploration:
            act = self.rng.choice(range(self.num_actions))
@@ -115,19 +117,19 @@ class ExpReplay(DataFlow, Callback):
        while True:
            batch_exp = [self._sample_one() for _ in range(self.batch_size)]

            # import cv2  # for debug
            # def view_state(state, next_state):
            #     """ for debugging state representation"""
            #     r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
            #     r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
            #     r = np.concatenate([r, r2], axis=0)
            #     print r.shape
            #     cv2.imshow("state", r)
            #     cv2.waitKey()
            # exp = batch_exp[0]
            # print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
            # if exp[2] or exp[4]:
            #     view_state(exp[0], exp[1])

            yield self._process_batch(batch_exp)
            self._populate_job_queue.put(1)
@@ -141,9 +143,10 @@ class ExpReplay(DataFlow, Callback):
        # when x.isOver==True, (x+1).state is of a different episode
        idx = self.rng.randint(len(self.mem) - self.history_len - 1)
        samples = [self.mem[k] for k in range(idx, idx + self.history_len + 1)]

        def concat(idx):
            v = [x.state for x in samples[idx:idx + self.history_len]]
            return np.concatenate(v, axis=2)
        state = concat(0)
        next_state = concat(1)
@@ -155,12 +158,12 @@ class ExpReplay(DataFlow, Callback):
        # zero-fill state before starting
        zero_fill = False
        for k in range(1, self.history_len):
            if samples[start_idx - k].isOver:
                zero_fill = True
            if zero_fill:
                state[:, :, -k - 1] = 0
                if k + 2 <= self.history_len:
                    next_state[:, :, -k - 2] = 0
        return (state, next_state, reward, action, isOver)

    def _process_batch(self, batch_exp):
@@ -178,6 +181,7 @@ class ExpReplay(DataFlow, Callback):
    def _before_train(self):
        # spawn a separate thread to run policy, can speed up 1.3x
        self._populate_job_queue = queue.Queue(maxsize=1)

        def populate_job_func():
            self._populate_job_queue.get()
            with self.trainer.sess.as_default():
@@ -203,10 +207,11 @@ class ExpReplay(DataFlow, Callback):
                pass
        self.player.reset_stat()


if __name__ == '__main__':
    from .atari import AtariPlayer
    import sys
    predictor = lambda x: np.array([1, 1, 1, 1])
    player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
    E = ExpReplay(predictor,
                  player=player,
@@ -216,9 +221,9 @@ if __name__ == '__main__':
    E._init_memory()
    for k in E.get_data():
-        import IPython as IP;
+        import IPython as IP
        IP.embed(config=IP.terminal.ipapp.load_default_config())
        pass
        # import IPython;
        # IPython.embed(config=IPython.terminal.ipapp.load_default_config())
        # break
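The `concat` helper and the zero-fill loop above implement the usual DQN trick: stack `history_len` consecutive frames along the channel axis to form one state, and blank out any frame that belongs to an earlier episode. A small numpy sketch of that stacking, with toy shapes assumed for illustration (not part of the commit):

import numpy as np

history_len = 4
# five consecutive single-channel frames, e.g. 84x84 Atari crops (assumed size)
frames = [np.random.rand(84, 84, 1) for _ in range(history_len + 1)]

state = np.concatenate(frames[0:history_len], axis=2)           # shape (84, 84, 4)
next_state = np.concatenate(frames[1:history_len + 1], axis=2)  # shifted by one frame

# if an earlier frame ended an episode, its channels are zero-filled so the
# stacked state never mixes two episodes:
state[:, :, :2] = 0   # e.g. frames 0 and 1 came from a previous episode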
tensorpack/RL/gymenv.py

@@ -9,7 +9,7 @@ from ..utils import logger
try:
    import gym
    # TODO
    # gym.undo_logger_setup()
    # https://github.com/openai/gym/pull/199
    # not sure does it cause other problems
    __all__ = ['GymEnv']
@@ -26,11 +26,13 @@ from .envbase import RLEnvironment, DiscreteActionSpace
_ENV_LOCK = threading.Lock()


class GymEnv(RLEnvironment):
    """
    An OpenAI/gym wrapper. Can optionally auto restart.
    Only support discrete action space now
    """

    def __init__(self, name, dumpdir=None, viz=False, auto_restart=True):
        with _ENV_LOCK:
            self.gymenv = gym.make(name)
@@ -82,7 +84,7 @@ if __name__ == '__main__':
    rng = get_rng(num)
    while True:
        act = rng.choice(range(num))
        # print act
        r, o = env.action(act)
        env.current_state()
        if r != 0 or o:
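A hedged usage sketch of the wrapper, following the `__main__` block above; the environment name and the `get_action_space()` accessor are assumptions inferred from the surrounding code rather than lines shown in this diff:

env = GymEnv('Breakout-v0', viz=False, auto_restart=True)
num = env.get_action_space().num_actions()   # assumed accessor
r, isOver = env.action(0)                    # returns (reward, episode-over flag)
s = env.current_state()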
tensorpack/RL/history.py

@@ -9,10 +9,12 @@ from .envbase import ProxyPlayer
__all__ = ['HistoryFramePlayer']


class HistoryFramePlayer(ProxyPlayer):
    """ Include history frames in state, or use black images
    Assume player will do auto-restart.
    """

    def __init__(self, player, hist_len):
        """
        :param hist_len: total length of the state, including the current
@@ -49,4 +51,3 @@ class HistoryFramePlayer(ProxyPlayer):
        super(HistoryFramePlayer, self).restart_episode()
        self.history.clear()
        self.history.append(self.player.current_state())
tensorpack/RL/simulator.py

@@ -34,8 +34,10 @@ except ImportError:
    logger.warn_dependency('Simulator', 'zmq')
__all__ = []


class TransitionExperience(object):
    """ A transition of state, or experience"""

    def __init__(self, state, action, reward, **kwargs):
        """ kwargs: whatever other attribute you want to save"""
        self.state = state
@@ -44,6 +46,7 @@ class TransitionExperience(object):
        for k, v in six.iteritems(kwargs):
            setattr(self, k, v)


@six.add_metaclass(ABCMeta)
class SimulatorProcessBase(mp.Process):
@@ -63,6 +66,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
    A process that simulates a player and communicates to master to
    send states and receive the next action
    """

    def __init__(self, idx, pipe_c2s, pipe_s2c):
        """
        :param idx: idx of this process
@@ -81,7 +85,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
        s2c_socket = context.socket(zmq.DEALER)
        s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
        # s2c_socket.set_hwm(5)
        s2c_socket.connect(self.s2c)
        state = player.current_state()
@@ -97,12 +101,14 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
# compatibility
SimulatorProcess = SimulatorProcessStateExchange


class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
    It should produce action for each simulator, as well as
    defining callbacks when a transition or an episode is finished.
    """

    class ClientState(object):

        def __init__(self):
            self.memory = []    # list of Experience
@@ -174,9 +180,11 @@ class SimulatorMaster(threading.Thread):
    def __del__(self):
        self.context.destroy(linger=0)


class SimulatorProcessDF(SimulatorProcessBase):
    """ A simulator which contains a forward model itself, allowing
    it to produce data points directly """

    def __init__(self, idx, pipe_c2s):
        super(SimulatorProcessDF, self).__init__(idx)
        self.pipe_c2s = pipe_c2s
@@ -202,12 +210,14 @@ class SimulatorProcessDF(SimulatorProcessBase):
    def get_data(self):
        pass


class SimulatorProcessSharedWeight(SimulatorProcessDF):
    """ A simulator process with an extra thread waiting for event,
    and take shared weight from shm.
    Start me under some CUDA_VISIBLE_DEVICES set!
    """

    def __init__(self, idx, pipe_c2s, condvar, shared_dic, pred_config):
        super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
        self.condvar = condvar
@@ -245,8 +255,10 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
        # can be overwritten to update more params
        return tf.trainable_variables()


class WeightSync(Callback):
    """ Sync weight from main process to shared_dic and notify"""

    def __init__(self, condvar, shared_dic):
        self.condvar = condvar
        self.shared_dic = shared_dic
@@ -260,6 +272,7 @@ class WeightSync(Callback):
    def _before_train(self):
        self._sync()

    def _trigger_epoch(self):
        self._sync()
@@ -274,13 +287,18 @@ class WeightSync(Callback):
if __name__ == '__main__':
    import random
    from tensorpack.RL import NaiveRLEnvironment

    class NaiveSimulator(SimulatorProcess):

        def _build_player(self):
            return NaiveRLEnvironment()

    class NaiveActioner(SimulatorActioner):

        def _get_action(self, state):
            time.sleep(1)
            return random.randint(1, 12)

        def _on_episode_over(self, client):
            # print("Over: ", client.memory)
            client.memory = []
@@ -296,4 +314,3 @@ if __name__ == '__main__':
    import time
    time.sleep(100)
tensorpack/__init__.py
tensorpack/callbacks/__init__.py

@@ -7,6 +7,8 @@ import os
__all__ = []


def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
@@ -23,4 +25,3 @@ for _, module_name, _ in walk_packages(
        continue
    if not module_name.startswith('_'):
        _global_import(module_name)
tensorpack/callbacks/base.py

@@ -11,6 +11,7 @@ import six
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']


@six.add_metaclass(ABCMeta)
class Callback(object):
    """ Base class for all callbacks """
@@ -72,7 +73,9 @@ class Callback(object):
    def __str__(self):
        return type(self).__name__


class ProxyCallback(Callback):

    def __init__(self, cb):
        self.cb = cb
@@ -91,11 +94,13 @@ class ProxyCallback(Callback):
    def __str__(self):
        return "Proxy-" + str(self.cb)


class PeriodicCallback(ProxyCallback):
    """
    A callback to be triggered after every `period` epochs.
    Doesn't work for trigger_step
    """

    def __init__(self, cb, period):
        """
        :param cb: a `Callback`
@@ -111,4 +116,3 @@ class PeriodicCallback(ProxyCallback):
    def __str__(self):
        return "Periodic-" + str(self.cb)
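ProxyCallback forwards everything to the wrapped callback, so PeriodicCallback can throttle any existing callback without modifying it. A hedged sketch, using DumpParamAsImage from dump.py in this same commit; the variable name is an assumption:

# Run an expensive image dump only every 5 epochs.
cb = PeriodicCallback(DumpParamAsImage('conv1/W'), period=5)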
tensorpack/callbacks/concurrency.py

@@ -9,7 +9,9 @@ from ..utils import logger
__all__ = ['StartProcOrThread']


class StartProcOrThread(Callback):

    def __init__(self, procs_threads):
        """
        Start extra threads and processes before training
@@ -20,7 +22,7 @@ class StartProcOrThread(Callback):
        self._procs_threads = procs_threads

    def _before_train(self):
-        logger.info("Starting " + \
-            ', '.join([k.name for k in self._procs_threads]))
+        logger.info("Starting " +
+                    ', '.join([k.name for k in self._procs_threads]))
        # avoid sigint get handled by other processes
        start_proc_mask_signal(self._procs_threads)
tensorpack/callbacks/dispatcher.py

@@ -6,7 +6,9 @@ from ..tfutils.common import get_op_tensor_name
__all__ = ['OutputTensorDispatcer']


class OutputTensorDispatcer(object):

    def __init__(self):
        self._names = []
        self._idxs = []
tensorpack/callbacks/dump.py

@@ -12,10 +12,12 @@ from ..tfutils import get_op_var_name
__all__ = ['DumpParamAsImage']


class DumpParamAsImage(Callback):
    """
    Dump a variable to image(s) after every epoch to logger.LOG_DIR.
    """

    def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
        """
        :param var_name: the name of the variable.
@@ -59,4 +61,3 @@ class DumpParamAsImage(Callback):
        if self.clip:
            res = np.clip(res, 0, 255)
        cv2.imwrite(fname, res.astype('uint8'))
tensorpack/callbacks/graph.py

@@ -10,8 +10,10 @@ from ..utils import logger
__all__ = ['RunOp']


class RunOp(Callback):
    """ Run an op periodically"""

    def __init__(self, setup_func, run_before=True, run_epoch=True):
        """
        :param setup_func: a function that returns the op in the graph
@@ -34,5 +36,5 @@ class RunOp(Callback):
        if self.run_epoch:
            self._op.run()

    # def _log(self):
        #logger.info("Running op {} ...".format(self._op_name))
tensorpack/callbacks/group.py

@@ -12,7 +12,9 @@ from ..utils import logger
__all__ = ['Callbacks']


class CallbackTimeLogger(object):

    def __init__(self):
        self.times = []
        self.tot = 0
@@ -39,10 +41,12 @@ class CallbackTimeLogger(object):
            "Callbacks took {:.3f} sec in total. {}".format(
                self.tot, '; '.join(msgs)))


class Callbacks(Callback):
    """
    A container to hold all callbacks, and execute them in the right order and proper session.
    """

    def __init__(self, cbs):
        """
        :param cbs: a list of `Callbacks`
tensorpack/callbacks/inference.py

@@ -16,6 +16,7 @@ from ..tfutils import get_op_var_name
__all__ = ['ClassificationError',
           'ScalarStats', 'Inferencer', 'BinaryClassificationStats']


@six.add_metaclass(ABCMeta)
class Inferencer(object):
@@ -59,12 +60,14 @@ class Inferencer(object):
    def _get_output_tensors(self):
        pass


class ScalarStats(Inferencer):
    """
    Write some scalar tensor to both stat and summary.
    The output of the given Ops must be a scalar.
    The value will be averaged over all data points in the inference dataflow.
    """

    def __init__(self, names_to_print, prefix='validation'):
        """
        :param names_to_print: list of names of tensors, or just a name
@@ -96,6 +99,7 @@ class ScalarStats(Inferencer):
            ret[name] = stat
        return ret


class ClassificationError(Inferencer):
    """
    Compute classification error in batch mode, from a `wrong` variable
@@ -109,6 +113,7 @@ class ClassificationError(Inferencer):
    testing (because the size of test set might not be a multiple of batch size).
    Therefore the result is different from averaging the error rate of each batch.
    """

    def __init__(self, wrong_var_name='incorrect_vector', summary_name='val_error'):
        """
        :param wrong_var_name: name of the `wrong` variable
@@ -138,6 +143,7 @@ class ClassificationError(Inferencer):
    def _after_inference(self):
        return {self.summary_name: self.err_stat.ratio}


class BinaryClassificationStats(Inferencer):
    """ Compute precision/recall in binary classification, given the
    prediction vector and the label vector.
tensorpack/callbacks/inference_runner.py

@@ -18,6 +18,7 @@ from ..train.input_data import FeedfreeInput
__all__ = ['InferenceRunner']


def summary_inferencer(trainer, infs):
    for inf in infs:
        ret = inf.after_inference()
@@ -29,6 +30,7 @@ def summary_inferencer(trainer, infs):
            continue
        trainer.write_scalar_summary(k, v)


class InferenceRunner(Callback):
    """
    A callback that runs different kinds of inferencer.
@@ -64,6 +66,7 @@ class InferenceRunner(Callback):
        input_vars = self.trainer.model.get_reuse_placehdrs()
        # TODO even if it works here, sparse still is unavailable
        # because get_tensor_by_name doesn't work for sparse

        def get_name(x):
            if isinstance(x, tf.SparseTensor):
                return x.op.name.split('/')[0]
@@ -79,6 +82,7 @@ class InferenceRunner(Callback):
        IOTensor = InferenceRunner.IOTensor
        self.output_tensors = list(filter(
            lambda x: x not in self.input_tensors, all_names))

        def find_oid(idxs):
            ret = []
            for idx in idxs:
@@ -110,6 +114,7 @@ class InferenceRunner(Callback):
    def _write_summary_after_inference(self):
        summary_inferencer(self.trainer, self.infs)


class FeedfreeInferenceRunner(Callback):
    IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
@@ -152,6 +157,7 @@ class FeedfreeInferenceRunner(Callback):
        IOTensor = InferenceRunner.IOTensor
        self.output_tensors = all_names

        def find_oid(idxs):
            ret = []
            for idx in idxs:
@@ -161,7 +167,6 @@ class FeedfreeInferenceRunner(Callback):
        self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
        # list of list of (var_name: IOTensor)

    def _trigger_epoch(self):
        for inf in self.infs:
            inf.before_inference()
@@ -170,11 +175,11 @@ class FeedfreeInferenceRunner(Callback):
        sz = self._input_data.size()
        with get_tqdm(total=sz) as pbar:
            for _ in range(sz):
                # outputs = self.pred_func(dp)
                # for inf, tensormap in zip(self.infs, self.inf_to_tensors):
                #     inf_output = [(outputs if k.isOutput else dp)[k.index]
                #                   for k in tensormap]
                #     inf.datapoint(inf_output)
                pbar.update()
        self._write_summary_after_inference()
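A hedged sketch of wiring the inferencers from inference.py into this runner; `dataset_test` and the monitored tensor names are assumptions, and the constructor signature is inferred from the class rather than shown in this diff:

callbacks = [
    InferenceRunner(dataset_test,                  # a validation DataFlow, assumed
                    [ScalarStats('cost'),
                     ClassificationError('incorrect_vector', 'val_error')])
]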
tensorpack/callbacks/param.py

@@ -17,6 +17,8 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
           'ScheduledHyperParamSetter',
           'StatMonitorParamSetter', 'HyperParamSetterWithFunc',
           'HyperParam', 'GraphVarParam', 'ObjAttrParam']


@six.add_metaclass(ABCMeta)
class HyperParam(object):
    """ Base class for a hyper param"""
@@ -35,8 +37,10 @@ class HyperParam(object):
        """ A name to display"""
        return self._readable_name


class GraphVarParam(HyperParam):
    """ a variable in the graph can be a hyperparam"""

    def __init__(self, name, shape=[]):
        self.name = name
        self.shape = shape
@@ -56,13 +60,15 @@ class GraphVarParam(HyperParam):
        self.assign_op = self.var.assign(self.val_holder)

    def set_value(self, v):
        self.assign_op.eval(feed_dict={self.val_holder: v})

    def get_value(self):
        return self.var.eval()


class ObjAttrParam(HyperParam):
    """ an attribute of an object can be a hyperparam"""

    def __init__(self, obj, attrname, readable_name=None):
        """ :param readable_name: default to be attrname."""
        self.obj = obj
@@ -78,6 +84,7 @@ class ObjAttrParam(HyperParam):
    def get_value(self, v):
        return getattr(self.obj, self.attrname)


class HyperParamSetter(Callback):
    """
    Base class to set hyperparameters after every epoch.
@@ -126,10 +133,12 @@ class HyperParamSetter(Callback):
        if v is not None:
            self.param.set_value(v)


class HumanHyperParamSetter(HyperParamSetter):
    """
    Set hyperparameters by loading the value from a file each time it get called.
    """

    def __init__(self, param, file_name='hyper.txt'):
        """
        :param file_name: a file containing the value of the variable.
@@ -149,7 +158,7 @@ class HumanHyperParamSetter(HyperParamSetter):
            with open(self.file_name) as f:
                lines = f.readlines()
            lines = [s.strip().split(':') for s in lines]
            dic = {str(k): float(v) for k, v in lines}
            ret = dic[self.param.readable_name]
            return ret
        except:
@@ -158,10 +167,12 @@ class HumanHyperParamSetter(HyperParamSetter):
                self.param.readable_name, self.file_name))
            return None


class ScheduledHyperParamSetter(HyperParamSetter):
    """
    Set hyperparameters by a predefined schedule.
    """

    def __init__(self, param, schedule, interp=None):
        """
        :param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
@@ -196,7 +207,9 @@ class ScheduledHyperParamSetter(HyperParamSetter):
            v = (self.epoch_num - laste) * 1. / (e - laste) * (v - lastv) + lastv
        return v


class HyperParamSetterWithFunc(HyperParamSetter):

    def __init__(self, param, func):
        """Set hyperparameter by a func
        new_value = f(epoch_num, old_value)
@@ -207,7 +220,9 @@ class HyperParamSetterWithFunc(HyperParamSetter):
    def _get_value_to_set(self):
        return self.f(self.epoch_num, self.get_current_value())


class StatMonitorParamSetter(HyperParamSetter):

    def __init__(self, param, stat_name, value_func, threshold,
                 last_k, reverse=False
                 ):
@@ -236,10 +251,10 @@ class StatMonitorParamSetter(HyperParamSetter):
    def _get_value_to_set(self):
        holder = self.trainer.stat_holder
        hist = holder.get_stat_history(self.stat_name)
        if len(hist) < self.last_k + 1 or \
                self.epoch_num - self.last_changed_epoch < self.last_k:
            return None
        hist = hist[-self.last_k - 1:]    # len==last_k+1
        hist_first = hist[0]
        if not self.reverse:
@@ -254,4 +269,3 @@ class StatMonitorParamSetter(HyperParamSetter):
        logger.info("[StatMonitorParamSetter] Triggered, history: " +
                    ','.join(map(str, hist)))
        return self.value_func(self.get_current_value())
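The schedule format in ScheduledHyperParamSetter's docstring makes step decay a one-liner. A hedged sketch; passing the parameter by name assumes the setter accepts a variable name and wraps it in GraphVarParam, a convenience not visible in this diff:

setter = ScheduledHyperParamSetter(
    'learning_rate',                          # assumed name-to-GraphVarParam shortcut
    [(1, 1e-2), (150, 1e-3), (300, 1e-4)])    # (epoch, value) pairs per the docstring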
tensorpack/callbacks/saver.py

@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
-import os, shutil
+import os
+import shutil
import re

from .base import Callback
@@ -13,10 +14,12 @@ from ..tfutils import get_global_step
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']


class ModelSaver(Callback):
    """
    Save the model to logger directory.
    """

    def __init__(self, keep_recent=10, keep_freq=0.5,
                 var_collections=None):
        """
@@ -83,7 +86,9 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
        except (OSError, IOError):  # disk error sometimes.. just ignore it
            logger.exception("Exception in ModelSaver.trigger_epoch!")


class MinSaver(Callback):

    def __init__(self, monitor_stat, reverse=True, filename=None):
        self.monitor_stat = monitor_stat
        self.reverse = reverse
@@ -122,9 +127,8 @@ class MinSaver(Callback):
        logger.info("Model with {} '{}' saved.".format(
            'maximum' if self.reverse else 'minimum', self.monitor_stat))


class MaxSaver(MinSaver):

    def __init__(self, monitor_stat):
        super(MaxSaver, self).__init__(monitor_stat, True)
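ModelSaver and MinSaver/MaxSaver are complementary: one keeps rolling checkpoints, the other re-saves whenever a monitored stat reaches a new best value. A hedged sketch (the stat name 'val_error' is an assumption):

callbacks = [
    ModelSaver(keep_recent=10, keep_freq=0.5),  # rolling plus periodic checkpoints
    MinSaver('val_error'),                      # keep the model with the lowest val_error
]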
tensorpack/callbacks/stats.py View file @ fb2a051c
...
@@ -3,7 +3,8 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import tensorflow as tf
-import re, os
+import re
+import os
 import operator
 import json
...
@@ -13,10 +14,12 @@ from ..tfutils.common import get_global_step
 __all__ = ['StatHolder', 'StatPrinter', 'SendStat']

 class StatHolder(object):
     """
     A holder to keep all statistics aside from tensorflow events.
     """
     def __init__(self, log_dir):
         """
         :param log_dir: directory to save the stats.
...
@@ -62,9 +65,11 @@ class StatHolder(object):
         ret = []
         for h in self.stat_history:
             v = h.get(key, None)
-            if v is not None: ret.append(v)
+            if v is not None:
+                ret.append(v)
         v = self.stat_now.get(key, None)
-        if v is not None: ret.append(v)
+        if v is not None:
+            ret.append(v)
         return ret

     def finalize(self):
...
@@ -91,10 +96,12 @@ class StatHolder(object):
         except IOError:  # disk error sometimes..
             logger.exception("Exception in StatHolder.finalize()!")

 class StatPrinter(Callback):
     """
     Control what stats to print.
     """
     def __init__(self, print_tag=None):
         """
         :param print_tag: a list of regex to match scalar summary to print.
...
@@ -116,6 +123,7 @@ class StatPrinter(Callback):
         self._stat_holder.finalize()
         self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)

 class SendStat(Callback):
     """
     Execute a command with some specific stats.
...
@@ -126,6 +134,7 @@ class SendStat(Callback):
         -d body={validation_error} > /dev/null 2>&1',
         'validation_error')
     """
     def __init__(self, command, stats):
         self.command = command
         if not isinstance(stats, list):
...
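The history lookup in the @@ -62,9 +65,11 @@ hunk above is self-contained enough to lift out; a standalone sketch with plain dicts standing in for tensorpack's per-epoch stat records:

def get_stat_history(stat_history, stat_now, key):
    # Collect every recorded value of `key`, oldest first, including the
    # in-progress epoch - mirrors the StatHolder loop in the hunk above.
    ret = []
    for h in stat_history:
        v = h.get(key, None)
        if v is not None:
            ret.append(v)
    v = stat_now.get(key, None)
    if v is not None:
        ret.append(v)
    return ret

history = [{'val_err': 0.31}, {'val_err': 0.24}, {'cost': 1.2}]
print(get_stat_history(history, {'val_err': 0.22}, 'val_err'))  # [0.31, 0.24, 0.22]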
tensorpack/dataflow/__init__.py View file @ fb2a051c
...
@@ -12,6 +12,7 @@ from . import imgaug
 __all__ = ['dataset', 'imgaug', 'dftools']

 def _global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -26,4 +27,3 @@ for _, module_name, _ in walk_packages(
     if not module_name.startswith('_') and \
             module_name not in __SKIP:
         _global_import(module_name)
tensorpack/dataflow/base.py View file @ fb2a051c
...
@@ -10,6 +10,7 @@ from ..utils import get_rng
 __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']

 @six.add_metaclass(ABCMeta)
 class DataFlow(object):
     """ Base class for all DataFlow """
...
@@ -17,7 +18,6 @@ class DataFlow(object):
     class Infinity:
         pass

     @abstractmethod
     def get_data(self):
         """
...
@@ -44,11 +44,14 @@ class DataFlow(object):
 class RNGDataFlow(DataFlow):
     """ A dataflow with rng"""
     def reset_state(self):
         self.rng = get_rng(self)

 class ProxyDataFlow(DataFlow):
     """ Base class for DataFlow that proxies another"""
     def __init__(self, ds):
         """
         :param ds: a :mod:`DataFlow` instance to proxy
...
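To make the interface above concrete, a toy flow implementing the same protocol; this is standalone Python, not importing tensorpack, and the class name is invented for the sketch:

class CounterFlow(object):
    # Follows the DataFlow protocol shown above: call reset_state() once,
    # then iterate get_data(), which yields datapoints as lists of components.
    def __init__(self, size):
        self._size = size

    def size(self):
        return self._size

    def reset_state(self):
        pass  # an RNGDataFlow would (re)seed its rng here

    def get_data(self):
        for i in range(self._size):
            yield [i, i * i]

df = CounterFlow(3)
df.reset_state()
print(list(df.get_data()))  # [[0, 0], [1, 1], [2, 4]]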
tensorpack/dataflow/common.py View file @ fb2a051c
...
@@ -15,7 +15,9 @@ __all__ = ['BatchData', 'FixedSizeData', 'MapData',
            'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
            'LocallyShuffleData', 'TestDataSpeed', 'BatchDataByShape']

 class TestDataSpeed(ProxyDataFlow):
     def __init__(self, ds, size=1000):
         super(TestDataSpeed, self).__init__(ds)
         self.test_size = size
...
@@ -31,7 +33,9 @@ class TestDataSpeed(ProxyDataFlow):
         for dp in self.ds.get_data():
             pbar.update()

 class BatchData(ProxyDataFlow):
     def __init__(self, ds, batch_size, remainder=False):
         """
         Group data in `ds` into batches.
...
@@ -91,11 +95,13 @@ class BatchData(ProxyDataFlow):
                 raise
             except:
                 logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
-                import IPython as IP;
+                import IPython as IP
                 IP.embed(config=IP.terminal.ipapp.load_default_config())
         return result

 class BatchDataByShape(BatchData):
     def __init__(self, ds, batch_size, idx):
         """ Group datapoint of the same shape together to batches
...
@@ -119,10 +125,12 @@ class BatchDataByShape(BatchData):
             yield BatchData._aggregate_batch(holder)
             del holder[:]

 class FixedSizeData(ProxyDataFlow):
     """ Generate data from another DataFlow, but with a fixed epoch size.
     The state of the underlying DataFlow is maintained among each epoch.
     """
     def __init__(self, ds, size):
         """
         :param ds: a :mod:`DataFlow` to produce data
...
@@ -154,10 +162,12 @@ class FixedSizeData(ProxyDataFlow):
             if cnt == self._size:
                 return

 class RepeatedData(ProxyDataFlow):
     """ Take data points from another `DataFlow` and produce them until
     it's exhausted for certain amount of times.
     """
     def __init__(self, ds, nr):
         """
         :param ds: a :mod:`DataFlow` instance.
...
@@ -184,8 +194,10 @@ class RepeatedData(ProxyDataFlow):
             for dp in self.ds.get_data():
                 yield dp

 class MapData(ProxyDataFlow):
     """ Apply map/filter a function on the datapoint"""
     def __init__(self, ds, func):
         """
         :param ds: a :mod:`DataFlow` instance.
...
@@ -202,8 +214,10 @@ class MapData(ProxyDataFlow):
             if ret is not None:
                 yield ret

 class MapDataComponent(ProxyDataFlow):
     """ Apply map/filter on the given index in the datapoint"""
     def __init__(self, ds, func, index=0):
         """
         :param ds: a :mod:`DataFlow` instance.
...
@@ -222,11 +236,13 @@ class MapDataComponent(ProxyDataFlow):
                 dp[self.index] = repl  # NOTE modifying
                 yield dp

 class RandomChooseData(RNGDataFlow):
     """
     Randomly choose from several DataFlow. Stop producing when any of them is
     exhausted.
     """
     def __init__(self, df_lists):
         """
         :param df_lists: list of dataflow, or list of (dataflow, probability) tuple
...
@@ -257,10 +273,12 @@ class RandomChooseData(RNGDataFlow):
             except StopIteration:
                 return

 class RandomMixData(RNGDataFlow):
     """
     Randomly choose from several dataflow, and will eventually exhaust all dataflow. So it's a perfect mix.
     """
     def __init__(self, df_lists):
         """
         :param df_lists: list of dataflow.
...
@@ -285,14 +303,16 @@ class RandomMixData(RNGDataFlow):
         idxs = np.array(list(map(
             lambda x: np.searchsorted(sums, x, 'right'), idxs)))
         itrs = [k.get_data() for k in self.df_lists]
         assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs) - 1)
         for k in idxs:
             yield next(itrs[k])

 class ConcatData(DataFlow):
     """
     Concatenate several dataflows.
     """
     def __init__(self, df_lists):
         """
         :param df_lists: list of :mod:`DataFlow` instances
...
@@ -311,6 +331,7 @@ class ConcatData(DataFlow):
             for dp in d.get_data():
                 yield dp

 class JoinData(DataFlow):
     """
     Join the components from each DataFlow.
...
@@ -321,6 +342,7 @@ class JoinData(DataFlow):
         df2: [dp3, dp4]
         join: [dp1, dp2, dp3, dp4]
     """
     def __init__(self, df_lists):
         """
         :param df_lists: list of :mod:`DataFlow` instances
...
@@ -352,7 +374,9 @@ class JoinData(DataFlow):
             for itr in itrs:
                 del itr

 class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
     def __init__(self, ds, cache_size, nr_reuse=1):
         """
         Cache a number of datapoints and shuffle them.
...
@@ -393,10 +417,10 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
             yield v
             return

 def SelectComponent(ds, idxs):
     """
     :param ds: a :mod:`DataFlow` instance
     :param idxs: a list of datapoint component index of the original dataflow
     """
     return MapData(ds, lambda dp: [dp[i] for i in idxs])
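The remainder flag on BatchData is the subtle bit; a standalone sketch of the grouping policy, using plain lists instead of tensorpack's component-wise array aggregation:

def batch_data(iterable, batch_size, remainder=False):
    # Yield groups of batch_size datapoints; with remainder=True the
    # final short batch is produced instead of dropped.
    holder = []
    for dp in iterable:
        holder.append(dp)
        if len(holder) == batch_size:
            yield list(holder)
            del holder[:]
    if remainder and holder:
        yield list(holder)

print(list(batch_data(range(7), 3)))                  # [[0, 1, 2], [3, 4, 5]]
print(list(batch_data(range(7), 3, remainder=True)))  # ... plus [6]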
tensorpack/dataflow/dataset/__init__.py View file @ fb2a051c
...
@@ -7,6 +7,8 @@ import os
 import os.path
 __all__ = []

 def global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -19,4 +21,3 @@ for _, module_name, _ in walk_packages(
         [os.path.dirname(__file__)]):
     if not module_name.startswith('_'):
         global_import(module_name)
tensorpack/dataflow/dataset/bsds500.py View file @ fb2a051c
...
@@ -3,7 +3,8 @@
 # File: bsds500.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-import os, glob
+import os
+import glob
 import cv2
 import numpy as np
...
@@ -21,6 +22,7 @@ except ImportError:
 DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
 IMG_W, IMG_H = 481, 321

 class BSDS500(RNGDataFlow):
     """
     `Berkeley Segmentation Data Set and Benchmarks 500
...
@@ -65,7 +67,7 @@ class BSDS500(RNGDataFlow):
             im = cv2.imread(f, cv2.IMREAD_COLOR)
             assert im is not None
             if im.shape[0] > im.shape[1]:
                 im = np.transpose(im, (1, 0, 2))
             assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W))
             imgid = os.path.basename(f).split('.')[0]
...
@@ -96,5 +98,5 @@ class BSDS500(RNGDataFlow):
 if __name__ == '__main__':
     a = BSDS500('val')
     for k in a.get_data():
         cv2.imshow("haha", k[1].astype('uint8') * 255)
         cv2.waitKey(1000)
tensorpack/dataflow/dataset/cifar.py View file @ fb2a051c
...
@@ -4,7 +4,8 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 #         Yukun Chen <cykustc@gmail.com>
-import os, sys
+import os
+import sys
 import pickle
 import numpy as np
 import random
...
@@ -23,6 +24,7 @@ __all__ = ['Cifar10', 'Cifar100']
 DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
 DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'

 def maybe_download_and_extract(dest_directory, cifar_classnum):
     """Download and extract the tarball from Alex's website.
        copied from tensorflow example """
...
@@ -42,6 +44,7 @@ def maybe_download_and_extract(dest_directory, cifar_classnum):
     import tarfile
     tarfile.open(filepath, 'r:gz').extractall(dest_directory)

 def read_cifar(filenames, cifar_classnum):
     assert cifar_classnum == 10 or cifar_classnum == 100
     ret = []
...
@@ -65,6 +68,7 @@ def read_cifar(filenames, cifar_classnum):
         ret.append([img, label[k]])
     return ret

 def get_filenames(dir, cifar_classnum):
     assert cifar_classnum == 10 or cifar_classnum == 100
     if cifar_classnum == 10:
...
@@ -77,11 +81,13 @@ def get_filenames(dir, cifar_classnum):
                      os.path.join(dir, 'cifar-100-python', 'test')]
     return filenames

 class CifarBase(RNGDataFlow):
     """
     Return [image, label],
     image is 32x32x3 in the range [0,255]
     """
     def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10):
         """
         Args:
...
@@ -132,13 +138,17 @@ class CifarBase(RNGDataFlow):
         return three values as mean of each channel
         """
         mean = self.get_per_pixel_mean()
         return np.mean(mean, axis=(0, 1))

 class Cifar10(CifarBase):
     def __init__(self, train_or_test, shuffle=True, dir=None):
         super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10)

 class Cifar100(CifarBase):
     def __init__(self, train_or_test, shuffle=True, dir=None):
         super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)
...
@@ -149,7 +159,6 @@ if __name__ == '__main__':
     print(mean)
     dump_dataset_images(ds, '/tmp/cifar', 100)
-    #for (img, label) in ds.get_data():
-    #from IPython import embed; embed()
-    #break
+    # for (img, label) in ds.get_data():
+    # from IPython import embed; embed()
+    # break
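The channel-mean reduction in the @@ -132,13 +138,17 @@ hunk is a one-liner worth seeing standalone; random data stands in for the CIFAR per-pixel mean image:

import numpy as np

per_pixel_mean = np.random.rand(32, 32, 3) * 255   # same shape as a CIFAR mean image
per_channel_mean = np.mean(per_pixel_mean, axis=(0, 1))
print(per_channel_mean.shape)  # (3,) - one mean per colour channel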
tensorpack/dataflow/dataset/ilsvrc.py View file @ fb2a051c
...
@@ -19,10 +19,12 @@ __all__ = ['ILSVRCMeta', 'ILSVRC12']
 CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"

 class ILSVRCMeta(object):
     """
     Some metadata for ILSVRC dataset.
     """
     def __init__(self, dir=None):
         if dir is None:
             dir = get_dataset_path('ilsvrc_metadata')
...
@@ -82,12 +84,14 @@ class ILSVRCMeta(object):
         with open(mean_file, 'rb') as f:
             obj.ParseFromString(f.read())
         arr = np.array(obj.data).reshape((3, 256, 256)).astype('float32')
         arr = np.transpose(arr, [1, 2, 0])
         if size is not None:
             arr = cv2.resize(arr, size[::-1])
         return arr

 class ILSVRC12(RNGDataFlow):
     def __init__(self, dir, name, meta_dir=None, shuffle=True,
                  dir_structure='original', include_bb=False):
         """
...
@@ -171,11 +175,11 @@ class ILSVRC12(RNGDataFlow):
             im = cv2.imread(fname.strip(), cv2.IMREAD_COLOR)
             assert im is not None, fname
             if im.ndim == 2:
                 im = np.expand_dims(im, 2).repeat(3, 2)
             if self.include_bb:
                 bb = self.bblist[k]
                 if bb is None:
                     bb = [0, 0, im.shape[1] - 1, im.shape[0] - 1]
                 yield [im, label, bb]
             else:
                 yield [im, label]
...
@@ -216,12 +220,13 @@ class ILSVRC12(RNGDataFlow):
 if __name__ == '__main__':
     meta = ILSVRCMeta()
-    #print(meta.get_synset_words_1000())
+    # print(meta.get_synset_words_1000())
     ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', include_bb=True,
                   shuffle=False)
     ds.reset_state()
     for k in ds.get_data():
-        from IPython import embed; embed()
+        from IPython import embed
+        embed()
         break
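The grayscale handling in ILSVRC12.get_data() above promotes a 2-D image to three identical channels; the same trick standalone:

import numpy as np

gray = np.arange(6, dtype='uint8').reshape(2, 3)   # H x W
color = np.expand_dims(gray, 2).repeat(3, 2)       # H x W x 3
print(color.shape)   # (2, 3, 3)
print(color[0, 1])   # [1 1 1] - the gray value copied into each channel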
tensorpack/dataflow/dataset/mnist.py View file @ fb2a051c
...
@@ -17,6 +17,7 @@ __all__ = ['Mnist']
 SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

 def maybe_download(filename, work_directory):
     """Download the data from Yann's website, unless it's already here."""
     filepath = os.path.join(work_directory, filename)
...
@@ -25,10 +26,12 @@ def maybe_download(filename, work_directory):
         download(SOURCE_URL + filename, work_directory)
     return filepath

 def _read32(bytestream):
     dt = numpy.dtype(numpy.uint32).newbyteorder('>')
     return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]

 def extract_images(filename):
     """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
     with gzip.open(filename) as bytestream:
...
@@ -46,6 +49,7 @@ def extract_images(filename):
     data = data.astype('float32') / 255.0
     return data

 def extract_labels(filename):
     """Extract the labels into a 1D uint8 numpy array [index]."""
     with gzip.open(filename) as bytestream:
...
@@ -59,11 +63,13 @@ def extract_labels(filename):
     labels = numpy.frombuffer(buf, dtype=numpy.uint8)
     return labels

 class Mnist(RNGDataFlow):
     """
     Return [image, label],
     image is 28x28 in the range [0,1]
     """
     def __init__(self, train_or_test, shuffle=True, dir=None):
         """
         Args:
...
@@ -107,6 +113,6 @@ class Mnist(RNGDataFlow):
 if __name__ == '__main__':
     ds = Mnist('train')
     for (img, label) in ds.get_data():
-        from IPython import embed; embed()
+        from IPython import embed
+        embed()
         break
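_read32 above decodes the big-endian 32-bit fields of the MNIST IDX headers; a standalone check of the same trick:

import io
import numpy

def read32(bytestream):
    # Big-endian uint32, as in the IDX file headers.
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]

# 0x00000803 is the magic number of the MNIST image files.
print(read32(io.BytesIO(b'\x00\x00\x08\x03')))  # 2051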
tensorpack/dataflow/dataset/ptb.py View file @ fb2a051c
...
@@ -24,6 +24,7 @@ TRAIN_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.tra
 VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'
 TEST_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'

 @memoized_ignoreargs
 def get_PennTreeBank(data_dir=None):
     if data_dir is None:
...
@@ -37,4 +38,3 @@ def get_PennTreeBank(data_dir=None):
     data3 = [np.asarray(tfreader._file_to_word_ids(os.path.join(data_dir, fname), word_to_id))
              for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']]
     return data3, word_to_id
tensorpack/dataflow/dataset/svhn.py View file @ fb2a051c
...
@@ -19,6 +19,7 @@ except ImportError:
 SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"

 class SVHNDigit(RNGDataFlow):
     """
     SVHN Cropped Digit Dataset.
...
@@ -44,9 +45,9 @@ class SVHNDigit(RNGDataFlow):
                 "File {} not found! Please download it from {}.".format(filename, SVHN_URL)
             logger.info("Loading {} ...".format(filename))
             data = scipy.io.loadmat(filename)
             self.X = data['X'].transpose(3, 0, 1, 2)
             self.Y = data['y'].reshape((-1))
             self.Y[self.Y == 10] = 0
             SVHNDigit._Cache[name] = (self.X, self.Y)

     def size(self):
...
tensorpack/dataflow/dataset/visualqa.py View file @ fb2a051c
...
@@ -12,6 +12,7 @@ import json
 __all__ = ['VisualQA']

 def read_json(fname):
     f = open(fname)
     ret = json.load(f)
...
@@ -19,11 +20,14 @@ def read_json(fname):
     return ret

 # TODO shuffle
 class VisualQA(DataFlow):
     """
     Visual QA dataset. See http://visualqa.org/
     Simply read q/a json file and produce q/a pairs in their original format.
     """
     def __init__(self, question_file, annotation_file):
         with timed_operation('Reading VQA JSON file'):
             qobj, aobj = list(map(read_json, [question_file, annotation_file]))
...
tensorpack/dataflow/dftools.py View file @ fb2a051c
...
@@ -2,7 +2,8 @@
 # File: dftools.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
-import sys, os
+import sys
+import os
 import cv2
 import multiprocessing as mp
 import six
...
@@ -23,6 +24,8 @@ else:
     __all__.extend(['dump_dataflow_to_lmdb'])

 # TODO pass a name_func to write label as filename?
 def dump_dataset_images(ds, dirname, max_count=None, index=0):
     """ Dump images from a `DataFlow` to a directory.
...
@@ -43,6 +46,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
         img = dp[index]
         cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)

 def dump_dataflow_to_lmdb(ds, lmdb_path):
     """ Dump a `Dataflow` ds to a lmdb database, where the key is the index
     and the data is the serialized datapoint.
...
@@ -87,7 +91,9 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
     the queue once you start it. Each element is (task_id, dp).
     """
     q = mp.Queue(size)

     class EnqueProc(mp.Process):
         def __init__(self, ds, q, nr_consumer):
             super(EnqueProc, self).__init__()
             self.ds = ds
...
@@ -104,4 +110,3 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
     proc = EnqueProc(ds, q, nr_consumer)
     return q, proc
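A standalone sketch of the producer pattern behind dataflow_to_process_queue, using plain multiprocessing; a toy list stands in for the DataFlow, and the sentinel convention is an assumption made for the sketch:

import multiprocessing as mp

class EnqueueProc(mp.Process):
    # Producer: push (task_id, datapoint) into a bounded queue, then one
    # sentinel per consumer so every reader knows when to stop.
    def __init__(self, data, q, nr_consumer):
        super(EnqueueProc, self).__init__()
        self.data, self.q, self.nr_consumer = data, q, nr_consumer

    def run(self):
        for idx, dp in enumerate(self.data):
            self.q.put((idx, dp))
        for _ in range(self.nr_consumer):
            self.q.put((None, None))  # sentinel

if __name__ == '__main__':
    q = mp.Queue(4)
    proc = EnqueueProc([10, 20, 30], q, nr_consumer=1)
    proc.start()
    while True:
        idx, dp = q.get()
        if idx is None:
            break
        print(idx, dp)
    proc.join()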
tensorpack/dataflow/format.py View file @ fb2a051c
...
@@ -40,10 +40,13 @@ Adapters for different data format.
 """

 # TODO lazy load
 class HDF5Data(RNGDataFlow):
     """
     Zip data from different paths in an HDF5 file. Will load all data into memory.
     """
     def __init__(self, filename, data_paths, shuffle=True):
         """
         :param filename: h5 data file.
...
@@ -54,7 +57,7 @@ class HDF5Data(RNGDataFlow):
         logger.info("Loading {} to memory...".format(filename))
         self.dps = [self.f[k].value for k in data_paths]
         lens = [len(k) for k in self.dps]
         assert all([k == lens[0] for k in lens])
         self._size = lens[0]
         self.shuffle = shuffle
...
@@ -71,6 +74,7 @@ class HDF5Data(RNGDataFlow):

 class LMDBData(RNGDataFlow):
     """ Read a lmdb and produce k,v pair """
     def __init__(self, lmdb_path, shuffle=True):
         self._lmdb_path = lmdb_path
         self._shuffle = shuffle
...
@@ -116,7 +120,9 @@ class LMDBData(RNGDataFlow):
                 v = self._txn.get(k)
                 yield [k, v]

 class LMDBDataDecoder(LMDBData):
     def __init__(self, lmdb_path, decoder, shuffle=True):
         """
         :param decoder: a function taking k, v and return a data point,
...
@@ -128,18 +134,24 @@ class LMDBDataDecoder(LMDBData):
     def get_data(self):
         for dp in super(LMDBDataDecoder, self).get_data():
             v = self.decoder(dp[0], dp[1])
             if v:
                 yield v

 class LMDBDataPoint(LMDBDataDecoder):
     """ Read a LMDB file where each value is a serialized datapoint"""
     def __init__(self, lmdb_path, shuffle=True):
         super(LMDBDataPoint, self).__init__(
             lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle)

 class CaffeLMDB(LMDBDataDecoder):
     """ Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
     def __init__(self, lmdb_path, shuffle=True):
         cpb = get_caffe_pb()

         def decoder(k, v):
             try:
                 datum = cpb.Datum()
...
@@ -154,8 +166,10 @@ class CaffeLMDB(LMDBDataDecoder):
         super(CaffeLMDB, self).__init__(
             lmdb_path, decoder=decoder, shuffle=shuffle)

 class SVMLightData(RNGDataFlow):
     """ Read X,y from a svmlight file """
     def __init__(self, filename, shuffle=True):
         self.X, self.y = sklearn.datasets.load_svmlight_file(filename)
         self.X = np.asarray(self.X.todense())
...
@@ -169,4 +183,4 @@ class SVMLightData(RNGDataFlow):
         if self.shuffle:
             self.rng.shuffle(idxs)
         for id in idxs:
-            yield [self.X[id,:], self.y[id]]
+            yield [self.X[id, :], self.y[id]]
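The decoder(k, v) hook above, where a falsy return drops the record as LMDBDataDecoder.get_data shows, works the same outside LMDB; a sketch with an in-memory dict and pickle standing in for the database and for tensorpack's own loads serializer:

import pickle

# Hypothetical stand-in for an LMDB transaction: key -> serialized datapoint.
store = {b'0': pickle.dumps([1, 'cat']), b'1': pickle.dumps([2, 'dog'])}

def iter_decoded(kv_pairs, decoder):
    # Mirror of LMDBDataDecoder.get_data(): decode, drop falsy results.
    for k, v in kv_pairs:
        dp = decoder(k, v)
        if dp:
            yield dp

# The same shape of decoder LMDBDataPoint plugs in: ignore the key,
# deserialize the value.
for dp in iter_decoded(sorted(store.items()), lambda k, v: pickle.loads(v)):
    print(dp)  # [1, 'cat'] then [2, 'dog']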
tensorpack/dataflow/image.py View file @ fb2a051c
...
@@ -11,7 +11,9 @@ from .imgaug import AugmentorList
 __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']

 class ImageFromFile(RNGDataFlow):
     def __init__(self, files, channel=3, resize=None, shuffle=False):
         """
         Generate images of 1 channel or 3 channels (in RGB order) from list of files.
...
@@ -39,11 +41,12 @@ class ImageFromFile(RNGDataFlow):
             if self.resize is not None:
                 im = cv2.resize(im, self.resize[::-1])
             if self.channel == 1:
                 im = im[:, :, np.newaxis]
             yield [im]

 class AugmentImageComponent(MapDataComponent):
     def __init__(self, ds, augmentors, index=0):
         """
         Augment the image component of datapoints
...
@@ -64,7 +67,8 @@ class AugmentImageComponent(MapDataComponent):

 class AugmentImageComponents(MapData):
     def __init__(self, ds, augmentors, index=(0, 1)):
         """ Augment a list of images of the same shape, with the same parameters
         :param ds: a `DataFlow` instance.
         :param augmentors: a list of `ImageAugmentor` instance to be applied in order.
...
tensorpack/dataflow/imgaug/__init__.py View file @ fb2a051c
...
@@ -7,6 +7,7 @@ from pkgutil import walk_packages
 __all__ = []

 def global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -19,4 +20,3 @@ for _, module_name, _ in walk_packages(
         [os.path.dirname(__file__)]):
     if not module_name.startswith('_'):
         global_import(module_name)
tensorpack/dataflow/imgaug/_test.py View file @ fb2a051c
...
@@ -15,10 +15,10 @@ from .noise import SaltPepperNoise
 anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
 augmentors = AugmentorList([
     Contrast((0.8, 1.2)),
     Flip(horiz=True),
     GaussianDeform(anchors, (360, 480), 0.2, randrange=20),
-    #RandomCropRandomShape(0.3),
+    # RandomCropRandomShape(0.3),
     SaltPepperNoise()
 ])
...
tensorpack/dataflow/imgaug/base.py View file @ fb2a051c
...
@@ -9,6 +9,7 @@ from six.moves import zip
 __all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']

 @six.add_metaclass(ABCMeta)
 class Augmentor(object):
     """ Base class for an augmentor"""
...
@@ -58,7 +59,9 @@ class Augmentor(object):
             size = []
         return self.rng.uniform(low, high, size)

 class ImageAugmentor(Augmentor):
     def augment(self, img):
         """
         Perform augmentation on the image in-place.
...
@@ -71,10 +74,12 @@ class ImageAugmentor(Augmentor):
     def _fprop_coord(self, coord, param):
         return coord

 class AugmentorList(ImageAugmentor):
     """
     Augment by a list of augmentors
     """
     def __init__(self, augmentors):
         """
         :param augmentors: list of `ImageAugmentor` instance to be applied
...
@@ -107,4 +112,3 @@ class AugmentorList(ImageAugmentor):
         """ Will reset state of each augmentor """
         for a in self.augs:
             a.reset_state()
tensorpack/dataflow/imgaug/crop.py View file @ fb2a051c
...
@@ -12,8 +12,10 @@ import numpy as np
 __all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop',
            'RandomCropRandomShape', 'perturb_BB', 'RandomCropAroundBox']

 class RandomCrop(ImageAugmentor):
     """ Randomly crop the image into a smaller one """
     def __init__(self, crop_shape):
         """
         :param crop_shape: a shape like (h, w)
...
@@ -34,13 +36,15 @@ class RandomCrop(ImageAugmentor):
     def _augment(self, img, param):
         h0, w0 = param
         return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]

     def _fprop_coord(self, coord, param):
         raise NotImplementedError()

 class CenterCrop(ImageAugmentor):
     """ Crop the image at the center"""
     def __init__(self, crop_shape):
         """
         :param crop_shape: a shape like (h, w)
...
@@ -52,13 +56,15 @@ class CenterCrop(ImageAugmentor):
         orig_shape = img.shape
         h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
         w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
         return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]

     def _fprop_coord(self, coord, param):
         raise NotImplementedError()

 class FixedCrop(ImageAugmentor):
     """ Crop a rectangle at a given location"""
     def __init__(self, rect):
         """
         Two arguments defined the range in both axes to crop, min inclued, max excluded.
...
@@ -69,12 +75,13 @@ class FixedCrop(ImageAugmentor):
     def _augment(self, img, _):
         orig_shape = img.shape
         return img[self.rect.y0: self.rect.y1 + 1,
                    self.rect.x0: self.rect.x0 + 1]

     def _fprop_coord(self, coord, param):
         raise NotImplementedError()

 def perturb_BB(image_shape, bb, max_pertub_pixel,
                rng=None, max_aspect_ratio_diff=0.3,
                max_try=100):
...
@@ -113,6 +120,7 @@ class RandomCropAroundBox(ImageAugmentor):
     """
     Crop a box around a bounding box
     """
     def __init__(self, perturb_ratio, max_aspect_ratio_diff=0.3):
         """
         :param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
...
@@ -124,7 +132,7 @@ class RandomCropAroundBox(ImageAugmentor):
     def _get_augment_params(self, img):
         shape = img.shape[:2]
         box = Rect(0, 0, shape[1] - 1, shape[0] - 1)
         dist = self.perturb_ratio * np.sqrt(shape[0] * shape[1])
         newbox = perturb_BB(shape, box, dist,
                             self.rng, self.max_aspect_ratio_diff)
         return newbox
...
@@ -135,7 +143,9 @@ class RandomCropAroundBox(ImageAugmentor):
     def _fprop_coord(self, coord, param):
         raise NotImplementedError()

 class RandomCropRandomShape(ImageAugmentor):
     def __init__(self, wmin, hmin,
                  wmax=None, hmax=None,
                  max_aspect_ratio=None):
...
@@ -151,18 +161,18 @@ class RandomCropRandomShape(ImageAugmentor):
     def _get_augment_params(self, img):
         hmax = self.hmax or img.shape[0]
         wmax = self.wmax or img.shape[1]
         h = self.rng.randint(self.hmin, hmax + 1)
         w = self.rng.randint(self.wmin, wmax + 1)
         diffh = img.shape[0] - h
         diffw = img.shape[1] - w
         assert diffh >= 0 and diffw >= 0
         y0 = 0 if diffh == 0 else self.rng.randint(diffh)
         x0 = 0 if diffw == 0 else self.rng.randint(diffw)
         return (y0, x0, h, w)

     def _augment(self, img, param):
         y0, x0, h, w = param
         return img[y0:y0 + h, x0:x0 + w]

 if __name__ == '__main__':
     print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
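A quick standalone check of the centre-crop arithmetic in CenterCrop._augment above (numpy only):

import numpy as np

def center_crop(img, crop_shape):
    # Same offsets as CenterCrop._augment in the hunk above.
    h0 = int((img.shape[0] - crop_shape[0]) * 0.5)
    w0 = int((img.shape[1] - crop_shape[1]) * 0.5)
    return img[h0:h0 + crop_shape[0], w0:w0 + crop_shape[1]]

img = np.arange(36).reshape(6, 6)
print(center_crop(img, (2, 2)))
# [[14 15]
#  [20 21]]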
tensorpack/dataflow/imgaug/deform.py View file @ fb2a051c
...
@@ -10,8 +10,10 @@ __all__ = ['GaussianDeform', 'GaussianMap']
 # TODO really needs speedup

 class GaussianMap(object):
     """ Generate gaussian weighted deformation map"""
     def __init__(self, image_shape, sigma=0.5):
         assert len(image_shape) == 2
         self.shape = image_shape
...
@@ -25,17 +27,18 @@ class GaussianMap(object):
         x = x.astype('float32') / ret.shape[1] - anchor[1]
         g = np.exp(-(x**2 + y**2) / self.sigma)
         #cv2.imshow(" ", g)
-        #cv2.waitKey()
+        # cv2.waitKey()
         return g

 def np_sample(img, coords):
     # a numpy implementation of ImageSample layer
     coords = np.maximum(coords, 0)
     coords = np.minimum(coords, np.array([img.shape[0] - 1, img.shape[1] - 1]))
     lcoor = np.floor(coords).astype('int32')
     ucoor = lcoor + 1
     ucoor = np.minimum(ucoor, np.array([img.shape[0] - 1, img.shape[1] - 1]))
     diff = coords - lcoor
     neg_diff = 1.0 - diff
...
@@ -46,17 +49,20 @@ def np_sample(img, coords):
     diffy, diffx = np.split(diff, 2, axis=2)
     ndiffy, ndiffx = np.split(neg_diff, 2, axis=2)
     ret = img[lcoory, lcoorx, :] * ndiffx * ndiffy + \
         img[ucoory, ucoorx, :] * diffx * diffy + \
         img[lcoory, ucoorx, :] * ndiffy * diffx + \
         img[ucoory, lcoorx, :] * diffy * ndiffx
     return ret[:, :, 0, :]

 # TODO input/output with different shape
 class GaussianDeform(ImageAugmentor):
     """
     Some kind of deformation. Quite slow.
     """
     def __init__(self, anchors, shape, sigma=0.5, randrange=None):
         """
         :param anchors: in [0,1] coordinate
...
@@ -69,13 +75,13 @@ class GaussianDeform(ImageAugmentor):
         self.anchors = anchors
         self.K = len(self.anchors)
         self.shape = shape
         self.grid = np.mgrid[0:self.shape[0], 0:self.shape[1]].transpose(1, 2, 0)
         self.grid = self.grid.astype('float32')  # HxWx2
         gm = GaussianMap(self.shape, sigma=sigma)
         self.gws = np.array([gm.get_gaussian_weight(ank)
                              for ank in self.anchors], dtype='float32')  # KxHxW
-        self.gws = self.gws.transpose(1, 2, 0)  #HxWxK
+        self.gws = self.gws.transpose(1, 2, 0)  # HxWxK
         if randrange is None:
             self.randrange = self.shape[0] / 8
         else:
...
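The four-term blend in np_sample is ordinary bilinear interpolation; a minimal single-point version of the same idea, simplified from the batched code above:

import numpy as np

def bilinear_sample(img, y, x):
    # Interpolate img (H x W x C) at float coords, clamped to the border -
    # the same four-corner weighting as np_sample.
    y = min(max(y, 0.0), img.shape[0] - 1)
    x = min(max(x, 0.0), img.shape[1] - 1)
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    y1, x1 = min(y0 + 1, img.shape[0] - 1), min(x0 + 1, img.shape[1] - 1)
    dy, dx = y - y0, x - x0
    return (img[y0, x0] * (1 - dy) * (1 - dx) + img[y1, x1] * dy * dx +
            img[y0, x1] * (1 - dy) * dx + img[y1, x0] * dy * (1 - dx))

img = np.zeros((2, 2, 1), 'float32')
img[1, 1, 0] = 4.0
print(bilinear_sample(img, 0.5, 0.5))  # [1.] - a quarter of the hot corner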
tensorpack/dataflow/imgaug/geometry.py View file @ fb2a051c
...
@@ -10,9 +10,11 @@ import numpy as np
 __all__ = ['Rotation', 'RotationAndCropValid']

 class Rotation(ImageAugmentor):
     """ Random rotate the image w.r.t a random center"""
     def __init__(self, max_deg, center_range=(0, 1),
                  interp=cv2.INTER_CUBIC,
                  border=cv2.BORDER_REPLICATE):
         """
...
@@ -33,10 +35,12 @@ class Rotation(ImageAugmentor):
             flags=self.interp, borderMode=self.border)
         return ret

 class RotationAndCropValid(ImageAugmentor):
     """ Random rotate and crop the largest possible rect without the border
         This will produce images of different shapes.
     """
     def __init__(self, max_deg, interp=cv2.INTER_CUBIC):
         super(RotationAndCropValid, self).__init__()
         self._init(locals())
...
@@ -46,7 +50,7 @@ class RotationAndCropValid(ImageAugmentor):
         return deg

     def _augment(self, img, deg):
         center = (img.shape[1] * 0.5, img.shape[0] * 0.5)
         rot_m = cv2.getRotationMatrix2D(center, deg, 1)
         ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
                              flags=self.interp, borderMode=cv2.BORDER_CONSTANT)
...
@@ -56,29 +60,29 @@ class RotationAndCropValid(ImageAugmentor):
         newx = int(center[0] - neww * 0.5)
         newy = int(center[1] - newh * 0.5)
         #print(ret.shape, deg, newx, newy, neww, newh)
         return ret[newy:newy + newh, newx:newx + neww]

     @staticmethod
     def largest_rotated_rect(w, h, angle):
         """ http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders """
         angle = angle / 180.0 * math.pi
         if w <= 0 or h <= 0:
             return 0, 0
         width_is_longer = w >= h
         side_long, side_short = (w, h) if width_is_longer else (h, w)
         # since the solutions for angle, -angle and 180-angle are all the same,
         # if suffices to look at the first quadrant and the absolute values of sin,cos:
         sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
         if side_short <= 2. * sin_a * cos_a * side_long:
             # half constrained case: two crop corners touch the longer side,
             # the other two corners are on the mid-line parallel to the longer line
             x = 0.5 * side_short
             wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
         else:
             # fully constrained case: crop touches all 4 sides
             cos_2a = cos_a * cos_a - sin_a * sin_a
             wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
         return int(wr), int(hr)
tensorpack/dataflow/imgaug/imgproc.py
View file @
fb2a051c
...
@@ -9,10 +9,12 @@ import cv2
...
@@ -9,10 +9,12 @@ import cv2
__all__
=
[
'Brightness'
,
'Contrast'
,
'MeanVarianceNormalize'
,
'GaussianBlur'
,
__all__
=
[
'Brightness'
,
'Contrast'
,
'MeanVarianceNormalize'
,
'GaussianBlur'
,
'Gamma'
,
'Clip'
,
'Saturation'
,
'Lighting'
]
'Gamma'
,
'Clip'
,
'Saturation'
,
'Lighting'
]
class
Brightness
(
ImageAugmentor
):
class
Brightness
(
ImageAugmentor
):
"""
"""
Random adjust brightness.
Random adjust brightness.
"""
"""
def
__init__
(
self
,
delta
,
clip
=
True
):
def
__init__
(
self
,
delta
,
clip
=
True
):
"""
"""
Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True.
Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True.
...
@@ -31,11 +33,13 @@ class Brightness(ImageAugmentor):
...
@@ -31,11 +33,13 @@ class Brightness(ImageAugmentor):
img
=
np
.
clip
(
img
,
0
,
255
)
img
=
np
.
clip
(
img
,
0
,
255
)
return
img
return
img
class
Contrast
(
ImageAugmentor
):
class
Contrast
(
ImageAugmentor
):
"""
"""
Apply x = (x - mean) * contrast_factor + mean to each channel
Apply x = (x - mean) * contrast_factor + mean to each channel
and clip to [0, 255]
and clip to [0, 255]
"""
"""
def
__init__
(
self
,
factor_range
,
clip
=
True
):
def
__init__
(
self
,
factor_range
,
clip
=
True
):
"""
"""
:param factor_range: an interval to random sample the `contrast_factor`.
:param factor_range: an interval to random sample the `contrast_factor`.
...
@@ -48,18 +52,20 @@ class Contrast(ImageAugmentor):
...
@@ -48,18 +52,20 @@ class Contrast(ImageAugmentor):
return
self
.
_rand_range
(
*
self
.
factor_range
)
return
self
.
_rand_range
(
*
self
.
factor_range
)
def
_augment
(
self
,
img
,
r
):
def
_augment
(
self
,
img
,
r
):
mean
=
np
.
mean
(
img
,
axis
=
(
0
,
1
),
keepdims
=
True
)
mean
=
np
.
mean
(
img
,
axis
=
(
0
,
1
),
keepdims
=
True
)
img
=
(
img
-
mean
)
*
r
+
mean
img
=
(
img
-
mean
)
*
r
+
mean
if
self
.
clip
:
if
self
.
clip
:
img
=
np
.
clip
(
img
,
0
,
255
)
img
=
np
.
clip
(
img
,
0
,
255
)
return
img
return
img
class
MeanVarianceNormalize
(
ImageAugmentor
):
class
MeanVarianceNormalize
(
ImageAugmentor
):
"""
"""
Linearly scales image to have zero mean and unit norm.
Linearly scales image to have zero mean and unit norm.
x = (x - mean) / adjusted_stddev
x = (x - mean) / adjusted_stddev
where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
"""
"""
def
__init__
(
self
,
all_channel
=
True
):
def
__init__
(
self
,
all_channel
=
True
):
"""
"""
:param all_channel: if True, normalize all channels together. else separately.
:param all_channel: if True, normalize all channels together. else separately.
...
@@ -71,14 +77,15 @@ class MeanVarianceNormalize(ImageAugmentor):
...
@@ -71,14 +77,15 @@ class MeanVarianceNormalize(ImageAugmentor):
mean
=
np
.
mean
(
img
)
mean
=
np
.
mean
(
img
)
std
=
np
.
std
(
img
)
std
=
np
.
std
(
img
)
else
:
else
:
mean
=
np
.
mean
(
img
,
axis
=
(
0
,
1
),
keepdims
=
True
)
mean
=
np
.
mean
(
img
,
axis
=
(
0
,
1
),
keepdims
=
True
)
std
=
np
.
std
(
img
,
axis
=
(
0
,
1
),
keepdims
=
True
)
std
=
np
.
std
(
img
,
axis
=
(
0
,
1
),
keepdims
=
True
)
std
=
np
.
maximum
(
std
,
1.0
/
np
.
sqrt
(
np
.
prod
(
img
.
shape
)))
std
=
np
.
maximum
(
std
,
1.0
/
np
.
sqrt
(
np
.
prod
(
img
.
shape
)))
img
=
(
img
-
mean
)
/
std
img
=
(
img
-
mean
)
/
std
return
img
return
img
class
GaussianBlur
(
ImageAugmentor
):
class
GaussianBlur
(
ImageAugmentor
):
def
__init__
(
self
,
max_size
=
3
):
def
__init__
(
self
,
max_size
=
3
):
""":params max_size: (maximum kernel size-1)/2"""
""":params max_size: (maximum kernel size-1)/2"""
super
(
GaussianBlur
,
self
)
.
__init__
()
super
(
GaussianBlur
,
self
)
.
__init__
()
...
@@ -96,6 +103,7 @@ class GaussianBlur(ImageAugmentor):
...
@@ -96,6 +103,7 @@ class GaussianBlur(ImageAugmentor):
class
Gamma
(
ImageAugmentor
):
class
Gamma
(
ImageAugmentor
):
def
__init__
(
self
,
range
=
(
-
0.5
,
0.5
)):
def
__init__
(
self
,
range
=
(
-
0.5
,
0.5
)):
super
(
Gamma
,
self
)
.
__init__
()
super
(
Gamma
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -109,7 +117,9 @@ class Gamma(ImageAugmentor):
...
@@ -109,7 +117,9 @@ class Gamma(ImageAugmentor):
img
=
cv2
.
LUT
(
img
,
lut
)
.
astype
(
'float32'
)
img
=
cv2
.
LUT
(
img
,
lut
)
.
astype
(
'float32'
)
return
img
return
img
class
Clip
(
ImageAugmentor
):
class
Clip
(
ImageAugmentor
):
def
__init__
(
self
,
min
=
0
,
max
=
255
):
def
__init__
(
self
,
min
=
0
,
max
=
255
):
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -117,7 +127,9 @@ class Clip(ImageAugmentor):
...
@@ -117,7 +127,9 @@ class Clip(ImageAugmentor):
img
=
np
.
clip
(
img
,
self
.
min
,
self
.
max
)
img
=
np
.
clip
(
img
,
self
.
min
,
self
.
max
)
return
img
return
img
class
Saturation
(
ImageAugmentor
):
class
Saturation
(
ImageAugmentor
):
def
__init__
(
self
,
alpha
=
0.4
):
def
__init__
(
self
,
alpha
=
0.4
):
""" Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
""" Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
"""
"""
...
@@ -130,9 +142,11 @@ class Saturation(ImageAugmentor):
...
@@ -130,9 +142,11 @@ class Saturation(ImageAugmentor):
def
_augment
(
self
,
img
,
v
):
def
_augment
(
self
,
img
,
v
):
grey
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
grey
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
return
img
*
v
+
(
grey
*
(
1
-
v
))[:,:,
np
.
newaxis
]
return
img
*
v
+
(
grey
*
(
1
-
v
))[:,
:,
np
.
newaxis
]
class
Lighting
(
ImageAugmentor
):
class
Lighting
(
ImageAugmentor
):
def
__init__
(
self
,
std
,
eigval
,
eigvec
):
def
__init__
(
self
,
std
,
eigval
,
eigvec
):
""" Lighting noise.
""" Lighting noise.
See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
...
@@ -143,7 +157,7 @@ class Lighting(ImageAugmentor):
...
@@ -143,7 +157,7 @@ class Lighting(ImageAugmentor):
eigval
=
np
.
asarray
(
eigval
)
eigval
=
np
.
asarray
(
eigval
)
eigvec
=
np
.
asarray
(
eigvec
)
eigvec
=
np
.
asarray
(
eigvec
)
assert
eigval
.
shape
==
(
3
,)
assert
eigval
.
shape
==
(
3
,)
assert
eigvec
.
shape
==
(
3
,
3
)
assert
eigvec
.
shape
==
(
3
,
3
)
self
.
_init
(
locals
())
self
.
_init
(
locals
())
def
_get_augment_params
(
self
,
img
):
def
_get_augment_params
(
self
,
img
):
...
@@ -156,4 +170,3 @@ class Lighting(ImageAugmentor):
...
@@ -156,4 +170,3 @@ class Lighting(ImageAugmentor):
inc
=
np
.
dot
(
self
.
eigvec
,
v
)
.
reshape
((
3
,))
inc
=
np
.
dot
(
self
.
eigvec
,
v
)
.
reshape
((
3
,))
img
+=
inc
img
+=
inc
return
img
return
img
tensorpack/dataflow/imgaug/meta.py
View file @
fb2a051c
...
@@ -9,12 +9,16 @@ from .base import ImageAugmentor
...
@@ -9,12 +9,16 @@ from .base import ImageAugmentor
__all__
=
[
'RandomChooseAug'
,
'MapImage'
,
'Identity'
,
'RandomApplyAug'
,
__all__
=
[
'RandomChooseAug'
,
'MapImage'
,
'Identity'
,
'RandomApplyAug'
,
'RandomOrderAug'
]
'RandomOrderAug'
]
class
Identity
(
ImageAugmentor
):
class
Identity
(
ImageAugmentor
):
def
_augment
(
self
,
img
,
_
):
def
_augment
(
self
,
img
,
_
):
return
img
return
img
class
RandomApplyAug
(
ImageAugmentor
):
class
RandomApplyAug
(
ImageAugmentor
):
""" Randomly apply the augmentor with a prob. Otherwise do nothing"""
""" Randomly apply the augmentor with a prob. Otherwise do nothing"""
def
__init__
(
self
,
aug
,
prob
):
def
__init__
(
self
,
aug
,
prob
):
self
.
_init
(
locals
())
self
.
_init
(
locals
())
super
(
RandomApplyAug
,
self
)
.
__init__
()
super
(
RandomApplyAug
,
self
)
.
__init__
()
...
@@ -37,7 +41,9 @@ class RandomApplyAug(ImageAugmentor):
...
@@ -37,7 +41,9 @@ class RandomApplyAug(ImageAugmentor):
else
:
else
:
return
self
.
aug
.
_augment
(
img
,
prm
[
1
])
return
self
.
aug
.
_augment
(
img
,
prm
[
1
])
class
RandomChooseAug
(
ImageAugmentor
):
class
RandomChooseAug
(
ImageAugmentor
):
def
__init__
(
self
,
aug_lists
):
def
__init__
(
self
,
aug_lists
):
"""
"""
:param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
:param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
...
@@ -65,7 +71,9 @@ class RandomChooseAug(ImageAugmentor):
...
@@ -65,7 +71,9 @@ class RandomChooseAug(ImageAugmentor):
idx
,
prm
=
prm
idx
,
prm
=
prm
return
self
.
aug_lists
[
idx
]
.
_augment
(
img
,
prm
)
return
self
.
aug_lists
[
idx
]
.
_augment
(
img
,
prm
)
class
RandomOrderAug
(
ImageAugmentor
):
class
RandomOrderAug
(
ImageAugmentor
):
def
__init__
(
self
,
aug_lists
):
def
__init__
(
self
,
aug_lists
):
"""
"""
Shuffle the augmentors into random order.
Shuffle the augmentors into random order.
...
@@ -93,10 +101,12 @@ class RandomOrderAug(ImageAugmentor):
...
@@ -93,10 +101,12 @@ class RandomOrderAug(ImageAugmentor):
img
=
self
.
aug_lists
[
k
]
.
_augment
(
img
,
prms
[
k
])
img
=
self
.
aug_lists
[
k
]
.
_augment
(
img
,
prms
[
k
])
return
img
return
img
class
MapImage
(
ImageAugmentor
):
class
MapImage
(
ImageAugmentor
):
"""
"""
Map the image array by a function.
Map the image array by a function.
"""
"""
def
__init__
(
self
,
func
):
def
__init__
(
self
,
func
):
"""
"""
:param func: a function which takes a image array and return a augmented one
:param func: a function which takes a image array and return a augmented one
...
@@ -105,4 +115,3 @@ class MapImage(ImageAugmentor):
...
@@ -105,4 +115,3 @@ class MapImage(ImageAugmentor):
def
_augment
(
self
,
img
,
_
):
def
_augment
(
self
,
img
,
_
):
return
self
.
func
(
img
)
return
self
.
func
(
img
)
tensorpack/dataflow/imgaug/noise.py
View file @
fb2a051c
...
@@ -9,7 +9,9 @@ import cv2
...
@@ -9,7 +9,9 @@ import cv2
__all__
=
[
'JpegNoise'
,
'GaussianNoise'
,
'SaltPepperNoise'
]
__all__
=
[
'JpegNoise'
,
'GaussianNoise'
,
'SaltPepperNoise'
]
class
JpegNoise
(
ImageAugmentor
):
class
JpegNoise
(
ImageAugmentor
):
def
__init__
(
self
,
quality_range
=
(
40
,
100
)):
def
__init__
(
self
,
quality_range
=
(
40
,
100
)):
super
(
JpegNoise
,
self
)
.
__init__
()
super
(
JpegNoise
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -23,6 +25,7 @@ class JpegNoise(ImageAugmentor):
...
@@ -23,6 +25,7 @@ class JpegNoise(ImageAugmentor):
class
GaussianNoise
(
ImageAugmentor
):
class
GaussianNoise
(
ImageAugmentor
):
def
__init__
(
self
,
sigma
=
1
,
clip
=
True
):
def
__init__
(
self
,
sigma
=
1
,
clip
=
True
):
"""
"""
Add a gaussian noise N(0, sigma^2) of the same shape to img.
Add a gaussian noise N(0, sigma^2) of the same shape to img.
...
@@ -39,7 +42,9 @@ class GaussianNoise(ImageAugmentor):
...
@@ -39,7 +42,9 @@ class GaussianNoise(ImageAugmentor):
ret
=
np
.
clip
(
ret
,
0
,
255
)
ret
=
np
.
clip
(
ret
,
0
,
255
)
return
ret
return
ret
class
SaltPepperNoise
(
ImageAugmentor
):
class
SaltPepperNoise
(
ImageAugmentor
):
def
__init__
(
self
,
white_prob
=
0.05
,
black_prob
=
0.05
):
def
__init__
(
self
,
white_prob
=
0.05
,
black_prob
=
0.05
):
""" Salt and pepper noise.
""" Salt and pepper noise.
Randomly set some elements in img to 0 or 255, regardless of its channels.
Randomly set some elements in img to 0 or 255, regardless of its channels.
...
...
tensorpack/dataflow/imgaug/noname.py
View file @
fb2a051c
...
@@ -10,10 +10,12 @@ import cv2
...
@@ -10,10 +10,12 @@ import cv2
__all__
=
[
'Flip'
,
'Resize'
,
'RandomResize'
,
'ResizeShortestEdge'
]
__all__
=
[
'Flip'
,
'Resize'
,
'RandomResize'
,
'ResizeShortestEdge'
]
class
Flip
(
ImageAugmentor
):
class
Flip
(
ImageAugmentor
):
"""
"""
Random flip.
Random flip.
"""
"""
def
__init__
(
self
,
horiz
=
False
,
vert
=
False
,
prob
=
0.5
):
def
__init__
(
self
,
horiz
=
False
,
vert
=
False
,
prob
=
0.5
):
"""
"""
Only one of horiz, vert can be set.
Only one of horiz, vert can be set.
...
@@ -45,8 +47,10 @@ class Flip(ImageAugmentor):
...
@@ -45,8 +47,10 @@ class Flip(ImageAugmentor):
def
_fprop_coord
(
self
,
coord
,
param
):
def
_fprop_coord
(
self
,
coord
,
param
):
raise
NotImplementedError
()
raise
NotImplementedError
()
class
Resize
(
ImageAugmentor
):
class
Resize
(
ImageAugmentor
):
""" Resize image to a target size"""
""" Resize image to a target size"""
def
__init__
(
self
,
shape
,
interp
=
cv2
.
INTER_CUBIC
):
def
__init__
(
self
,
shape
,
interp
=
cv2
.
INTER_CUBIC
):
"""
"""
:param shape: shape in (h, w)
:param shape: shape in (h, w)
...
@@ -59,13 +63,15 @@ class Resize(ImageAugmentor):
...
@@ -59,13 +63,15 @@ class Resize(ImageAugmentor):
img
,
self
.
shape
[::
-
1
],
img
,
self
.
shape
[::
-
1
],
interpolation
=
self
.
interp
)
interpolation
=
self
.
interp
)
if
img
.
ndim
==
3
and
ret
.
ndim
==
2
:
if
img
.
ndim
==
3
and
ret
.
ndim
==
2
:
ret
=
ret
[:,
:,
np
.
newaxis
]
ret
=
ret
[:,
:,
np
.
newaxis
]
return
ret
return
ret
class
ResizeShortestEdge
(
ImageAugmentor
):
class
ResizeShortestEdge
(
ImageAugmentor
):
""" Resize the shortest edge to a certain number while
""" Resize the shortest edge to a certain number while
keeping the aspect ratio
keeping the aspect ratio
"""
"""
def
__init__
(
self
,
size
):
def
__init__
(
self
,
size
):
size
=
size
*
1.0
size
=
size
*
1.0
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -76,12 +82,14 @@ class ResizeShortestEdge(ImageAugmentor):
...
@@ -76,12 +82,14 @@ class ResizeShortestEdge(ImageAugmentor):
desSize
=
map
(
int
,
[
scale
*
w
,
scale
*
h
])
desSize
=
map
(
int
,
[
scale
*
w
,
scale
*
h
])
ret
=
cv2
.
resize
(
img
,
tuple
(
desSize
),
interpolation
=
cv2
.
INTER_CUBIC
)
ret
=
cv2
.
resize
(
img
,
tuple
(
desSize
),
interpolation
=
cv2
.
INTER_CUBIC
)
if
img
.
ndim
==
3
and
ret
.
ndim
==
2
:
if
img
.
ndim
==
3
and
ret
.
ndim
==
2
:
ret
=
ret
[:,
:,
np
.
newaxis
]
ret
=
ret
[:,
:,
np
.
newaxis
]
return
ret
return
ret
class
RandomResize
(
ImageAugmentor
):
class
RandomResize
(
ImageAugmentor
):
""" randomly rescale w and h of the image"""
""" randomly rescale w and h of the image"""
def
__init__
(
self
,
xrange
,
yrange
,
minimum
=
(
0
,
0
),
aspect_ratio_thres
=
0.15
,
def
__init__
(
self
,
xrange
,
yrange
,
minimum
=
(
0
,
0
),
aspect_ratio_thres
=
0.15
,
interp
=
cv2
.
INTER_CUBIC
):
interp
=
cv2
.
INTER_CUBIC
):
"""
"""
:param xrange: (min, max) scaling ratio
:param xrange: (min, max) scaling ratio
...
@@ -112,6 +120,5 @@ class RandomResize(ImageAugmentor):
...
@@ -112,6 +120,5 @@ class RandomResize(ImageAugmentor):
def
_augment
(
self
,
img
,
dsize
):
def
_augment
(
self
,
img
,
dsize
):
ret
=
cv2
.
resize
(
img
,
dsize
,
interpolation
=
self
.
interp
)
ret
=
cv2
.
resize
(
img
,
dsize
,
interpolation
=
self
.
interp
)
if
img
.
ndim
==
3
and
ret
.
ndim
==
2
:
if
img
.
ndim
==
3
and
ret
.
ndim
==
2
:
ret
=
ret
[:,
:,
np
.
newaxis
]
ret
=
ret
[:,
:,
np
.
newaxis
]
return
ret
return
ret
tensorpack/dataflow/imgaug/paste.py
View file @
fb2a051c
...
@@ -14,6 +14,7 @@ __all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
...
@@ -14,6 +14,7 @@ __all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
class
BackgroundFiller
(
object
):
class
BackgroundFiller
(
object
):
""" Base class for all BackgroundFiller"""
""" Base class for all BackgroundFiller"""
def
fill
(
self
,
background_shape
,
img
):
def
fill
(
self
,
background_shape
,
img
):
"""
"""
Return a proper background image of background_shape, given img
Return a proper background image of background_shape, given img
...
@@ -28,8 +29,10 @@ class BackgroundFiller(object):
...
@@ -28,8 +29,10 @@ class BackgroundFiller(object):
def
_fill
(
self
,
background_shape
,
img
):
def
_fill
(
self
,
background_shape
,
img
):
pass
pass
class
ConstantBackgroundFiller
(
BackgroundFiller
):
class
ConstantBackgroundFiller
(
BackgroundFiller
):
""" Fill the background by a constant """
""" Fill the background by a constant """
def
__init__
(
self
,
value
):
def
__init__
(
self
,
value
):
"""
"""
:param value: the value to fill the background.
:param value: the value to fill the background.
...
@@ -44,10 +47,12 @@ class ConstantBackgroundFiller(BackgroundFiller):
...
@@ -44,10 +47,12 @@ class ConstantBackgroundFiller(BackgroundFiller):
return_shape
=
background_shape
return_shape
=
background_shape
return
np
.
zeros
(
return_shape
)
+
self
.
value
return
np
.
zeros
(
return_shape
)
+
self
.
value
class
CenterPaste
(
ImageAugmentor
):
class
CenterPaste
(
ImageAugmentor
):
"""
"""
Paste the image onto the center of a background canvas.
Paste the image onto the center of a background canvas.
"""
"""
def
__init__
(
self
,
background_shape
,
background_filler
=
None
):
def
__init__
(
self
,
background_shape
,
background_filler
=
None
):
"""
"""
:param background_shape: shape of the background canvas.
:param background_shape: shape of the background canvas.
...
@@ -66,16 +71,18 @@ class CenterPaste(ImageAugmentor):
...
@@ -66,16 +71,18 @@ class CenterPaste(ImageAugmentor):
self
.
background_shape
,
img
)
self
.
background_shape
,
img
)
y0
=
int
((
self
.
background_shape
[
0
]
-
img_shape
[
0
])
*
0.5
)
y0
=
int
((
self
.
background_shape
[
0
]
-
img_shape
[
0
])
*
0.5
)
x0
=
int
((
self
.
background_shape
[
1
]
-
img_shape
[
1
])
*
0.5
)
x0
=
int
((
self
.
background_shape
[
1
]
-
img_shape
[
1
])
*
0.5
)
background
[
y0
:
y0
+
img_shape
[
0
],
x0
:
x0
+
img_shape
[
1
]]
=
img
background
[
y0
:
y0
+
img_shape
[
0
],
x0
:
x0
+
img_shape
[
1
]]
=
img
return
background
return
background
def
_fprop_coord
(
self
,
coord
,
param
):
def
_fprop_coord
(
self
,
coord
,
param
):
raise
NotImplementedError
()
raise
NotImplementedError
()
class
RandomPaste
(
CenterPaste
):
class
RandomPaste
(
CenterPaste
):
"""
"""
Randomly paste the image onto a background convas
Randomly paste the image onto a background convas
"""
"""
def
_get_augment_params
(
self
,
img
):
def
_get_augment_params
(
self
,
img
):
img_shape
=
img
.
shape
[:
2
]
img_shape
=
img
.
shape
[:
2
]
assert
self
.
background_shape
[
0
]
>
img_shape
[
0
]
and
self
.
background_shape
[
1
]
>
img_shape
[
1
]
assert
self
.
background_shape
[
0
]
>
img_shape
[
0
]
and
self
.
background_shape
[
1
]
>
img_shape
[
1
]
...
@@ -89,5 +96,5 @@ class RandomPaste(CenterPaste):
...
@@ -89,5 +96,5 @@ class RandomPaste(CenterPaste):
img_shape
=
img
.
shape
[:
2
]
img_shape
=
img
.
shape
[:
2
]
background
=
self
.
background_filler
.
fill
(
background
=
self
.
background_filler
.
fill
(
self
.
background_shape
,
img
)
self
.
background_shape
,
img
)
background
[
y0
:
y0
+
img_shape
[
0
],
x0
:
x0
+
img_shape
[
1
]]
=
img
background
[
y0
:
y0
+
img_shape
[
0
],
x0
:
x0
+
img_shape
[
1
]]
=
img
return
background
return
background
tensorpack/dataflow/prefetch.py
View file @
fb2a051c
...
@@ -28,6 +28,7 @@ else:
...
@@ -28,6 +28,7 @@ else:
class
PrefetchProcess
(
mp
.
Process
):
class
PrefetchProcess
(
mp
.
Process
):
def
__init__
(
self
,
ds
,
queue
,
reset_after_spawn
=
True
):
def
__init__
(
self
,
ds
,
queue
,
reset_after_spawn
=
True
):
"""
"""
:param ds: ds to take data from
:param ds: ds to take data from
...
@@ -46,10 +47,12 @@ class PrefetchProcess(mp.Process):
...
@@ -46,10 +47,12 @@ class PrefetchProcess(mp.Process):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
self
.
queue
.
put
(
dp
)
self
.
queue
.
put
(
dp
)
class
PrefetchData
(
ProxyDataFlow
):
class
PrefetchData
(
ProxyDataFlow
):
"""
"""
Prefetch data from a `DataFlow` using multiprocessing
Prefetch data from a `DataFlow` using multiprocessing
"""
"""
def
__init__
(
self
,
ds
,
nr_prefetch
,
nr_proc
=
1
):
def
__init__
(
self
,
ds
,
nr_prefetch
,
nr_proc
=
1
):
"""
"""
:param ds: a `DataFlow` instance.
:param ds: a `DataFlow` instance.
...
@@ -82,6 +85,7 @@ class PrefetchData(ProxyDataFlow):
...
@@ -82,6 +85,7 @@ class PrefetchData(ProxyDataFlow):
# do nothing. all ds are reset once and only once in spawned processes
# do nothing. all ds are reset once and only once in spawned processes
pass
pass
def
BlockParallel
(
ds
,
queue_size
):
def
BlockParallel
(
ds
,
queue_size
):
# TODO more doc
# TODO more doc
"""
"""
...
@@ -92,7 +96,9 @@ def BlockParallel(ds, queue_size):
...
@@ -92,7 +96,9 @@ def BlockParallel(ds, queue_size):
"""
"""
return
PrefetchData
(
ds
,
queue_size
,
1
)
return
PrefetchData
(
ds
,
queue_size
,
1
)
class
PrefetchProcessZMQ
(
mp
.
Process
):
class
PrefetchProcessZMQ
(
mp
.
Process
):
def
__init__
(
self
,
ds
,
conn_name
):
def
__init__
(
self
,
ds
,
conn_name
):
"""
"""
:param ds: a `DataFlow` instance.
:param ds: a `DataFlow` instance.
...
@@ -112,8 +118,10 @@ class PrefetchProcessZMQ(mp.Process):
...
@@ -112,8 +118,10 @@ class PrefetchProcessZMQ(mp.Process):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
self
.
socket
.
send
(
dumps
(
dp
),
copy
=
False
)
self
.
socket
.
send
(
dumps
(
dp
),
copy
=
False
)
class
PrefetchDataZMQ
(
ProxyDataFlow
):
class
PrefetchDataZMQ
(
ProxyDataFlow
):
""" Work the same as `PrefetchData`, but faster. """
""" Work the same as `PrefetchData`, but faster. """
def
__init__
(
self
,
ds
,
nr_proc
=
1
,
pipedir
=
None
):
def
__init__
(
self
,
ds
,
nr_proc
=
1
,
pipedir
=
None
):
"""
"""
:param ds: a `DataFlow` instance.
:param ds: a `DataFlow` instance.
...
@@ -176,9 +184,11 @@ class PrefetchDataZMQ(ProxyDataFlow):
...
@@ -176,9 +184,11 @@ class PrefetchDataZMQ(ProxyDataFlow):
except
:
except
:
pass
pass
class
PrefetchOnGPUs
(
PrefetchDataZMQ
):
class
PrefetchOnGPUs
(
PrefetchDataZMQ
):
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
variable"""
variable"""
def
__init__
(
self
,
ds
,
gpus
,
pipedir
=
None
):
def
__init__
(
self
,
ds
,
gpus
,
pipedir
=
None
):
self
.
gpus
=
gpus
self
.
gpus
=
gpus
super
(
PrefetchOnGPUs
,
self
)
.
__init__
(
ds
,
len
(
gpus
),
pipedir
)
super
(
PrefetchOnGPUs
,
self
)
.
__init__
(
ds
,
len
(
gpus
),
pipedir
)
...
@@ -188,4 +198,3 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
...
@@ -188,4 +198,3 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
for
gpu
,
proc
in
zip
(
self
.
gpus
,
self
.
procs
):
for
gpu
,
proc
in
zip
(
self
.
gpus
,
self
.
procs
):
with
change_gpu
(
gpu
):
with
change_gpu
(
gpu
):
proc
.
start
()
proc
.
start
()
tensorpack/dataflow/raw.py
View file @
fb2a051c
...
@@ -17,8 +17,10 @@ except:
...
@@ -17,8 +17,10 @@ except:
else
:
else
:
__all__
.
append
(
'DataFromSocket'
)
__all__
.
append
(
'DataFromSocket'
)
class
FakeData
(
RNGDataFlow
):
class
FakeData
(
RNGDataFlow
):
""" Generate fake fixed data of given shapes"""
""" Generate fake fixed data of given shapes"""
def
__init__
(
self
,
shapes
,
size
,
random
=
True
,
dtype
=
'float32'
):
def
__init__
(
self
,
shapes
,
size
,
random
=
True
,
dtype
=
'float32'
):
"""
"""
:param shapes: a list of lists/tuples
:param shapes: a list of lists/tuples
...
@@ -44,8 +46,10 @@ class FakeData(RNGDataFlow):
...
@@ -44,8 +46,10 @@ class FakeData(RNGDataFlow):
for
_
in
range
(
self
.
_size
):
for
_
in
range
(
self
.
_size
):
yield
copy
.
deepcopy
(
v
)
yield
copy
.
deepcopy
(
v
)
class
DataFromQueue
(
DataFlow
):
class
DataFromQueue
(
DataFlow
):
""" Produce data from a queue """
""" Produce data from a queue """
def
__init__
(
self
,
queue
):
def
__init__
(
self
,
queue
):
self
.
queue
=
queue
self
.
queue
=
queue
...
@@ -53,8 +57,10 @@ class DataFromQueue(DataFlow):
...
@@ -53,8 +57,10 @@ class DataFromQueue(DataFlow):
while
True
:
while
True
:
yield
self
.
queue
.
get
()
yield
self
.
queue
.
get
()
class
DataFromList
(
RNGDataFlow
):
class
DataFromList
(
RNGDataFlow
):
""" Produce data from a list"""
""" Produce data from a list"""
def
__init__
(
self
,
lst
,
shuffle
=
True
):
def
__init__
(
self
,
lst
,
shuffle
=
True
):
super
(
DataFromList
,
self
)
.
__init__
()
super
(
DataFromList
,
self
)
.
__init__
()
self
.
lst
=
lst
self
.
lst
=
lst
...
@@ -73,8 +79,10 @@ class DataFromList(RNGDataFlow):
...
@@ -73,8 +79,10 @@ class DataFromList(RNGDataFlow):
for
k
in
idxs
:
for
k
in
idxs
:
yield
self
.
lst
[
k
]
yield
self
.
lst
[
k
]
class
DataFromSocket
(
DataFlow
):
class
DataFromSocket
(
DataFlow
):
""" Produce data from a zmq socket"""
""" Produce data from a zmq socket"""
def
__init__
(
self
,
socket_name
):
def
__init__
(
self
,
socket_name
):
self
.
_name
=
socket_name
self
.
_name
=
socket_name
...
@@ -89,4 +97,3 @@ class DataFromSocket(DataFlow):
...
@@ -89,4 +97,3 @@ class DataFromSocket(DataFlow):
yield
dp
yield
dp
finally
:
finally
:
ctx
.
destroy
(
linger
=
0
)
ctx
.
destroy
(
linger
=
0
)
tensorpack/dataflow/remote.py
View file @
fb2a051c
...
@@ -17,6 +17,7 @@ from .common import RepeatedData
...
@@ -17,6 +17,7 @@ from .common import RepeatedData
from
..utils
import
logger
from
..utils
import
logger
from
..utils.serialize
import
dumps
,
loads
from
..utils.serialize
import
dumps
,
loads
def
serve_data
(
ds
,
addr
):
def
serve_data
(
ds
,
addr
):
ctx
=
zmq
.
Context
()
ctx
=
zmq
.
Context
()
socket
=
ctx
.
socket
(
zmq
.
PUSH
)
socket
=
ctx
.
socket
(
zmq
.
PUSH
)
...
@@ -36,7 +37,9 @@ def serve_data(ds, addr):
...
@@ -36,7 +37,9 @@ def serve_data(ds, addr):
if
not
ctx
.
closed
:
if
not
ctx
.
closed
:
ctx
.
destroy
(
0
)
ctx
.
destroy
(
0
)
class
RemoteData
(
DataFlow
):
class
RemoteData
(
DataFlow
):
def
__init__
(
self
,
addr
):
def
__init__
(
self
,
addr
):
self
.
ctx
=
zmq
.
Context
()
self
.
ctx
=
zmq
.
Context
()
self
.
socket
=
self
.
ctx
.
socket
(
zmq
.
PULL
)
self
.
socket
=
self
.
ctx
.
socket
(
zmq
.
PULL
)
...
@@ -54,7 +57,7 @@ if __name__ == '__main__':
...
@@ -54,7 +57,7 @@ if __name__ == '__main__':
from
.raw
import
FakeData
from
.raw
import
FakeData
addr
=
"tcp://127.0.0.1:8877"
addr
=
"tcp://127.0.0.1:8877"
if
sys
.
argv
[
1
]
==
'serve'
:
if
sys
.
argv
[
1
]
==
'serve'
:
ds
=
FakeData
([(
128
,
244
,
244
,
3
)],
1000
)
ds
=
FakeData
([(
128
,
244
,
244
,
3
)],
1000
)
serve_data
(
ds
,
addr
)
serve_data
(
ds
,
addr
)
else
:
else
:
ds
=
RemoteData
(
addr
)
ds
=
RemoteData
(
addr
)
...
@@ -62,4 +65,3 @@ if __name__ == '__main__':
...
@@ -62,4 +65,3 @@ if __name__ == '__main__':
with
tqdm
(
total
=
10000
)
as
pbar
:
with
tqdm
(
total
=
10000
)
as
pbar
:
for
k
in
ds
.
get_data
():
for
k
in
ds
.
get_data
():
pbar
.
update
()
pbar
.
update
()
tensorpack/dataflow/tf_func.py
View file @
fb2a051c
...
@@ -14,7 +14,9 @@ except ImportError:
...
@@ -14,7 +14,9 @@ except ImportError:
else
:
else
:
__all__
=
[
'TFFuncMapper'
]
__all__
=
[
'TFFuncMapper'
]
class
TFFuncMapper
(
ProxyDataFlow
):
class
TFFuncMapper
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
def
__init__
(
self
,
ds
,
get_placeholders
,
symbf
,
apply_symbf_on_dp
,
device
=
'/cpu:0'
):
get_placeholders
,
symbf
,
apply_symbf_on_dp
,
device
=
'/cpu:0'
):
"""
"""
...
@@ -67,12 +69,12 @@ if __name__ == '__main__':
...
@@ -67,12 +69,12 @@ if __name__ == '__main__':
tf_aug
,
tf_aug
,
lambda
dp
,
f
:
[
f
([
dp
[
0
]])[
0
]]
lambda
dp
,
f
:
[
f
([
dp
[
0
]])[
0
]]
)
)
#ds = AugmentImageComponent(ds,
#
ds = AugmentImageComponent(ds,
#
[imgaug.Brightness(0.1, clip=False),
#
[imgaug.Brightness(0.1, clip=False),
#
imgaug.Contrast((0.8, 1.2), clip=False),
#
imgaug.Contrast((0.8, 1.2), clip=False),
#
imgaug.Flip(horiz=True)
#
imgaug.Flip(horiz=True)
#
])
#
])
#ds = PrefetchDataZMQ(ds, 4)
#
ds = PrefetchDataZMQ(ds, 4)
ds
.
reset_state
()
ds
.
reset_state
()
import
tqdm
import
tqdm
...
...
tensorpack/models/__init__.py
View file @
fb2a051c
...
@@ -12,6 +12,7 @@ from ..utils import logger
...
@@ -12,6 +12,7 @@ from ..utils import logger
__all__
=
[
'LinearWrap'
]
__all__
=
[
'LinearWrap'
]
def
_global_import
(
name
):
def
_global_import
(
name
):
p
=
__import__
(
name
,
globals
(),
locals
(),
level
=
1
)
p
=
__import__
(
name
,
globals
(),
locals
(),
level
=
1
)
lst
=
p
.
__all__
if
'__all__'
in
dir
(
p
)
else
dir
(
p
)
lst
=
p
.
__all__
if
'__all__'
in
dir
(
p
)
else
dir
(
p
)
...
@@ -32,6 +33,7 @@ class LinearWrap(object):
...
@@ -32,6 +33,7 @@ class LinearWrap(object):
"""
"""
class
TFModuleFunc
(
object
):
class
TFModuleFunc
(
object
):
def
__init__
(
self
,
mod
,
tensor
):
def
__init__
(
self
,
mod
,
tensor
):
self
.
_mod
=
mod
self
.
_mod
=
mod
self
.
_t
=
tensor
self
.
_t
=
tensor
...
@@ -88,4 +90,3 @@ class LinearWrap(object):
...
@@ -88,4 +90,3 @@ class LinearWrap(object):
def
print_tensor
(
self
):
def
print_tensor
(
self
):
print
(
self
.
_t
)
print
(
self
.
_t
)
return
self
return
self
tensorpack/models/_common.py
View file @
fb2a051c
...
@@ -5,7 +5,8 @@
...
@@ -5,7 +5,8 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
functools
import
wraps
from
functools
import
wraps
import
six
import
six
import
copy
,
os
import
copy
import
os
from
..tfutils.argscope
import
get_arg_scope
from
..tfutils.argscope
import
get_arg_scope
from
..tfutils.modelutils
import
get_shape_str
from
..tfutils.modelutils
import
get_shape_str
...
@@ -16,13 +17,16 @@ from ..utils.argtools import shape2d
...
@@ -16,13 +17,16 @@ from ..utils.argtools import shape2d
# make sure each layer is only logged once
# make sure each layer is only logged once
_layer_logged
=
set
()
_layer_logged
=
set
()
def
disable_layer_logging
():
def
disable_layer_logging
():
class
ContainEverything
:
class
ContainEverything
:
def
__contains__
(
self
,
x
):
def
__contains__
(
self
,
x
):
return
True
return
True
# can use nonlocal in python3, but how
# can use nonlocal in python3, but how
globals
()[
'_layer_logged'
]
=
ContainEverything
()
globals
()[
'_layer_logged'
]
=
ContainEverything
()
def
layer_register
(
def
layer_register
(
summary_activation
=
False
,
summary_activation
=
False
,
log_shape
=
True
,
log_shape
=
True
,
...
@@ -104,6 +108,7 @@ def layer_register(
...
@@ -104,6 +108,7 @@ def layer_register(
return
wrapper
return
wrapper
def
shape4d
(
a
):
def
shape4d
(
a
):
# for use with tensorflow NHWC ops
# for use with tensorflow NHWC ops
return
[
1
]
+
shape2d
(
a
)
+
[
1
]
return
[
1
]
+
shape2d
(
a
)
+
[
1
]
tensorpack/models/_test.py
View file @
fb2a051c
...
@@ -7,7 +7,9 @@ import tensorflow as tf
...
@@ -7,7 +7,9 @@ import tensorflow as tf
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
class
TestModel
(
unittest
.
TestCase
):
class
TestModel
(
unittest
.
TestCase
):
def
run_variable
(
self
,
var
):
def
run_variable
(
self
,
var
):
sess
=
tf
.
Session
()
sess
=
tf
.
Session
()
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
global_variables_initializer
())
...
@@ -22,6 +24,7 @@ class TestModel(unittest.TestCase):
...
@@ -22,6 +24,7 @@ class TestModel(unittest.TestCase):
else
:
else
:
return
tf
.
Variable
(
args
[
0
])
return
tf
.
Variable
(
args
[
0
])
def
run_test_case
(
case
):
def
run_test_case
(
case
):
suite
=
unittest
.
TestLoader
()
.
loadTestsFromTestCase
(
case
)
suite
=
unittest
.
TestLoader
()
.
loadTestsFromTestCase
(
case
)
unittest
.
TextTestRunner
(
verbosity
=
2
)
.
run
(
suite
)
unittest
.
TextTestRunner
(
verbosity
=
2
)
.
run
(
suite
)
...
@@ -34,5 +37,3 @@ if __name__ == '__main__':
...
@@ -34,5 +37,3 @@ if __name__ == '__main__':
subs
=
tensorpack
.
models
.
_test
.
TestModel
.
__subclasses__
()
subs
=
tensorpack
.
models
.
_test
.
TestModel
.
__subclasses__
()
for
cls
in
subs
:
for
cls
in
subs
:
run_test_case
(
cls
)
run_test_case
(
cls
)
tensorpack/models/batch_norm.py
View file @
fb2a051c
...
@@ -18,6 +18,8 @@ __all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
...
@@ -18,6 +18,8 @@ __all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
# decay: being too close to 1 leads to slow start-up. torch use 0.9.
# decay: being too close to 1 leads to slow start-up. torch use 0.9.
# eps: torch: 1e-5. Lasagne: 1e-4
# eps: torch: 1e-5. Lasagne: 1e-4
@
layer_register
(
log_shape
=
False
)
@
layer_register
(
log_shape
=
False
)
def
BatchNormV1
(
x
,
use_local_stat
=
None
,
decay
=
0.9
,
epsilon
=
1e-5
):
def
BatchNormV1
(
x
,
use_local_stat
=
None
,
decay
=
0.9
,
epsilon
=
1e-5
):
"""
"""
...
@@ -93,7 +95,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
...
@@ -93,7 +95,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
ema_mean
=
tf
.
get_variable
(
'mean/'
+
emaname
,
[
n_out
])
ema_mean
=
tf
.
get_variable
(
'mean/'
+
emaname
,
[
n_out
])
ema_var
=
tf
.
get_variable
(
'variance/'
+
emaname
,
[
n_out
])
ema_var
=
tf
.
get_variable
(
'variance/'
+
emaname
,
[
n_out
])
else
:
else
:
#
#
use statistics in another tower
# use statistics in another tower
G
=
tf
.
get_default_graph
()
G
=
tf
.
get_default_graph
()
ema_mean
=
ctx
.
find_tensor_in_main_tower
(
G
,
mean_var_name
+
':0'
)
ema_mean
=
ctx
.
find_tensor_in_main_tower
(
G
,
mean_var_name
+
':0'
)
ema_var
=
ctx
.
find_tensor_in_main_tower
(
G
,
var_var_name
+
':0'
)
ema_var
=
ctx
.
find_tensor_in_main_tower
(
G
,
var_var_name
+
':0'
)
...
@@ -111,6 +113,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
...
@@ -111,6 +113,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
return
tf
.
nn
.
batch_normalization
(
return
tf
.
nn
.
batch_normalization
(
x
,
ema_mean
,
ema_var
,
beta
,
gamma
,
epsilon
,
'output'
)
x
,
ema_mean
,
ema_var
,
beta
,
gamma
,
epsilon
,
'output'
)
@
layer_register
(
log_shape
=
False
)
@
layer_register
(
log_shape
=
False
)
def
BatchNormV2
(
x
,
use_local_stat
=
None
,
decay
=
0.9
,
epsilon
=
1e-5
):
def
BatchNormV2
(
x
,
use_local_stat
=
None
,
decay
=
0.9
,
epsilon
=
1e-5
):
"""
"""
...
@@ -171,9 +174,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
...
@@ -171,9 +174,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
# consider some fixed-param tasks, such as load model and fine tune one layer
# consider some fixed-param tasks, such as load model and fine tune one layer
# fused seems slower in inference
# fused seems slower in inference
#xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
#
xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
#
moving_mean, moving_var,
#
moving_mean, moving_var,
#
epsilon=epsilon, is_training=False, name='output')
#
epsilon=epsilon, is_training=False, name='output')
xn
=
tf
.
nn
.
batch_normalization
(
xn
=
tf
.
nn
.
batch_normalization
(
x
,
moving_mean
,
moving_var
,
beta
,
gamma
,
epsilon
)
x
,
moving_mean
,
moving_var
,
beta
,
gamma
,
epsilon
)
...
...
tensorpack/models/conv2d.py
View file @
fb2a051c
...
@@ -12,6 +12,7 @@ from ..utils.argtools import shape2d
...
@@ -12,6 +12,7 @@ from ..utils.argtools import shape2d
__all__
=
[
'Conv2D'
,
'Deconv2D'
]
__all__
=
[
'Conv2D'
,
'Deconv2D'
]
@
layer_register
()
@
layer_register
()
def
Conv2D
(
x
,
out_channel
,
kernel_shape
,
def
Conv2D
(
x
,
out_channel
,
kernel_shape
,
padding
=
'SAME'
,
stride
=
1
,
padding
=
'SAME'
,
stride
=
1
,
...
@@ -61,14 +62,18 @@ def Conv2D(x, out_channel, kernel_shape,
...
@@ -61,14 +62,18 @@ def Conv2D(x, out_channel, kernel_shape,
for
i
,
k
in
zip
(
inputs
,
kernels
)]
for
i
,
k
in
zip
(
inputs
,
kernels
)]
conv
=
tf
.
concat
(
3
,
outputs
)
conv
=
tf
.
concat
(
3
,
outputs
)
if
nl
is
None
:
if
nl
is
None
:
logger
.
warn
(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead."
)
logger
.
warn
(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead."
)
nl
=
tf
.
nn
.
relu
nl
=
tf
.
nn
.
relu
return
nl
(
tf
.
nn
.
bias_add
(
conv
,
b
)
if
use_bias
else
conv
,
name
=
'output'
)
return
nl
(
tf
.
nn
.
bias_add
(
conv
,
b
)
if
use_bias
else
conv
,
name
=
'output'
)
class
StaticDynamicShape
(
object
):
class
StaticDynamicShape
(
object
):
def
__init__
(
self
,
static
,
dynamic
):
def
__init__
(
self
,
static
,
dynamic
):
self
.
static
=
static
self
.
static
=
static
self
.
dynamic
=
dynamic
self
.
dynamic
=
dynamic
def
apply
(
self
,
f
):
def
apply
(
self
,
f
):
try
:
try
:
st
=
f
(
self
.
static
)
st
=
f
(
self
.
static
)
...
@@ -76,6 +81,7 @@ class StaticDynamicShape(object):
...
@@ -76,6 +81,7 @@ class StaticDynamicShape(object):
except
:
except
:
return
StaticDynamicShape
(
None
,
f
(
self
.
dynamic
))
return
StaticDynamicShape
(
None
,
f
(
self
.
dynamic
))
@
layer_register
()
@
layer_register
()
def
Deconv2D
(
x
,
out_shape
,
kernel_shape
,
def
Deconv2D
(
x
,
out_shape
,
kernel_shape
,
stride
,
padding
=
'SAME'
,
stride
,
padding
=
'SAME'
,
...
...
tensorpack/models/fc.py
View file @
fb2a051c
...
@@ -11,6 +11,7 @@ from ..tfutils import symbolic_functions as symbf
...
@@ -11,6 +11,7 @@ from ..tfutils import symbolic_functions as symbf
__all__
=
[
'FullyConnected'
]
__all__
=
[
'FullyConnected'
]
@
layer_register
()
@
layer_register
()
def
FullyConnected
(
x
,
out_dim
,
def
FullyConnected
(
x
,
out_dim
,
W_init
=
None
,
b_init
=
None
,
W_init
=
None
,
b_init
=
None
,
...
@@ -40,6 +41,7 @@ def FullyConnected(x, out_dim,
...
@@ -40,6 +41,7 @@ def FullyConnected(x, out_dim,
b
=
tf
.
get_variable
(
'b'
,
[
out_dim
],
initializer
=
b_init
)
b
=
tf
.
get_variable
(
'b'
,
[
out_dim
],
initializer
=
b_init
)
prod
=
tf
.
nn
.
xw_plus_b
(
x
,
W
,
b
)
if
use_bias
else
tf
.
matmul
(
x
,
W
)
prod
=
tf
.
nn
.
xw_plus_b
(
x
,
W
,
b
)
if
use_bias
else
tf
.
matmul
(
x
,
W
)
if
nl
is
None
:
if
nl
is
None
:
logger
.
warn
(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead."
)
logger
.
warn
(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead."
)
nl
=
tf
.
nn
.
relu
nl
=
tf
.
nn
.
relu
return
nl
(
prod
,
name
=
'output'
)
return
nl
(
prod
,
name
=
'output'
)
tensorpack/models/image_sample.py
View file @
fb2a051c
...
@@ -12,6 +12,8 @@ __all__ = ['ImageSample']
...
@@ -12,6 +12,8 @@ __all__ = ['ImageSample']
# XXX TODO ugly.
# XXX TODO ugly.
# really need to fix this after tensorflow supports advanced indexing
# really need to fix this after tensorflow supports advanced indexing
# See github:tensorflow#418,#206
# See github:tensorflow#418,#206
def
sample
(
img
,
coords
):
def
sample
(
img
,
coords
):
"""
"""
:param img: bxhxwxc
:param img: bxhxwxc
...
@@ -33,14 +35,15 @@ def sample(img, coords):
...
@@ -33,14 +35,15 @@ def sample(img, coords):
# bxh2xw2
# bxh2xw2
batch_add
=
tf
.
range
(
tf
.
shape
(
img
)[
0
])
*
(
shape
[
0
]
*
shape
[
1
])
batch_add
=
tf
.
range
(
tf
.
shape
(
img
)[
0
])
*
(
shape
[
0
]
*
shape
[
1
])
batch_add
=
tf
.
reshape
(
batch_add
,
[
-
1
,
1
,
1
])
#
bx1x1
batch_add
=
tf
.
reshape
(
batch_add
,
[
-
1
,
1
,
1
])
#
bx1x1
flat_coords
=
coords
+
batch_add
flat_coords
=
coords
+
batch_add
img
=
tf
.
reshape
(
img
,
[
-
1
,
shape
[
2
]])
#
bhw x c
img
=
tf
.
reshape
(
img
,
[
-
1
,
shape
[
2
]])
#
bhw x c
sampled
=
tf
.
gather
(
img
,
flat_coords
)
sampled
=
tf
.
gather
(
img
,
flat_coords
)
return
sampled
return
sampled
@
layer_register
()
@
layer_register
()
def
ImageSample
(
inputs
,
borderMode
=
'repeat'
):
def
ImageSample
(
inputs
,
borderMode
=
'repeat'
):
"""
"""
...
@@ -68,7 +71,7 @@ def ImageSample(inputs, borderMode='repeat'):
...
@@ -68,7 +71,7 @@ def ImageSample(inputs, borderMode='repeat'):
ucoor
=
lcoor
+
1
ucoor
=
lcoor
+
1
diff
=
mapping
-
lcoor
diff
=
mapping
-
lcoor
neg_diff
=
1.0
-
diff
#
bxh2xw2x2
neg_diff
=
1.0
-
diff
#
bxh2xw2x2
lcoory
,
lcoorx
=
tf
.
split
(
3
,
2
,
lcoor
)
lcoory
,
lcoorx
=
tf
.
split
(
3
,
2
,
lcoor
)
ucoory
,
ucoorx
=
tf
.
split
(
3
,
2
,
ucoor
)
ucoory
,
ucoorx
=
tf
.
split
(
3
,
2
,
ucoor
)
...
@@ -80,8 +83,8 @@ def ImageSample(inputs, borderMode='repeat'):
...
@@ -80,8 +83,8 @@ def ImageSample(inputs, borderMode='repeat'):
neg_diffy
,
neg_diffx
=
tf
.
split
(
3
,
2
,
neg_diff
)
neg_diffy
,
neg_diffx
=
tf
.
split
(
3
,
2
,
neg_diff
)
#prod = tf.reduce_prod(diff, 3, keep_dims=True)
#prod = tf.reduce_prod(diff, 3, keep_dims=True)
#diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
#
diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
#
tf.reduce_max(diff), diff], summarize=50)
#
tf.reduce_max(diff), diff], summarize=50)
ret
=
tf
.
add_n
([
sample
(
template
,
lcoor
)
*
neg_diffx
*
neg_diffy
,
ret
=
tf
.
add_n
([
sample
(
template
,
lcoor
)
*
neg_diffx
*
neg_diffy
,
sample
(
template
,
ucoor
)
*
diffx
*
diffy
,
sample
(
template
,
ucoor
)
*
diffx
*
diffy
,
...
@@ -91,36 +94,40 @@ def ImageSample(inputs, borderMode='repeat'):
...
@@ -91,36 +94,40 @@ def ImageSample(inputs, borderMode='repeat'):
max_coor
=
tf
.
constant
([
input_shape
[
0
]
-
1
,
input_shape
[
1
]
-
1
],
dtype
=
tf
.
float32
)
max_coor
=
tf
.
constant
([
input_shape
[
0
]
-
1
,
input_shape
[
1
]
-
1
],
dtype
=
tf
.
float32
)
mask
=
tf
.
greater_equal
(
orig_mapping
,
0.0
)
mask
=
tf
.
greater_equal
(
orig_mapping
,
0.0
)
mask2
=
tf
.
less_equal
(
orig_mapping
,
max_coor
)
mask2
=
tf
.
less_equal
(
orig_mapping
,
max_coor
)
mask
=
tf
.
logical_and
(
mask
,
mask2
)
#
bxh2xw2x2
mask
=
tf
.
logical_and
(
mask
,
mask2
)
#
bxh2xw2x2
mask
=
tf
.
reduce_all
(
mask
,
[
3
])
# bxh2xw2 boolean
mask
=
tf
.
reduce_all
(
mask
,
[
3
])
# bxh2xw2 boolean
mask
=
tf
.
expand_dims
(
mask
,
3
)
mask
=
tf
.
expand_dims
(
mask
,
3
)
ret
=
ret
*
tf
.
cast
(
mask
,
tf
.
float32
)
ret
=
ret
*
tf
.
cast
(
mask
,
tf
.
float32
)
return
ret
return
ret
from
._test
import
TestModel
from
._test
import
TestModel
class
TestSample
(
TestModel
):
class
TestSample
(
TestModel
):
def
test_sample
(
self
):
def
test_sample
(
self
):
import
numpy
as
np
import
numpy
as
np
h
,
w
=
3
,
4
h
,
w
=
3
,
4
def
np_sample
(
img
,
coords
):
def
np_sample
(
img
,
coords
):
# a reference implementation
# a reference implementation
coords
=
np
.
maximum
(
coords
,
0
)
coords
=
np
.
maximum
(
coords
,
0
)
coords
=
np
.
minimum
(
coords
,
coords
=
np
.
minimum
(
coords
,
np
.
array
([
img
.
shape
[
1
]
-
1
,
img
.
shape
[
2
]
-
1
]))
np
.
array
([
img
.
shape
[
1
]
-
1
,
img
.
shape
[
2
]
-
1
]))
xs
=
coords
[:,
:,:,
1
]
.
reshape
((
img
.
shape
[
0
],
-
1
))
xs
=
coords
[:,
:,
:,
1
]
.
reshape
((
img
.
shape
[
0
],
-
1
))
ys
=
coords
[:,
:,:,
0
]
.
reshape
((
img
.
shape
[
0
],
-
1
))
ys
=
coords
[:,
:,
:,
0
]
.
reshape
((
img
.
shape
[
0
],
-
1
))
ret
=
np
.
zeros
((
img
.
shape
[
0
],
coords
.
shape
[
1
],
coords
.
shape
[
2
],
ret
=
np
.
zeros
((
img
.
shape
[
0
],
coords
.
shape
[
1
],
coords
.
shape
[
2
],
img
.
shape
[
3
]),
dtype
=
'float32'
)
img
.
shape
[
3
]),
dtype
=
'float32'
)
for
k
in
range
(
img
.
shape
[
0
]):
for
k
in
range
(
img
.
shape
[
0
]):
xss
,
yss
=
xs
[
k
],
ys
[
k
]
xss
,
yss
=
xs
[
k
],
ys
[
k
]
ret
[
k
,
:,:,:]
=
img
[
k
,
yss
,
xss
,
:]
.
reshape
((
coords
.
shape
[
1
],
ret
[
k
,
:,
:,
:]
=
img
[
k
,
yss
,
xss
,
:]
.
reshape
((
coords
.
shape
[
1
],
coords
.
shape
[
2
],
3
))
coords
.
shape
[
2
],
3
))
return
ret
return
ret
bimg
=
np
.
random
.
rand
(
2
,
h
,
w
,
3
)
.
astype
(
'float32'
)
bimg
=
np
.
random
.
rand
(
2
,
h
,
w
,
3
)
.
astype
(
'float32'
)
#mat = np.array([
#
mat = np.array([
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
#], dtype='float32') #2x2x2x2
#], dtype='float32') #2x2x2x2
...
@@ -128,7 +135,7 @@ class TestSample(TestModel):
...
@@ -128,7 +135,7 @@ class TestSample(TestModel):
true_res
=
np_sample
(
bimg
,
np
.
floor
(
mat
+
0.5
)
.
astype
(
'int32'
))
true_res
=
np_sample
(
bimg
,
np
.
floor
(
mat
+
0.5
)
.
astype
(
'int32'
))
inp
,
mapping
=
self
.
make_variable
(
bimg
,
mat
)
inp
,
mapping
=
self
.
make_variable
(
bimg
,
mat
)
output
=
sample
(
inp
,
tf
.
cast
(
tf
.
floor
(
mapping
+
0.5
),
tf
.
int32
))
output
=
sample
(
inp
,
tf
.
cast
(
tf
.
floor
(
mapping
+
0.5
),
tf
.
int32
))
res
=
self
.
run_variable
(
output
)
res
=
self
.
run_variable
(
output
)
self
.
assertTrue
((
res
==
true_res
)
.
all
())
self
.
assertTrue
((
res
==
true_res
)
.
all
())
...
@@ -146,7 +153,7 @@ if __name__ == '__main__':
...
@@ -146,7 +153,7 @@ if __name__ == '__main__':
diff
=
200
diff
=
200
for
x
in
range
(
w
):
for
x
in
range
(
w
):
for
y
in
range
(
h
):
for
y
in
range
(
h
):
mapping
[
0
,
y
,
x
,:]
=
np
.
array
([
y
-
diff
+
0.4
,
x
-
diff
+
0.5
])
mapping
[
0
,
y
,
x
,
:]
=
np
.
array
([
y
-
diff
+
0.4
,
x
-
diff
+
0.5
])
mapv
=
tf
.
Variable
(
mapping
)
mapv
=
tf
.
Variable
(
mapping
)
output
=
ImageSample
(
'sample'
,
[
imv
,
mapv
],
borderMode
=
'constant'
)
output
=
ImageSample
(
'sample'
,
[
imv
,
mapv
],
borderMode
=
'constant'
)
...
@@ -155,12 +162,10 @@ if __name__ == '__main__':
...
@@ -155,12 +162,10 @@ if __name__ == '__main__':
#out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#out = sess.run(output)
#out = sess.run(output)
#print(out[0].min())
#
print(out[0].min())
#print(out[0].max())
#
print(out[0].max())
#print(out[0].sum())
#
print(out[0].sum())
out
=
sess
.
run
([
output
])[
0
]
out
=
sess
.
run
([
output
])[
0
]
im
=
out
[
0
]
im
=
out
[
0
]
cv2
.
imwrite
(
'sampled.jpg'
,
im
)
cv2
.
imwrite
(
'sampled.jpg'
,
im
)
tensorpack/models/model_desc.py
View file @
fb2a051c
...
@@ -16,21 +16,27 @@ from ..tfutils.common import get_tensors_by_names
...
@@ -16,21 +16,27 @@ from ..tfutils.common import get_tensors_by_names
from
..tfutils.gradproc
import
CheckGradient
from
..tfutils.gradproc
import
CheckGradient
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.tower
import
get_current_tower_context
__all__
=
[
'ModelDesc'
,
'InputVar'
,
'ModelFromMetaGraph'
]
__all__
=
[
'ModelDesc'
,
'InputVar'
,
'ModelFromMetaGraph'
]
#_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])
#_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])
class
InputVar
(
object
):
class
InputVar
(
object
):
def
__init__
(
self
,
type
,
shape
,
name
,
sparse
=
False
):
def
__init__
(
self
,
type
,
shape
,
name
,
sparse
=
False
):
self
.
type
=
type
self
.
type
=
type
self
.
shape
=
shape
self
.
shape
=
shape
self
.
name
=
name
self
.
name
=
name
self
.
sparse
=
sparse
self
.
sparse
=
sparse
def
dumps
(
self
):
def
dumps
(
self
):
return
pickle
.
dumps
(
self
)
return
pickle
.
dumps
(
self
)
@
staticmethod
@
staticmethod
def
loads
(
buf
):
def
loads
(
buf
):
return
pickle
.
loads
(
buf
)
return
pickle
.
loads
(
buf
)
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
ModelDesc
(
object
):
class
ModelDesc
(
object
):
""" Base class for a model description """
""" Base class for a model description """
...
@@ -99,15 +105,17 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
...
@@ -99,15 +105,17 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
def
get_gradient_processor
(
self
):
def
get_gradient_processor
(
self
):
""" Return a list of GradientProcessor. They will be executed in order"""
""" Return a list of GradientProcessor. They will be executed in order"""
return
[
#
SummaryGradient(),
return
[
#
SummaryGradient(),
CheckGradient
()
CheckGradient
()
]
]
class
ModelFromMetaGraph
(
ModelDesc
):
class
ModelFromMetaGraph
(
ModelDesc
):
"""
"""
Load the whole exact TF graph from a saved meta_graph.
Load the whole exact TF graph from a saved meta_graph.
Only useful for inference.
Only useful for inference.
"""
"""
def
__init__
(
self
,
filename
):
def
__init__
(
self
,
filename
):
tf
.
train
.
import_meta_graph
(
filename
)
tf
.
train
.
import_meta_graph
(
filename
)
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
...
...
tensorpack/models/nonlin.py
View file @
fb2a051c
...
@@ -11,6 +11,7 @@ from .batch_norm import BatchNorm
...
@@ -11,6 +11,7 @@ from .batch_norm import BatchNorm
__all__
=
[
'Maxout'
,
'PReLU'
,
'LeakyReLU'
,
'BNReLU'
]
__all__
=
[
'Maxout'
,
'PReLU'
,
'LeakyReLU'
,
'BNReLU'
]
@
layer_register
()
@
layer_register
()
def
Maxout
(
x
,
num_unit
):
def
Maxout
(
x
,
num_unit
):
"""
"""
...
@@ -31,6 +32,7 @@ def Maxout(x, num_unit):
...
@@ -31,6 +32,7 @@ def Maxout(x, num_unit):
x
=
tf
.
reshape
(
x
,
[
-
1
,
ch
/
num_unit
,
num_unit
])
x
=
tf
.
reshape
(
x
,
[
-
1
,
ch
/
num_unit
,
num_unit
])
return
tf
.
reduce_max
(
x
,
ndim
,
name
=
'output'
)
return
tf
.
reduce_max
(
x
,
ndim
,
name
=
'output'
)
@
layer_register
(
log_shape
=
False
)
@
layer_register
(
log_shape
=
False
)
def
PReLU
(
x
,
init
=
tf
.
constant_initializer
(
0.001
),
name
=
None
):
def
PReLU
(
x
,
init
=
tf
.
constant_initializer
(
0.001
),
name
=
None
):
"""
"""
...
@@ -47,6 +49,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
...
@@ -47,6 +49,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
name
=
'output'
name
=
'output'
return
tf
.
mul
(
x
,
0.5
,
name
=
name
)
return
tf
.
mul
(
x
,
0.5
,
name
=
name
)
@
layer_register
(
use_scope
=
False
,
log_shape
=
False
)
@
layer_register
(
use_scope
=
False
,
log_shape
=
False
)
def
LeakyReLU
(
x
,
alpha
,
name
=
None
):
def
LeakyReLU
(
x
,
alpha
,
name
=
None
):
"""
"""
...
@@ -62,7 +65,8 @@ def LeakyReLU(x, alpha, name=None):
...
@@ -62,7 +65,8 @@ def LeakyReLU(x, alpha, name=None):
return
tf
.
maximum
(
x
,
alpha
*
x
,
name
=
name
)
return
tf
.
maximum
(
x
,
alpha
*
x
,
name
=
name
)
#alpha = float(alpha)
#alpha = float(alpha)
#x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
#x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
#return tf.mul(x, 0.5, name=name)
# return tf.mul(x, 0.5, name=name)
@
layer_register
(
log_shape
=
False
,
use_scope
=
False
)
@
layer_register
(
log_shape
=
False
,
use_scope
=
False
)
def
BNReLU
(
x
,
name
=
None
):
def
BNReLU
(
x
,
name
=
None
):
...
...
tensorpack/models/pool.py
View file @
fb2a051c
...
@@ -12,6 +12,7 @@ from ..tfutils import symbolic_functions as symbf
...
@@ -12,6 +12,7 @@ from ..tfutils import symbolic_functions as symbf
__all__
=
[
'MaxPooling'
,
'FixedUnPooling'
,
'AvgPooling'
,
'GlobalAvgPooling'
,
__all__
=
[
'MaxPooling'
,
'FixedUnPooling'
,
'AvgPooling'
,
'GlobalAvgPooling'
,
'BilinearUpSample'
]
'BilinearUpSample'
]
@
layer_register
()
@
layer_register
()
def
MaxPooling
(
x
,
shape
,
stride
=
None
,
padding
=
'VALID'
):
def
MaxPooling
(
x
,
shape
,
stride
=
None
,
padding
=
'VALID'
):
"""
"""
...
@@ -32,6 +33,7 @@ def MaxPooling(x, shape, stride=None, padding='VALID'):
...
@@ -32,6 +33,7 @@ def MaxPooling(x, shape, stride=None, padding='VALID'):
return
tf
.
nn
.
max_pool
(
x
,
ksize
=
shape
,
strides
=
stride
,
padding
=
padding
)
return
tf
.
nn
.
max_pool
(
x
,
ksize
=
shape
,
strides
=
stride
,
padding
=
padding
)
@
layer_register
()
@
layer_register
()
def
AvgPooling
(
x
,
shape
,
stride
=
None
,
padding
=
'VALID'
):
def
AvgPooling
(
x
,
shape
,
stride
=
None
,
padding
=
'VALID'
):
"""
"""
...
@@ -52,6 +54,7 @@ def AvgPooling(x, shape, stride=None, padding='VALID'):
...
@@ -52,6 +54,7 @@ def AvgPooling(x, shape, stride=None, padding='VALID'):
return
tf
.
nn
.
avg_pool
(
x
,
ksize
=
shape
,
strides
=
stride
,
padding
=
padding
)
return
tf
.
nn
.
avg_pool
(
x
,
ksize
=
shape
,
strides
=
stride
,
padding
=
padding
)
@
layer_register
()
@
layer_register
()
def
GlobalAvgPooling
(
x
):
def
GlobalAvgPooling
(
x
):
"""
"""
...
@@ -65,6 +68,8 @@ def GlobalAvgPooling(x):
...
@@ -65,6 +68,8 @@ def GlobalAvgPooling(x):
return
tf
.
reduce_mean
(
x
,
[
1
,
2
])
return
tf
.
reduce_mean
(
x
,
[
1
,
2
])
# https://github.com/tensorflow/tensorflow/issues/2169
# https://github.com/tensorflow/tensorflow/issues/2169
def
UnPooling2x2ZeroFilled
(
x
):
def
UnPooling2x2ZeroFilled
(
x
):
out
=
tf
.
concat
(
3
,
[
x
,
tf
.
zeros_like
(
x
)])
out
=
tf
.
concat
(
3
,
[
x
,
tf
.
zeros_like
(
x
)])
out
=
tf
.
concat
(
2
,
[
out
,
tf
.
zeros_like
(
out
)])
out
=
tf
.
concat
(
2
,
[
out
,
tf
.
zeros_like
(
out
)])
...
@@ -79,6 +84,7 @@ def UnPooling2x2ZeroFilled(x):
...
@@ -79,6 +84,7 @@ def UnPooling2x2ZeroFilled(x):
ret
.
set_shape
([
None
,
None
,
None
,
sh
[
3
]])
ret
.
set_shape
([
None
,
None
,
None
,
sh
[
3
]])
return
ret
return
ret
@
layer_register
()
@
layer_register
()
def
FixedUnPooling
(
x
,
shape
,
unpool_mat
=
None
):
def
FixedUnPooling(x, shape, unpool_mat=None):
    """
    """
...
@@ -108,8 +114,8 @@ def FixedUnPooling(x, shape, unpool_mat=None):
     # perform a tensor-matrix kronecker product
     fx = symbf.flatten(tf.transpose(x, [0, 3, 1, 2]))
     fx = tf.expand_dims(fx, -1)  # (bchw)x1
     mat = tf.expand_dims(symbf.flatten(unpool_mat), 0)  # 1x(shxsw)
     prod = tf.matmul(fx, mat)  # (bchw) x(shxsw)
     prod = tf.reshape(prod, tf.pack(
         [-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]]))
     prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1])
...
@@ -117,6 +123,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
         [-1, input_shape[1] * shape[0], input_shape[2] * shape[1], input_shape[3]]))
     return prod

 @layer_register()
 def BilinearUpSample(x, shape):
     """
...
@@ -125,9 +132,9 @@ def BilinearUpSample(x, shape):
     :param shape: an integer, the upsample factor
     """
     #inp_shape = tf.shape(x)
-    #return tf.image.resize_bilinear(x,
-    #tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
-    #align_corners=True)
+    # return tf.image.resize_bilinear(x,
+    # tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
+    # align_corners=True)
     inp_shape = x.get_shape().as_list()
     ch = inp_shape[3]
...
@@ -136,7 +143,6 @@ def BilinearUpSample(x, shape):
     shape = int(shape)
     filter_shape = 2 * shape

     def bilinear_conv_filler(s):
         """
         s: width, height of the conv filter
...
@@ -147,7 +153,7 @@ def BilinearUpSample(x, shape):
         ret = np.zeros((s, s), dtype='float32')
         for x in range(s):
             for y in range(s):
                 ret[x, y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
         return ret
     w = bilinear_conv_filler(filter_shape)
     w = np.repeat(w, ch * ch).reshape((filter_shape, filter_shape, ch, ch))
...
@@ -156,16 +162,21 @@ def BilinearUpSample(x, shape):
                              name='bilinear_upsample_filter')
     deconv = tf.nn.conv2d_transpose(x, weight_var,
                                     tf.shape(x) * tf.constant([1, shape, shape, 1], tf.int32),
                                     [1, shape, shape, 1], 'SAME')
-    if inp_shape[1]: inp_shape[1] *= shape
-    if inp_shape[2]: inp_shape[2] *= shape
+    if inp_shape[1]:
+        inp_shape[1] *= shape
+    if inp_shape[2]:
+        inp_shape[2] *= shape
     deconv.set_shape(inp_shape)
     return deconv

 from ._test import TestModel

 class TestPool(TestModel):
     def test_fixed_unpooling(self):
         h, w = 3, 4
         mat = np.random.rand(h, w, 3).astype('float32')
...
@@ -173,13 +184,13 @@ class TestPool(TestModel):
         inp = tf.reshape(inp, [1, h, w, 3])
         output = FixedUnPooling('unpool', inp, 2)
         res = self.run_variable(output)
         self.assertEqual(res.shape, (1, 2 * h, 2 * w, 3))

         # mat is on cornser
-        ele = res[0, ::2,::2, 0]
+        ele = res[0, ::2, ::2, 0]
         self.assertTrue((ele == mat[:, :, 0]).all())
         # the rest are zeros
-        res[0, ::2,::2, :] = 0
+        res[0, ::2, ::2, :] = 0
         self.assertTrue((res == 0).all())

     def test_upsample(self):
...
@@ -191,7 +202,7 @@ class TestPool(TestModel):
         inp = tf.reshape(inp, [1, h, w, 1])
         output = BilinearUpSample('upsample', inp, scale)
-        res = self.run_variable(output)[0, :,:, 0]
+        res = self.run_variable(output)[0, :, :, 0]

         from skimage.transform import rescale
         res2 = rescale(mat, scale)
...
@@ -199,9 +210,9 @@ class TestPool(TestModel):
         diff = np.abs(res2 - res)

         # not equivalent to rescale on edge?
-        diff[0,:] = 0
+        diff[0, :] = 0
         diff[:, 0] = 0
         if not diff.max() < 1e-4:
-            import IPython;
+            import IPython
             IPython.embed(config=IPython.terminal.ipapp.load_default_config())
         self.assertTrue(diff.max() < 1e-4)
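The Kronecker-product trick in FixedUnPooling is easier to see outside of TensorFlow. Below is a minimal NumPy sketch (names are illustrative, not part of tensorpack) of what the layer computes per channel with the default unpool matrix, matching the assertions in test_fixed_unpooling above:

import numpy as np

h, w, scale = 3, 4, 2
x = np.random.rand(h, w).astype('float32')
unpool_mat = np.zeros((scale, scale), dtype='float32')
unpool_mat[0, 0] = 1.0  # the default: place each value at the corner

out = np.kron(x, unpool_mat)          # (scale*h) x (scale*w)
assert out.shape == (scale * h, scale * w)
assert (out[::2, ::2] == x).all()     # values sit on the corners
out[::2, ::2] = 0
assert (out == 0).all()               # everything else is zero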
tensorpack/models/regularize.py
View file @
fb2a051c
...
@@ -12,6 +12,7 @@ from ._common import layer_register

 __all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']

 @memoized
 def _log_regularizer(name):
     logger.info("Apply regularizer for {}".format(name))
...
@@ -19,6 +20,7 @@ def _log_regularizer(name):
 l2_regularizer = tf.contrib.layers.l2_regularizer
 l1_regularizer = tf.contrib.layers.l1_regularizer

 def regularize_cost(regex, func, name=None):
     """
     Apply a regularizer on every trainable variable matching the regex.
...
@@ -48,4 +50,3 @@ def Dropout(x, keep_prob=0.5, is_training=None):
         is_training = get_current_tower_context().is_training
     keep_prob = tf.constant(keep_prob if is_training else 1.0)
     return tf.nn.dropout(x, keep_prob)
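As a usage sketch of regularize_cost: it pairs a regex over trainable-variable names with one of the regularizers aliased above. The variable name, weight-decay factor, and import path below are assumptions for this era of tensorpack, not verbatim project code:

import tensorflow as tf
from tensorpack.models.regularize import regularize_cost, l2_regularizer

# a trainable variable whose name matches the regex below
with tf.variable_scope('conv0'):
    W = tf.get_variable('W', shape=[3, 3, 3, 16])

# sum of 1e-4 * l2(W) over every trainable variable matching 'conv.*/W'
wd_cost = regularize_cost('conv.*/W', l2_regularizer(1e-4))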
tensorpack/models/shapes.py
View file @
fb2a051c
...
@@ -8,6 +8,7 @@ from ._common import layer_register

 __all__ = ['ConcatWith']

 @layer_register(use_scope=False, log_shape=False)
 def ConcatWith(x, dim, tensor):
     """
...
tensorpack/models/softmax.py
View file @
fb2a051c
...
@@ -8,6 +8,7 @@ from ._common import layer_register

 __all__ = ['SoftMax']

 @layer_register()
 def SoftMax(x, use_temperature=False, temperature_init=1.0):
     """
...
tensorpack/predict/__init__.py
View file @
fb2a051c
...
@@ -8,6 +8,7 @@ import os.path

 __all__ = []

 def global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -25,4 +26,3 @@ for _, module_name, _ in walk_packages(
     if module_name.startswith('_'):
         continue
     global_import(module_name)
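The body of global_import is elided in the hunk; what it implements is a common re-export pattern. Here is a self-contained sketch of that pattern (not the verbatim elided code), using importlib and a stdlib module as the placeholder target:

import importlib

def global_import_sketch(name, namespace):
    # import a module and copy its public names into `namespace`,
    # so `from package import *` exposes every submodule's API
    p = importlib.import_module(name)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        if not k.startswith('_'):
            namespace[k] = getattr(p, k)

global_import_sketch('math', globals())
print(sqrt(16))  # -> 4.0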
tensorpack/predict/base.py
View file @
fb2a051c
...
@@ -16,6 +16,7 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor',
            'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
            'DataParallelOfflinePredictor']

 @six.add_metaclass(ABCMeta)
 class PredictorBase(object):
     """
...
@@ -46,7 +47,9 @@ class PredictorBase(object):
         :return: output as defined by the config
         """

 class AsyncPredictorBase(PredictorBase):
     @abstractmethod
     def put_task(self, dp, callback=None):
         """
...
@@ -67,7 +70,9 @@ class AsyncPredictorBase(PredictorBase):
         # in Tornado, Future.result() doesn't wait
         return fut.result()

 class OnlinePredictor(PredictorBase):
     def __init__(self, sess, input_tensors, output_tensors, return_input=False):
         self.session = sess
         self.return_input = return_input
...
@@ -85,6 +90,7 @@ class OnlinePredictor(PredictorBase):

 class OfflinePredictor(OnlinePredictor):
     """ Build a predictor from a given config, in an independent graph"""

     def __init__(self, config):
         self.graph = tf.Graph()
         with self.graph.as_default():
...
@@ -108,13 +114,15 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
     """
     for k in towers:
         logger.info(
             "Building graph for predictor tower {}...".format(k))
         with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                 TowerContext('{}{}'.format(PREDICT_TOWER, k)):
             build_tower_fn(k)
             tf.get_variable_scope().reuse_variables()

 class MultiTowerOfflinePredictor(OnlinePredictor):
     def __init__(self, config, towers):
         self.graph = tf.Graph()
         self.predictors = []
...
@@ -130,7 +138,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
             for k in towers:
                 output_vars = get_tensors_by_names(
-                    ['{}{}/'.format(PREDICT_TOWER, k) + n \
+                    ['{}{}/'.format(PREDICT_TOWER, k) + n
                      for n in config.output_names])
                 self.predictors.append(OnlinePredictor(
                     self.sess, input_vars, output_vars, config.return_input))
...
@@ -142,7 +150,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
     def get_predictors(self, n):
         return [self.predictors[k % len(self.predictors)] for k in range(n)]

 class DataParallelOfflinePredictor(OnlinePredictor):
     def __init__(self, config, towers):
         self.graph = tf.Graph()
         with self.graph.as_default():
...
@@ -161,7 +171,7 @@ class DataParallelOfflinePredictor(OnlinePredictor):
                 tf.get_variable_scope().reuse_variables()
                 input_var_names.extend([k.name for k in input_vars])
                 output_vars.extend(get_tensors_by_names(
-                    [towername + '/' + n \
+                    [towername + '/' + n
                      for n in config.output_names]))
             input_vars = get_tensors_by_names(input_var_names)
...
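Stripped of graph construction, the predictor protocol above reduces to zipping a datapoint against input tensors and fetching output tensors. A minimal runnable sketch, with a fake session standing in for tf.Session (all names here are illustrative):

class FakeSession(object):
    """Stands in for tf.Session; run() just echoes doubled feed values."""
    def run(self, fetches, feed_dict=None):
        return [feed_dict[t] * 2 for t in fetches]

class TinyPredictor(object):
    """What OnlinePredictor.__call__ boils down to, minus bookkeeping."""
    def __init__(self, sess, input_tensors, output_tensors):
        self.sess = sess
        self.input_tensors = input_tensors
        self.output_tensors = output_tensors

    def __call__(self, dp):
        feed = dict(zip(self.input_tensors, dp))
        return self.sess.run(self.output_tensors, feed_dict=feed)

pred = TinyPredictor(FakeSession(), ['input:0'], ['input:0'])
print(pred([21]))  # -> [42]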
tensorpack/predict/common.py
View file @
fb2a051c
...
@@ -15,11 +15,13 @@ from .base import OfflinePredictor
 import multiprocessing

 __all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']

 PredictResult = namedtuple('PredictResult', ['input', 'output'])

 class PredictConfig(object):
     def __init__(self, **kwargs):
         """
         The config used by `get_predict_func`.
...
@@ -61,12 +63,14 @@ class PredictConfig(object):
             self.output_names = kwargs.pop('output_var_names')
             #logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
         assert len(self.input_names), self.input_names
         for v in self.input_names:
             assert_type(v, six.string_types)
         assert len(self.output_names), self.output_names
         self.return_input = kwargs.pop('return_input', False)
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

 def get_predict_func(config):
     """
     Produce a offline predictor run inside a new session.
...
@@ -76,4 +80,3 @@ def get_predict_func(config):
     a list of output values defined in ``config.output_var_names``.
     """
     return OfflinePredictor(config)
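The constructor above relies on a pop-and-assert pattern for keyword validation: pop every recognized key, then fail loudly on leftovers. In isolation the pattern looks like this (a runnable sketch, not the full PredictConfig):

def make_config(**kwargs):
    # pop every recognized key, then reject anything left over
    output_names = kwargs.pop('output_names')
    return_input = kwargs.pop('return_input', False)
    assert len(output_names), output_names
    assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
    return output_names, return_input

print(make_config(output_names=['prob']))        # (['prob'], False)
# make_config(output_names=['prob'], foo=1)  -> AssertionError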
tensorpack/predict/concurrency.py
View file @
fb2a051c
...
@@ -3,7 +3,8 @@
 # File: concurrency.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import multiprocessing, threading
+import multiprocessing
+import threading
 import tensorflow as tf
 import time
 import six
...
@@ -27,8 +28,10 @@ else:

 __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
            'MultiThreadAsyncPredictor']

 class MultiProcessPredictWorker(multiprocessing.Process):
     """ Base class for predict worker that runs offline in multiprocess"""

     def __init__(self, idx, config):
         """
         :param idx: index of the worker. the 0th worker will print log.
...
@@ -51,8 +54,10 @@ class MultiProcessPredictWorker(multiprocessing.Process):
             with self.predictor.graph.as_default():
                 describe_model()

 class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
     """ An offline predictor worker that takes input and produces output by queue"""

     def __init__(self, idx, inqueue, outqueue, config):
         """
         :param inqueue: input queue to get data point. elements are (task_id, dp)
...
@@ -76,6 +81,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):

 class PredictorWorkerThread(threading.Thread):
     def __init__(self, queue, pred_func, id, batch_size=5):
         super(PredictorWorkerThread, self).__init__()
         self.queue = queue
...
@@ -88,13 +94,13 @@ class PredictorWorkerThread(threading.Thread):
         while True:
             batched, futures = self.fetch_batch()
             outputs = self.func(batched)
-            #print "Worker {} batched {} Queue {}".format(
-            #self.id, len(futures), self.queue.qsize())
+            # print "Worker {} batched {} Queue {}".format(
+            # self.id, len(futures), self.queue.qsize())
             # debug, for speed testing
-            #if not hasattr(self, 'xxx'):
-            #self.xxx = outputs = self.func(batched)
-            #else:
-            #outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
+            # if not hasattr(self, 'xxx'):
+            # self.xxx = outputs = self.func(batched)
+            # else:
+            # outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
             for idx, f in enumerate(futures):
                 f.set_result([k[idx] for k in outputs])
...
@@ -119,11 +125,13 @@ class PredictorWorkerThread(threading.Thread):
                 cnt += 1
         return batched, futures

 class MultiThreadAsyncPredictor(AsyncPredictorBase):
     """
     An multithread online async predictor which run a list of PredictorBase.
     It would do an extra batching internally.
     """

     def __init__(self, predictors, batch_size=5):
         """ :param predictors: a list of OnlinePredictor"""
         assert len(predictors)
...
@@ -131,7 +139,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
             #assert isinstance(k, OnlinePredictor), type(k)
             # TODO use predictors.return_input here
             assert k.return_input == False
         self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
         self.threads = [
             PredictorWorkerThread(
                 self.input_queue, f, id, batch_size=batch_size)
...
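The batching idea behind PredictorWorkerThread.fetch_batch (mostly elided above) is: block for one (input, Future) pair, then opportunistically drain up to batch_size more, run the batched function once, and scatter per-item results to the futures. A self-contained sketch of that drain step, with hypothetical helper names:

from concurrent.futures import Future
from six.moves import queue

def fetch_batch(q, batch_size):
    # block for the first item, then take whatever else is ready
    inputs, futures = [], []
    inp, fut = q.get()
    inputs.append(inp)
    futures.append(fut)
    while len(futures) < batch_size:
        try:
            inp, fut = q.get_nowait()
        except queue.Empty:
            break
        inputs.append(inp)
        futures.append(fut)
    return inputs, futures

q = queue.Queue()
for i in range(3):
    q.put((i, Future()))
inputs, futures = fetch_batch(q, batch_size=5)
assert inputs == [0, 1, 2] and len(futures) == 3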
tensorpack/predict/dataset.py
View file @
fb2a051c
...
@@ -22,8 +22,10 @@ from .base import OfflinePredictor

 __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
            'MultiProcessDatasetPredictor']

 @six.add_metaclass(ABCMeta)
 class DatasetPredictorBase(object):
     def __init__(self, config, dataset):
         """
         :param config: a `PredictConfig` instance.
...
@@ -45,10 +47,12 @@ class DatasetPredictorBase(object):
         """
         return list(self.get_result())

 class SimpleDatasetPredictor(DatasetPredictorBase):
     """
     Run the predict_config on a given `DataFlow`.
     """

     def __init__(self, config, dataset):
         super(SimpleDatasetPredictor, self).__init__(config, dataset)
         self.predictor = OfflinePredictor(config)
...
@@ -60,14 +64,17 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
             sz = self.dataset.size()
         except NotImplementedError:
             sz = 0
         with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
             for dp in self.dataset.get_data():
                 res = self.predictor(dp)
                 yield res
                 pbar.update()

 # TODO allow unordered
 class MultiProcessDatasetPredictor(DatasetPredictorBase):
     def __init__(self, config, dataset, nr_proc, use_gpu=True, ordered=True):
         """
         Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported.
...
@@ -130,7 +137,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
             sz = self.dataset.size()
         except NotImplementedError:
             sz = 0
         with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
             die_cnt = 0
             while True:
                 res = self.result_queue.get()
...
@@ -147,4 +154,5 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
             self.result_queue.join()
             self.result_queue.terminate()
             for p in self.workers:
-                p.join(); p.terminate()
+                p.join()
+                p.terminate()
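The progress-bar idiom above, total=sz with disable=(sz == 0), generalizes to any iterator whose length may be unknown. A sketch with plain tqdm standing in for tensorpack's get_tqdm wrapper (the function name is illustrative):

from tqdm import tqdm

def get_result_sketch(predictor, dataflow, size=0):
    # disable the bar entirely when the dataflow size is unknown
    with tqdm(total=size, disable=(size == 0)) as pbar:
        for dp in dataflow:
            yield predictor(dp)
            pbar.update()

print(list(get_result_sketch(lambda dp: dp * 2, range(5), size=5)))
# -> [0, 2, 4, 6, 8]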
tensorpack/tfutils/argscope.py
View file @
fb2a051c
...
@@ -12,6 +12,7 @@ __all__ = ['argscope', 'get_arg_scope']

 _ArgScopeStack = []

 @contextmanager
 def argscope(layers, **param):
     if not isinstance(layers, list):
...
@@ -33,6 +34,7 @@ def argscope(layers, **param):
     yield
     del _ArgScopeStack[-1]

 def get_arg_scope():
     """
     :returns: the current argscope.
...
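The mechanism here is a plain stack-based context manager: push a dict of default kwargs on entry, pop on exit, and let readers inspect the top of the stack. A self-contained sketch of the same pattern (function names are illustrative, not the tensorpack API):

from contextlib import contextmanager

_ArgScopeStack = []

@contextmanager
def argscope_sketch(**param):
    # push defaults for the duration of the `with` block
    _ArgScopeStack.append(param)
    yield
    del _ArgScopeStack[-1]

def current_scope():
    return _ArgScopeStack[-1] if _ArgScopeStack else {}

with argscope_sketch(kernel_shape=3, out_channel=32):
    assert current_scope()['kernel_shape'] == 3
assert current_scope() == {}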
tensorpack/tfutils/common.py
View file @
fb2a051c
...
@@ -22,6 +22,7 @@ __all__ = ['get_default_sess_config',
            'freeze_collection',
            'get_tf_version']

 def get_default_sess_config(mem_fraction=0.99):
     """
     Return a better session config to use as default.
...
@@ -38,6 +39,7 @@ def get_default_sess_config(mem_fraction=0.99):
     #conf.log_device_placement = True
     return conf

 def get_global_step_var():
     """ :returns: the global_step variable in the current graph. create if not existed"""
     try:
...
@@ -52,12 +54,14 @@ def get_global_step_var():
             trainable=False, dtype=tf.int32)
     return var

 def get_global_step():
     """ :returns: global_step value in current graph and session"""
     return tf.train.global_step(
         tf.get_default_session(),
         get_global_step_var())

 def get_op_tensor_name(name):
     """
     Tensor name is assumed to be ``op_name + ':0'``
...
@@ -72,6 +76,7 @@ def get_op_tensor_name(name):

 get_op_var_name = get_op_tensor_name

 def get_tensors_by_names(names):
     """
     Get a list of tensors in the default graph by a list of names
...
@@ -85,26 +90,31 @@ def get_tensors_by_names(names):

 get_vars_by_names = get_tensors_by_names

 def backup_collection(keys):
     ret = {}
     for k in keys:
         ret[k] = copy(tf.get_collection(k))
     return ret

 def restore_collection(backup):
     for k, v in six.iteritems(backup):
         del tf.get_collection_ref(k)[:]
         tf.get_collection_ref(k).extend(v)

 def clear_collection(keys):
     for k in keys:
         del tf.get_collection_ref(k)[:]

 @contextmanager
 def freeze_collection(keys):
     backup = backup_collection(keys)
     yield
     restore_collection(backup)

 def get_tf_version():
     return int(tf.__version__.split('.')[1])
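freeze_collection above is just backup plus restore wrapped around a yield. The same idea is runnable on a plain dict of lists, without TensorFlow (all names below are illustrative):

from contextlib import contextmanager
from copy import copy

graph_collections = {'losses': [1, 2]}

@contextmanager
def freeze(keys):
    # snapshot the named collections, restore them on exit
    backup = {k: copy(graph_collections[k]) for k in keys}
    yield
    for k, v in backup.items():
        graph_collections[k][:] = v  # restore contents in place

with freeze(['losses']):
    graph_collections['losses'].append(3)
assert graph_collections['losses'] == [1, 2]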
tensorpack/tfutils/gradproc.py
View file @
fb2a051c
...
@@ -16,6 +16,7 @@ __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
            'ScaleGradient', 'MapGradient', 'apply_grad_processors',
            'GlobalNormClip']

 def apply_grad_processors(grads, gradprocs):
     """
     :param grads: list of (grad, var).
...
@@ -32,6 +33,7 @@ def apply_grad_processors(grads, gradprocs):
             g = proc.process(g)
     return g

 @six.add_metaclass(ABCMeta)
 class GradientProcessor(object):
...
@@ -51,6 +53,7 @@ class GradientProcessor(object):

 class GlobalNormClip(GradientProcessor):
     def __init__(self, global_norm):
         """ Clip by global norm
             Note that the global norm is the sum of norm for **all** gradients
...
@@ -63,11 +66,13 @@ class GlobalNormClip(GradientProcessor):
         g, _ = tf.clip_by_global_norm(g, self._norm, name='clip_by_global_norm')
         return list(zip(g, v))

 class MapGradient(GradientProcessor):
     """
     Apply a function on all gradient if the name matches regex.
     Keep the other gradients unchanged.
     """

     def __init__(self, func, regex='.*'):
         """
         :param func: takes a grad or (grad, var) pair and returns a grad. If return None, the
...
@@ -100,10 +105,12 @@ class MapGradient(GradientProcessor):

 _summaried_gradient = set()

 class SummaryGradient(MapGradient):
     """
     Summary history and RMS for each graident variable
     """

     def __init__(self):
         super(SummaryGradient, self).__init__(self._mapper)
...
@@ -115,10 +122,12 @@ class SummaryGradient(MapGradient):
             add_moving_summary(rms(grad, name=name + '/rms'))
         return grad

 class CheckGradient(MapGradient):
     """
     Check for numeric issue.
     """

     def __init__(self):
         super(CheckGradient, self).__init__(self._mapper)
...
@@ -128,10 +137,12 @@ class CheckGradient(MapGradient):
         grad = tf.check_numerics(grad, 'CheckGradient-' + var.op.name)
         return grad

 class ScaleGradient(MapGradient):
     """
     Scale certain gradient by a multiplier
     """

     def __init__(self, multipliers, log=True):
         """
         :param multipliers: list of (regex, float)
...
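A hypothetical end-to-end use of these processors, per the signatures shown above (apply_grad_processors(grads, gradprocs), ScaleGradient taking a list of (regex, float)). The import path, variable names, and regex are assumptions for this pre-1.0-TensorFlow era of tensorpack:

import tensorflow as tf
from tensorpack.tfutils.gradproc import (
    apply_grad_processors, ScaleGradient, CheckGradient)

with tf.variable_scope('layer'):
    W = tf.get_variable('W', shape=[2, 2])
    b = tf.get_variable('b', shape=[2])
cost = tf.reduce_sum(tf.square(W)) + tf.reduce_sum(tf.square(b))

opt = tf.train.GradientDescentOptimizer(0.1)
grads = opt.compute_gradients(cost)
grads = apply_grad_processors(grads, [
    ScaleGradient([('.*/b', 2.0)]),  # double the gradient on biases
    CheckGradient()])                # tf.check_numerics on every grad
train_op = opt.apply_gradients(grads)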
tensorpack/tfutils/modelutils.py
View file @
fb2a051c
...
@@ -9,6 +9,7 @@ from ..utils import logger

 __all__ = ['describe_model', 'get_shape_str']

 def describe_model():
     """ print a description of the current model parameters """
     train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
...
@@ -40,5 +41,3 @@ def get_shape_str(tensors):
         assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
         shape_str = str(tensors.get_shape().as_list())
     return shape_str
tensorpack/tfutils/sessinit.py
View file @
fb2a051c
...
@@ -20,6 +20,7 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
 # TODO they initialize_all at the beginning by default.

 @six.add_metaclass(ABCMeta)
 class SessionInit(object):
     """ Base class for utilities to initialize a session"""
...
@@ -35,23 +36,29 @@ class SessionInit(object):
     def _init(self, sess):
         pass

 class JustCurrentSession(SessionInit):
     """ Just use the current default session. This is a no-op placeholder"""

     def _init(self, sess):
         pass

 class NewSession(SessionInit):
     """
     Create a new session. All variables will be initialized by their
     initializer.
     """

     def _init(self, sess):
         sess.run(tf.global_variables_initializer())

 class SaverRestore(SessionInit):
     """
     Restore an old model saved by `ModelSaver`.
     """

     def __init__(self, model_path, prefix=None):
         """
         :param model_path: a model name (model-xxxx) or a ``checkpoint`` file.
...
@@ -146,10 +153,12 @@ class SaverRestore(SessionInit):
                 logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
         return var_dict

 class ParamRestore(SessionInit):
     """
     Restore variables from a dictionary.
     """

     def __init__(self, param_dict):
         """
         :param param_dict: a dict of {name: value}
...
@@ -174,7 +183,7 @@ class ParamRestore(SessionInit):
             logger.warn("Variable {} in the dict not found in the graph!".format(k))

         upd = SessionUpdate(sess,
-                            [v for v in variables if \
+                            [v for v in variables if
                              get_savename_from_varname(v.name) in intersect])
         logger.info("Restoring from dict ...")
         upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
...
@@ -182,6 +191,7 @@ class ParamRestore(SessionInit):

 class ChainInit(SessionInit):
     """ Init a session by a list of SessionInit instance."""

     def __init__(self, sess_inits, new_session=True):
         """
         :params sess_inits: list of `SessionInit` instances.
...
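All of these classes share a template-method shape: a public init() does shared bookkeeping and delegates to a subclass _init(). Reduced to a runnable toy (a dict stands in for the session; names are illustrative):

class SessionInitSketch(object):
    def init(self, sess):
        # shared bookkeeping lives here, in the base class
        print('initializing session via {}'.format(type(self).__name__))
        self._init(sess)

    def _init(self, sess):
        pass  # subclasses override only this hook

class NewSessionSketch(SessionInitSketch):
    def _init(self, sess):
        # stand-in for sess.run(tf.global_variables_initializer())
        sess['initialized'] = True

sess = {}
NewSessionSketch().init(sess)
assert sess['initialized']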
tensorpack/tfutils/summary.py
View file @
fb2a051c
...
@@ -15,6 +15,7 @@ from .symbolic_functions import rms

 __all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
            'add_moving_summary', 'summary_moving_average']

 def create_summary(name, v):
     """
     Return a tf.Summary object with name and simple scalar value v
...
@@ -25,6 +26,7 @@ def create_summary(name, v):
     s.value.add(tag=name, simple_value=v)
     return s

 def add_activation_summary(x, name=None):
     """
     Add summary to graph for an activation tensor x.
...
@@ -44,6 +46,7 @@ def add_activation_summary(x, name=None):
     tf.summary.scalar(name + '-sparsity', tf.nn.zero_fraction(x))
     tf.summary.scalar(name + '-rms', rms(x))

 def add_param_summary(summary_lists):
     """
     Add summary for all trainable variables matching the regex
...
@@ -54,6 +57,7 @@ def add_param_summary(summary_lists):
     ctx = get_current_tower_context()
     if ctx is not None and not ctx.is_main_training_tower:
         return

     def perform(var, action):
         ndim = var.get_shape().ndims
         name = var.name.replace(':0', '')
...
@@ -87,6 +91,7 @@ def add_param_summary(summary_lists):
         for act in actions:
             perform(p, act)

 def add_moving_summary(v, *args):
     """
     :param v: tensor or list of tensor to summary
...
@@ -102,6 +107,7 @@ def add_moving_summary(v, *args):
         assert x.get_shape().ndims == 0, x.get_shape()
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)

 @memoized
 def summary_moving_average(tensors=None):
     """
...
@@ -121,4 +127,3 @@ def summary_moving_average(tensors=None):
         name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.summary.scalar(name + '-summary', averager.average(c))
     return avg_maintain_op
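create_summary builds a tf.Summary protobuf carrying a single scalar, which is how values outside the graph get written to the event log. A runnable sketch of just that step (assuming the TensorFlow 1.x-era tf.Summary proto):

import tensorflow as tf

def create_summary_sketch(name, v):
    # a tf.Summary proto with one scalar value, ready for a FileWriter
    s = tf.Summary()
    s.value.add(tag=name, simple_value=float(v))
    return s

s = create_summary_sketch('val-error', 0.13)
assert abs(s.value[0].simple_value - 0.13) < 1e-6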
tensorpack/tfutils/symbolic_functions.py
View file @
fb2a051c
...
@@ -6,6 +6,7 @@ import tensorflow as tf
 import numpy as np
 from ..utils import logger

 def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
     """
     :param logits: NxC
...
@@ -15,12 +16,14 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
     return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
                    tf.float32, name=name)

 def flatten(x):
     """
     Flatten the tensor.
     """
     return tf.reshape(x, [-1])

 def batch_flatten(x):
     """
     Flatten the tensor except the first dimension.
...
@@ -30,6 +33,7 @@ def batch_flatten(x):
         return tf.reshape(x, [-1, int(np.prod(shape))])
     return tf.reshape(x, tf.pack([tf.shape(x)[0], -1]))

 def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
     """
     The class-balanced cross entropy loss,
...
@@ -53,6 +57,7 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
     cost = tf.sub(loss_pos, loss_neg, name=name)
     return cost

 def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
     """
     The class-balanced cross entropy loss,
...
@@ -75,13 +80,14 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
     cost = tf.reduce_mean(cost * (1 - beta), name=name)

     #logstable = tf.log(1 + tf.exp(-tf.abs(z)))
-    #loss_pos = -beta * tf.reduce_mean(-y *
+    # loss_pos = -beta * tf.reduce_mean(-y *
     #(logstable - tf.minimum(0.0, z)))
-    #loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
+    # loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
     #(logstable + tf.maximum(z, 0.0)))
     #cost = tf.sub(loss_pos, loss_neg, name=name)
     return cost

 def print_stat(x, message=None):
     """ a simple print op.
         Use it like: x = print_stat(x)
...
@@ -91,6 +97,7 @@ def print_stat(x, message=None):
     return tf.Print(x, [tf.shape(x), tf.reduce_mean(x), x], summarize=20,
                     message=message, name='print_' + x.op.name)

 def rms(x, name=None):
     if name is None:
         name = x.op.name + '/rms'
...
@@ -98,6 +105,7 @@ def rms(x, name=None):
             return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
     return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)

 def huber_loss(x, delta=1, name='huber_loss'):
     sqrcost = tf.square(x)
     abscost = tf.abs(x)
...
@@ -107,6 +115,7 @@ def huber_loss(x, delta=1, name='huber_loss'):
                           abscost * delta - 0.5 * delta ** 2),
                    name=name)

 def get_scalar_var(name, init_value, summary=False, trainable=False):
     """
     get a scalar variable with certain initial value
...
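The hunk above only shows huber_loss's linear branch (abscost * delta - 0.5 * delta ** 2); the quadratic branch below assumes the standard Huber form 0.5 * x**2 inside |x| <= delta, which makes the two branches meet at the boundary. A NumPy rendering for quick checking:

import numpy as np

def huber_loss_np(x, delta=1.0):
    # quadratic inside |x| <= delta, linear outside
    x = np.asarray(x, dtype='float32')
    sqr = np.square(x)
    absc = np.abs(x)
    return np.where(sqr < delta ** 2,
                    sqr * 0.5,
                    absc * delta - 0.5 * delta ** 2)

print(huber_loss_np([0.5, 2.0]))  # -> [0.125, 1.5]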
tensorpack/tfutils/tower.py
View file @
fb2a051c
...
@@ -11,7 +11,9 @@ __all__ = ['get_current_tower_context', 'TowerContext']

 _CurrentTowerContext = None

 class TowerContext(object):
     def __init__(self, tower_name, is_training=None):
         """ tower_name: 'tower0', 'towerp0', or '' """
         self._name = tower_name
...
@@ -78,7 +80,7 @@ class TowerContext(object):
         self._scope.__exit__(exc_type, exc_val, exc_tb)
         return False

 def get_current_tower_context():
     global _CurrentTowerContext
     return _CurrentTowerContext
tensorpack/tfutils/varmanip.py
View file @
fb2a051c
...
@@ -3,7 +3,8 @@
 # File: varmanip.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import six, os
+import six
+import os
 import tensorflow as tf
 from collections import defaultdict
 import re
...
@@ -15,6 +16,7 @@ from .common import get_op_tensor_name

 __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
            'get_savename_from_varname', 'is_training_name']

 def get_savename_from_varname(
         varname, varname_prefix=None,
         savename_prefix=None):
...
@@ -33,13 +35,15 @@ def get_savename_from_varname(
     name = re.sub('tower[p0-9]+/', '', name)
     if varname_prefix is not None \
             and name.startswith(varname_prefix):
         name = name[len(varname_prefix) + 1:]
     if savename_prefix is not None:
         name = savename_prefix + '/' + name
     return name

 class SessionUpdate(object):
     """ Update the variables in a session """

     def __init__(self, sess, vars_to_update):
         """
         :param vars_to_update: a collection of variables to update
...
@@ -71,6 +75,7 @@ class SessionUpdate(object):
                 value = value.reshape(varshape)
             self.sess.run(op, feed_dict={p: value})

 def dump_session_params(path):
     """ Dump value of all trainable + to_save variables to a dict and save to `path` as
     npy format, loadable by ParamRestore
...
@@ -90,6 +95,7 @@ the same name".format(v.name))
     logger.info(str(result.keys()))
     np.save(path, result)

 def dump_chkpt_vars(model_path):
     """ Dump all variables from a checkpoint to a dict"""
     if os.path.basename(model_path) == model_path:
...
@@ -101,6 +107,7 @@ def dump_chkpt_vars(model_path):
         result[n] = reader.get_tensor(n)
     return result

 def is_training_name(name):
     """
     This is only used to improve logging.
...
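What the tower[p0-9]+/ substitution in get_savename_from_varname does, in isolation: strip the tower-replication prefix so every tower's copy saves under one canonical name. A runnable check using the same regex as above (the wrapper name is illustrative):

import re

def savename(varname):
    return re.sub('tower[p0-9]+/', '', varname)

assert savename('tower0/conv1/W:0') == 'conv1/W:0'
assert savename('towerp3/fc/b') == 'fc/b'     # predict towers too
assert savename('conv1/W:0') == 'conv1/W:0'   # untouched without a prefix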
tensorpack/train/__init__.py
View file @
fb2a051c
...
@@ -8,6 +8,7 @@ import os.path

 __all__ = []

 def global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else []
...
@@ -25,4 +26,3 @@ for _, module_name, _ in walk_packages(
     if module_name.startswith('_'):
         continue
     global_import(module_name)
tensorpack/train/base.py
View file @
fb2a051c
...
@@ -21,8 +21,11 @@ from ..tfutils.summary import create_summary

 __all__ = ['Trainer', 'StopTraining']

 class StopTraining(BaseException):
     pass

 @six.add_metaclass(ABCMeta)
 class Trainer(object):
     """ Base class for a trainer."""
...
@@ -138,7 +141,7 @@ class Trainer(object):
             callbacks.before_train()
             logger.info("Start training with global_step={}".format(get_global_step()))
             for epoch_num in range(
                     self.config.starting_epoch, self.config.max_epoch + 1):
                 with timed_operation(
                         'Epoch {} (global_step {})'.format(
                             epoch_num, get_global_step() + self.config.step_per_epoch)):
...
tensorpack/train/config.py
View file @
fb2a051c
...
@@ -14,10 +14,12 @@ from .input_data import InputData

 __all__ = ['TrainConfig']

 class TrainConfig(object):
     """
     Config for training a model with a single loss
     """

     def __init__(self, **kwargs):
         """
         :param dataset: the dataset to train. a `DataFlow` instance.
...
tensorpack/train/feedfree.py
View file @
fb2a051c
...
@@ -17,8 +17,10 @@ from .trainer import MultiPredictorTowerTrainer

 __all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer',
            'SimpleFeedfreeTrainer', 'QueueInputTrainer']

 class FeedfreeTrainer(Trainer):
     """ A trainer which runs iteration without feed_dict (therefore faster) """

     def _trigger_epoch(self):
         # need to run summary_op every epoch
         # note that summary_op will take a data from the queue
...
@@ -33,7 +35,9 @@ class FeedfreeTrainer(Trainer):
         assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
         self._input_method._setup(self)

 class SingleCostFeedfreeTrainer(FeedfreeTrainer):
     def _get_cost_and_grad(self):
         """ get the cost and gradient on a new tower"""
         actual_inputs = self._get_input_tensors()
...
@@ -50,26 +54,28 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
     def run_step(self):
         """ Simply run self.train_op"""
         self.sess.run(self.train_op)
-        #if not hasattr(self, 'cnt'):
-        #self.cnt = 0
-        #else:
-        #self.cnt += 1
-        #if self.cnt % 10 == 0:
-        ## debug-benchmark code:
-        #run_metadata = tf.RunMetadata()
-        #self.sess.run([self.train_op],
-        #options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
-        #run_metadata=run_metadata
-        #)
-        #from tensorflow.python.client import timeline
-        #trace = timeline.Timeline(step_stats=run_metadata.step_stats)
-        #trace_file = open('timeline.ctf.json', 'w')
-        #trace_file.write(trace.generate_chrome_trace_format())
-        #import sys; sys.exit()
+        # if not hasattr(self, 'cnt'):
+        # self.cnt = 0
+        # else:
+        # self.cnt += 1
+        # if self.cnt % 10 == 0:
+        # # debug-benchmark code:
+        # run_metadata = tf.RunMetadata()
+        # self.sess.run([self.train_op],
+        # options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
+        # run_metadata=run_metadata
+        # )
+        # from tensorflow.python.client import timeline
+        # trace = timeline.Timeline(step_stats=run_metadata.step_stats)
+        # trace_file = open('timeline.ctf.json', 'w')
+        # trace_file.write(trace.generate_chrome_trace_format())
+        # import sys; sys.exit()

 class SimpleFeedfreeTrainer(
         MultiPredictorTowerTrainer,
         SingleCostFeedfreeTrainer):

     def __init__(self, config):
         """
         A trainer with single cost, single training tower and feed-free input
...
@@ -94,6 +100,7 @@ class SimpleFeedfreeTrainer(
         # skip training
         #self.train_op = tf.group(*self.dequed_inputs)

 class QueueInputTrainer(SimpleFeedfreeTrainer):
     def __init__(self, config, input_queue=None, predict_tower=None):
...
tensorpack/train/input_data.py
View file @ fb2a051c
...
@@ -16,11 +16,14 @@ from ..callbacks.concurrency import StartProcOrThread
__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput',
           'DummyConstantInput']


@six.add_metaclass(ABCMeta)
class InputData(object):
    pass


class FeedInput(InputData):

    def __init__(self, ds):
        assert isinstance(ds, DataFlow), ds
        self.ds = ds
...
@@ -39,7 +42,9 @@ class FeedInput(InputData):
        feed = dict(zip(self.input_vars, data))
        return feed


class FeedfreeInput(InputData):

    def get_input_tensors(self):
        return self._get_input_tensors()
...
@@ -49,7 +54,9 @@ class FeedfreeInput(InputData):
        always create and return a list of new input tensors
        """


class EnqueueThread(threading.Thread):

    def __init__(self, trainer, queue, ds, input_placehdrs):
        super(EnqueueThread, self).__init__()
        self.name = 'EnqueueThread'
...
@@ -77,7 +84,7 @@ class EnqueueThread(threading.Thread):
                if self.coord.should_stop():
                    return
                feed = dict(zip(self.placehdrs, dp))
                # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
                self.op.run(feed_dict=feed)
        except tf.errors.CancelledError as e:
            pass
...
@@ -91,7 +98,9 @@ class EnqueueThread(threading.Thread):
            pass
        logger.info("Enqueue Thread Exited.")


class QueueInput(FeedfreeInput):

    def __init__(self, ds, queue=None):
        """
        :param ds: a `DataFlow` instance
...
@@ -126,14 +135,16 @@ class QueueInput(FeedfreeInput):
            qv.set_shape(v.get_shape())

        # test the overhead of queue
        # with tf.device('/gpu:0'):
        #     ret = [tf.Variable(tf.random_normal([128,224,224,3],
        #                                         dtype=tf.float32), trainable=False),
        #            tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
        return ret


class DummyConstantInput(QueueInput):
    """ only for debugging performance issues """

    def __init__(self, ds, shapes):
        super(DummyConstantInput, self).__init__(ds)
        self.shapes = shapes
...
@@ -150,7 +161,9 @@ class DummyConstantInput(QueueInput):
                initializer=tf.constant_initializer()))
        return ret


class TensorInput(FeedfreeInput):

    def __init__(self, get_tensor_fn, size=None):
        self.get_tensor_fn = get_tensor_fn
        self._size = size
...
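The EnqueueThread/QueueInput pair above implements the queue-based, feed-free input pattern: a background thread keeps pushing datapoints from the DataFlow into a TensorFlow queue, and the training tower dequeues them as tensors. A minimal sketch of that pattern under TF1-style APIs; the capacity of 50 and the placeholder shape are illustrative assumptions, not values from this diff:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28], name='input')
queue = tf.FIFOQueue(50, [x.dtype], name='input_queue')
enqueue_op = queue.enqueue([x])     # run repeatedly by an EnqueueThread
inputs = queue.dequeue()            # consumed by the training tower
inputs.set_shape(x.get_shape())     # dequeue loses the static shape; restore it,
                                    # as the hunk does with qv.set_shape(v.get_shape())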
tensorpack/train/multigpu.py
View file @ fb2a051c
...
@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
import itertools
import re
from six.moves import zip, range

from ..utils import logger
...
@@ -22,6 +23,7 @@ from .input_data import QueueInput
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']


class MultiGPUTrainer(Trainer):
    """ Base class for multi-gpu training"""
    @staticmethod
...
@@ -45,9 +47,11 @@ class MultiGPUTrainer(Trainer):
        restore_collection(backup)
        return grad_list
class SyncMultiGPUTrainer(MultiGPUTrainer,
                          SingleCostFeedfreeTrainer,
                          MultiPredictorTowerTrainer):

    def __init__(self, config, input_queue=None, predict_tower=None):
        if hasattr(config, 'dataset'):
            self._input_method = QueueInput(config.dataset, input_queue)
...
@@ -64,7 +68,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
        assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
        assert tf.test.is_gpu_available()

    @staticmethod
    def _average_grads(tower_grads):
        if len(tower_grads) == 1:
...
@@ -97,7 +100,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
        # debug tower performance:
        # ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
        # self.train_op = tf.group(*ops)
        # return

        grads = SyncMultiGPUTrainer._average_grads(grad_list)
        grads = apply_grad_processors(grads, self.model.get_gradient_processor())
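_average_grads reduces the per-tower gradient lists into a single list for the shared variables. A sketch of what such an averaging step looks like, assuming (consistent with the hunk) that each element of tower_grads is a list of (gradient, variable) pairs and that all towers share their variables:

def average_grads(tower_grads):
    if len(tower_grads) == 1:                   # single tower: nothing to average
        return tower_grads[0]
    ret = []
    for grad_and_vars in zip(*tower_grads):     # iterate variable by variable
        grads = [g for g, _ in grad_and_vars]
        avg = tf.add_n(grads) / float(len(grads))
        ret.append((avg, grad_and_vars[0][1]))  # the variable itself is shared
    return ret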
...
@@ -109,9 +112,11 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
    def run_step(self):
        self.sess.run(self.train_op)
class AsyncMultiGPUTrainer(MultiGPUTrainer,
                           SingleCostFeedfreeTrainer,
                           MultiPredictorTowerTrainer):

    def __init__(self, config,
                 input_queue=None,
                 average_gradient=True,
...
@@ -157,6 +162,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
        self.training_threads = []
        for k in range(1, len(self.config.tower)):
            train_op = self.config.optimizer.apply_gradients(grad_list[k])

            def f(op=train_op):  # avoid late-binding
                self.sess.run([op])
                next(self.async_step_counter)
...
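The default argument in def f(op=train_op) works around Python's late-binding closures: without it, every training thread's f would see the train_op of the last loop iteration. A self-contained illustration of the pitfall:

fs = [lambda: k for k in range(3)]
print([f() for f in fs])             # [2, 2, 2] -- closures look up k at call time
fs = [lambda k=k: k for k in range(3)]
print([f() for f in fs])             # [0, 1, 2] -- the default binds at definition time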
tensorpack/train/trainer.py
View file @ fb2a051c
...
@@ -16,7 +16,8 @@ from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput, FeedfreeInput

__all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']


class PredictorFactory(object):
    """ Make predictors for a trainer"""
...
@@ -52,8 +53,10 @@ class PredictorFactory(object):
            build_multi_tower_prediction_graph(fn, self.towers)
            self.tower_built = True
class SimpleTrainer(Trainer):
    """ A naive demo trainer """

    def __init__(self, config):
        super(SimpleTrainer, self).__init__(config)
        self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
...
@@ -93,8 +96,10 @@ class SimpleTrainer(Trainer):
    def get_predict_func(self, input_names, output_names):
        return self._predictor_factory.get_predictor(input_names, output_names, 0)


class MultiPredictorTowerTrainer(Trainer):
    """ A trainer with possibly multiple prediction tower """

    def _setup_predictor_factory(self, predict_tower):
        # by default, use the first training gpu for prediction
        predict_tower = predict_tower or [0]
...
tensorpack/utils/__init__.py
View file @ fb2a051c
...
@@ -12,6 +12,7 @@ These utils should be irrelevant to tensorflow.
__all__ = []


def _global_import(name):
    p = __import__(name, globals(), None, level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -23,7 +24,7 @@ _TO_IMPORT = set([
    'naming',
    'utils',
    'gpu'
])

_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in walk_packages(
...
@@ -36,5 +37,3 @@ for _, module_name, _ in walk_packages(
    if module_name in _TO_IMPORT:
        _global_import(module_name)
        __all__.append(module_name)
tensorpack/utils/argtools.py
View file @ fb2a051c
...
@@ -5,10 +5,13 @@
import operator
import inspect
import six
import functools
import collections

__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']


def map_arg(**maps):
    """
...
@@ -26,11 +29,13 @@ def map_arg(**maps):
        return wrapper
    return deco
class memoized(object):
    '''Decorator. Caches a function's return value each time it is called.
    If called later with the same arguments, the cached value is returned
    (not reevaluated).
    '''

    def __init__(self, func):
        self.func = func
        self.cache = {}
...
@@ -60,8 +65,11 @@ class memoized(object):
        return functools.partial(self.__call__, obj)
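memoized caches by exact call arguments, so it pays off for pure functions that are called repeatedly with the same inputs. A minimal usage sketch (fib is a hypothetical example, not part of this diff):

@memoized
def fib(n):
    return n if n < 2 else fib(n - 1) + fib(n - 2)

fib(100)    # fast: every subproblem is computed once, repeats hit the cache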
_MEMOIZED_NOARGS = {}


def memoized_ignoreargs(func):
    h = hash(func)  # make sure it is hashable. is it necessary?

    def wrapper(*args, **kwargs):
        if func not in _MEMOIZED_NOARGS:
            res = func(*args, **kwargs)
...
@@ -70,15 +78,16 @@ def memoized_ignoreargs(func):
            return _MEMOIZED_NOARGS[func]
    return wrapper
# _GLOBAL_MEMOIZED_CACHE = dict()
# def global_memoized(func):
#     """ Make sure that the same `memoized` object is returned on different
#     calls to global_memoized(func)
#     """
#     ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
#     if ret is None:
#         ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
#     return ret


def shape2d(a):
    """
...
tensorpack/utils/concurrency.py
View file @ fb2a051c
...
@@ -23,10 +23,12 @@ __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
           'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
           'mask_sigint', 'start_proc_mask_signal']


class StoppableThread(threading.Thread):
    """
    A thread that has a 'stop' event.
    """

    def __init__(self):
        super(StoppableThread, self).__init__()
        self._stop_evt = threading.Event()
...
@@ -56,8 +58,10 @@ class StoppableThread(threading.Thread):
        except queue.Empty:
            pass


class LoopThread(StoppableThread):
    """ A pausable thread that simply runs a loop"""

    def __init__(self, func, pausable=True):
        """
        :param func: the function to run
...
@@ -89,6 +93,7 @@ class DIE(object):
    """ A placeholder class indicating end of queue """
    pass
def ensure_proc_terminate(proc):
    if isinstance(proc, list):
        for p in proc:
...
@@ -114,6 +119,7 @@ def mask_sigint():
    yield
    signal.signal(signal.SIGINT, sigint_handler)


def start_proc_mask_signal(proc):
    if not isinstance(proc, list):
        proc = [proc]
...
@@ -122,6 +128,7 @@ def start_proc_mask_signal(proc):
    for p in proc:
        p.start()


def subproc_call(cmd, timeout=None):
    try:
        output = subprocess.check_output(
...
@@ -135,10 +142,12 @@ def subproc_call(cmd, timeout=None):
        logger.warn("Commnad failed: {}".format(e.returncode))
        logger.warn(e.output)
class OrderedContainer(object):
    """
    Like a priority queue, but will always wait for item with index (x+1) before producing (x+2).
    """

    def __init__(self, start=0):
        self.ranks = []
        self.data = []
...
@@ -163,11 +172,13 @@ class OrderedContainer(object):
        self.wait_for += 1
        return rank, ret
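OrderedContainer releases items strictly in rank order, holding back anything that arrives early. A behavioral sketch, assuming a put(rank, value) counterpart to the get() shown above (put is not visible in this hunk, so its signature is an assumption):

oc = OrderedContainer(start=0)
oc.put(1, 'b')      # arrives early; held back until rank 0 shows up
oc.put(0, 'a')
oc.get()            # -> (0, 'a')
oc.get()            # -> (1, 'b')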
class OrderedResultGatherProc(multiprocessing.Process):
    """
    Gather indexed data from a data queue, and produce results with the
    original index-based order.
    """

    def __init__(self, data_queue, nr_producer, start=0):
        """
        :param data_queue: a multiprocessing.Queue to produce input dp
...
tensorpack/utils/debug.py
View file @ fb2a051c
...
@@ -7,6 +7,7 @@
import sys

__all__ = ['enable_call_trace']


def enable_call_trace():
    def tracer(frame, event, arg):
        if event == 'call':
...
@@ -21,7 +22,7 @@ def enable_call_trace():
            if caller:
                caller_line_no = caller.f_lineno
                caller_filename = caller.f_code.co_filename
                print('Call to `%s` on line %s:%s from %s:%s' %
                      (func_name, func_filename, func_line_no,
                       caller_filename, caller_line_no))
            return
...
@@ -32,6 +33,7 @@ if __name__ == '__main__':
    def b(a):
        print(2)

    def a():
        print(1)
        b(1)
...
tensorpack/utils/discretize.py
View file @ fb2a051c
...
@@ -12,11 +12,14 @@ from six.moves import range
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']


@memoized
def log_once(s):
    logger.warn(s)

# just a placeholder


@six.add_metaclass(ABCMeta)
class Discretizer(object):
...
@@ -28,10 +31,13 @@ class Discretizer(object):
    def get_bin(self, v):
        pass


class Discretizer1D(Discretizer):
    pass


class UniformDiscretizer1D(Discretizer1D):

    def __init__(self, minv, maxv, spacing):
        """
        :params minv: minimum value of the first bin
...
@@ -69,17 +75,18 @@ class UniformDiscretizer1D(Discretizer1D):
        if v >= self.maxv or v <= self.minv:
            return ret
        try:
            for k in range(1, smooth_radius + 1):
                ret[b + k] = smooth_factor ** k
        except IndexError:
            pass
        for k in range(1, min(smooth_radius + 1, b + 1)):
            ret[b - k] = smooth_factor ** k
        ret /= ret.sum()
        return ret
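The hunk above builds a smoothed one-hot distribution over the bins: the target bin keeps weight 1 and each neighbor at distance k gets smooth_factor ** k before the vector is renormalized. A usage sketch with illustrative numbers; the name of the method containing this smoothing code is elided from the diff, so get_distribution below is an assumed name:

d = UniformDiscretizer1D(0.0, 10.0, 1.0)    # minv, maxv, spacing
b = d.get_bin(3.7)                          # index of the bin containing 3.7
# dist = d.get_distribution(3.7, smooth_factor=0.5, smooth_radius=2)
# would peak at bin b and decay by a factor of 0.5 per neighboring bin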
class UniformDiscretizerND(Discretizer):

    def __init__(self, *min_max_spacing):
        """
        :params min_max_spacing: (minv, maxv, spacing) for each dimension
...
@@ -122,6 +129,5 @@ class UniformDiscretizerND(Discretizer):
if __name__ == '__main__':
    # u = UniformDiscretizer1D(-10, 10, 0.12)
    u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1))
    import IPython as IP
    IP.embed(config=IP.terminal.ipapp.load_default_config())
tensorpack/utils/fs.py
View file @ fb2a051c
...
@@ -3,13 +3,15 @@
# File: fs.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import os
import sys
from six.moves import urllib
import errno
from . import logger

__all__ = ['mkdir_p', 'download', 'recursive_walk']


def mkdir_p(dirname):
    """ make a dir recursively, but do nothing if the dir exists"""
    assert dirname is not None
...
@@ -21,6 +23,7 @@ def mkdir_p(dirname):
        if e.errno != errno.EEXIST:
            raise e


def download(url, dir):
    mkdir_p(dir)
    fname = url.split('/')[-1]
...
@@ -29,7 +32,7 @@ def download(url, dir):
    def _progress(count, block_size, total_size):
        sys.stdout.write('\r>> Downloading %s %.1f%%' %
                         (fname,
                          min(float(count * block_size) / total_size,
                              1.0) * 100.0))
        sys.stdout.flush()
    try:
...
@@ -45,6 +48,7 @@ def download(url, dir):
    print('Succesfully downloaded ' + fname + " " + str(size) + ' bytes.')
    return fpath
def recursive_walk(rootdir):
    for r, dirs, files in os.walk(rootdir):
        for f in files:
...
tensorpack/utils/globvars.py
View file @ fb2a051c
...
@@ -9,13 +9,15 @@ import argparse
__all__ = ['globalns', 'use_global_argument']

if six.PY2:
    class NS:
        pass
else:
    import types
    NS = types.SimpleNamespace

globalns = NS()


def use_global_argument(args):
    """
    Add the content of argparse.Namespace to globalns
...
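use_global_argument copies parsed command-line arguments onto the module-level globalns namespace, so other modules can read them without threading args through every call. A usage sketch (the --batch flag is a hypothetical example):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=32)
use_global_argument(parser.parse_args())
print(globalns.batch)    # 32, readable from any module that imports globalns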
tensorpack/utils/gpu.py
View file @ fb2a051c
...
@@ -8,20 +8,22 @@ from .utils import change_env
__all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus']


def change_gpu(val):
    val = str(val)
    if val == '-1':
        val = ''
    return change_env('CUDA_VISIBLE_DEVICES', val)


def get_nr_gpu():
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    assert env is not None, 'gpu not set!'  # TODO
    return len(env.split(','))


def get_gpus():
    """ return a list of GPU physical id"""
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    assert env is not None, 'gpu not set!'  # TODO
    return map(int, env.strip().split(','))
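change_gpu builds on the change_env context manager from utils.py, so GPU visibility can be switched temporarily; passing -1 maps to an empty CUDA_VISIBLE_DEVICES. A usage sketch:

with change_gpu(1):
    print(get_nr_gpu())    # 1: only physical GPU 1 is visible inside this block
with change_gpu(-1):
    pass                   # CUDA_VISIBLE_DEVICES == '' here: CPU only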
tensorpack/utils/loadcaffe.py
View file @ fb2a051c
...
@@ -19,7 +19,9 @@ __all__ = ['load_caffe', 'get_caffe_pb']
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"


class CaffeLayerProcessor(object):

    def __init__(self, net):
        self.net = net
        self.layer_names = net._layer_names
...
@@ -49,7 +51,7 @@ class CaffeLayerProcessor(object):
        assert len(param) <= 2
        assert param[0].data.ndim == 4
        # caffe: ch_out, ch_in, h, w
        W = param[0].data.transpose(2, 3, 1, 0)
        if len(param) == 1:
            return {name + '/W': W}
        else:
...
@@ -65,7 +67,7 @@ class CaffeLayerProcessor(object):
            logger.info("FC layer {} takes spatial data.".format(name))
            W = param[0].data
            # original: outx(CxHxW)
            W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2, 3, 1, 0)
            # become: (HxWxC)xout
        else:
            W = param[0].data.transpose()
...
@@ -74,8 +76,8 @@ class CaffeLayerProcessor(object):
    def proc_bn(self, idx, name, param):
        assert param[2].data[0] == 1.0
        return {name + '/mean/EMA': param[0].data,
                name + '/variance/EMA': param[1].data}

    def proc_scale(self, idx, name, param):
        bottom_name = self.net.bottom_names[name][0]
...
@@ -89,7 +91,7 @@ class CaffeLayerProcessor(object):
                logger.info("Merge {} and {} into one BatchNorm layer".format(
                    name, name2))
                return {name2 + '/beta': param[1].data,
                        name2 + '/gamma': param[0].data}
        # assume this scaling layer is part of some BN
        logger.error("Could not find a BN layer corresponding to this Scale layer!")
        raise ValueError()
...
@@ -104,10 +106,11 @@ def load_caffe(model_desc, model_file):
    caffe.set_mode_cpu()
    net = caffe.Net(model_desc, model_file, caffe.TEST)
    param_dict = CaffeLayerProcessor(net).process()
    logger.info("Model loaded from caffe. Params: " +
                " ".join(sorted(param_dict.keys())))
    return param_dict
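load_caffe walks a Caffe net and returns a plain dict of numpy arrays keyed in tensorpack's naming scheme ('layer/W', 'layer/mean/EMA', ...), produced by the processors above. A usage sketch with hypothetical file names, mirroring the script's __main__ block:

import numpy as np

params = load_caffe('resnet50-deploy.prototxt', 'resnet50.caffemodel')
np.save('resnet50.npy', params)   # reload later with np.load(...).item()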
def get_caffe_pb():
    dir = get_dataset_path('caffe')
    caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
...
@@ -131,4 +134,3 @@ if __name__ == '__main__':
    import numpy as np
    np.save(args.output, ret)
tensorpack/utils/logger.py
View file @ fb2a051c
...
@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import logging
import os
import shutil
import os.path
from termcolor import colored
from datetime import datetime
...
@@ -12,7 +13,9 @@ import sys
__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', 'warn_dependency']


class _MyFormatter(logging.Formatter):

    def format(self, record):
        date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
        msg = '%(message)s'
...
@@ -28,6 +31,7 @@ class _MyFormatter(logging.Formatter):
        self._fmt = fmt
        return super(_MyFormatter, self).format(record)
def _getlogger():
    logger = logging.getLogger('tensorpack')
    logger.propagate = False
...
@@ -45,6 +49,8 @@ def get_time_str():
# logger file and directory:
global LOG_FILE, LOG_DIR
LOG_DIR = None


def _set_file(path):
    if os.path.isfile(path):
        backup_name = path + '.' + get_time_str()
...
@@ -56,6 +62,7 @@ def _set_file(path):
    _logger.addHandler(hdl)
    _logger.info("Argv: " + ' '.join(sys.argv))
def set_logger_dir(dirname, action=None):
    """
    Set the directory for global logging.
...
@@ -98,11 +105,13 @@ _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception',
for func in _LOGGING_METHOD:
    locals()[func] = getattr(_logger, func)


def disable_logger():
    """ disable all logging ability from this moment"""
    for func in _LOGGING_METHOD:
        globals()[func] = lambda x: None


def auto_set_dir(action=None, overwrite=False):
    """ set log directory to a subdir inside 'train_log', with the name being
    the main python file currently running"""
...
@@ -116,5 +125,6 @@ def auto_set_dir(action=None, overwrite=False):
        basename[:basename.rfind('.')]),
        action=action)


def warn_dependency(name, dependencies):
    warn("Failed to import '{}', {} won't be available'".format(dependencies, name))
tensorpack/utils/lut.py
View file @ fb2a051c
...
@@ -7,10 +7,12 @@ import six
__all__ = ['LookUpTable']


class LookUpTable(object):

    def __init__(self, objlist):
        self.idx2obj = dict(enumerate(objlist))
        self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)}

    def size(self):
        return len(self.idx2obj)
...
View file @
fb2a051c
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
numpy
as
np
import
numpy
as
np
class
Rect
(
object
):
class
Rect
(
object
):
"""
"""
A Rectangle.
A Rectangle.
...
@@ -68,7 +69,7 @@ class Rect(object):
...
@@ -68,7 +69,7 @@ class Rect(object):
def
roi
(
self
,
img
):
def
roi
(
self
,
img
):
assert
self
.
validate
(
img
.
shape
[:
2
]),
"{} vs {}"
.
format
(
self
,
img
.
shape
[:
2
])
assert
self
.
validate
(
img
.
shape
[:
2
]),
"{} vs {}"
.
format
(
self
,
img
.
shape
[:
2
])
return
img
[
self
.
y0
:
self
.
y1
+
1
,
self
.
x0
:
self
.
x1
+
1
]
return
img
[
self
.
y0
:
self
.
y1
+
1
,
self
.
x0
:
self
.
x1
+
1
]
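Note that roi slices inclusively on both ends (y1 + 1, x1 + 1), so a Rect's corner pixels are part of the extracted patch. A quick sanity check, assuming a Rect instance r whose x0/y0/x1/y1 fields are already set and valid for img:

patch = r.roi(img)
assert patch.shape[:2] == (r.y1 - r.y0 + 1, r.x1 - r.x0 + 1)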
    def expand(self, frac):
        assert frac > 1.0, frac
...
@@ -92,7 +93,7 @@ class Rect(object):
        xmax = min(self.x1, img.shape[1])
        ymax = min(self.y1, img.shape[0])
        patch = img[ymin:ymax, xmin:xmax]
        ret[ystart:ystart + patch.shape[0], xstart:xstart + patch.shape[1]] = patch
        return ret

    __repr__ = __str__
...
@@ -101,6 +102,6 @@ class Rect(object):
if __name__ == '__main__':
    x = Rect(2, 1, 3, 3, allow_neg=True)
    img = np.random.rand(3, 3)
    print(img)
    print(x.roi_zeropad(img))
tensorpack/utils/serialize.py
View file @ fb2a051c
...
@@ -10,10 +10,12 @@ msgpack_numpy.patch()
__all__ = ['loads', 'dumps']


def dumps(obj):
    # return dill.dumps(obj)
    return msgpack.dumps(obj, use_bin_type=True)


def loads(buf):
    # return dill.loads(buf)
    return msgpack.loads(buf)
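With msgpack_numpy patched in (visible in the hunk header), dumps/loads round-trip plain Python containers and numpy arrays through msgpack. A minimal sketch; exact string/bytes round-tripping depends on the msgpack version:

import numpy as np

buf = dumps([1, 'two', np.arange(3)])   # bytes, safe to send over a pipe or socket
obj = loads(buf)                        # [1, 'two', array([0, 1, 2])]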
tensorpack/utils/stats.py
View file @ fb2a051c
...
@@ -6,8 +6,10 @@ import numpy as np
__all__ = ['StatCounter', 'Accuracy', 'BinaryStatistics', 'RatioCounter',
           'OnlineMoments']


class StatCounter(object):
    """ A simple counter"""

    def __init__(self):
        self.reset()
...
@@ -36,8 +38,10 @@ class StatCounter(object):
        assert len(self._values)
        return max(self._values)


class RatioCounter(object):
    """ A counter to count ratio of something"""

    def __init__(self):
        self.reset()
...
@@ -59,17 +63,20 @@ class RatioCounter(object):
    def count(self):
        return self._tot


class Accuracy(RatioCounter):
    """ A RatioCounter with a fancy name """
    @property
    def accuracy(self):
        return self.ratio


class BinaryStatistics(object):
    """
    Statistics for binary decision,
    including precision, recall, false positive, false negative
    """

    def __init__(self):
        self.reset()
...
...
@@ -118,10 +125,12 @@ class BinaryStatistics(object):
            return 0
        return 1 - self.recall


class OnlineMoments(object):
    """Compute 1st and 2nd moments online
    See algorithm at: https://www.wikiwand.com/en/Algorithms_for_calculating_variance#/Online_algorithm
    """

    def __init__(self):
        self._mean = 0
        self._M2 = 0
...
@@ -140,7 +149,7 @@ class OnlineMoments(object):
    @property
    def variance(self):
        return self._M2 / (self._n - 1)

    @property
    def std(self):
...
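OnlineMoments is Welford's single-pass algorithm: each sample updates the running mean and the running sum of squared deviations M2, from which the sample variance above is M2 / (n - 1). A sketch of the update step the class needs; the actual method and its name (feed, by analogy with StatCounter) are elided from this hunk:

def feed(self, x):
    self._n += 1
    delta = x - self._mean
    self._mean += delta / self._n
    self._M2 += delta * (x - self._mean)   # uses the already-updated mean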
tensorpack/utils/timer.py
View file @ fb2a051c
...
@@ -16,8 +16,10 @@ from . import logger
__all__ = ['total_timer', 'timed_operation',
           'print_total_timer', 'IterSpeedCounter']


class IterSpeedCounter(object):
    """ To count how often some code gets reached"""

    def __init__(self, print_every, name=None):
        self.cnt = 0
        self.print_every = int(print_every)
...
@@ -36,6 +38,7 @@ class IterSpeedCounter(object):
        logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
            self.name, t, self.cnt, t / self.cnt))


@contextmanager
def timed_operation(msg, log_start=False):
    if log_start:
...
@@ -47,6 +50,7 @@ def timed_operation(msg, log_start=False):
_TOTAL_TIMER_DATA = defaultdict(StatCounter)


@contextmanager
def total_timer(msg):
    start = time.time()
...
@@ -54,6 +58,7 @@ def total_timer(msg):
    t = time.time() - start
    _TOTAL_TIMER_DATA[msg].feed(t)


def print_total_timer():
    if len(_TOTAL_TIMER_DATA) == 0:
        return
...
View file @
fb2a051c
...
@@ -2,7 +2,8 @@
...
@@ -2,7 +2,8 @@
# File: utils.py
# File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
os
,
sys
import
os
import
sys
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
inspect
import
inspect
from
datetime
import
datetime
from
datetime
import
datetime
...
@@ -18,6 +19,7 @@ __all__ = ['change_env',
...
@@ -18,6 +19,7 @@ __all__ = ['change_env',
'execute_only_once'
'execute_only_once'
]
]
@
contextmanager
@
contextmanager
def
change_env
(
name
,
val
):
def
change_env
(
name
,
val
):
oldval
=
os
.
environ
.
get
(
name
,
None
)
oldval
=
os
.
environ
.
get
(
name
,
None
)
...
@@ -28,6 +30,7 @@ def change_env(name, val):
...
@@ -28,6 +30,7 @@ def change_env(name, val):
else
:
else
:
os
.
environ
[
name
]
=
oldval
os
.
environ
[
name
]
=
oldval
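change_env is a context manager that sets an environment variable for the duration of the block and then restores the previous value, or removes the variable if it was unset. A usage sketch:

with change_env('TENSORPACK_DATASET', '/tmp/datasets'):
    path = get_dataset_path('mnist')   # resolved under /tmp/datasets
# the old value of TENSORPACK_DATASET is back here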
def get_rng(obj=None):
    """ obj: some object to use to generate random seed"""
    seed = (id(obj) + os.getpid() +
...
@@ -36,6 +39,8 @@ def get_rng(obj=None):
_EXECUTE_HISTORY = set()


def execute_only_once():
    """
    when called with:
...
@@ -50,6 +55,7 @@ def execute_only_once():
    _EXECUTE_HISTORY.add(ident)
    return True
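execute_only_once returns True only the first time a given call site reaches it (it keys on the caller's identity via _EXECUTE_HISTORY), which makes it handy for one-shot warnings inside loops. A usage sketch:

def step(dp):
    if execute_only_once():
        logger.warn("step() got a deprecated input format")   # logged once, ever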
def get_dataset_path(*args):
    d = os.environ.get('TENSORPACK_DATASET', None)
    if d is None:
...
@@ -61,6 +67,7 @@ def get_dataset_path(*args):
    assert os.path.isdir(d), d
    return os.path.join(d, *args)


def get_tqdm_kwargs(**kwargs):
    default = dict(
        smoothing=0.5,
...
@@ -76,5 +83,6 @@ def get_tqdm_kwargs(**kwargs):
    default.update(kwargs)
    return default


def get_tqdm(**kwargs):
    return tqdm(**get_tqdm_kwargs(**kwargs))
tensorpack/utils/viz.py
View file @ fb2a051c
...
@@ -4,7 +4,8 @@
# Credit: zxytim

import numpy as np
import os
import sys
import io
import cv2
from .fs import mkdir_p
...
@@ -18,6 +19,7 @@ except ImportError:
__all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz',
           'dump_dataflow_images', 'interactive_imshow']


def pyplot2img(plt):
    buf = io.BytesIO()
    plt.axis('off')
...
@@ -28,23 +30,28 @@ def pyplot2img(plt):
    buf.close()
    return im


def pyplot_viz(img, shape=None):
    """ use pyplot to visualize the image
    Note: this is quite slow. and the returned image will have a border
    """
    plt.clf()
    plt.axes([0, 0, 1, 1])
    plt.imshow(img)
    ret = pyplot2img(plt)
    if shape is not None:
        ret = cv2.resize(ret, shape)
    return ret
def minnone(x, y):
    if x is None:
        x = y
    elif y is None:
        y = x
    return min(x, y)


def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
    """
    :param lclick_cb: a callback(img, x, y) for left click
...
@@ -70,6 +77,7 @@ def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
    elif key == 's':
        cv2.imwrite('out.png', img)
def build_patch_list(patch_list,
                     nr_row=None, nr_col=None, border=None,
                     max_width=1000, max_height=1000,
...
@@ -89,7 +97,7 @@ def build_patch_list(patch_list,
    # setup parameters
    patch_list = np.asarray(patch_list)
    if patch_list.ndim == 3:
        patch_list = patch_list[:, :, :, np.newaxis]
    assert patch_list.ndim == 4 and patch_list.shape[3] in [1, 3], patch_list.shape
    if shuffle:
        np.random.shuffle(patch_list)
...
@@ -114,7 +122,7 @@ def build_patch_list(patch_list,
        for patch in plist:
            r0 = cur_row * (ph + border)
            c0 = cur_col * (pw + border)
            canvas[r0:r0 + ph, c0:c0 + pw] = patch
            cur_col += 1
            if cur_col == nr_col:
                cur_col = 0
...
@@ -143,6 +151,7 @@ def build_patch_list(patch_list,
        yield canvas
        start = end


def dump_dataflow_images(df, index=0, batched=True,
                         number=1000, output_dir=None,
                         scale=1, resize=None, viz=None,
...
@@ -188,7 +197,7 @@ def dump_dataflow_images(df, index=0, batched=True,
        if resize is not None:
            img = cv2.resize(img, resize)
        if flipRGB:
            img = img[:, :, ::-1]
        if output_dir:
            fname = os.path.join(output_dir, '{:03d}.jpg'.format(cnt))
            cv2.imwrite(fname, img)
...
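build_patch_list is a generator: it tiles patches onto fixed-size canvases and yields one canvas at a time, so arbitrarily many patches can be paged through. A usage sketch (the random patches and the 4x8 grid are illustrative):

patches = np.random.rand(32, 64, 64, 3)
for canvas in build_patch_list(patches, nr_row=4, nr_col=8, border=2):
    interactive_imshow(canvas)    # per the hunk above, pressing 's' saves out.png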