Shashank Suhas / seminar-breakout · Commits

Commit fb2a051c, authored Jan 02, 2017 by Yuxin Wu

    run autopep8 over tensorpack/

Parent: 59553585

Showing 100 changed files with 1047 additions and 532 deletions (+1047 -532)
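autopep8 rewrites Python source to satisfy PEP 8 — comment spacing, one import per line, two blank lines between top-level definitions — which is exactly the shape of the hunks below. A minimal sketch of how such a pass can be reproduced; the exact command and flags used for this commit are not recorded on this page, so the defaults below are an assumption:

import glob

import autopep8  # pip install autopep8

# Sketch only: apply autopep8's default fixes to every .py file under
# tensorpack/.  The real commit may have been made with the CLI and
# different options; nothing here is recovered from the commit itself.
for path in glob.glob('tensorpack/**/*.py', recursive=True):
    with open(path) as f:
        src = f.read()
    fixed = autopep8.fix_code(src)  # returns the PEP 8-formatted source
    if fixed != src:
        with open(path, 'w') as f:
            f.write(fixed)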
Changed files:

tensorpack/RL/__init__.py  +2 -1
tensorpack/RL/common.py  +9 -1
tensorpack/RL/envbase.py  +13 -1
tensorpack/RL/expreplay.py  +50 -45
tensorpack/RL/gymenv.py  +4 -2
tensorpack/RL/history.py  +2 -1
tensorpack/RL/simulator.py  +22 -5
tensorpack/__init__.py  +1 -1
tensorpack/callbacks/__init__.py  +2 -1
tensorpack/callbacks/base.py  +5 -1
tensorpack/callbacks/concurrency.py  +4 -2
tensorpack/callbacks/dispatcher.py  +2 -0
tensorpack/callbacks/dump.py  +2 -1
tensorpack/callbacks/graph.py  +3 -1
tensorpack/callbacks/group.py  +4 -0
tensorpack/callbacks/inference.py  +7 -1
tensorpack/callbacks/inference_runner.py  +17 -12
tensorpack/callbacks/param.py  +24 -10
tensorpack/callbacks/saver.py  +14 -10
tensorpack/callbacks/stats.py  +13 -4
tensorpack/dataflow/__init__.py  +2 -2
tensorpack/dataflow/base.py  +4 -1
tensorpack/dataflow/common.py  +28 -4
tensorpack/dataflow/dataset/__init__.py  +2 -1
tensorpack/dataflow/dataset/bsds500.py  +5 -3
tensorpack/dataflow/dataset/cifar.py  +16 -7
tensorpack/dataflow/dataset/ilsvrc.py  +13 -8
tensorpack/dataflow/dataset/mnist.py  +12 -6
tensorpack/dataflow/dataset/ptb.py  +2 -2
tensorpack/dataflow/dataset/svhn.py  +4 -3
tensorpack/dataflow/dataset/visualqa.py  +6 -2
tensorpack/dataflow/dftools.py  +9 -4
tensorpack/dataflow/format.py  +22 -8
tensorpack/dataflow/image.py  +6 -2
tensorpack/dataflow/imgaug/__init__.py  +1 -1
tensorpack/dataflow/imgaug/_test.py  +4 -4
tensorpack/dataflow/imgaug/base.py  +5 -1
tensorpack/dataflow/imgaug/crop.py  +26 -16
tensorpack/dataflow/imgaug/deform.py  +18 -12
tensorpack/dataflow/imgaug/geometry.py  +22 -18
tensorpack/dataflow/imgaug/imgproc.py  +21 -8
tensorpack/dataflow/imgaug/meta.py  +11 -2
tensorpack/dataflow/imgaug/noise.py  +5 -0
tensorpack/dataflow/imgaug/noname.py  +13 -6
tensorpack/dataflow/imgaug/paste.py  +10 -3
tensorpack/dataflow/prefetch.py  +11 -2
tensorpack/dataflow/raw.py  +8 -1
tensorpack/dataflow/remote.py  +4 -2
tensorpack/dataflow/tf_func.py  +14 -12
tensorpack/models/__init__.py  +2 -1
tensorpack/models/_common.py  +9 -4
tensorpack/models/_test.py  +3 -2
tensorpack/models/batch_norm.py  +19 -16
tensorpack/models/conv2d.py  +10 -4
tensorpack/models/fc.py  +3 -1
tensorpack/models/image_sample.py  +31 -26
tensorpack/models/model_desc.py  +14 -6
tensorpack/models/nonlin.py  +5 -1
tensorpack/models/pool.py  +30 -19
tensorpack/models/regularize.py  +2 -1
tensorpack/models/shapes.py  +1 -0
tensorpack/models/softmax.py  +2 -1
tensorpack/predict/__init__.py  +1 -1
tensorpack/predict/base.py  +22 -12
tensorpack/predict/common.py  +6 -3
tensorpack/predict/concurrency.py  +18 -10
tensorpack/predict/dataset.py  +18 -10
tensorpack/tfutils/argscope.py  +2 -0
tensorpack/tfutils/common.py  +13 -3
tensorpack/tfutils/gradproc.py  +12 -1
tensorpack/tfutils/modelutils.py  +1 -2
tensorpack/tfutils/sessinit.py  +14 -4
tensorpack/tfutils/summary.py  +6 -1
tensorpack/tfutils/symbolic_functions.py  +21 -12
tensorpack/tfutils/tower.py  +4 -2
tensorpack/tfutils/varmanip.py  +11 -4
tensorpack/train/__init__.py  +1 -1
tensorpack/train/base.py  +7 -4
tensorpack/train/config.py  +3 -1
tensorpack/train/feedfree.py  +28 -21
tensorpack/train/input_data.py  +26 -13
tensorpack/train/multigpu.py  +22 -16
tensorpack/train/trainer.py  +9 -4
tensorpack/utils/__init__.py  +2 -3
tensorpack/utils/argtools.py  +20 -11
tensorpack/utils/concurrency.py  +13 -2
tensorpack/utils/debug.py  +5 -3
tensorpack/utils/discretize.py  +14 -8
tensorpack/utils/fs.py  +6 -2
tensorpack/utils/globvars.py  +3 -1
tensorpack/utils/gpu.py  +3 -1
tensorpack/utils/loadcaffe.py  +11 -9
tensorpack/utils/logger.py  +14 -4
tensorpack/utils/lut.py  +3 -1
tensorpack/utils/rect.py  +4 -3
tensorpack/utils/serialize.py  +4 -2
tensorpack/utils/stats.py  +11 -2
tensorpack/utils/timer.py  +6 -1
tensorpack/utils/utils.py  +22 -14
tensorpack/utils/viz.py  +26 -17
tensorpack/RL/__init__.py
@@ -8,6 +8,8 @@ import os
import os.path

__all__ = []


def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
@@ -20,4 +22,3 @@ for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        _global_import(module_name)
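The hunks above only respace tensorpack's auto-import idiom: at import time the package walks its own directory, imports every public submodule, and re-exports its names. A self-contained sketch of the same pattern; the hoisting loop body is collapsed out of the diff above, so its two lines here are an assumption based on the idiom, not recovered from this commit:

import os.path
from pkgutil import walk_packages

__all__ = []


def _global_import(name):
    # level=1 makes this a relative import, i.e. "from . import <name>"
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        globals()[k] = getattr(p, k)  # assumed body: hoist the symbol into the package
        __all__.append(k)             # assumed body: and re-export it


for _, module_name, _ in walk_packages([os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        _global_import(module_name)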
tensorpack/RL/common.py
@@ -9,7 +9,8 @@ from collections import deque
from .envbase import ProxyPlayer

__all__ = ['PreventStuckPlayer', 'LimitLengthPlayer', 'AutoRestartPlayer',
           'MapPlayerState']


class PreventStuckPlayer(ProxyPlayer):
    """ Prevent the player from getting stuck (repeating a no-op)
@@ -17,6 +18,7 @@ class PreventStuckPlayer(ProxyPlayer):
    where the agent needs to press the 'start' button to start playing.
    """
    # TODO hash the state as well?

    def __init__(self, player, nr_repeat, action):
        """
        It does auto-reset, but doesn't auto-restart the underlying player.
@@ -40,10 +42,12 @@ class PreventStuckPlayer(ProxyPlayer):
        super(PreventStuckPlayer, self).restart_episode()
        self.act_que.clear()


class LimitLengthPlayer(ProxyPlayer):
    """ Limit the total number of actions in an episode.
        Will auto restart the underlying player on timeout
    """

    def __init__(self, player, limit):
        super(LimitLengthPlayer, self).__init__(player)
        self.limit = limit
@@ -64,9 +68,11 @@ class LimitLengthPlayer(ProxyPlayer):
        self.player.restart_episode()
        self.cnt = 0


class AutoRestartPlayer(ProxyPlayer):
    """ Auto-restart the player on episode ends,
        in case some player wasn't designed to do so. """

    def action(self, act):
        r, isOver = self.player.action(act)
        if isOver:
@@ -74,7 +80,9 @@ class AutoRestartPlayer(ProxyPlayer):
            self.player.restart_episode()
        return r, isOver


class MapPlayerState(ProxyPlayer):

    def __init__(self, player, func):
        super(MapPlayerState, self).__init__(player)
        self.func = func
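These ProxyPlayer wrappers are designed to nest; a hypothetical composition for an Atari-style player (base_player and every numeric argument below are illustrative assumptions, not values from this commit):

player = MapPlayerState(base_player, lambda s: s / 255.0)    # preprocess states
player = PreventStuckPlayer(player, nr_repeat=30, action=1)  # fire action 1 after 30 repeats
player = LimitLengthPlayer(player, limit=40000)              # cap episode length, restart on timeout
player = AutoRestartPlayer(player)                           # restart whenever an episode ends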
tensorpack/RL/envbase.py
@@ -13,8 +13,10 @@ from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
           'DiscreteActionSpace']


@six.add_metaclass(ABCMeta)
class RLEnvironment(object):

    def __init__(self):
        self.reset_stat()
@@ -60,13 +62,15 @@ class RLEnvironment(object):
            s = self.current_state()
            act = func(s)
            r, isOver = self.action(act)
-            #print r
+            # print r
            if isOver:
                s = [self.stats[k] for k in stat]
                self.reset_stat()
                return s if len(s) > 1 else s[0]


class ActionSpace(object):

    def __init__(self):
        self.rng = get_rng(self)
@@ -77,7 +81,9 @@ class ActionSpace(object):
    def num_actions(self):
        raise NotImplementedError()


class DiscreteActionSpace(ActionSpace):

    def __init__(self, num):
        super(DiscreteActionSpace, self).__init__()
        self.num = num
@@ -94,19 +100,25 @@ class DiscreteActionSpace(ActionSpace):
    def __str__(self):
        return "DiscreteActionSpace({})".format(self.num)


class NaiveRLEnvironment(RLEnvironment):
    """ for testing only"""

    def __init__(self):
        self.k = 0

    def current_state(self):
        self.k += 1
        return self.k

    def action(self, act):
        self.k = act
        return (self.k, self.k > 10)


class ProxyPlayer(RLEnvironment):
    """ Serve as a proxy another player """

    def __init__(self, player):
        self.player = player
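The loop in the @@ -60,13 hunk is the episode driver: query the state, ask a policy function for an action, step, and collect stats on termination. A minimal driver for the NaiveRLEnvironment defined above (the policy is a hypothetical example):

env = NaiveRLEnvironment()
while True:
    s = env.current_state()   # increments and returns the counter
    act = s + 1               # hypothetical policy: always increment
    r, isOver = env.action(act)
    if isOver:                # NaiveRLEnvironment ends once the state exceeds 10
        break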
tensorpack/RL/expreplay.py
@@ -10,14 +10,15 @@ import six
from six.moves import queue

from ..dataflow import DataFlow
-from ..utils import logger, get_tqdm
+from ..utils import logger, get_tqdm, get_rng
from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback

__all__ = ['ExpReplay']

Experience = namedtuple('Experience',
                        ['state', 'action', 'reward', 'isOver'])


class ExpReplay(DataFlow, Callback):
    """
@@ -27,19 +28,20 @@ class ExpReplay(DataFlow, Callback):
    This implementation provides the interface as an DataFlow.
    This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
    """

    def __init__(self,
                 predictor_io_names,
                 player,
                 batch_size=32,
                 memory_size=1e6,
                 init_memory_size=50000,
                 exploration=1,
                 end_exploration=0.1,
                 exploration_epoch_anneal=0.002,
                 reward_clip=None,
                 update_frequency=1,
                 history_len=1
                 ):
        """
        :param predictor: a callabale running the up-to-date network.
            called with a state, return a distribution.
@@ -78,10 +80,10 @@ class ExpReplay(DataFlow, Callback):
    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
-        #if len(self.mem):
+        # if len(self.mem):
        # from copy import deepcopy  # quickly fill the memory for debug
        # self.mem.append(deepcopy(self.mem[0]))
        # return
        old_s = self.player.current_state()
        if self.rng.rand() <= self.exploration:
            act = self.rng.choice(range(self.num_actions))
@@ -115,19 +117,19 @@ class ExpReplay(DataFlow, Callback):
        while True:
            batch_exp = [self._sample_one() for _ in range(self.batch_size)]

            # import cv2  # for debug
            # def view_state(state, next_state):
            #     """ for debugging state representation"""
            #     r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
            #     r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
            #     r = np.concatenate([r, r2], axis=0)
            #     print r.shape
            #     cv2.imshow("state", r)
            #     cv2.waitKey()
            # exp = batch_exp[0]
            # print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
            # if exp[2] or exp[4]:
            #     view_state(exp[0], exp[1])

            yield self._process_batch(batch_exp)
            self._populate_job_queue.put(1)
@@ -141,9 +143,10 @@ class ExpReplay(DataFlow, Callback):
        # when x.isOver==True, (x+1).state is of a different episode
        idx = self.rng.randint(len(self.mem) - self.history_len - 1)
        samples = [self.mem[k] for k in range(idx, idx + self.history_len + 1)]

        def concat(idx):
            v = [x.state for x in samples[idx:idx + self.history_len]]
            return np.concatenate(v, axis=2)
        state = concat(0)
        next_state = concat(1)
@@ -155,12 +158,12 @@ class ExpReplay(DataFlow, Callback):
        # zero-fill state before starting
        zero_fill = False
        for k in range(1, self.history_len):
            if samples[start_idx - k].isOver:
                zero_fill = True
            if zero_fill:
                state[:, :, -k - 1] = 0
                if k + 2 <= self.history_len:
                    next_state[:, :, -k - 2] = 0
        return (state, next_state, reward, action, isOver)

    def _process_batch(self, batch_exp):
@@ -178,6 +181,7 @@ class ExpReplay(DataFlow, Callback):
    def _before_train(self):
        # spawn a separate thread to run policy, can speed up 1.3x
        self._populate_job_queue = queue.Queue(maxsize=1)

        def populate_job_func():
            self._populate_job_queue.get()
            with self.trainer.sess.as_default():
@@ -203,22 +207,23 @@ class ExpReplay(DataFlow, Callback):
                pass
        self.player.reset_stat()


if __name__ == '__main__':
    from .atari import AtariPlayer
    import sys
    predictor = lambda x: np.array([1, 1, 1, 1])
    player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
    E = ExpReplay(predictor,
                  player=player,
                  num_actions=player.get_action_space().num_actions(),
                  populate_size=1001,
                  history_len=4)
    E._init_memory()

    for k in E.get_data():
-        import IPython as IP;
+        import IPython as IP
        IP.embed(config=IP.terminal.ipapp.load_default_config())
        pass
-        #import IPython;
-        #IPython.embed(config=IPython.terminal.ipapp.load_default_config())
-        #break
+        # import IPython;
+        # IPython.embed(config=IPython.terminal.ipapp.load_default_config())
+        # break
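The concat() helper in the @@ -141,9 hunk builds state and next_state by stacking history_len consecutive frames along the channel axis, offset from each other by one frame. A small numpy sketch of that windowing (the 84x84 frame shape is an illustrative assumption):

import numpy as np

history_len = 4
# Five stored frames of shape (84, 84, 1), numbered 0..4.
mem = [np.full((84, 84, 1), i) for i in range(history_len + 1)]

state = np.concatenate(mem[0:history_len], axis=2)           # frames 0..3
next_state = np.concatenate(mem[1:history_len + 1], axis=2)  # frames 1..4
assert state.shape == (84, 84, 4)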
tensorpack/RL/gymenv.py
@@ -9,7 +9,7 @@ from ..utils import logger
try:
    import gym
    # TODO
-    #gym.undo_logger_setup()
+    # gym.undo_logger_setup()
    # https://github.com/openai/gym/pull/199
    # not sure does it cause other problems
    __all__ = ['GymEnv']
@@ -26,11 +26,13 @@ from .envbase import RLEnvironment, DiscreteActionSpace
_ENV_LOCK = threading.Lock()


class GymEnv(RLEnvironment):
    """
    An OpenAI/gym wrapper. Can optionally auto restart.
    Only support discrete action space now
    """

    def __init__(self, name, dumpdir=None, viz=False, auto_restart=True):
        with _ENV_LOCK:
            self.gymenv = gym.make(name)
@@ -82,7 +84,7 @@ if __name__ == '__main__':
    rng = get_rng(num)
    while True:
        act = rng.choice(range(num))
-        #print act
+        # print act
        r, o = env.action(act)
        env.current_state()
        if r != 0 or o:
tensorpack/RL/history.py
@@ -9,10 +9,12 @@ from .envbase import ProxyPlayer
__all__ = ['HistoryFramePlayer']


class HistoryFramePlayer(ProxyPlayer):
    """ Include history frames in state, or use black images
        Assume player will do auto-restart.
    """

    def __init__(self, player, hist_len):
        """
        :param hist_len: total length of the state, including the current
@@ -49,4 +51,3 @@ class HistoryFramePlayer(ProxyPlayer):
        super(HistoryFramePlayer, self).restart_episode()
        self.history.clear()
        self.history.append(self.player.current_state())
tensorpack/RL/simulator.py
@@ -25,8 +25,8 @@ from ..utils.serialize import loads, dumps
from ..utils.concurrency import LoopThread, ensure_proc_terminate

__all__ = ['SimulatorProcess', 'SimulatorMaster',
           'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
           'TransitionExperience', 'WeightSync']

try:
    import zmq
@@ -34,8 +34,10 @@ except ImportError:
    logger.warn_dependency('Simulator', 'zmq')
    __all__ = []


class TransitionExperience(object):
    """ A transition of state, or experience"""

    def __init__(self, state, action, reward, **kwargs):
        """ kwargs: whatever other attribute you want to save"""
        self.state = state
@@ -44,6 +46,7 @@ class TransitionExperience(object):
        for k, v in six.iteritems(kwargs):
            setattr(self, k, v)


@six.add_metaclass(ABCMeta)
class SimulatorProcessBase(mp.Process):
@@ -63,6 +66,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
    A process that simulates a player and communicates to master to
    send states and receive the next action
    """

    def __init__(self, idx, pipe_c2s, pipe_s2c):
        """
        :param idx: idx of this process
@@ -81,7 +85,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
        s2c_socket = context.socket(zmq.DEALER)
        s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
-        #s2c_socket.set_hwm(5)
+        # s2c_socket.set_hwm(5)
        s2c_socket.connect(self.s2c)

        state = player.current_state()
@@ -97,12 +101,14 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
# compatibility
SimulatorProcess = SimulatorProcessStateExchange


class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
        It should produce action for each simulator, as well as
        defining callbacks when a transition or an episode is finished.
    """

    class ClientState(object):

        def __init__(self):
            self.memory = []    # list of Experience
@@ -174,9 +180,11 @@ class SimulatorMaster(threading.Thread):
    def __del__(self):
        self.context.destroy(linger=0)


class SimulatorProcessDF(SimulatorProcessBase):
    """ A simulator which contains a forward model itself, allowing
        it to produce data points directly """

    def __init__(self, idx, pipe_c2s):
        super(SimulatorProcessDF, self).__init__(idx)
        self.pipe_c2s = pipe_c2s
@@ -202,12 +210,14 @@ class SimulatorProcessDF(SimulatorProcessBase):
    def get_data(self):
        pass


class SimulatorProcessSharedWeight(SimulatorProcessDF):
    """ A simulator process with an extra thread waiting for event,
        and take shared weight from shm.
        Start me under some CUDA_VISIBLE_DEVICES set!
    """

    def __init__(self, idx, pipe_c2s, condvar, shared_dic, pred_config):
        super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
        self.condvar = condvar
@@ -220,7 +230,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
        with self.predictor.graph.as_default():
            vars_to_update = self._params_to_update()
            self.sess_updater = SessionUpdate(
                self.predictor.session, vars_to_update)
            # TODO setup callback for explore?
            self.predictor.graph.finalize()
@@ -245,8 +255,10 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
        # can be overwritten to update more params
        return tf.trainable_variables()


class WeightSync(Callback):
    """ Sync weight from main process to shared_dic and notify"""

    def __init__(self, condvar, shared_dic):
        self.condvar = condvar
        self.shared_dic = shared_dic
@@ -260,6 +272,7 @@ class WeightSync(Callback):
    def _before_train(self):
        self._sync()

    def _trigger_epoch(self):
        self._sync()
@@ -274,13 +287,18 @@ class WeightSync(Callback):
if __name__ == '__main__':
    import random
    from tensorpack.RL import NaiveRLEnvironment

    class NaiveSimulator(SimulatorProcess):

        def _build_player(self):
            return NaiveRLEnvironment()

    class NaiveActioner(SimulatorActioner):

        def _get_action(self, state):
            time.sleep(1)
            return random.randint(1, 12)

        def _on_episode_over(self, client):
            #print("Over: ", client.memory)
            client.memory = []
@@ -296,4 +314,3 @@ if __name__ == '__main__':
    import time
    time.sleep(100)
tensorpack/__init__.py
@@ -2,7 +2,7 @@
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import numpy  # avoid https://github.com/tensorflow/tensorflow/issues/2034
import cv2  # avoid https://github.com/tensorflow/tensorflow/issues/1924
from tensorpack.train import *
tensorpack/callbacks/__init__.py
@@ -7,6 +7,8 @@ import os
__all__ = []


def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
@@ -23,4 +25,3 @@ for _, module_name, _ in walk_packages(
        continue
    if not module_name.startswith('_'):
        _global_import(module_name)
tensorpack/callbacks/base.py
@@ -11,6 +11,7 @@ import six
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']


@six.add_metaclass(ABCMeta)
class Callback(object):
    """ Base class for all callbacks """
@@ -72,7 +73,9 @@ class Callback(object):
    def __str__(self):
        return type(self).__name__


class ProxyCallback(Callback):

    def __init__(self, cb):
        self.cb = cb
@@ -91,11 +94,13 @@ class ProxyCallback(Callback):
    def __str__(self):
        return "Proxy-" + str(self.cb)


class PeriodicCallback(ProxyCallback):
    """
    A callback to be triggered after every `period` epochs.
    Doesn't work for trigger_step
    """

    def __init__(self, cb, period):
        """
        :param cb: a `Callback`
@@ -111,4 +116,3 @@ class PeriodicCallback(ProxyCallback):
    def __str__(self):
        return "Periodic-" + str(self.cb)
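PeriodicCallback simply forwards to the wrapped callback every `period` epochs; a hypothetical use with DumpParamAsImage (defined below in this commit; the variable name is an illustrative assumption):

# Dump the variable as an image only every 5 epochs instead of every epoch.
cb = PeriodicCallback(DumpParamAsImage('some_var_name'), period=5)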
tensorpack/callbacks/concurrency.py
@@ -9,7 +9,9 @@ from ..utils import logger
__all__ = ['StartProcOrThread']


class StartProcOrThread(Callback):

    def __init__(self, procs_threads):
        """
        Start extra threads and processes before training
@@ -20,7 +22,7 @@ class StartProcOrThread(Callback):
        self._procs_threads = procs_threads

    def _before_train(self):
-        logger.info("Starting " + \
+        logger.info("Starting " +
                    ', '.join([k.name for k in self._procs_threads]))
        # avoid sigint get handled by other processes
        start_proc_mask_signal(self._procs_threads)
tensorpack/callbacks/dispatcher.py
@@ -6,7 +6,9 @@ from ..tfutils.common import get_op_tensor_name
__all__ = ['OutputTensorDispatcer']


class OutputTensorDispatcer(object):

    def __init__(self):
        self._names = []
        self._idxs = []
tensorpack/callbacks/dump.py
@@ -12,10 +12,12 @@ from ..tfutils import get_op_var_name
__all__ = ['DumpParamAsImage']


class DumpParamAsImage(Callback):
    """
    Dump a variable to image(s) after every epoch to logger.LOG_DIR.
    """

    def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
        """
        :param var_name: the name of the variable.
@@ -59,4 +61,3 @@ class DumpParamAsImage(Callback):
        if self.clip:
            res = np.clip(res, 0, 255)
        cv2.imwrite(fname, res.astype('uint8'))
tensorpack/callbacks/graph.py
@@ -10,8 +10,10 @@ from ..utils import logger
__all__ = ['RunOp']


class RunOp(Callback):
    """ Run an op periodically"""

    def __init__(self, setup_func, run_before=True, run_epoch=True):
        """
        :param setup_func: a function that returns the op in the graph
@@ -34,5 +36,5 @@ class RunOp(Callback):
        if self.run_epoch:
            self._op.run()

-    #def _log(self):
+    # def _log(self):
    #logger.info("Running op {} ...".format(self._op_name))
tensorpack/callbacks/group.py
@@ -12,7 +12,9 @@ from ..utils import logger
__all__ = ['Callbacks']


class CallbackTimeLogger(object):

    def __init__(self):
        self.times = []
        self.tot = 0
@@ -39,10 +41,12 @@ class CallbackTimeLogger(object):
            "Callbacks took {:.3f} sec in total. {}".format(
                self.tot, '; '.join(msgs)))


class Callbacks(Callback):
    """
    A container to hold all callbacks, and execute them in the right order and proper session.
    """

    def __init__(self, cbs):
        """
        :param cbs: a list of `Callbacks`
tensorpack/callbacks/inference.py
@@ -14,7 +14,8 @@ from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_var_name

__all__ = ['ClassificationError',
           'ScalarStats', 'Inferencer', 'BinaryClassificationStats']


@six.add_metaclass(ABCMeta)
class Inferencer(object):
@@ -59,12 +60,14 @@ class Inferencer(object):
    def _get_output_tensors(self):
        pass


class ScalarStats(Inferencer):
    """
    Write some scalar tensor to both stat and summary.
    The output of the given Ops must be a scalar.
    The value will be averaged over all data points in the inference dataflow.
    """

    def __init__(self, names_to_print, prefix='validation'):
        """
        :param names_to_print: list of names of tensors, or just a name
@@ -96,6 +99,7 @@ class ScalarStats(Inferencer):
            ret[name] = stat
        return ret


class ClassificationError(Inferencer):
    """
    Compute classification error in batch mode, from a `wrong` variable
@@ -109,6 +113,7 @@ class ClassificationError(Inferencer):
    testing (because the size of test set might not be a multiple of batch size).
    Therefore the result is different from averaging the error rate of each batch.
    """

    def __init__(self, wrong_var_name='incorrect_vector', summary_name='val_error'):
        """
        :param wrong_var_name: name of the `wrong` variable
@@ -138,6 +143,7 @@ class ClassificationError(Inferencer):
    def _after_inference(self):
        return {self.summary_name: self.err_stat.ratio}


class BinaryClassificationStats(Inferencer):
    """ Compute precision/recall in binary classification, given the
        prediction vector and the label vector.
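The ClassificationError docstring's caveat — counting per sample differs from averaging per-batch error rates when the last batch is smaller — in two lines of arithmetic (the batch sizes are an illustrative assumption):

# Two batches: 100 samples with 10 wrong, then a final batch of 20 with 10 wrong.
per_sample = (10.0 + 10.0) / (100 + 20)       # 0.1667, what ClassificationError reports
mean_of_batches = (10.0/100 + 10.0/20) / 2.0  # 0.30, skewed by the small final batch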
tensorpack/callbacks/inference_runner.py
@@ -18,6 +18,7 @@ from ..train.input_data import FeedfreeInput
__all__ = ['InferenceRunner']


def summary_inferencer(trainer, infs):
    for inf in infs:
        ret = inf.after_inference()
@@ -29,6 +30,7 @@ def summary_inferencer(trainer, infs):
            continue
        trainer.write_scalar_summary(k, v)


class InferenceRunner(Callback):
    """
    A callback that runs different kinds of inferencer.
@@ -54,16 +56,17 @@ class InferenceRunner(Callback):
        self.input_tensors = input_tensors

    def _setup_graph(self):
        self._find_input_tensors()  # these are all tensor names
        self._find_output_tensors()  # may be either tensor name or op name
        self.pred_func = self.trainer.get_predict_func(
            self.input_tensors, self.output_tensors)

    def _find_input_tensors(self):
        if self.input_tensors is None:
            input_vars = self.trainer.model.get_reuse_placehdrs()
            # TODO even if it works here, sparse still is unavailable
            # because get_tensor_by_name doesn't work for sparse

            def get_name(x):
                if isinstance(x, tf.SparseTensor):
                    return x.op.name.split('/')[0]
@@ -79,6 +82,7 @@ class InferenceRunner(Callback):
        IOTensor = InferenceRunner.IOTensor
        self.output_tensors = list(filter(
            lambda x: x not in self.input_tensors, all_names))

        def find_oid(idxs):
            ret = []
            for idx in idxs:
@@ -102,7 +106,7 @@ class InferenceRunner(Callback):
                outputs = self.pred_func(dp)
                for inf, tensormap in zip(self.infs, self.inf_to_tensors):
                    inf_output = [(outputs if k.isOutput else dp)[k.index]
                                  for k in tensormap]
                    inf.datapoint(inf_output)
                pbar.update()
        self._write_summary_after_inference()
@@ -110,6 +114,7 @@ class InferenceRunner(Callback):
    def _write_summary_after_inference(self):
        summary_inferencer(self.trainer, self.infs)


class FeedfreeInferenceRunner(Callback):
    IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
@@ -139,9 +144,9 @@ class FeedfreeInferenceRunner(Callback):
        if self.input_tensor_names is not None:
            assert isinstance(self.input_tensor_names, list)
            self._input_tensors = [k for idx, k in enumerate(self._input_tensors)
                                   if model_placehdrs[idx].name in self.input_tensor_names]
            assert len(self._input_tensors) == len(self.input_tensor_names), \
                "names of input tensors are not defined in the Model"

    def _find_output_tensors(self):
        # doesn't support output an input tensor
@@ -152,6 +157,7 @@ class FeedfreeInferenceRunner(Callback):
        IOTensor = InferenceRunner.IOTensor
        self.output_tensors = all_names

        def find_oid(idxs):
            ret = []
            for idx in idxs:
@@ -161,7 +167,6 @@ class FeedfreeInferenceRunner(Callback):
        self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
        # list of list of (var_name: IOTensor)

    def _trigger_epoch(self):
        for inf in self.infs:
            inf.before_inference()
@@ -170,11 +175,11 @@ class FeedfreeInferenceRunner(Callback):
        sz = self._input_data.size()
        with get_tqdm(total=sz) as pbar:
            for _ in range(sz):
                # outputs = self.pred_func(dp)
                # for inf, tensormap in zip(self.infs, self.inf_to_tensors):
                #     inf_output = [(outputs if k.isOutput else dp)[k.index]
                #                   for k in tensormap]
                #     inf.datapoint(inf_output)
                pbar.update()
        self._write_summary_after_inference()
tensorpack/callbacks/param.py
View file @
fb2a051c
...
@@ -17,6 +17,8 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
...
@@ -17,6 +17,8 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter'
,
'ScheduledHyperParamSetter'
,
'StatMonitorParamSetter'
,
'HyperParamSetterWithFunc'
,
'StatMonitorParamSetter'
,
'HyperParamSetterWithFunc'
,
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
]
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
]
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
HyperParam
(
object
):
class
HyperParam
(
object
):
""" Base class for a hyper param"""
""" Base class for a hyper param"""
...
@@ -35,8 +37,10 @@ class HyperParam(object):
...
@@ -35,8 +37,10 @@ class HyperParam(object):
""" A name to display"""
""" A name to display"""
return
self
.
_readable_name
return
self
.
_readable_name
class
GraphVarParam
(
HyperParam
):
class
GraphVarParam
(
HyperParam
):
""" a variable in the graph can be a hyperparam"""
""" a variable in the graph can be a hyperparam"""
def
__init__
(
self
,
name
,
shape
=
[]):
def
__init__
(
self
,
name
,
shape
=
[]):
self
.
name
=
name
self
.
name
=
name
self
.
shape
=
shape
self
.
shape
=
shape
...
@@ -56,13 +60,15 @@ class GraphVarParam(HyperParam):
...
@@ -56,13 +60,15 @@ class GraphVarParam(HyperParam):
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
def
set_value
(
self
,
v
):
def
set_value
(
self
,
v
):
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
def
get_value
(
self
):
def
get_value
(
self
):
return
self
.
var
.
eval
()
return
self
.
var
.
eval
()
class
ObjAttrParam
(
HyperParam
):
class
ObjAttrParam
(
HyperParam
):
""" an attribute of an object can be a hyperparam"""
""" an attribute of an object can be a hyperparam"""
def
__init__
(
self
,
obj
,
attrname
,
readable_name
=
None
):
def
__init__
(
self
,
obj
,
attrname
,
readable_name
=
None
):
""" :param readable_name: default to be attrname."""
""" :param readable_name: default to be attrname."""
self
.
obj
=
obj
self
.
obj
=
obj
...
@@ -78,6 +84,7 @@ class ObjAttrParam(HyperParam):
...
@@ -78,6 +84,7 @@ class ObjAttrParam(HyperParam):
def
get_value
(
self
,
v
):
def
get_value
(
self
,
v
):
return
getattr
(
self
.
obj
,
self
.
attrname
)
return
getattr
(
self
.
obj
,
self
.
attrname
)
class
HyperParamSetter
(
Callback
):
class
HyperParamSetter
(
Callback
):
"""
"""
Base class to set hyperparameters after every epoch.
Base class to set hyperparameters after every epoch.
...
@@ -126,10 +133,12 @@ class HyperParamSetter(Callback):
...
@@ -126,10 +133,12 @@ class HyperParamSetter(Callback):
if
v
is
not
None
:
if
v
is
not
None
:
self
.
param
.
set_value
(
v
)
self
.
param
.
set_value
(
v
)
class
HumanHyperParamSetter
(
HyperParamSetter
):
class
HumanHyperParamSetter
(
HyperParamSetter
):
"""
"""
Set hyperparameters by loading the value from a file each time it get called.
Set hyperparameters by loading the value from a file each time it get called.
"""
"""
def
__init__
(
self
,
param
,
file_name
=
'hyper.txt'
):
def
__init__
(
self
,
param
,
file_name
=
'hyper.txt'
):
"""
"""
:param file_name: a file containing the value of the variable.
:param file_name: a file containing the value of the variable.
...
@@ -149,7 +158,7 @@ class HumanHyperParamSetter(HyperParamSetter):
...
@@ -149,7 +158,7 @@ class HumanHyperParamSetter(HyperParamSetter):
with
open
(
self
.
file_name
)
as
f
:
with
open
(
self
.
file_name
)
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
lines
=
[
s
.
strip
()
.
split
(
':'
)
for
s
in
lines
]
lines
=
[
s
.
strip
()
.
split
(
':'
)
for
s
in
lines
]
dic
=
{
str
(
k
):
float
(
v
)
for
k
,
v
in
lines
}
dic
=
{
str
(
k
):
float
(
v
)
for
k
,
v
in
lines
}
ret
=
dic
[
self
.
param
.
readable_name
]
ret
=
dic
[
self
.
param
.
readable_name
]
return
ret
return
ret
except
:
except
:
...
@@ -158,10 +167,12 @@ class HumanHyperParamSetter(HyperParamSetter):
...
@@ -158,10 +167,12 @@ class HumanHyperParamSetter(HyperParamSetter):
self
.
param
.
readable_name
,
self
.
file_name
))
self
.
param
.
readable_name
,
self
.
file_name
))
return
None
return
None
class
ScheduledHyperParamSetter
(
HyperParamSetter
):
class
ScheduledHyperParamSetter
(
HyperParamSetter
):
"""
"""
Set hyperparameters by a predefined schedule.
Set hyperparameters by a predefined schedule.
"""
"""
def
__init__
(
self
,
param
,
schedule
,
interp
=
None
):
def
__init__
(
self
,
param
,
schedule
,
interp
=
None
):
"""
"""
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
...
@@ -196,7 +207,9 @@ class ScheduledHyperParamSetter(HyperParamSetter):
...
@@ -196,7 +207,9 @@ class ScheduledHyperParamSetter(HyperParamSetter):
v
=
(
self
.
epoch_num
-
laste
)
*
1.
/
(
e
-
laste
)
*
(
v
-
lastv
)
+
lastv
v
=
(
self
.
epoch_num
-
laste
)
*
1.
/
(
e
-
laste
)
*
(
v
-
lastv
)
+
lastv
return
v
return
v
class
HyperParamSetterWithFunc
(
HyperParamSetter
):
class
HyperParamSetterWithFunc
(
HyperParamSetter
):
def
__init__
(
self
,
param
,
func
):
def
__init__
(
self
,
param
,
func
):
"""Set hyperparameter by a func
"""Set hyperparameter by a func
new_value = f(epoch_num, old_value)
new_value = f(epoch_num, old_value)
...
@@ -207,10 +220,12 @@ class HyperParamSetterWithFunc(HyperParamSetter):
...
@@ -207,10 +220,12 @@ class HyperParamSetterWithFunc(HyperParamSetter):
def
_get_value_to_set
(
self
):
def
_get_value_to_set
(
self
):
return
self
.
f
(
self
.
epoch_num
,
self
.
get_current_value
())
return
self
.
f
(
self
.
epoch_num
,
self
.
get_current_value
())
class
StatMonitorParamSetter
(
HyperParamSetter
):
class
StatMonitorParamSetter
(
HyperParamSetter
):
def
__init__
(
self
,
param
,
stat_name
,
value_func
,
threshold
,
def
__init__
(
self
,
param
,
stat_name
,
value_func
,
threshold
,
last_k
,
reverse
=
False
last_k
,
reverse
=
False
):
):
"""
"""
Set hyperparameter by a func, when a specific stat wasn't
Set hyperparameter by a func, when a specific stat wasn't
decreasing/increasing enough in the last $k$ epochs.
decreasing/increasing enough in the last $k$ epochs.
...
@@ -236,22 +251,21 @@ class StatMonitorParamSetter(HyperParamSetter):
...
@@ -236,22 +251,21 @@ class StatMonitorParamSetter(HyperParamSetter):
def
_get_value_to_set
(
self
):
def
_get_value_to_set
(
self
):
holder
=
self
.
trainer
.
stat_holder
holder
=
self
.
trainer
.
stat_holder
hist
=
holder
.
get_stat_history
(
self
.
stat_name
)
hist
=
holder
.
get_stat_history
(
self
.
stat_name
)
if
len
(
hist
)
<
self
.
last_k
+
1
or
\
if
len
(
hist
)
<
self
.
last_k
+
1
or
\
self
.
epoch_num
-
self
.
last_changed_epoch
<
self
.
last_k
:
self
.
epoch_num
-
self
.
last_changed_epoch
<
self
.
last_k
:
return
None
return
None
hist
=
hist
[
-
self
.
last_k
-
1
:]
# len==last_k+1
hist
=
hist
[
-
self
.
last_k
-
1
:]
# len==last_k+1
hist_first
=
hist
[
0
]
hist_first
=
hist
[
0
]
if
not
self
.
reverse
:
if
not
self
.
reverse
:
hist_min
=
min
(
hist
)
hist_min
=
min
(
hist
)
if
hist_min
<
hist_first
-
self
.
threshold
:
# small enough
if
hist_min
<
hist_first
-
self
.
threshold
:
# small enough
return
None
return
None
else
:
else
:
hist_max
=
max
(
hist
)
hist_max
=
max
(
hist
)
if
hist_max
>
hist_first
+
self
.
threshold
:
# large enough
if
hist_max
>
hist_first
+
self
.
threshold
:
# large enough
return
None
return
None
self
.
last_changed_epoch
=
self
.
epoch_num
self
.
last_changed_epoch
=
self
.
epoch_num
logger
.
info
(
"[StatMonitorParamSetter] Triggered, history: "
+
logger
.
info
(
"[StatMonitorParamSetter] Triggered, history: "
+
','
.
join
(
map
(
str
,
hist
)))
','
.
join
(
map
(
str
,
hist
)))
return
self
.
value_func
(
self
.
get_current_value
())
return
self
.
value_func
(
self
.
get_current_value
())
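The interpolation line in ScheduledHyperParamSetter above is plain linear interpolation between the two neighboring schedule points (laste, lastv) and (e, v). A minimal standalone check of the same formula; the numbers are made up for illustration:

    # Linear interpolation exactly as written above, with hypothetical values.
    laste, lastv = 10, 0.1   # previous schedule point: epoch 10, value 0.1
    e, v = 20, 0.01          # next schedule point: epoch 20, value 0.01
    epoch_num = 15
    v = (epoch_num - laste) * 1. / (e - laste) * (v - lastv) + lastv
    assert abs(v - 0.055) < 1e-12   # halfway between 0.1 and 0.01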
tensorpack/callbacks/saver.py
...
@@ -3,7 +3,8 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-import os, shutil
+import os
+import shutil
 import re
 from .base import Callback
...
@@ -13,12 +14,14 @@ from ..tfutils import get_global_step
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']


class ModelSaver(Callback):
    """
    Save the model to logger directory.
    """
    def __init__(self, keep_recent=10, keep_freq=0.5,
                 var_collections=None):
        """
        :param keep_recent: see `tf.train.Saver` documentation.
        :param keep_freq: see `tf.train.Saver` documentation.
...
@@ -71,9 +74,9 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
        try:
            if not self.meta_graph_written:
                self.saver.export_meta_graph(
                    os.path.join(logger.LOG_DIR,
                                 'graph-{}.meta'.format(logger.get_time_str())),
                    collection_list=self.graph.get_all_collection_keys())
                self.meta_graph_written = True
            self.saver.save(
                tf.get_default_session(),
...
@@ -83,7 +86,9 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
        except (OSError, IOError):   # disk error sometimes.. just ignore it
            logger.exception("Exception in ModelSaver.trigger_epoch!")


class MinSaver(Callback):
    def __init__(self, monitor_stat, reverse=True, filename=None):
        self.monitor_stat = monitor_stat
        self.reverse = reverse
...
@@ -116,15 +121,14 @@ class MinSaver(Callback):
                "Cannot find a checkpoint state. Do you forget to use ModelSaver?")
        path = ckpt.model_checkpoint_path
        newname = os.path.join(logger.LOG_DIR,
                               self.filename or
                               ('max-' if self.reverse else 'min-' + self.monitor_stat + '.tfmodel'))
        shutil.copy(path, newname)
        logger.info("Model with {} '{}' saved.".format(
            'maximum' if self.reverse else 'minimum', self.monitor_stat))


class MaxSaver(MinSaver):
    def __init__(self, monitor_stat):
        super(MaxSaver, self).__init__(monitor_stat, True)
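One thing the autopep8 pass does not touch: in the MinSaver code above, the conditional expression in `newname` binds more loosely than string concatenation, so when `self.reverse` is true the generated name is just 'max-'. A quick standalone illustration of the precedence (this is an observation, not a fix applied in this commit; the stat name is hypothetical):

    reverse, monitor_stat = True, 'val-error'
    # As written above: the else-branch owns the whole concatenation.
    name = ('max-' if reverse else 'min-' + monitor_stat + '.tfmodel')
    assert name == 'max-'
    # Parenthesizing only the prefix gives a name with stat and extension in both cases.
    name = ('max-' if reverse else 'min-') + monitor_stat + '.tfmodel'
    assert name == 'max-val-error.tfmodel'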
tensorpack/callbacks/stats.py
...
@@ -3,7 +3,8 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import tensorflow as tf
-import re, os
+import re
+import os
 import operator
 import json
...
@@ -13,10 +14,12 @@ from ..tfutils.common import get_global_step
__all__ = ['StatHolder', 'StatPrinter', 'SendStat']


class StatHolder(object):
    """
    A holder to keep all statistics aside from tensorflow events.
    """
    def __init__(self, log_dir):
        """
        :param log_dir: directory to save the stats.
...
@@ -62,9 +65,11 @@ class StatHolder(object):
        ret = []
        for h in self.stat_history:
            v = h.get(key, None)
            if v is not None:
                ret.append(v)
        v = self.stat_now.get(key, None)
        if v is not None:
            ret.append(v)
        return ret

    def finalize(self):
...
@@ -88,13 +93,15 @@ class StatHolder(object):
            with open(tmp_filename, 'w') as f:
                json.dump(self.stat_history, f)
            os.rename(tmp_filename, self.filename)
        except IOError:  # disk error sometimes..
            logger.exception("Exception in StatHolder.finalize()!")


class StatPrinter(Callback):
    """
    Control what stats to print.
    """
    def __init__(self, print_tag=None):
        """
        :param print_tag: a list of regex to match scalar summary to print.
...
@@ -116,6 +123,7 @@ class StatPrinter(Callback):
        self._stat_holder.finalize()
        self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)


class SendStat(Callback):
    """
    Execute a command with some specific stats.
...
@@ -126,6 +134,7 @@ class SendStat(Callback):
        -d body={validation_error} > /dev/null 2>&1',
        'validation_error')
    """
    def __init__(self, command, stats):
        self.command = command
        if not isinstance(stats, list):
...
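StatHolder.finalize persists stats with a write-to-temp-then-rename pattern: on POSIX, os.rename is atomic, so a crash mid-write cannot corrupt the existing stats file. A self-contained sketch of the same idea; the function name and path are illustrative, not from the source:

    import json
    import os

    def atomic_json_dump(obj, filename):
        # Write to a sibling temp file, then atomically swap it into place.
        tmp = filename + '.tmp'
        with open(tmp, 'w') as f:
            json.dump(obj, f)
        os.rename(tmp, filename)

    atomic_json_dump([{'epoch_num': 1, 'train_error': 0.9}], '/tmp/stat.json')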
tensorpack/dataflow/__init__.py
...
@@ -12,6 +12,7 @@ from . import imgaug
__all__ = ['dataset', 'imgaug', 'dftools']


def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -24,6 +25,5 @@ __SKIP = ['dftools', 'dataset', 'imgaug']
for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_') and \
            module_name not in __SKIP:
        _global_import(module_name)
tensorpack/dataflow/base.py
...
@@ -10,6 +10,7 @@ from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']


@six.add_metaclass(ABCMeta)
class DataFlow(object):
    """ Base class for all DataFlow """
...
@@ -17,7 +18,6 @@ class DataFlow(object):
    class Infinity:
        pass

    @abstractmethod
    def get_data(self):
        """
...
@@ -44,11 +44,14 @@ class DataFlow(object):
class RNGDataFlow(DataFlow):
    """ A dataflow with rng"""
    def reset_state(self):
        self.rng = get_rng(self)


class ProxyDataFlow(DataFlow):
    """ Base class for DataFlow that proxies another"""
    def __init__(self, ds):
        """
        :param ds: a :mod:`DataFlow` instance to proxy
...
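Given the interface in base.py (an abstract get_data generator plus optional size and reset_state), a minimal user-defined DataFlow looks like the sketch below. The Counter class is hypothetical, written against the interface visible in this diff, not part of the library:

    from tensorpack.dataflow import DataFlow

    class Counter(DataFlow):
        """Toy DataFlow yielding the datapoints [0], [1], ..., [n-1]."""
        def __init__(self, n):
            self.n = n

        def size(self):
            return self.n

        def get_data(self):
            for i in range(self.n):
                yield [i]   # a datapoint is a list of components

    df = Counter(5)
    df.reset_state()   # a no-op here; an RNGDataFlow would seed its rng instead
    assert [dp[0] for dp in df.get_data()] == [0, 1, 2, 3, 4]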
tensorpack/dataflow/common.py
...
@@ -15,7 +15,9 @@ __all__ = ['BatchData', 'FixedSizeData', 'MapData',
           'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
           'LocallyShuffleData', 'TestDataSpeed', 'BatchDataByShape']


class TestDataSpeed(ProxyDataFlow):
    def __init__(self, ds, size=1000):
        super(TestDataSpeed, self).__init__(ds)
        self.test_size = size
...
@@ -31,7 +33,9 @@ class TestDataSpeed(ProxyDataFlow):
        for dp in self.ds.get_data():
            pbar.update()


class BatchData(ProxyDataFlow):
    def __init__(self, ds, batch_size, remainder=False):
        """
        Group data in `ds` into batches.
...
@@ -91,11 +95,13 @@ class BatchData(ProxyDataFlow):
                raise
            except:
                logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
-                import IPython as IP;
+                import IPython as IP
                IP.embed(config=IP.terminal.ipapp.load_default_config())
        return result


class BatchDataByShape(BatchData):
    def __init__(self, ds, batch_size, idx):
        """ Group datapoint of the same shape together to batches
...
@@ -119,10 +125,12 @@ class BatchDataByShape(BatchData):
                yield BatchData._aggregate_batch(holder)
                del holder[:]


class FixedSizeData(ProxyDataFlow):
    """ Generate data from another DataFlow, but with a fixed epoch size.
        The state of the underlying DataFlow is maintained among each epoch.
    """
    def __init__(self, ds, size):
        """
        :param ds: a :mod:`DataFlow` to produce data
...
@@ -154,10 +162,12 @@ class FixedSizeData(ProxyDataFlow):
            if cnt == self._size:
                return


class RepeatedData(ProxyDataFlow):
    """ Take data points from another `DataFlow` and produce them until
        it's exhausted for certain amount of times.
    """
    def __init__(self, ds, nr):
        """
        :param ds: a :mod:`DataFlow` instance.
...
@@ -184,8 +194,10 @@ class RepeatedData(ProxyDataFlow):
                for dp in self.ds.get_data():
                    yield dp


class MapData(ProxyDataFlow):
    """ Apply map/filter a function on the datapoint"""
    def __init__(self, ds, func):
        """
        :param ds: a :mod:`DataFlow` instance.
...
@@ -202,8 +214,10 @@ class MapData(ProxyDataFlow):
            if ret is not None:
                yield ret


class MapDataComponent(ProxyDataFlow):
    """ Apply map/filter on the given index in the datapoint"""
    def __init__(self, ds, func, index=0):
        """
        :param ds: a :mod:`DataFlow` instance.
...
@@ -222,11 +236,13 @@ class MapDataComponent(ProxyDataFlow):
                dp[self.index] = repl  # NOTE modifying
                yield dp


class RandomChooseData(RNGDataFlow):
    """
    Randomly choose from several DataFlow. Stop producing when any of them is
    exhausted.
    """
    def __init__(self, df_lists):
        """
        :param df_lists: list of dataflow, or list of (dataflow, probability) tuple
...
@@ -257,10 +273,12 @@ class RandomChooseData(RNGDataFlow):
            except StopIteration:
                return


class RandomMixData(RNGDataFlow):
    """
    Randomly choose from several dataflow, and will eventually exhaust all dataflow. So it's a perfect mix.
    """
    def __init__(self, df_lists):
        """
        :param df_lists: list of dataflow.
...
@@ -285,14 +303,16 @@ class RandomMixData(RNGDataFlow):
        idxs = np.array(list(map(
            lambda x: np.searchsorted(sums, x, 'right'), idxs)))
        itrs = [k.get_data() for k in self.df_lists]
        assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs) - 1)
        for k in idxs:
            yield next(itrs[k])


class ConcatData(DataFlow):
    """
    Concatenate several dataflows.
    """
    def __init__(self, df_lists):
        """
        :param df_lists: list of :mod:`DataFlow` instances
...
@@ -311,6 +331,7 @@ class ConcatData(DataFlow):
            for dp in d.get_data():
                yield dp


class JoinData(DataFlow):
    """
    Join the components from each DataFlow.
...
@@ -321,6 +342,7 @@ class JoinData(DataFlow):
        df2: [dp3, dp4]
        join: [dp1, dp2, dp3, dp4]
    """
    def __init__(self, df_lists):
        """
        :param df_lists: list of :mod:`DataFlow` instances
...
@@ -329,7 +351,7 @@ class JoinData(DataFlow):
        self._size = self.df_lists[0].size()
        for d in self.df_lists:
            assert d.size() == self._size, \
                "All DataFlow must have the same size! {} != {}".format(d.size(), self._size)

    def reset_state(self):
        for d in self.df_lists:
...
@@ -352,7 +374,9 @@ class JoinData(DataFlow):
            for itr in itrs:
                del itr


class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
    def __init__(self, ds, cache_size, nr_reuse=1):
        """
        Cache a number of datapoints and shuffle them.
...
@@ -393,10 +417,10 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
                    yield v
            return


def SelectComponent(ds, idxs):
    """
    :param ds: a :mod:`DataFlow` instance
    :param idxs: a list of datapoint component index of the original dataflow
    """
    return MapData(ds, lambda dp: [dp[i] for i in idxs])
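These classes compose by wrapping: every ProxyDataFlow takes another DataFlow. A hedged usage sketch, assuming the Mnist dataflow from this repository (which, per its docstring below, yields [image, label] with 28x28 images in [0, 1]) and this commit's module layout:

    from tensorpack.dataflow import BatchData, MapDataComponent
    from tensorpack.dataflow.dataset import Mnist

    df = Mnist('train')                                        # yields [image, label]
    df = MapDataComponent(df, lambda im: im * 2 - 1, index=0)  # rescale [0,1] -> [-1,1]
    df = BatchData(df, 128, remainder=False)                   # yields [images, labels]
    df.reset_state()
    for images, labels in df.get_data():
        print(images.shape, labels.shape)                      # (128, 28, 28) (128,)
        break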
tensorpack/dataflow/dataset/__init__.py
...
@@ -7,6 +7,8 @@ import os
import os.path

__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -19,4 +21,3 @@ for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        global_import(module_name)
tensorpack/dataflow/dataset/bsds500.py
...
@@ -3,7 +3,8 @@
 # File: bsds500.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-import os, glob
+import os
+import glob
 import cv2
 import numpy as np
...
@@ -21,6 +22,7 @@ except ImportError:
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W, IMG_H = 481, 321


class BSDS500(RNGDataFlow):
    """
    `Berkeley Segmentation Data Set and Benchmarks 500
...
@@ -65,7 +67,7 @@ class BSDS500(RNGDataFlow):
            im = cv2.imread(f, cv2.IMREAD_COLOR)
            assert im is not None
            if im.shape[0] > im.shape[1]:
                im = np.transpose(im, (1, 0, 2))
            assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W))
            imgid = os.path.basename(f).split('.')[0]
...
@@ -96,5 +98,5 @@ class BSDS500(RNGDataFlow):
if __name__ == '__main__':
    a = BSDS500('val')
    for k in a.get_data():
        cv2.imshow("haha", k[1].astype('uint8') * 255)
        cv2.waitKey(1000)
tensorpack/dataflow/dataset/cifar.py
...
@@ -4,7 +4,8 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 #         Yukun Chen <cykustc@gmail.com>
-import os, sys
+import os
+import sys
 import pickle
 import numpy as np
 import random
...
@@ -23,6 +24,7 @@ __all__ = ['Cifar10', 'Cifar100']
DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'


def maybe_download_and_extract(dest_directory, cifar_classnum):
    """Download and extract the tarball from Alex's website.
       copied from tensorflow example """
...
@@ -42,6 +44,7 @@ def maybe_download_and_extract(dest_directory, cifar_classnum):
        import tarfile
        tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def read_cifar(filenames, cifar_classnum):
    assert cifar_classnum == 10 or cifar_classnum == 100
    ret = []
...
@@ -54,7 +57,7 @@ def read_cifar(filenames, cifar_classnum):
        data = dic[b'data']
        if cifar_classnum == 10:
            label = dic[b'labels']
            IMG_NUM = 10000  # cifar10 data are split into blocks of 10000
        elif cifar_classnum == 100:
            label = dic[b'fine_labels']
            IMG_NUM = 50000 if 'train' in fname else 10000
...
@@ -65,6 +68,7 @@ def read_cifar(filenames, cifar_classnum):
            ret.append([img, label[k]])
    return ret


def get_filenames(dir, cifar_classnum):
    assert cifar_classnum == 10 or cifar_classnum == 100
    if cifar_classnum == 10:
...
@@ -77,11 +81,13 @@ def get_filenames(dir, cifar_classnum):
                     os.path.join(dir, 'cifar-100-python', 'test')]
    return filenames


class CifarBase(RNGDataFlow):
    """
    Return [image, label],
        image is 32x32x3 in the range [0,255]
    """
    def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10):
        """
        Args:
...
@@ -132,13 +138,17 @@ class CifarBase(RNGDataFlow):
        return three values as mean of each channel
        """
        mean = self.get_per_pixel_mean()
        return np.mean(mean, axis=(0, 1))


class Cifar10(CifarBase):
    def __init__(self, train_or_test, shuffle=True, dir=None):
        super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10)


class Cifar100(CifarBase):
    def __init__(self, train_or_test, shuffle=True, dir=None):
        super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)
...
@@ -149,7 +159,6 @@ if __name__ == '__main__':
    print(mean)
    dump_dataset_images(ds, '/tmp/cifar', 100)
-    #for (img, label) in ds.get_data():
-    #from IPython import embed; embed()
-    #break
+    # for (img, label) in ds.get_data():
+    # from IPython import embed; embed()
+    # break
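The per-channel mean computed above just collapses the (32, 32, 3) per-pixel mean over the two spatial axes. A standalone check with a stand-in array:

    import numpy as np

    per_pixel_mean = np.random.rand(32, 32, 3)   # stand-in for get_per_pixel_mean()
    per_channel_mean = np.mean(per_pixel_mean, axis=(0, 1))
    assert per_channel_mean.shape == (3,)        # one mean per channel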
tensorpack/dataflow/dataset/ilsvrc.py
...
@@ -19,10 +19,12 @@ __all__ = ['ILSVRCMeta', 'ILSVRC12']
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"


class ILSVRCMeta(object):
    """
    Some metadata for ILSVRC dataset.
    """
    def __init__(self, dir=None):
        if dir is None:
            dir = get_dataset_path('ilsvrc_metadata')
...
@@ -82,14 +84,16 @@ class ILSVRCMeta(object):
        with open(mean_file, 'rb') as f:
            obj.ParseFromString(f.read())
        arr = np.array(obj.data).reshape((3, 256, 256)).astype('float32')
        arr = np.transpose(arr, [1, 2, 0])
        if size is not None:
            arr = cv2.resize(arr, size[::-1])
        return arr


class ILSVRC12(RNGDataFlow):
    def __init__(self, dir, name, meta_dir=None, shuffle=True,
                 dir_structure='original', include_bb=False):
        """
        :param dir: A directory containing a subdir named `name`, where the
            original ILSVRC12_`name`.tar gets decompressed.
...
@@ -145,7 +149,7 @@ class ILSVRC12(RNGDataFlow):
        if include_bb:
            bbdir = os.path.join(dir, 'bbox') if not \
                isinstance(include_bb, six.string_types) else include_bb
            assert name == 'train', 'Bounding box only available for training'
            self.bblist = ILSVRC12.get_training_bbox(bbdir, self.imglist)
        self.include_bb = include_bb
...
@@ -171,11 +175,11 @@ class ILSVRC12(RNGDataFlow):
            im = cv2.imread(fname.strip(), cv2.IMREAD_COLOR)
            assert im is not None, fname
            if im.ndim == 2:
                im = np.expand_dims(im, 2).repeat(3, 2)
            if self.include_bb:
                bb = self.bblist[k]
                if bb is None:
                    bb = [0, 0, im.shape[1] - 1, im.shape[0] - 1]
                yield [im, label, bb]
            else:
                yield [im, label]
...
@@ -216,12 +220,13 @@ class ILSVRC12(RNGDataFlow):
if __name__ == '__main__':
    meta = ILSVRCMeta()
-    #print(meta.get_synset_words_1000())
+    # print(meta.get_synset_words_1000())
    ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', include_bb=True,
                  shuffle=False)
    ds.reset_state()
    for k in ds.get_data():
-        from IPython import embed; embed()
+        from IPython import embed
+        embed()
        break
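The `np.expand_dims(im, 2).repeat(3, 2)` idiom in get_data above turns a 2-D grayscale image into a 3-channel one by stacking three identical planes, so grayscale JPEGs flow through the same pipeline as color ones. A standalone check:

    import numpy as np

    gray = np.zeros((321, 481), dtype='uint8')   # a 2-D grayscale image
    rgb = np.expand_dims(gray, 2).repeat(3, 2)   # (321, 481, 1) -> (321, 481, 3)
    assert rgb.shape == (321, 481, 3)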
tensorpack/dataflow/dataset/mnist.py
...
@@ -17,6 +17,7 @@ __all__ = ['Mnist']
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
    """Download the data from Yann's website, unless it's already here."""
    filepath = os.path.join(work_directory, filename)
...
@@ -25,18 +26,20 @@ def maybe_download(filename, work_directory):
        download(SOURCE_URL + filename, work_directory)
    return filepath


def _read32(bytestream):
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
...
@@ -46,24 +49,27 @@ def extract_images(filename):
        data = data.astype('float32') / 255.0
        return data


def extract_labels(filename):
    """Extract the labels into a 1D uint8 numpy array [index]."""
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
        return labels


class Mnist(RNGDataFlow):
    """
    Return [image, label],
        image is 28x28 in the range [0,1]
    """
    def __init__(self, train_or_test, shuffle=True, dir=None):
        """
        Args:
...
@@ -107,6 +113,6 @@ class Mnist(RNGDataFlow):
if __name__ == '__main__':
    ds = Mnist('train')
    for (img, label) in ds.get_data():
-        from IPython import embed; embed()
+        from IPython import embed
+        embed()
        break
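_read32 above decodes the 32-bit big-endian integers of the MNIST IDX header. A self-contained round-trip using a synthetic header word:

    import io
    import struct
    import numpy

    stream = io.BytesIO(struct.pack('>I', 2051))   # big-endian header word
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    magic = numpy.frombuffer(stream.read(4), dtype=dt)[0]
    assert magic == 2051   # the magic number extract_images checks for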
tensorpack/dataflow/dataset/ptb.py
...
@@ -24,6 +24,7 @@ TRAIN_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.tra
VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'
TEST_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'


@memoized_ignoreargs
def get_PennTreeBank(data_dir=None):
    if data_dir is None:
...
@@ -35,6 +36,5 @@ def get_PennTreeBank(data_dir=None):
    # TODO these functions in TF might not be available in the future
    word_to_id = tfreader._build_vocab(os.path.join(data_dir, 'ptb.train.txt'))
    data3 = [np.asarray(tfreader._file_to_word_ids(os.path.join(data_dir, fname), word_to_id))
             for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']]
    return data3, word_to_id
tensorpack/dataflow/dataset/svhn.py
...
@@ -19,6 +19,7 @@ except ImportError:
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"


class SVHNDigit(RNGDataFlow):
    """
    SVHN Cropped Digit Dataset.
...
@@ -41,12 +42,12 @@ class SVHNDigit(RNGDataFlow):
        assert name in ['train', 'test', 'extra'], name
        filename = os.path.join(data_dir, name + '_32x32.mat')
        assert os.path.isfile(filename), \
            "File {} not found! Please download it from {}.".format(filename, SVHN_URL)
        logger.info("Loading {} ...".format(filename))
        data = scipy.io.loadmat(filename)
        self.X = data['X'].transpose(3, 0, 1, 2)
        self.Y = data['y'].reshape((-1))
        self.Y[self.Y == 10] = 0
        SVHNDigit._Cache[name] = (self.X, self.Y)

    def size(self):
...
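The `self.Y[self.Y == 10] = 0` line exists because the SVHN .mat files encode the digit '0' as label 10; the remap brings all labels into 0..9. A standalone check of the same operation:

    import numpy as np

    y = np.array([10, 1, 2, 10])   # raw SVHN labels: '0' is stored as 10
    y[y == 10] = 0
    assert (y == [0, 1, 2, 0]).all()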
tensorpack/dataflow/dataset/visualqa.py
...
@@ -12,6 +12,7 @@ import json
__all__ = ['VisualQA']


def read_json(fname):
    f = open(fname)
    ret = json.load(f)
...
@@ -19,11 +20,14 @@ def read_json(fname):
    return ret


# TODO shuffle
class VisualQA(DataFlow):
    """
    Visual QA dataset. See http://visualqa.org/
    Simply read q/a json file and produce q/a pairs in their original format.
    """
    def __init__(self, question_file, annotation_file):
        with timed_operation('Reading VQA JSON file'):
            qobj, aobj = list(map(read_json, [question_file, annotation_file]))
...
@@ -62,7 +66,7 @@ class VisualQA(DataFlow):
        """ Get the n most common words in questions
            n=4600 ~= thresh 6
        """
        from nltk.tokenize import word_tokenize  # will need to download 'punckt'
        cnt = Counter()
        for q in self.questions:
            cnt.update(word_tokenize(q['question'].lower()))
...
@@ -72,7 +76,7 @@ class VisualQA(DataFlow):
if __name__ == '__main__':
    vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
                   '/home/wyx/data/VQA/mscoco_train2014_annotations.json')
    for k in vqa.get_data():
        print(json.dumps(k))
        break
...
tensorpack/dataflow/dftools.py
...
@@ -2,7 +2,8 @@
 # File: dftools.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
-import sys, os
+import sys
+import os
 import cv2
 import multiprocessing as mp
 import six
...
@@ -23,6 +24,8 @@ else:
    __all__.extend(['dump_dataflow_to_lmdb'])


# TODO pass a name_func to write label as filename?
def dump_dataset_images(ds, dirname, max_count=None, index=0):
    """ Dump images from a `DataFlow` to a directory.
...
@@ -43,6 +46,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
        img = dp[index]
        cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)


def dump_dataflow_to_lmdb(ds, lmdb_path):
    """ Dump a `Dataflow` ds to a lmdb database, where the key is the index
        and the data is the serialized datapoint.
...
@@ -56,8 +60,8 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
    assert not os.path.isfile(lmdb_path), "LMDB file exists!"
    ds.reset_state()
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)    # need sync() at the end
    try:
        sz = ds.size()
    except NotImplementedError:
...
@@ -87,7 +91,9 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
        the queue once you start it. Each element is (task_id, dp).
    """
    q = mp.Queue(size)

    class EnqueProc(mp.Process):
        def __init__(self, ds, q, nr_consumer):
            super(EnqueProc, self).__init__()
            self.ds = ds
...
@@ -104,4 +110,3 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
    proc = EnqueProc(ds, q, nr_consumer)
    return q, proc
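A hedged usage sketch for dump_dataflow_to_lmdb, assuming the lmdb package is installed (the function is only exported in that case) and using the Mnist dataflow from this repo; the output path is illustrative:

    from tensorpack.dataflow import dftools
    from tensorpack.dataflow.dataset import Mnist

    ds = Mnist('train', shuffle=False)
    # Serializes each datapoint under its index; asserts if the file already exists.
    dftools.dump_dataflow_to_lmdb(ds, '/tmp/mnist-train.lmdb')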
tensorpack/dataflow/format.py
...
@@ -40,10 +40,13 @@ Adapters for different data format.
"""


# TODO lazy load
class HDF5Data(RNGDataFlow):
    """
    Zip data from different paths in an HDF5 file. Will load all data into memory.
    """
    def __init__(self, filename, data_paths, shuffle=True):
        """
        :param filename: h5 data file.
...
@@ -54,7 +57,7 @@ class HDF5Data(RNGDataFlow):
        logger.info("Loading {} to memory...".format(filename))
        self.dps = [self.f[k].value for k in data_paths]
        lens = [len(k) for k in self.dps]
        assert all([k == lens[0] for k in lens])
        self._size = lens[0]
        self.shuffle = shuffle
...
@@ -71,6 +74,7 @@ class HDF5Data(RNGDataFlow):
class LMDBData(RNGDataFlow):
    """ Read a lmdb and produce k,v pair """
    def __init__(self, lmdb_path, shuffle=True):
        self._lmdb_path = lmdb_path
        self._shuffle = shuffle
...
@@ -78,9 +82,9 @@ class LMDBData(RNGDataFlow):
    def open_lmdb(self):
        self._lmdb = lmdb.open(self._lmdb_path,
                               subdir=os.path.isdir(self._lmdb_path),
                               readonly=True, lock=False, readahead=False,
                               map_size=1099511627776 * 2, max_readers=100)
        self._txn = self._lmdb.begin()
        self._size = self._txn.stat()['entries']
        if self._shuffle:
...
@@ -116,7 +120,9 @@ class LMDBData(RNGDataFlow):
                v = self._txn.get(k)
                yield [k, v]


class LMDBDataDecoder(LMDBData):
    def __init__(self, lmdb_path, decoder, shuffle=True):
        """
        :param decoder: a function taking k, v and return a data point,
...
@@ -128,18 +134,24 @@ class LMDBDataDecoder(LMDBData):
    def get_data(self):
        for dp in super(LMDBDataDecoder, self).get_data():
            v = self.decoder(dp[0], dp[1])
            if v:
                yield v


class LMDBDataPoint(LMDBDataDecoder):
    """ Read a LMDB file where each value is a serialized datapoint"""
    def __init__(self, lmdb_path, shuffle=True):
        super(LMDBDataPoint, self).__init__(
            lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle)


class CaffeLMDB(LMDBDataDecoder):
    """ Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
    def __init__(self, lmdb_path, shuffle=True):
        cpb = get_caffe_pb()

        def decoder(k, v):
            try:
                datum = cpb.Datum()
...
@@ -152,10 +164,12 @@ class CaffeLMDB(LMDBDataDecoder):
                return [img.transpose(1, 2, 0), datum.label]

        super(CaffeLMDB, self).__init__(
            lmdb_path, decoder=decoder, shuffle=shuffle)


class SVMLightData(RNGDataFlow):
    """ Read X,y from a svmlight file """
    def __init__(self, filename, shuffle=True):
        self.X, self.y = sklearn.datasets.load_svmlight_file(filename)
        self.X = np.asarray(self.X.todense())
...
@@ -169,4 +183,4 @@ class SVMLightData(RNGDataFlow):
        if self.shuffle:
            self.rng.shuffle(idxs)
        for id in idxs:
-            yield [self.X[id,:], self.y[id]]
+            yield [self.X[id, :], self.y[id]]
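Reading such a dump back is LMDBDataPoint's job: it deserializes every value with loads. A sketch that pairs with the dftools example above, under the assumption that the database was produced by dump_dataflow_to_lmdb from [image, label] datapoints (same illustrative path):

    from tensorpack.dataflow import LMDBDataPoint

    ds = LMDBDataPoint('/tmp/mnist-train.lmdb', shuffle=False)
    ds.reset_state()
    for img, label in ds.get_data():   # each value deserializes back to [image, label]
        break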
tensorpack/dataflow/image.py
...
@@ -11,7 +11,9 @@ from .imgaug import AugmentorList
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']


class ImageFromFile(RNGDataFlow):
    def __init__(self, files, channel=3, resize=None, shuffle=False):
        """
        Generate images of 1 channel or 3 channels (in RGB order) from list of files.
...
@@ -39,11 +41,12 @@ class ImageFromFile(RNGDataFlow):
            if self.resize is not None:
                im = cv2.resize(im, self.resize[::-1])
            if self.channel == 1:
                im = im[:, :, np.newaxis]
            yield [im]


class AugmentImageComponent(MapDataComponent):
    def __init__(self, ds, augmentors, index=0):
        """
        Augment the image component of datapoints
...
@@ -64,7 +67,8 @@ class AugmentImageComponent(MapDataComponent):
class AugmentImageComponents(MapData):
    def __init__(self, ds, augmentors, index=(0, 1)):
        """ Augment a list of images of the same shape, with the same parameters
        :param ds: a `DataFlow` instance.
        :param augmentors: a list of `ImageAugmentor` instance to be applied in order.
...
tensorpack/dataflow/imgaug/__init__.py
...
@@ -7,6 +7,7 @@ from pkgutil import walk_packages
__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -19,4 +20,3 @@ for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        global_import(module_name)
tensorpack/dataflow/imgaug/_test.py
...
@@ -10,15 +10,15 @@ from .crop import *
from .imgproc import *
from .noname import *
from .deform import *
from .noise import SaltPepperNoise

anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([
    Contrast((0.8, 1.2)),
    Flip(horiz=True),
    GaussianDeform(anchors, (360, 480), 0.2, randrange=20),
-    #RandomCropRandomShape(0.3),
+    # RandomCropRandomShape(0.3),
    SaltPepperNoise()
])
...
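The _test.py above exercises AugmentorList directly; a minimal sketch of the same pattern, assuming the Flip and Contrast augmentors it imports and this commit's module layout:

    import numpy as np
    from tensorpack.dataflow.imgaug import AugmentorList, Contrast, Flip

    augs = AugmentorList([Contrast((0.8, 1.2)), Flip(horiz=True)])
    augs.reset_state()   # seeds the rng of every augmentor in the list
    img = np.random.rand(360, 480, 3).astype('float32')
    out = augs.augment(img)   # applies the whole list; same interface as one augmentor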
tensorpack/dataflow/imgaug/base.py
...
@@ -9,6 +9,7 @@ from six.moves import zip
__all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']


@six.add_metaclass(ABCMeta)
class Augmentor(object):
    """ Base class for an augmentor"""
...
@@ -58,7 +59,9 @@ class Augmentor(object):
            size = []
        return self.rng.uniform(low, high, size)


class ImageAugmentor(Augmentor):
    def augment(self, img):
        """
        Perform augmentation on the image in-place.
...
@@ -71,10 +74,12 @@ class ImageAugmentor(Augmentor):
    def _fprop_coord(self, coord, param):
        return coord


class AugmentorList(ImageAugmentor):
    """
    Augment by a list of augmentors
    """
    def __init__(self, augmentors):
        """
        :param augmentors: list of `ImageAugmentor` instance to be applied
...
@@ -107,4 +112,3 @@ class AugmentorList(ImageAugmentor):
        """ Will reset state of each augmentor """
        for a in self.augs:
            a.reset_state()
tensorpack/dataflow/imgaug/crop.py  View file @ fb2a051c
...
@@ -10,10 +10,12 @@ from six.moves import range
import numpy as np

__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop',
           'RandomCropRandomShape', 'perturb_BB', 'RandomCropAroundBox']


class RandomCrop(ImageAugmentor):
    """ Randomly crop the image into a smaller one """

    def __init__(self, crop_shape):
        """
        :param crop_shape: a shape like (h, w)
...
@@ -25,7 +27,7 @@ class RandomCrop(ImageAugmentor):
    def _get_augment_params(self, img):
        orig_shape = img.shape
        assert orig_shape[0] >= self.crop_shape[0] \
            and orig_shape[1] >= self.crop_shape[1], orig_shape
        diffh = orig_shape[0] - self.crop_shape[0]
        h0 = 0 if diffh == 0 else self.rng.randint(diffh)
        diffw = orig_shape[1] - self.crop_shape[1]
...
@@ -34,13 +36,15 @@ class RandomCrop(ImageAugmentor):
    def _augment(self, img, param):
        h0, w0 = param
        return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]

    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


class CenterCrop(ImageAugmentor):
    """ Crop the image at the center"""

    def __init__(self, crop_shape):
        """
        :param crop_shape: a shape like (h, w)
...
@@ -52,13 +56,15 @@ class CenterCrop(ImageAugmentor):
        orig_shape = img.shape
        h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
        w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
        return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]

    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


class FixedCrop(ImageAugmentor):
    """ Crop a rectangle at a given location"""

    def __init__(self, rect):
        """
        Two arguments define the range in both axes to crop: min included, max excluded.
...
@@ -69,15 +75,16 @@ class FixedCrop(ImageAugmentor):
    def _augment(self, img, _):
        orig_shape = img.shape
        return img[self.rect.y0: self.rect.y1 + 1,
                   self.rect.x0: self.rect.x0 + 1]

    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


def perturb_BB(image_shape, bb, max_pertub_pixel,
               rng=None, max_aspect_ratio_diff=0.3,
               max_try=100):
    """
    Perturb a bounding box.
    :param image_shape: [h, w]
...
@@ -113,6 +120,7 @@ class RandomCropAroundBox(ImageAugmentor):
    """
    Crop a box around a bounding box
    """

    def __init__(self, perturb_ratio, max_aspect_ratio_diff=0.3):
        """
        :param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
...
@@ -124,9 +132,9 @@ class RandomCropAroundBox(ImageAugmentor):
    def _get_augment_params(self, img):
        shape = img.shape[:2]
        box = Rect(0, 0, shape[1] - 1, shape[0] - 1)
        dist = self.perturb_ratio * np.sqrt(shape[0] * shape[1])
        newbox = perturb_BB(shape, box, dist,
                            self.rng, self.max_aspect_ratio_diff)
        return newbox

    def _augment(self, img, newbox):
...
@@ -135,10 +143,12 @@ class RandomCropAroundBox(ImageAugmentor):
    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


class RandomCropRandomShape(ImageAugmentor):
    def __init__(self, wmin, hmin,
                 wmax=None, hmax=None,
                 max_aspect_ratio=None):
        """
        Randomly crop a box of shape (h, w), sampled from [min, max](inclusive).
        If max is None, will use the input image shape.
...
@@ -151,18 +161,18 @@ class RandomCropRandomShape(ImageAugmentor):
    def _get_augment_params(self, img):
        hmax = self.hmax or img.shape[0]
        wmax = self.wmax or img.shape[1]
        h = self.rng.randint(self.hmin, hmax + 1)
        w = self.rng.randint(self.wmin, wmax + 1)
        diffh = img.shape[0] - h
        diffw = img.shape[1] - w
        assert diffh >= 0 and diffw >= 0
        y0 = 0 if diffh == 0 else self.rng.randint(diffh)
        x0 = 0 if diffw == 0 else self.rng.randint(diffw)
        return (y0, x0, h, w)

    def _augment(self, img, param):
        y0, x0, h, w = param
        return img[y0:y0 + h, x0:x0 + w]


if __name__ == '__main__':
    print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
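As a usage sketch (assuming only the constructor signatures shown above and the `augment` / `reset_state` interface from base.py), the crop augmentors can be driven directly on a numpy image:

import numpy as np
from tensorpack.dataflow.imgaug import RandomCrop, CenterCrop, AugmentorList

img = np.random.randint(0, 256, size=(240, 320, 3)).astype('uint8')

augs = AugmentorList([
    RandomCrop((224, 224)),   # random (h0, w0) offset, see _get_augment_params above
    CenterCrop((200, 200)),   # deterministic center offset
])
augs.reset_state()            # seeds self.rng for every augmentor in the list
out = augs.augment(img)
print(out.shape)              # -> (200, 200, 3)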
tensorpack/dataflow/imgaug/deform.py  View file @ fb2a051c
...
@@ -10,8 +10,10 @@ __all__ = ['GaussianDeform', 'GaussianMap']

# TODO really needs speedup


class GaussianMap(object):
    """ Generate gaussian weighted deformation map"""

    def __init__(self, image_shape, sigma=0.5):
        assert len(image_shape) == 2
        self.shape = image_shape
...
@@ -25,17 +27,18 @@ class GaussianMap(object):
        x = x.astype('float32') / ret.shape[1] - anchor[1]
        g = np.exp(-(x**2 + y**2) / self.sigma)
        #cv2.imshow(" ", g)
-        #cv2.waitKey()
+        # cv2.waitKey()
        return g


def np_sample(img, coords):
    # a numpy implementation of ImageSample layer
    coords = np.maximum(coords, 0)
    coords = np.minimum(coords, np.array([img.shape[0] - 1, img.shape[1] - 1]))
    lcoor = np.floor(coords).astype('int32')
    ucoor = lcoor + 1
    ucoor = np.minimum(ucoor, np.array([img.shape[0] - 1, img.shape[1] - 1]))
    diff = coords - lcoor
    neg_diff = 1.0 - diff
...
@@ -46,17 +49,20 @@ def np_sample(img, coords):
    diffy, diffx = np.split(diff, 2, axis=2)
    ndiffy, ndiffx = np.split(neg_diff, 2, axis=2)
    ret = img[lcoory, lcoorx, :] * ndiffx * ndiffy + \
        img[ucoory, ucoorx, :] * diffx * diffy + \
        img[lcoory, ucoorx, :] * ndiffy * diffx + \
        img[ucoory, lcoorx, :] * diffy * ndiffx
    return ret[:, :, 0, :]

# TODO input/output with different shape


class GaussianDeform(ImageAugmentor):
    """
    Some kind of deformation. Quite slow.
    """

    def __init__(self, anchors, shape, sigma=0.5, randrange=None):
        """
        :param anchors: in [0,1] coordinate
...
@@ -69,13 +75,13 @@ class GaussianDeform(ImageAugmentor):
        self.anchors = anchors
        self.K = len(self.anchors)
        self.shape = shape
        self.grid = np.mgrid[0:self.shape[0], 0:self.shape[1]].transpose(1, 2, 0)
        self.grid = self.grid.astype('float32')  # HxWx2
        gm = GaussianMap(self.shape, sigma=sigma)
        self.gws = np.array([gm.get_gaussian_weight(ank)
                             for ank in self.anchors], dtype='float32')  # KxHxW
-        self.gws = self.gws.transpose(1, 2, 0)  #HxWxK
+        self.gws = self.gws.transpose(1, 2, 0)  # HxWxK
        if randrange is None:
            self.randrange = self.shape[0] / 8
        else:
...
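`np_sample` above is plain bilinear interpolation; a scalar version makes the four-corner weighting explicit. This is a standalone sketch for illustration, not the library function:

import numpy as np


def bilinear_at(img, y, x):
    """Sample img (H, W, C) at a fractional (y, x), clamped to the border."""
    y = min(max(y, 0), img.shape[0] - 1)
    x = min(max(x, 0), img.shape[1] - 1)
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    y1 = min(y0 + 1, img.shape[0] - 1)
    x1 = min(x0 + 1, img.shape[1] - 1)
    dy, dx = y - y0, x - x0
    # the same four-corner blend as np_sample above
    return (img[y0, x0] * (1 - dx) * (1 - dy) +
            img[y1, x1] * dx * dy +
            img[y0, x1] * dx * (1 - dy) +
            img[y1, x0] * (1 - dx) * dy)


img = np.arange(12, dtype='float32').reshape(3, 4, 1)
print(bilinear_at(img, 0.5, 0.5))   # average of img[0,0], img[0,1], img[1,0], img[1,1] -> [2.5]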
tensorpack/dataflow/imgaug/geometry.py  View file @ fb2a051c
...
@@ -10,11 +10,13 @@ import numpy as np
__all__ = ['Rotation', 'RotationAndCropValid']


class Rotation(ImageAugmentor):
    """ Randomly rotate the image w.r.t a random center"""

-    def __init__(self, max_deg, center_range=(0, 1), interp=cv2.INTER_CUBIC,
-                 border=cv2.BORDER_REPLICATE):
+    def __init__(self, max_deg, center_range=(0, 1),
+                 interp=cv2.INTER_CUBIC, border=cv2.BORDER_REPLICATE):
        """
        :param max_deg: max abs value of the rotation degree
        :param center_range: the location of the rotation center
...
@@ -24,19 +26,21 @@ class Rotation(ImageAugmentor):
    def _get_augment_params(self, img):
        center = img.shape[1::-1] * self._rand_range(
            self.center_range[0], self.center_range[1], (2,))
        deg = self._rand_range(-self.max_deg, self.max_deg)
        return cv2.getRotationMatrix2D(tuple(center), deg, 1)

    def _augment(self, img, rot_m):
        ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
                             flags=self.interp, borderMode=self.border)
        return ret


class RotationAndCropValid(ImageAugmentor):
    """ Randomly rotate and crop the largest possible rect without the border.
    This will produce images of different shapes.
    """

    def __init__(self, max_deg, interp=cv2.INTER_CUBIC):
        super(RotationAndCropValid, self).__init__()
        self._init(locals())
...
@@ -46,39 +50,39 @@ class RotationAndCropValid(ImageAugmentor):
        return deg

    def _augment(self, img, deg):
        center = (img.shape[1] * 0.5, img.shape[0] * 0.5)
        rot_m = cv2.getRotationMatrix2D(center, deg, 1)
        ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
                             flags=self.interp, borderMode=cv2.BORDER_CONSTANT)
        neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
        neww = min(neww, ret.shape[1])
        newh = min(newh, ret.shape[0])
        newx = int(center[0] - neww * 0.5)
        newy = int(center[1] - newh * 0.5)
        #print(ret.shape, deg, newx, newy, neww, newh)
        return ret[newy:newy + newh, newx:newx + neww]

    @staticmethod
    def largest_rotated_rect(w, h, angle):
        """ http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders """
        angle = angle / 180.0 * math.pi
        if w <= 0 or h <= 0:
            return 0, 0

        width_is_longer = w >= h
        side_long, side_short = (w, h) if width_is_longer else (h, w)

        # since the solutions for angle, -angle and 180-angle are all the same,
        # it suffices to look at the first quadrant and the absolute values of sin,cos:
        sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
        if side_short <= 2. * sin_a * cos_a * side_long:
            # half constrained case: two crop corners touch the longer side,
            # the other two corners are on the mid-line parallel to the longer line
            x = 0.5 * side_short
            wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
        else:
            # fully constrained case: crop touches all 4 sides
            cos_2a = cos_a * cos_a - sin_a * sin_a
            wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
        return int(wr), int(hr)
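`largest_rotated_rect` is self-contained and easy to sanity-check. Copying its logic verbatim into a script, a 100x100 image rotated by 45 degrees keeps a 70x70 valid crop (the expected 100/sqrt(2) ~= 70.7, truncated by int()):

import math


def largest_rotated_rect(w, h, angle):
    # verbatim logic from RotationAndCropValid.largest_rotated_rect above
    angle = angle / 180.0 * math.pi
    if w <= 0 or h <= 0:
        return 0, 0
    width_is_longer = w >= h
    side_long, side_short = (w, h) if width_is_longer else (h, w)
    sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
    if side_short <= 2. * sin_a * cos_a * side_long:
        # half constrained: two crop corners touch the longer side
        x = 0.5 * side_short
        wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
    else:
        # fully constrained: crop touches all 4 sides
        cos_2a = cos_a * cos_a - sin_a * sin_a
        wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
    return int(wr), int(hr)


print(largest_rotated_rect(100, 100, 45))   # -> (70, 70)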
tensorpack/dataflow/imgaug/imgproc.py  View file @ fb2a051c
...
@@ -7,12 +7,14 @@ import numpy as np
import cv2

__all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur',
           'Gamma', 'Clip', 'Saturation', 'Lighting']


class Brightness(ImageAugmentor):
    """
    Randomly adjust brightness.
    """

    def __init__(self, delta, clip=True):
        """
        Randomly add a value within [-delta,delta], and clip to [0,255] if clip is True.
...
@@ -31,11 +33,13 @@ class Brightness(ImageAugmentor):
            img = np.clip(img, 0, 255)
        return img


class Contrast(ImageAugmentor):
    """
    Apply x = (x - mean) * contrast_factor + mean to each channel
    and clip to [0, 255]
    """

    def __init__(self, factor_range, clip=True):
        """
        :param factor_range: an interval from which to randomly sample the `contrast_factor`.
...
@@ -48,18 +52,20 @@ class Contrast(ImageAugmentor):
        return self._rand_range(*self.factor_range)

    def _augment(self, img, r):
        mean = np.mean(img, axis=(0, 1), keepdims=True)
        img = (img - mean) * r + mean
        if self.clip:
            img = np.clip(img, 0, 255)
        return img


class MeanVarianceNormalize(ImageAugmentor):
    """
    Linearly scales image to have zero mean and unit norm.
    x = (x - mean) / adjusted_stddev
    where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
    """

    def __init__(self, all_channel=True):
        """
        :param all_channel: if True, normalize all channels together; else separately.
...
@@ -71,14 +77,15 @@ class MeanVarianceNormalize(ImageAugmentor):
            mean = np.mean(img)
            std = np.std(img)
        else:
            mean = np.mean(img, axis=(0, 1), keepdims=True)
            std = np.std(img, axis=(0, 1), keepdims=True)
        std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape)))
        img = (img - mean) / std
        return img


class GaussianBlur(ImageAugmentor):
    def __init__(self, max_size=3):
        """:params max_size: (maximum kernel size-1)/2"""
        super(GaussianBlur, self).__init__()
...
@@ -92,10 +99,11 @@ class GaussianBlur(ImageAugmentor):
    def _augment(self, img, s):
        return cv2.GaussianBlur(img, s, sigmaX=0, sigmaY=0,
                                borderType=cv2.BORDER_REPLICATE)


class Gamma(ImageAugmentor):
    def __init__(self, range=(-0.5, 0.5)):
        super(Gamma, self).__init__()
        self._init(locals())
...
@@ -109,7 +117,9 @@ class Gamma(ImageAugmentor):
        img = cv2.LUT(img, lut).astype('float32')
        return img


class Clip(ImageAugmentor):
    def __init__(self, min=0, max=255):
        self._init(locals())
...
@@ -117,7 +127,9 @@ class Clip(ImageAugmentor):
        img = np.clip(img, self.min, self.max)
        return img


class Saturation(ImageAugmentor):
    def __init__(self, alpha=0.4):
        """ Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
        """
...
@@ -130,9 +142,11 @@ class Saturation(ImageAugmentor):
    def _augment(self, img, v):
        grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
-        return img * v + (grey * (1 - v))[:,:,np.newaxis]
+        return img * v + (grey * (1 - v))[:, :, np.newaxis]


class Lighting(ImageAugmentor):
    def __init__(self, std, eigval, eigvec):
        """ Lighting noise.
        See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
...
@@ -143,7 +157,7 @@ class Lighting(ImageAugmentor):
        eigval = np.asarray(eigval)
        eigvec = np.asarray(eigvec)
        assert eigval.shape == (3,)
        assert eigvec.shape == (3, 3)
        self._init(locals())

    def _get_augment_params(self, img):
...
@@ -156,4 +170,3 @@ class Lighting(ImageAugmentor):
        inc = np.dot(self.eigvec, v).reshape((3,))
        img += inc
        return img
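The Contrast docstring above is literal: each channel is pulled toward (factor < 1) or pushed away from (factor > 1) its own mean. A quick numpy check of x = (x - mean) * contrast_factor + mean:

import numpy as np

img = np.array([100., 200., 100., 200.]).reshape(2, 2, 1)   # one channel, mean 150
r = 1.2                                                     # contrast_factor
mean = np.mean(img, axis=(0, 1), keepdims=True)
out = np.clip((img - mean) * r + mean, 0, 255)
print(out.ravel())   # -> [ 90. 210.  90. 210.]: spread around the mean grows by 1.2x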
tensorpack/dataflow/imgaug/meta.py  View file @ fb2a051c
...
@@ -7,14 +7,18 @@
from .base import ImageAugmentor

__all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
           'RandomOrderAug']


class Identity(ImageAugmentor):
    def _augment(self, img, _):
        return img


class RandomApplyAug(ImageAugmentor):
    """ Randomly apply the augmentor with a given probability; otherwise do nothing"""

    def __init__(self, aug, prob):
        self._init(locals())
        super(RandomApplyAug, self).__init__()
...
@@ -37,7 +41,9 @@ class RandomApplyAug(ImageAugmentor):
        else:
            return self.aug._augment(img, prm[1])


class RandomChooseAug(ImageAugmentor):
    def __init__(self, aug_lists):
        """
        :param aug_lists: list of augmentors, or list of (augmentor, probability) tuples
...
@@ -65,7 +71,9 @@ class RandomChooseAug(ImageAugmentor):
        idx, prm = prm
        return self.aug_lists[idx]._augment(img, prm)


class RandomOrderAug(ImageAugmentor):
    def __init__(self, aug_lists):
        """
        Shuffle the augmentors into random order.
...
@@ -93,10 +101,12 @@ class RandomOrderAug(ImageAugmentor):
            img = self.aug_lists[k]._augment(img, prms[k])
        return img


class MapImage(ImageAugmentor):
    """
    Map the image array by a function.
    """

    def __init__(self, func):
        """
        :param func: a function which takes an image array and returns an augmented one
...
@@ -105,4 +115,3 @@ class MapImage(ImageAugmentor):
    def _augment(self, img, _):
        return self.func(img)
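A composition sketch using only the constructor signatures shown in this commit (the import path assumes the package-level re-exports set up by the imgaug `__init__` earlier): apply a blur half the time, then pick one of two photometric augmentors at random, then a deterministic mapping.

from tensorpack.dataflow.imgaug import (AugmentorList, RandomApplyAug,
                                        RandomChooseAug, MapImage,
                                        Brightness, Contrast, GaussianBlur)

augs = AugmentorList([
    RandomApplyAug(GaussianBlur(max_size=3), prob=0.5),       # blur with prob 0.5
    RandomChooseAug([Brightness(30), Contrast((0.8, 1.2))]),  # pick one uniformly
    MapImage(lambda img: img / 255.0),                        # always applied last
])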
tensorpack/dataflow/imgaug/noise.py  View file @ fb2a051c
...
@@ -9,7 +9,9 @@ import cv2
__all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']


class JpegNoise(ImageAugmentor):
    def __init__(self, quality_range=(40, 100)):
        super(JpegNoise, self).__init__()
        self._init(locals())
...
@@ -23,6 +25,7 @@ class JpegNoise(ImageAugmentor):


class GaussianNoise(ImageAugmentor):
    def __init__(self, sigma=1, clip=True):
        """
        Add gaussian noise N(0, sigma^2) of the same shape to img.
...
@@ -39,7 +42,9 @@ class GaussianNoise(ImageAugmentor):
            ret = np.clip(ret, 0, 255)
        return ret


class SaltPepperNoise(ImageAugmentor):
    def __init__(self, white_prob=0.05, black_prob=0.05):
        """ Salt and pepper noise.
        Randomly set some elements in img to 0 or 255, regardless of its channels.
...
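SaltPepperNoise's description ("set some elements to 0 or 255, regardless of its channels") translates to a couple of numpy lines. A standalone sketch of the idea, not the library's exact implementation:

import numpy as np

rng = np.random.RandomState(0)


def salt_pepper(img, white_prob=0.05, black_prob=0.05):
    u = rng.uniform(size=img.shape[:2])        # one draw per pixel location
    out = img.copy()
    out[u < white_prob] = 255                  # salt: whole pixel, all channels
    out[u > 1 - black_prob] = 0                # pepper
    return out


img = np.full((4, 4, 3), 128, dtype='uint8')
print(salt_pepper(img)[:, :, 0])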
tensorpack/dataflow/imgaug/noname.py  View file @ fb2a051c
...
@@ -10,10 +10,12 @@ import cv2
__all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge']


class Flip(ImageAugmentor):
    """
    Random flip.
    """

    def __init__(self, horiz=False, vert=False, prob=0.5):
        """
        Only one of horiz, vert can be set.
...
@@ -45,8 +47,10 @@ class Flip(ImageAugmentor):
    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


class Resize(ImageAugmentor):
    """ Resize image to a target size"""

    def __init__(self, shape, interp=cv2.INTER_CUBIC):
        """
        :param shape: shape in (h, w)
...
@@ -59,13 +63,15 @@ class Resize(ImageAugmentor):
            img, self.shape[::-1],
            interpolation=self.interp)
        if img.ndim == 3 and ret.ndim == 2:
            ret = ret[:, :, np.newaxis]
        return ret


class ResizeShortestEdge(ImageAugmentor):
    """ Resize the shortest edge to a certain number while
    keeping the aspect ratio
    """

    def __init__(self, size):
        size = size * 1.0
        self._init(locals())
...
@@ -76,13 +82,15 @@ class ResizeShortestEdge(ImageAugmentor):
        desSize = map(int, [scale * w, scale * h])
        ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC)
        if img.ndim == 3 and ret.ndim == 2:
            ret = ret[:, :, np.newaxis]
        return ret


class RandomResize(ImageAugmentor):
    """ randomly rescale w and h of the image"""

    def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15, interp=cv2.INTER_CUBIC):
        """
        :param xrange: (min, max) scaling ratio
        :param yrange: (min, max) scaling ratio
...
@@ -112,6 +120,5 @@ class RandomResize(ImageAugmentor):
    def _augment(self, img, dsize):
        ret = cv2.resize(img, dsize, interpolation=self.interp)
        if img.ndim == 3 and ret.ndim == 2:
            ret = ret[:, :, np.newaxis]
        return ret
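The scale in ResizeShortestEdge is presumably size / min(h, w) (the line computing `scale` is elided above), applied to both edges so the aspect ratio is preserved; note `desSize` is ordered (w, h) because that is what cv2.resize expects. A sketch of the computation, with the Python 2 style `map` made list-explicit:

import cv2
import numpy as np

size = 256.0
img = np.zeros((480, 640, 3), dtype='uint8')
h, w = img.shape[:2]
scale = size / min(h, w)                       # 256 / 480 ~= 0.533 (assumed formula)
desSize = [int(scale * w), int(scale * h)]     # (w, h) order for cv2.resize
ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC)
print(ret.shape)                               # -> (256, 341, 3)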
tensorpack/dataflow/imgaug/paste.py  View file @ fb2a051c
...
@@ -9,11 +9,12 @@ from abc import abstractmethod
import numpy as np

__all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
           'RandomPaste']


class BackgroundFiller(object):
    """ Base class for all BackgroundFiller"""

    def fill(self, background_shape, img):
        """
        Return a proper background image of background_shape, given img
...
@@ -28,8 +29,10 @@ class BackgroundFiller(object):
    def _fill(self, background_shape, img):
        pass


class ConstantBackgroundFiller(BackgroundFiller):
    """ Fill the background by a constant """

    def __init__(self, value):
        """
        :param value: the value to fill the background.
...
@@ -44,10 +47,12 @@ class ConstantBackgroundFiller(BackgroundFiller):
            return_shape = background_shape
        return np.zeros(return_shape) + self.value


class CenterPaste(ImageAugmentor):
    """
    Paste the image onto the center of a background canvas.
    """

    def __init__(self, background_shape, background_filler=None):
        """
        :param background_shape: shape of the background canvas.
...
@@ -66,16 +71,18 @@ class CenterPaste(ImageAugmentor):
            self.background_shape, img)
        y0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
        x0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
        background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
        return background

    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


class RandomPaste(CenterPaste):
    """
    Randomly paste the image onto a background canvas
    """

    def _get_augment_params(self, img):
        img_shape = img.shape[:2]
        assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
...
@@ -89,5 +96,5 @@ class RandomPaste(CenterPaste):
        img_shape = img.shape[:2]
        background = self.background_filler.fill(
            self.background_shape, img)
        background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
        return background
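CenterPaste's offsets generalize directly: the random variant just draws y0/x0 instead of centering. A numpy sketch of the paste step, assuming an already-filled background and valid offsets:

import numpy as np

background = np.zeros((300, 300, 3), dtype='uint8')   # e.g. from ConstantBackgroundFiller(0)
img = np.full((100, 120, 3), 255, dtype='uint8')
img_shape = img.shape[:2]

# center paste, exactly as in CenterPaste._augment above
y0 = int((background.shape[0] - img_shape[0]) * 0.5)
x0 = int((background.shape[1] - img_shape[1]) * 0.5)
background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
print(y0, x0)   # -> 100 90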
tensorpack/dataflow/prefetch.py  View file @ fb2a051c
...
@@ -13,7 +13,7 @@ import os
from .base import ProxyDataFlow
from ..utils.concurrency import (ensure_proc_terminate,
                                 mask_sigint, start_proc_mask_signal)
from ..utils.serialize import loads, dumps
from ..utils import logger
from ..utils.gpu import change_gpu
...
@@ -28,6 +28,7 @@ else:


class PrefetchProcess(mp.Process):
    def __init__(self, ds, queue, reset_after_spawn=True):
        """
        :param ds: ds to take data from
...
@@ -46,10 +47,12 @@ class PrefetchProcess(mp.Process):
        for dp in self.ds.get_data():
            self.queue.put(dp)


class PrefetchData(ProxyDataFlow):
    """
    Prefetch data from a `DataFlow` using multiprocessing
    """

    def __init__(self, ds, nr_prefetch, nr_proc=1):
        """
        :param ds: a `DataFlow` instance.
...
@@ -82,6 +85,7 @@ class PrefetchData(ProxyDataFlow):
        # do nothing. all ds are reset once and only once in spawned processes
        pass


def BlockParallel(ds, queue_size):
    # TODO more doc
    """
...
@@ -92,7 +96,9 @@ def BlockParallel(ds, queue_size):
    """
    return PrefetchData(ds, queue_size, 1)


class PrefetchProcessZMQ(mp.Process):
    def __init__(self, ds, conn_name):
        """
        :param ds: a `DataFlow` instance.
...
@@ -112,8 +118,10 @@ class PrefetchProcessZMQ(mp.Process):
        for dp in self.ds.get_data():
            self.socket.send(dumps(dp), copy=False)


class PrefetchDataZMQ(ProxyDataFlow):
    """ Work the same as `PrefetchData`, but faster. """

    def __init__(self, ds, nr_proc=1, pipedir=None):
        """
        :param ds: a `DataFlow` instance.
...
@@ -176,9 +184,11 @@ class PrefetchDataZMQ(ProxyDataFlow):
        except:
            pass


class PrefetchOnGPUs(PrefetchDataZMQ):
    """ Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
    variable"""

    def __init__(self, ds, gpus, pipedir=None):
        self.gpus = gpus
        super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
...
@@ -188,4 +198,3 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
        for gpu, proc in zip(self.gpus, self.procs):
            with change_gpu(gpu):
                proc.start()
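PrefetchProcessZMQ / PrefetchDataZMQ pair a ZMQ PUSH socket per worker process with a single PULL socket in the consumer, the classic fan-in pattern. Stripped of the DataFlow machinery it reduces to the minimal pyzmq sketch below, using pickle in place of tensorpack's dumps/loads and a hypothetical ipc pipe name:

import multiprocessing as mp
import pickle
import zmq

ADDR = "ipc:///tmp/demo-pipe"   # hypothetical pipe name


def worker():
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUSH)
    sock.connect(ADDR)
    for i in range(5):
        sock.send(pickle.dumps({'datapoint': i}), copy=False)


if __name__ == '__main__':
    ctx = zmq.Context()
    pull = ctx.socket(zmq.PULL)
    pull.bind(ADDR)                 # consumer binds, workers connect
    procs = [mp.Process(target=worker) for _ in range(2)]
    for p in procs:
        p.start()
    for _ in range(10):             # 2 workers x 5 datapoints each
        print(pickle.loads(pull.recv()))
    for p in procs:
        p.join()
    ctx.destroy(linger=0)

PULL fair-queues across the connected PUSH sockets, which is why datapoint order across workers is not deterministic, exactly the caveat that applies to PrefetchDataZMQ as well.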
tensorpack/dataflow/raw.py  View file @ fb2a051c
...
@@ -17,8 +17,10 @@ except:
else:
    __all__.append('DataFromSocket')


class FakeData(RNGDataFlow):
    """ Generate fake fixed data of given shapes"""

    def __init__(self, shapes, size, random=True, dtype='float32'):
        """
        :param shapes: a list of lists/tuples
...
@@ -44,8 +46,10 @@ class FakeData(RNGDataFlow):
            for _ in range(self._size):
                yield copy.deepcopy(v)


class DataFromQueue(DataFlow):
    """ Produce data from a queue """

    def __init__(self, queue):
        self.queue = queue
...
@@ -53,8 +57,10 @@ class DataFromQueue(DataFlow):
        while True:
            yield self.queue.get()


class DataFromList(RNGDataFlow):
    """ Produce data from a list"""

    def __init__(self, lst, shuffle=True):
        super(DataFromList, self).__init__()
        self.lst = lst
...
@@ -73,8 +79,10 @@ class DataFromList(RNGDataFlow):
            for k in idxs:
                yield self.lst[k]


class DataFromSocket(DataFlow):
    """ Produce data from a zmq socket"""

    def __init__(self, socket_name):
        self._name = socket_name
...
@@ -89,4 +97,3 @@ class DataFromSocket(DataFlow):
                yield dp
        finally:
            ctx.destroy(linger=0)
tensorpack/dataflow/remote.py  View file @ fb2a051c
...
@@ -17,6 +17,7 @@ from .common import RepeatedData
from ..utils import logger
from ..utils.serialize import dumps, loads


def serve_data(ds, addr):
    ctx = zmq.Context()
    socket = ctx.socket(zmq.PUSH)
...
@@ -36,7 +37,9 @@ def serve_data(ds, addr):
    if not ctx.closed:
        ctx.destroy(0)


class RemoteData(DataFlow):
    def __init__(self, addr):
        self.ctx = zmq.Context()
        self.socket = self.ctx.socket(zmq.PULL)
...
@@ -54,7 +57,7 @@ if __name__ == '__main__':
    from .raw import FakeData
    addr = "tcp://127.0.0.1:8877"
    if sys.argv[1] == 'serve':
        ds = FakeData([(128, 244, 244, 3)], 1000)
        serve_data(ds, addr)
    else:
        ds = RemoteData(addr)
...
@@ -62,4 +65,3 @@ if __name__ == '__main__':
        with tqdm(total=10000) as pbar:
            for k in ds.get_data():
                pbar.update()
tensorpack/dataflow/tf_func.py  View file @ fb2a051c
...
@@ -14,9 +14,11 @@ except ImportError:
else:
    __all__ = ['TFFuncMapper']


class TFFuncMapper(ProxyDataFlow):
    def __init__(self, ds,
                 get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'):
        """
        :param get_placeholders: a function returning the placeholders
        :param symbf: a symbolic function taking the placeholders
...
@@ -39,7 +41,7 @@ class TFFuncMapper(ProxyDataFlow):
        def run_func(vals):
            return self.sess.run(self.output_vars,
                                 feed_dict=dict(zip(self.placeholders, vals)))
        self.run_func = run_func

    def get_data(self):
...
@@ -63,16 +65,16 @@ if __name__ == '__main__':
        v = tf.image.random_flip_left_right(v)
        return v

    ds = TFFuncMapper(ds,
                      lambda: [tf.placeholder(tf.float32, [224, 224, 3], name='img')],
                      tf_aug,
                      lambda dp, f: [f([dp[0]])[0]]
                      )
-    #ds = AugmentImageComponent(ds,
+    # ds = AugmentImageComponent(ds,
    #        [imgaug.Brightness(0.1, clip=False),
    #         imgaug.Contrast((0.8, 1.2), clip=False),
    #         imgaug.Flip(horiz=True)
    #        ])
-    #ds = PrefetchDataZMQ(ds, 4)
+    # ds = PrefetchDataZMQ(ds, 4)
    ds.reset_state()
    import tqdm
...
tensorpack/models/__init__.py  View file @ fb2a051c
...
@@ -12,6 +12,7 @@ from ..utils import logger
__all__ = ['LinearWrap']


def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -32,6 +33,7 @@ class LinearWrap(object):
    """

    class TFModuleFunc(object):
        def __init__(self, mod, tensor):
            self._mod = mod
            self._t = tensor
...
@@ -88,4 +90,3 @@ class LinearWrap(object):
    def print_tensor(self):
        print(self._t)
        return self
tensorpack/models/_common.py  View file @ fb2a051c
...
@@ -5,7 +5,8 @@
import tensorflow as tf
from functools import wraps
import six
-import copy, os
+import copy
+import os

from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str
...
@@ -16,13 +17,16 @@ from ..utils.argtools import shape2d
# make sure each layer is only logged once
_layer_logged = set()


def disable_layer_logging():
    class ContainEverything:
        def __contains__(self, x):
            return True
    # can use nonlocal in python3, but how
    globals()['_layer_logged'] = ContainEverything()


def layer_register(
        summary_activation=False,
        log_shape=True,
...
@@ -42,13 +46,13 @@ def layer_register(
    def wrapped_func(*args, **kwargs):
        if use_scope:
            name, inputs = args[0], args[1]
            args = args[1:]  # actual positional args used to call func
            assert isinstance(name, six.string_types), name
        else:
            assert not log_shape and not summary_activation
            if isinstance(args[0], six.string_types):
                name, inputs = args[0], args[1]
                args = args[1:]  # actual positional args used to call func
            else:
                inputs = args[0]
                name = None
...
@@ -97,13 +101,14 @@ def layer_register(
    # need some special handling for sphinx to work with the arguments
    on_doc = os.environ.get('READTHEDOCS') == 'True' \
        or os.environ.get('TENSORPACK_DOC_BUILDING')
    if on_doc:
        from decorator import decorator
        wrapper = decorator(wrapper)

    return wrapper


def shape4d(a):
    # for use with tensorflow NHWC ops
    return [1] + shape2d(a) + [1]
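`shape4d` pads a 2D kernel/stride spec into the [1, h, w, 1] form that TensorFlow's NHWC ops expect. `shape2d` is imported from ..utils.argtools and not shown here; assuming the behavior its usage implies (an int becomes [a, a], a pair passes through), the result is:

# minimal re-implementation for illustration; shape2d's real code lives in
# tensorpack/utils/argtools.py and is assumed, not quoted, here
def shape2d(a):
    if isinstance(a, int):
        return [a, a]
    assert len(a) == 2
    return list(a)


def shape4d(a):
    # for use with tensorflow NHWC ops, as in _common.py above
    return [1] + shape2d(a) + [1]


print(shape4d(3))        # -> [1, 3, 3, 1]
print(shape4d((2, 5)))   # -> [1, 2, 5, 1]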
tensorpack/models/_test.py  View file @ fb2a051c
...
@@ -7,7 +7,9 @@ import tensorflow as tf
import numpy as np
import unittest


class TestModel(unittest.TestCase):
    def run_variable(self, var):
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
...
@@ -22,6 +24,7 @@ class TestModel(unittest.TestCase):
        else:
            return tf.Variable(args[0])


def run_test_case(case):
    suite = unittest.TestLoader().loadTestsFromTestCase(case)
    unittest.TextTestRunner(verbosity=2).run(suite)
...
@@ -34,5 +37,3 @@ if __name__ == '__main__':
    subs = tensorpack.models._test.TestModel.__subclasses__()
    for cls in subs:
        run_test_case(cls)
tensorpack/models/batch_norm.py
View file @
fb2a051c
...
@@ -18,6 +18,8 @@ __all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
...
@@ -18,6 +18,8 @@ __all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
# decay: being too close to 1 leads to slow start-up. torch use 0.9.
# decay: being too close to 1 leads to slow start-up. torch use 0.9.
# eps: torch: 1e-5. Lasagne: 1e-4
# eps: torch: 1e-5. Lasagne: 1e-4
@
layer_register
(
log_shape
=
False
)
@
layer_register
(
log_shape
=
False
)
def
BatchNormV1
(
x
,
use_local_stat
=
None
,
decay
=
0.9
,
epsilon
=
1e-5
):
def
BatchNormV1
(
x
,
use_local_stat
=
None
,
decay
=
0.9
,
epsilon
=
1e-5
):
"""
"""
...
@@ -41,9 +43,9 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
...
@@ -41,9 +43,9 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
n_out
=
shape
[
-
1
]
# channel
n_out
=
shape
[
-
1
]
# channel
assert
n_out
is
not
None
assert
n_out
is
not
None
beta
=
tf
.
get_variable
(
'beta'
,
[
n_out
],
beta
=
tf
.
get_variable
(
'beta'
,
[
n_out
],
initializer
=
tf
.
constant_initializer
())
initializer
=
tf
.
constant_initializer
())
gamma
=
tf
.
get_variable
(
'gamma'
,
[
n_out
],
gamma
=
tf
.
get_variable
(
'gamma'
,
[
n_out
],
initializer
=
tf
.
constant_initializer
(
1.0
))
initializer
=
tf
.
constant_initializer
(
1.0
))
if
len
(
shape
)
==
2
:
if
len
(
shape
)
==
2
:
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
],
keep_dims
=
False
)
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
],
keep_dims
=
False
)
...
@@ -66,7 +68,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
...
@@ -66,7 +68,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
#reuse = tf.get_variable_scope().reuse
#reuse = tf.get_variable_scope().reuse
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
False
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
False
):
# BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
# BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
with
tf
.
name_scope
(
None
):
# https://github.com/tensorflow/tensorflow/issues/2740
with
tf
.
name_scope
(
None
):
# https://github.com/tensorflow/tensorflow/issues/2740
# TODO if reuse=True, try to find and use the existing statistics
# TODO if reuse=True, try to find and use the existing statistics
# how to use multiple tensors to update one EMA? seems impossbile
# how to use multiple tensors to update one EMA? seems impossbile
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
...
@@ -93,7 +95,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
...
@@ -93,7 +95,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
ema_mean
=
tf
.
get_variable
(
'mean/'
+
emaname
,
[
n_out
])
ema_mean
=
tf
.
get_variable
(
'mean/'
+
emaname
,
[
n_out
])
ema_var
=
tf
.
get_variable
(
'variance/'
+
emaname
,
[
n_out
])
ema_var
=
tf
.
get_variable
(
'variance/'
+
emaname
,
[
n_out
])
else
:
else
:
#
#
use statistics in another tower
# use statistics in another tower
G
=
tf
.
get_default_graph
()
G
=
tf
.
get_default_graph
()
ema_mean
=
ctx
.
find_tensor_in_main_tower
(
G
,
mean_var_name
+
':0'
)
ema_mean
=
ctx
.
find_tensor_in_main_tower
(
G
,
mean_var_name
+
':0'
)
ema_var
=
ctx
.
find_tensor_in_main_tower
(
G
,
var_var_name
+
':0'
)
ema_var
=
ctx
.
find_tensor_in_main_tower
(
G
,
var_var_name
+
':0'
)
...
@@ -111,6 +113,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        return tf.nn.batch_normalization(
            x, ema_mean, ema_var, beta, gamma, epsilon, 'output')


@layer_register(log_shape=False)
def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    """
...
@@ -135,9 +138,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        x = tf.reshape(x, [-1, 1, 1, n_out])
    beta = tf.get_variable('beta', [n_out],
                           initializer=tf.constant_initializer())
    gamma = tf.get_variable('gamma', [n_out],
                            initializer=tf.constant_initializer(1.0))
    # x * gamma + beta
    ctx = get_current_tower_context()
...
@@ -147,22 +150,22 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        logger.warn("[BatchNorm] use_local_stat != is_training")
    moving_mean = tf.get_variable('mean/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
    moving_var = tf.get_variable('variance/EMA', [n_out],
                                 initializer=tf.constant_initializer(), trainable=False)

    if use_local_stat:
        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
                                                           epsilon=epsilon, is_training=True)
        # maintain EMA only in the main training tower
        if ctx.is_main_training_tower:
            update_op1 = moving_averages.assign_moving_average(
                moving_mean, batch_mean, decay, zero_debias=False,
                name='mean_ema_op')
            update_op2 = moving_averages.assign_moving_average(
                moving_var, batch_var, decay, zero_debias=False,
                name='var_ema_op')
            add_model_variable(moving_mean)
            add_model_variable(moving_var)
    else:
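For reference, the moving-average update that `assign_moving_average` performs in the hunk above reduces to plain exponential decay. A minimal NumPy sketch of that arithmetic (variable names are illustrative, not tensorpack's API):

import numpy as np

def ema_update(moving, batch, decay=0.9):
    # without zero-debias: moving <- decay * moving + (1 - decay) * batch
    return decay * moving + (1 - decay) * batch

moving_mean = np.zeros(4)
for _ in range(100):
    batch_mean = np.random.randn(4) + 5.0   # pretend per-batch statistics
    moving_mean = ema_update(moving_mean, batch_mean)
print(moving_mean)  # drifts toward the true mean (~5) as batches accumulate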
...
@@ -171,9 +174,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        # consider some fixed-param tasks, such as load model and fine tune one layer
        # fused seems slower in inference
        # xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
        #    moving_mean, moving_var,
        #    epsilon=epsilon, is_training=False, name='output')
        xn = tf.nn.batch_normalization(
            x, moving_mean, moving_var, beta, gamma, epsilon)
...
tensorpack/models/conv2d.py
...
@@ -12,6 +12,7 @@ from ..utils.argtools import shape2d
__all__ = ['Conv2D', 'Deconv2D']


@layer_register()
def Conv2D(x, out_channel, kernel_shape,
           padding='SAME', stride=1,
...
@@ -61,14 +62,18 @@ def Conv2D(x, out_channel, kernel_shape,
                   for i, k in zip(inputs, kernels)]
        conv = tf.concat(3, outputs)
    if nl is None:
        logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
        nl = tf.nn.relu
    return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')


class StaticDynamicShape(object):

    def __init__(self, static, dynamic):
        self.static = static
        self.dynamic = dynamic

    def apply(self, f):
        try:
            st = f(self.static)
...
@@ -76,11 +81,12 @@ class StaticDynamicShape(object):
        except:
            return StaticDynamicShape(None, f(self.dynamic))


@layer_register()
def Deconv2D(x, out_shape, kernel_shape,
             stride, padding='SAME',
             W_init=None, b_init=None,
             nl=tf.identity, use_bias=True):
    """
    2D deconvolution on 4D inputs.
...
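The `StaticDynamicShape` helper above tries a function on the statically-known shape and falls back to the dynamic value when that fails. Roughly this pattern in plain Python (a sketch of the idea, not the class itself):

class StaticOrDynamic(object):
    # illustrative re-implementation of the try-static, fall-back-to-dynamic idea
    def __init__(self, static, dynamic):
        self.static = static      # an int, or None if unknown at graph-build time
        self.dynamic = dynamic    # the runtime value (a tensor in the real code)

    def apply(self, f):
        try:
            return StaticOrDynamic(f(self.static), f(self.dynamic))
        except TypeError:         # e.g. f(None) fails: keep only the dynamic form
            return StaticOrDynamic(None, f(self.dynamic))

s = StaticOrDynamic(None, 32).apply(lambda v: v * 2)
print(s.static, s.dynamic)        # None 64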
tensorpack/models/fc.py
...
@@ -11,6 +11,7 @@ from ..tfutils import symbolic_functions as symbf
__all__ = ['FullyConnected']


@layer_register()
def FullyConnected(x, out_dim,
                   W_init=None, b_init=None,
...
@@ -40,6 +41,7 @@ def FullyConnected(x, out_dim,
    b = tf.get_variable('b', [out_dim], initializer=b_init)
    prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
    if nl is None:
        logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
        nl = tf.nn.relu
    return nl(prod, name='output')
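`xw_plus_b` is just the affine map y = xW + b; the equivalent NumPy computation, with shapes chosen arbitrarily for illustration:

import numpy as np

x = np.random.rand(8, 16)    # batch of 8, 16 input features
W = np.random.rand(16, 32)   # out_dim = 32
b = np.zeros(32)
y = x.dot(W) + b             # what xw_plus_b computes; matmul alone drops the bias
print(y.shape)               # (8, 32)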
tensorpack/models/image_sample.py
...
@@ -12,6 +12,8 @@ __all__ = ['ImageSample']
# XXX TODO ugly.
# really need to fix this after tensorflow supports advanced indexing
# See github:tensorflow#418,#206


def sample(img, coords):
    """
    :param img: bxhxwxc
...
@@ -33,14 +35,15 @@ def sample(img, coords):
    # bxh2xw2
    batch_add = tf.range(tf.shape(img)[0]) * (shape[0] * shape[1])
    batch_add = tf.reshape(batch_add, [-1, 1, 1])  # bx1x1

    flat_coords = coords + batch_add
    img = tf.reshape(img, [-1, shape[2]])  # bhw x c
    sampled = tf.gather(img, flat_coords)
    return sampled
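The `batch_add`/`flat_coords` trick above linearizes (batch, y, x) indices so a single gather over the flattened image works. A NumPy sketch of the same arithmetic (shapes are hypothetical):

import numpy as np

b, h, w, c = 2, 3, 4, 3
img = np.random.rand(b, h, w, c)
# integer coords[b, i, j] = y * w + x, i.e. already flattened within one image
ys = np.random.randint(0, h, (b, 5, 5))
xs = np.random.randint(0, w, (b, 5, 5))
coords = ys * w + xs
# offset each batch element by its start in the flattened (b*h*w, c) array
batch_add = (np.arange(b) * (h * w)).reshape(-1, 1, 1)
flat_coords = coords + batch_add
flat_img = img.reshape(-1, c)            # (b*h*w) x c
sampled = flat_img[flat_coords]          # fancy indexing, like tf.gather
assert np.allclose(sampled[1, 0, 0], img[1, ys[1, 0, 0], xs[1, 0, 0]])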
@layer_register()
def ImageSample(inputs, borderMode='repeat'):
    """
...
@@ -59,7 +62,7 @@ def ImageSample(inputs, borderMode='repeat'):
    assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
    input_shape = template.get_shape().as_list()[1:]
    assert None not in input_shape, \
        "Images in ImageSample layer must have fully-defined shape"
    assert borderMode in ['repeat', 'constant']
    orig_mapping = mapping
...
@@ -68,7 +71,7 @@ def ImageSample(inputs, borderMode='repeat'):
    ucoor = lcoor + 1

    diff = mapping - lcoor
    neg_diff = 1.0 - diff  # bxh2xw2x2
    lcoory, lcoorx = tf.split(3, 2, lcoor)
    ucoory, ucoorx = tf.split(3, 2, ucoor)
...
@@ -80,55 +83,59 @@ def ImageSample(inputs, borderMode='repeat'):
    neg_diffy, neg_diffx = tf.split(3, 2, neg_diff)

    #prod = tf.reduce_prod(diff, 3, keep_dims=True)
    # diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
    #    tf.reduce_max(diff), diff], summarize=50)

    ret = tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy,
                    sample(template, ucoor) * diffx * diffy,
                    sample(template, lyux) * neg_diffy * diffx,
                    sample(template, uylx) * diffy * neg_diffx], name='sampled')
    if borderMode == 'constant':
        max_coor = tf.constant([input_shape[0] - 1, input_shape[1] - 1], dtype=tf.float32)
        mask = tf.greater_equal(orig_mapping, 0.0)
        mask2 = tf.less_equal(orig_mapping, max_coor)
        mask = tf.logical_and(mask, mask2)  # bxh2xw2x2
        mask = tf.reduce_all(mask, [3])  # bxh2xw2 boolean
        mask = tf.expand_dims(mask, 3)
        ret = ret * tf.cast(mask, tf.float32)
    return ret
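The `add_n` above blends the four neighboring pixels with weights that always sum to one. The weighting in isolation, as a NumPy sketch for a single sample point:

import numpy as np

def bilinear_weights(y, x):
    # fractional offsets from the lower integer corner
    dy, dx = y - np.floor(y), x - np.floor(x)
    # weights for the (low,low), (up,up), (low_y,up_x), (up_y,low_x) corners,
    # matching the four sample() terms above
    return ((1 - dy) * (1 - dx), dy * dx, (1 - dy) * dx, dy * (1 - dx))

w = bilinear_weights(1.3, 2.75)
print(w, sum(w))   # the four weights sum to 1.0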
from ._test import TestModel


class TestSample(TestModel):

    def test_sample(self):
        import numpy as np
        h, w = 3, 4

        def np_sample(img, coords):
            # a reference implementation
            coords = np.maximum(coords, 0)
            coords = np.minimum(coords,
                                np.array([img.shape[1] - 1, img.shape[2] - 1]))
            xs = coords[:, :, :, 1].reshape((img.shape[0], -1))
            ys = coords[:, :, :, 0].reshape((img.shape[0], -1))

            ret = np.zeros((img.shape[0], coords.shape[1], coords.shape[2],
                            img.shape[3]), dtype='float32')
            for k in range(img.shape[0]):
                xss, yss = xs[k], ys[k]
                ret[k, :, :, :] = img[k, yss, xss, :].reshape((coords.shape[1],
                                                               coords.shape[2], 3))
            return ret

        bimg = np.random.rand(2, h, w, 3).astype('float32')

        # mat = np.array([
        #[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
        #[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
        #], dtype='float32') #2x2x2x2
        mat = (np.random.rand(2, 5, 5, 2) - 0.2) * np.array([h + 3, w + 3])

        true_res = np_sample(bimg, np.floor(mat + 0.5).astype('int32'))

        inp, mapping = self.make_variable(bimg, mat)
        output = sample(inp, tf.cast(tf.floor(mapping + 0.5), tf.int32))
        res = self.run_variable(output)

        self.assertTrue((res == true_res).all())
...
@@ -146,7 +153,7 @@ if __name__ == '__main__':
    diff = 200
    for x in range(w):
        for y in range(h):
            mapping[0, y, x, :] = np.array([y - diff + 0.4, x - diff + 0.5])
    mapv = tf.Variable(mapping)
    output = ImageSample('sample', [imv, mapv], borderMode='constant')
...
@@ -155,12 +162,10 @@ if __name__ == '__main__':
    #out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
    #out = sess.run(output)
    # print(out[0].min())
    # print(out[0].max())
    # print(out[0].sum())
    out = sess.run([output])[0]
    im = out[0]
    cv2.imwrite('sampled.jpg', im)
tensorpack/models/model_desc.py
...
@@ -16,21 +16,27 @@ from ..tfutils.common import get_tensors_by_names
from ..tfutils.gradproc import CheckGradient
from ..tfutils.tower import get_current_tower_context

__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']

#_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])


class InputVar(object):

    def __init__(self, type, shape, name, sparse=False):
        self.type = type
        self.shape = shape
        self.name = name
        self.sparse = sparse

    def dumps(self):
        return pickle.dumps(self)

    @staticmethod
    def loads(buf):
        return pickle.loads(buf)


@six.add_metaclass(ABCMeta)
class ModelDesc(object):
    """ Base class for a model description """
...
@@ -99,22 +105,24 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
    def get_gradient_processor(self):
        """ Return a list of GradientProcessor. They will be executed in order"""
        return [  # SummaryGradient(),
            CheckGradient()
        ]


class ModelFromMetaGraph(ModelDesc):
    """
    Load the whole exact TF graph from a saved meta_graph.
    Only useful for inference.
    """

    def __init__(self, filename):
        tf.train.import_meta_graph(filename)
        all_coll = tf.get_default_graph().get_all_collection_keys()
        for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
                  tf.GraphKeys().VARIABLES]:
            assert k in all_coll, \
                "Collection {} not found in metagraph!".format(k)

    def _get_input_vars(self):
        col = tf.get_collection(INPUT_VARS_KEY)
...
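`InputVar.dumps`/`loads` above are a plain pickle round-trip, which is what lets these specs cross process boundaries. The same pattern on a stand-in class (illustrative only):

import pickle

class Spec(object):
    def __init__(self, shape, name):
        self.shape, self.name = shape, name

buf = pickle.dumps(Spec([None, 28, 28], 'input'))   # bytes, safe to ship over IPC
spec = pickle.loads(buf)                            # reconstructs an equal object
print(spec.shape, spec.name)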
tensorpack/models/nonlin.py
...
@@ -11,6 +11,7 @@ from .batch_norm import BatchNorm
__all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']


@layer_register()
def Maxout(x, num_unit):
    """
...
@@ -31,6 +32,7 @@ def Maxout(x, num_unit):
    x = tf.reshape(x, [-1, ch / num_unit, num_unit])
    return tf.reduce_max(x, ndim, name='output')


@layer_register(log_shape=False)
def PReLU(x, init=tf.constant_initializer(0.001), name=None):
    """
...
@@ -47,6 +49,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
        name = 'output'
    return tf.mul(x, 0.5, name=name)


@layer_register(use_scope=False, log_shape=False)
def LeakyReLU(x, alpha, name=None):
    """
...
@@ -62,7 +65,8 @@ def LeakyReLU(x, alpha, name=None):
    return tf.maximum(x, alpha * x, name=name)
    #alpha = float(alpha)
    #x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
    # return tf.mul(x, 0.5, name=name)


@layer_register(log_shape=False, use_scope=False)
def BNReLU(x, name=None):
...
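Maxout above reshapes the channel axis into groups of `num_unit` and takes a max over each group. In NumPy, with hypothetical sizes:

import numpy as np

x = np.random.rand(8, 12)          # 12 channels
num_unit = 3
out = x.reshape(8, 12 // num_unit, num_unit).max(axis=-1)
print(out.shape)                   # (8, 4): every 3 channels collapsed to their max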
tensorpack/models/pool.py
...
@@ -12,6 +12,7 @@ from ..tfutils import symbolic_functions as symbf
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
           'BilinearUpSample']


@layer_register()
def MaxPooling(x, shape, stride=None, padding='VALID'):
    """
...
@@ -32,6 +33,7 @@ def MaxPooling(x, shape, stride=None, padding='VALID'):
    return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding)


@layer_register()
def AvgPooling(x, shape, stride=None, padding='VALID'):
    """
...
@@ -52,6 +54,7 @@ def AvgPooling(x, shape, stride=None, padding='VALID'):
    return tf.nn.avg_pool(x, ksize=shape, strides=stride, padding=padding)


@layer_register()
def GlobalAvgPooling(x):
    """
...
@@ -65,6 +68,8 @@ def GlobalAvgPooling(x):
    return tf.reduce_mean(x, [1, 2])


# https://github.com/tensorflow/tensorflow/issues/2169
def UnPooling2x2ZeroFilled(x):
    out = tf.concat(3, [x, tf.zeros_like(x)])
    out = tf.concat(2, [out, tf.zeros_like(out)])
...
@@ -79,6 +84,7 @@ def UnPooling2x2ZeroFilled(x):
    ret.set_shape([None, None, None, sh[3]])
    return ret


@layer_register()
def FixedUnPooling(x, shape, unpool_mat=None):
    """
...
@@ -108,8 +114,8 @@ def FixedUnPooling(x, shape, unpool_mat=None):
    # perform a tensor-matrix kronecker product
    fx = symbf.flatten(tf.transpose(x, [0, 3, 1, 2]))
    fx = tf.expand_dims(fx, -1)  # (bchw)x1
    mat = tf.expand_dims(symbf.flatten(unpool_mat), 0)  # 1x(shxsw)
    prod = tf.matmul(fx, mat)  # (bchw) x(shxsw)
    prod = tf.reshape(prod, tf.pack(
        [-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]]))
    prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1])
...
@@ -117,6 +123,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
        [-1, input_shape[1] * shape[0], input_shape[2] * shape[1], input_shape[3]]))
    return prod
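The tensor-matrix Kronecker product above places each input value somewhere inside an sh×sw block according to `unpool_mat`. For a single 2-D map, the effect is exactly np.kron with a one-hot matrix; a sketch under that assumption:

import numpy as np

x = np.arange(6).reshape(2, 3)
unpool_mat = np.zeros((2, 2))
unpool_mat[0, 0] = 1               # default: each value lands at its block's corner
out = np.kron(x, unpool_mat)       # shape (4, 6); zeros everywhere except corners
print(out)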
@layer_register()
def BilinearUpSample(x, shape):
    """
...
@@ -125,9 +132,9 @@ def BilinearUpSample(x, shape):
    :param shape: an integer, the upsample factor
    """
    #inp_shape = tf.shape(x)
    # return tf.image.resize_bilinear(x,
    #    tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
    #    align_corners=True)
    inp_shape = x.get_shape().as_list()
    ch = inp_shape[3]
...
@@ -136,7 +143,6 @@ def BilinearUpSample(x, shape):
    shape = int(shape)
    filter_shape = 2 * shape

    def bilinear_conv_filler(s):
        """
        s: width, height of the conv filter
...
@@ -147,7 +153,7 @@ def BilinearUpSample(x, shape):
        ret = np.zeros((s, s), dtype='float32')
        for x in range(s):
            for y in range(s):
                ret[x, y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
        return ret
    w = bilinear_conv_filler(filter_shape)
    w = np.repeat(w, ch * ch).reshape((filter_shape, filter_shape, ch, ch))
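`bilinear_conv_filler` builds the tent-shaped kernel commonly used to initialize upsampling deconvolutions. A vectorized NumPy sketch of the same computation; the constants f and c are elided in the hunk above, so the values here follow the common FCN-style filler and are an assumption:

import numpy as np

def bilinear_filler(s):
    f = np.ceil(s / 2.0)                    # assumed, as in the FCN-style filler
    c = (2 * f - 1 - f % 2) / (2.0 * f)     # assumed center offset
    r = np.arange(s)
    w1 = 1 - np.abs(r / f - c)              # 1-D tent profile
    return np.outer(w1, w1).astype('float32')

print(bilinear_filler(4))   # peaks in the middle, tapers linearly to the edges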
...
@@ -155,17 +161,22 @@ def BilinearUpSample(x, shape):
                                 shape=(filter_shape, filter_shape, ch, ch),
                                 name='bilinear_upsample_filter')
    deconv = tf.nn.conv2d_transpose(x, weight_var,
                                    tf.shape(x) * tf.constant([1, shape, shape, 1], tf.int32),
                                    [1, shape, shape, 1], 'SAME')

    if inp_shape[1]:
        inp_shape[1] *= shape
    if inp_shape[2]:
        inp_shape[2] *= shape
    deconv.set_shape(inp_shape)
    return deconv

from ._test import TestModel


class TestPool(TestModel):

    def test_fixed_unpooling(self):
        h, w = 3, 4
        mat = np.random.rand(h, w, 3).astype('float32')
...
@@ -173,13 +184,13 @@ class TestPool(TestModel):
        inp = tf.reshape(inp, [1, h, w, 3])
        output = FixedUnPooling('unpool', inp, 2)
        res = self.run_variable(output)
        self.assertEqual(res.shape, (1, 2 * h, 2 * w, 3))

        # mat is on cornser
        ele = res[0, ::2, ::2, 0]
        self.assertTrue((ele == mat[:, :, 0]).all())

        # the rest are zeros
        res[0, ::2, ::2, :] = 0
        self.assertTrue((res == 0).all())

    def test_upsample(self):
...
@@ -191,7 +202,7 @@ class TestPool(TestModel):
        inp = tf.reshape(inp, [1, h, w, 1])
        output = BilinearUpSample('upsample', inp, scale)
        res = self.run_variable(output)[0, :, :, 0]

        from skimage.transform import rescale
        res2 = rescale(mat, scale)
...
@@ -199,9 +210,9 @@ class TestPool(TestModel):
        diff = np.abs(res2 - res)

        # not equivalent to rescale on edge?
        diff[0, :] = 0
        diff[:, 0] = 0
        if not diff.max() < 1e-4:
            import IPython
            IPython.embed(config=IPython.terminal.ipapp.load_default_config())
        self.assertTrue(diff.max() < 1e-4)
tensorpack/models/regularize.py
...
@@ -12,6 +12,7 @@ from ._common import layer_register
__all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']


@memoized
def _log_regularizer(name):
    logger.info("Apply regularizer for {}".format(name))
...
@@ -19,6 +20,7 @@ def _log_regularizer(name):
l2_regularizer = tf.contrib.layers.l2_regularizer
l1_regularizer = tf.contrib.layers.l1_regularizer


def regularize_cost(regex, func, name=None):
    """
    Apply a regularizer on every trainable variable matching the regex.
...
@@ -48,4 +50,3 @@ def Dropout(x, keep_prob=0.5, is_training=None):
        is_training = get_current_tower_context().is_training
    keep_prob = tf.constant(keep_prob if is_training else 1.0)
    return tf.nn.dropout(x, keep_prob)
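`tf.nn.dropout` keeps each unit with probability keep_prob and scales the survivors by 1/keep_prob so the expected activation is unchanged (which is why keep_prob is forced to 1.0 at inference above). A NumPy sketch:

import numpy as np

def dropout(x, keep_prob=0.5):
    mask = np.random.rand(*x.shape) < keep_prob
    return x * mask / keep_prob    # inverted dropout: E[output] == x

x = np.ones((4, 4))
print(dropout(x, 0.5))            # surviving entries become 2.0, the rest 0.0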
tensorpack/models/shapes.py
...
@@ -8,6 +8,7 @@ from ._common import layer_register
__all__ = ['ConcatWith']


@layer_register(use_scope=False, log_shape=False)
def ConcatWith(x, dim, tensor):
    """
...
tensorpack/models/softmax.py
...
@@ -8,6 +8,7 @@ from ._common import layer_register
__all__ = ['SoftMax']


@layer_register()
def SoftMax(x, use_temperature=False, temperature_init=1.0):
    """
...
@@ -16,6 +17,6 @@ def SoftMax(x, use_temperature=False, temperature_init=1.0):
    """
    if use_temperature:
        t = tf.get_variable('invtemp', [],
                            initializer=tf.constant_initializer(1.0 / float(temperature_init)))
        x = x * t
    return tf.nn.softmax(x, name='output')
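Multiplying the logits by a learned inverse temperature before softmax, as above, sharpens the distribution for low temperatures and flattens it for high ones. In NumPy:

import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    z = logits / temperature           # equivalent to logits * invtemp
    z = z - z.max()                    # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = np.array([1.0, 2.0, 3.0])
print(softmax_with_temperature(logits, 1.0))   # ordinary softmax
print(softmax_with_temperature(logits, 10.0))  # nearly uniform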
tensorpack/predict/__init__.py
...
@@ -8,6 +8,7 @@ import os.path
__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -25,4 +26,3 @@ for _, module_name, _ in walk_packages(
    if module_name.startswith('_'):
        continue
    global_import(module_name)
tensorpack/predict/base.py
...
@@ -12,9 +12,10 @@ from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext

__all__ = ['OnlinePredictor', 'OfflinePredictor',
           'AsyncPredictorBase',
           'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
           'DataParallelOfflinePredictor']


@six.add_metaclass(ABCMeta)
class PredictorBase(object):
...
@@ -46,7 +47,9 @@ class PredictorBase(object):
        :return: output as defined by the config
        """


class AsyncPredictorBase(PredictorBase):

    @abstractmethod
    def put_task(self, dp, callback=None):
        """
...
@@ -67,7 +70,9 @@ class AsyncPredictorBase(PredictorBase):
        # in Tornado, Future.result() doesn't wait
        return fut.result()


class OnlinePredictor(PredictorBase):

    def __init__(self, sess, input_tensors, output_tensors, return_input=False):
        self.session = sess
        self.return_input = return_input
...
@@ -85,6 +90,7 @@ class OnlinePredictor(PredictorBase):
class OfflinePredictor(OnlinePredictor):
    """ Build a predictor from a given config, in an independent graph"""

    def __init__(self, config):
        self.graph = tf.Graph()
        with self.graph.as_default():
...
@@ -98,7 +104,7 @@ class OfflinePredictor(OnlinePredictor):
            sess = tf.Session(config=config.session_config)
            config.session_init.init(sess)
            super(OfflinePredictor, self).__init__(
                sess, input_vars, output_vars, config.return_input)


def build_multi_tower_prediction_graph(build_tower_fn, towers):
...
@@ -108,13 +114,15 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
    """
    for k in towers:
        logger.info(
            "Building graph for predictor tower {}...".format(k))
        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                TowerContext('{}{}'.format(PREDICT_TOWER, k)):
            build_tower_fn(k)
            tf.get_variable_scope().reuse_variables()


class MultiTowerOfflinePredictor(OnlinePredictor):

    def __init__(self, config, towers):
        self.graph = tf.Graph()
        self.predictors = []
...
@@ -130,8 +138,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
            for k in towers:
                output_vars = get_tensors_by_names(
                    ['{}{}/'.format(PREDICT_TOWER, k) + n
                     for n in config.output_names])
                self.predictors.append(OnlinePredictor(
                    self.sess, input_vars, output_vars, config.return_input))
...
@@ -142,7 +150,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
    def get_predictors(self, n):
        return [self.predictors[k % len(self.predictors)] for k in range(n)]


class DataParallelOfflinePredictor(OnlinePredictor):

    def __init__(self, config, towers):
        self.graph = tf.Graph()
        with self.graph.as_default():
...
@@ -152,19 +162,19 @@ class DataParallelOfflinePredictor(OnlinePredictor):
            for k in towers:
                towername = PREDICT_TOWER + str(k)
                input_vars = config.model.build_placeholders(
                    prefix=towername + '-')
                logger.info(
                    "Building graph for predictor tower {}...".format(k))
                with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                        TowerContext(towername, is_training=False):
                    config.model.build_graph(input_vars)
                    tf.get_variable_scope().reuse_variables()
                input_var_names.extend([k.name for k in input_vars])
                output_vars.extend(get_tensors_by_names(
                    [towername + '/' + n
                     for n in config.output_names]))

            input_vars = get_tensors_by_names(input_var_names)
            config.session_init.init(sess)
            super(DataParallelOfflinePredictor, self).__init__(
                sess, input_vars, output_vars, config.return_input)
tensorpack/predict/common.py
...
@@ -15,11 +15,13 @@ from .base import OfflinePredictor
import multiprocessing

__all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']

PredictResult = namedtuple('PredictResult', ['input', 'output'])


class PredictConfig(object):

    def __init__(self, **kwargs):
        """
        The config used by `get_predict_func`.
...
@@ -61,12 +63,14 @@ class PredictConfig(object):
            self.output_names = kwargs.pop('output_var_names')
            #logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
        assert len(self.input_names), self.input_names
        for v in self.input_names:
            assert_type(v, six.string_types)
        assert len(self.output_names), self.output_names
        self.return_input = kwargs.pop('return_input', False)
        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))


def get_predict_func(config):
    """
    Produce a offline predictor run inside a new session.
...
@@ -76,4 +80,3 @@ def get_predict_func(config):
        a list of output values defined in ``config.output_var_names``.
    """
    return OfflinePredictor(config)
tensorpack/predict/concurrency.py
...
@@ -3,7 +3,8 @@
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import multiprocessing
import threading
import tensorflow as tf
import time
import six
...
@@ -25,10 +26,12 @@ except ImportError:
    __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
else:
    __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
               'MultiThreadAsyncPredictor']


class MultiProcessPredictWorker(multiprocessing.Process):
    """ Base class for predict worker that runs offline in multiprocess"""

    def __init__(self, idx, config):
        """
        :param idx: index of the worker. the 0th worker will print log.
...
@@ -51,8 +54,10 @@ class MultiProcessPredictWorker(multiprocessing.Process):
            with self.predictor.graph.as_default():
                describe_model()


class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
    """ An offline predictor worker that takes input and produces output by queue"""

    def __init__(self, idx, inqueue, outqueue, config):
        """
        :param inqueue: input queue to get data point. elements are (task_id, dp)
...
@@ -76,6 +81,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
class PredictorWorkerThread(threading.Thread):

    def __init__(self, queue, pred_func, id, batch_size=5):
        super(PredictorWorkerThread, self).__init__()
        self.queue = queue
...
@@ -88,13 +94,13 @@ class PredictorWorkerThread(threading.Thread):
        while True:
            batched, futures = self.fetch_batch()
            outputs = self.func(batched)
            # print "Worker {} batched {} Queue {}".format(
            #     self.id, len(futures), self.queue.qsize())
            # debug, for speed testing
            # if not hasattr(self, 'xxx'):
            #     self.xxx = outputs = self.func(batched)
            # else:
            #     outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
            for idx, f in enumerate(futures):
                f.set_result([k[idx] for k in outputs])
...
@@ -119,11 +125,13 @@ class PredictorWorkerThread(threading.Thread):
                cnt += 1
        return batched, futures


class MultiThreadAsyncPredictor(AsyncPredictorBase):
    """
    An multithread online async predictor which run a list of PredictorBase.
    It would do an extra batching internally.
    """

    def __init__(self, predictors, batch_size=5):
        """ :param predictors: a list of OnlinePredictor"""
        assert len(predictors)
...
@@ -131,7 +139,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
            #assert isinstance(k, OnlinePredictor), type(k)
            # TODO use predictors.return_input here
            assert k.return_input == False
        self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
        self.threads = [
            PredictorWorkerThread(
                self.input_queue, f, id, batch_size=batch_size)
...
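`PredictorWorkerThread` above drains a shared queue, batches the requests, runs the function once, and scatters per-item results back to their futures. A minimal, thread-free sketch of just the scatter step (names are illustrative, not tensorpack's API):

from concurrent.futures import Future

def run_batch(func, items):
    futures = [Future() for _ in items]
    outputs = func(items)                 # e.g. one list per output tensor
    for idx, f in enumerate(futures):
        # each future gets the idx-th element of every output list
        f.set_result([k[idx] for k in outputs])
    return futures

futs = run_batch(lambda xs: [[x * 2 for x in xs]], [1, 2, 3])
print([f.result() for f in futs])         # [[2], [4], [6]]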
tensorpack/predict/dataset.py
...
@@ -20,10 +20,12 @@ from .common import PredictConfig
from .base import OfflinePredictor

__all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
           'MultiProcessDatasetPredictor']


@six.add_metaclass(ABCMeta)
class DatasetPredictorBase(object):

    def __init__(self, config, dataset):
        """
        :param config: a `PredictConfig` instance.
...
@@ -45,10 +47,12 @@ class DatasetPredictorBase(object):
        """
        return list(self.get_result())


class SimpleDatasetPredictor(DatasetPredictorBase):
    """
    Run the predict_config on a given `DataFlow`.
    """

    def __init__(self, config, dataset):
        super(SimpleDatasetPredictor, self).__init__(config, dataset)
        self.predictor = OfflinePredictor(config)
...
@@ -60,14 +64,17 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
            sz = self.dataset.size()
        except NotImplementedError:
            sz = 0
        with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
            for dp in self.dataset.get_data():
                res = self.predictor(dp)
                yield res
                pbar.update()


# TODO allow unordered
class MultiProcessDatasetPredictor(DatasetPredictorBase):

    def __init__(self, config, dataset, nr_proc, use_gpu=True, ordered=True):
        """
        Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported.
...
@@ -87,14 +94,14 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        self.ordered = ordered

        self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
            self.dataset, nr_proc * 2, self.nr_proc)  # put (idx, dp) to inqueue

        if use_gpu:
            try:
                gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
                assert len(gpus) >= self.nr_proc, \
                    "nr_proc={} while only {} gpus available".format(
                        self.nr_proc, len(gpus))
            except KeyError:
                # TODO number of GPUs not checked
                gpus = list(range(self.nr_proc))
...
@@ -103,8 +110,8 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        # worker produces (idx, result) to outqueue
        self.outqueue = multiprocessing.Queue()
        self.workers = [MultiProcessQueuePredictWorker(
            i, self.inqueue, self.outqueue, self.config)
            for i in range(self.nr_proc)]

        # start inqueue and workers
        self.inqueue_proc.start()
...
@@ -118,7 +125,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        if ordered:
            self.result_queue = OrderedResultGatherProc(
                self.outqueue, nr_producer=self.nr_proc)
            self.result_queue.start()
            ensure_proc_terminate(self.result_queue)
        else:
...
@@ -130,7 +137,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
            sz = self.dataset.size()
        except NotImplementedError:
            sz = 0
        with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
            die_cnt = 0
            while True:
                res = self.result_queue.get()
...
@@ -147,4 +154,5 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        self.result_queue.join()
        self.result_queue.terminate()
        for p in self.workers:
            p.join()
            p.terminate()
tensorpack/tfutils/argscope.py
...
@@ -12,6 +12,7 @@ __all__ = ['argscope', 'get_arg_scope']
_ArgScopeStack = []


@contextmanager
def argscope(layers, **param):
    if not isinstance(layers, list):
...
@@ -33,6 +34,7 @@ def argscope(layers, **param):
    yield
    del _ArgScopeStack[-1]


def get_arg_scope():
    """
    :returns: the current argscope.
...
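`argscope` pushes a dict of default kwargs onto a module-level stack for the duration of a with-block. The stack discipline in isolation (a sketch of the pattern, not tensorpack's implementation; the try/finally is an addition for safety):

from contextlib import contextmanager

_SCOPE_STACK = []

@contextmanager
def scope_defaults(**param):
    _SCOPE_STACK.append(param)     # defaults visible to anything called inside
    try:
        yield
    finally:
        del _SCOPE_STACK[-1]       # popped even if the block raises

with scope_defaults(stride=2, padding='SAME'):
    print(_SCOPE_STACK[-1])        # {'stride': 2, 'padding': 'SAME'}
print(_SCOPE_STACK)                # [] again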
tensorpack/tfutils/common.py
...
@@ -22,6 +22,7 @@ __all__ = ['get_default_sess_config',
...
@@ -22,6 +22,7 @@ __all__ = ['get_default_sess_config',
'freeze_collection'
,
'freeze_collection'
,
'get_tf_version'
]
'get_tf_version'
]
def
get_default_sess_config
(
mem_fraction
=
0.99
):
def
get_default_sess_config
(
mem_fraction
=
0.99
):
"""
"""
Return a better session config to use as default.
Return a better session config to use as default.
...
@@ -38,6 +39,7 @@ def get_default_sess_config(mem_fraction=0.99):
...
@@ -38,6 +39,7 @@ def get_default_sess_config(mem_fraction=0.99):
#conf.log_device_placement = True
#conf.log_device_placement = True
return
conf
return
conf
def
get_global_step_var
():
def
get_global_step_var
():
""" :returns: the global_step variable in the current graph. create if not existed"""
""" :returns: the global_step variable in the current graph. create if not existed"""
try
:
try
:
...
@@ -45,19 +47,21 @@ def get_global_step_var():
...
@@ -45,19 +47,21 @@ def get_global_step_var():
except
KeyError
:
except
KeyError
:
scope
=
tf
.
get_variable_scope
()
scope
=
tf
.
get_variable_scope
()
assert
scope
.
name
==
''
,
\
assert
scope
.
name
==
''
,
\
"Creating global_step_var under a variable scope would cause problems!"
"Creating global_step_var under a variable scope would cause problems!"
with
tf
.
variable_scope
(
scope
,
reuse
=
False
):
with
tf
.
variable_scope
(
scope
,
reuse
=
False
):
var
=
tf
.
get_variable
(
GLOBAL_STEP_OP_NAME
,
shape
=
[],
var
=
tf
.
get_variable
(
GLOBAL_STEP_OP_NAME
,
shape
=
[],
initializer
=
tf
.
constant_initializer
(
dtype
=
tf
.
int32
),
initializer
=
tf
.
constant_initializer
(
dtype
=
tf
.
int32
),
trainable
=
False
,
dtype
=
tf
.
int32
)
trainable
=
False
,
dtype
=
tf
.
int32
)
return
var
return
var
def
get_global_step
():
def
get_global_step
():
""" :returns: global_step value in current graph and session"""
""" :returns: global_step value in current graph and session"""
return
tf
.
train
.
global_step
(
return
tf
.
train
.
global_step
(
tf
.
get_default_session
(),
tf
.
get_default_session
(),
get_global_step_var
())
get_global_step_var
())
def
get_op_tensor_name
(
name
):
def
get_op_tensor_name
(
name
):
"""
"""
Tensor name is assumed to be ``op_name + ':0'``
Tensor name is assumed to be ``op_name + ':0'``
...
@@ -72,6 +76,7 @@ def get_op_tensor_name(name):
...
@@ -72,6 +76,7 @@ def get_op_tensor_name(name):
get_op_var_name
=
get_op_tensor_name
get_op_var_name
=
get_op_tensor_name
def
get_tensors_by_names
(
names
):
def
get_tensors_by_names
(
names
):
"""
"""
Get a list of tensors in the default graph by a list of names
Get a list of tensors in the default graph by a list of names
...
@@ -85,26 +90,31 @@ def get_tensors_by_names(names):
...
@@ -85,26 +90,31 @@ def get_tensors_by_names(names):
get_vars_by_names
=
get_tensors_by_names
get_vars_by_names
=
get_tensors_by_names
def
backup_collection
(
keys
):
def
backup_collection
(
keys
):
ret
=
{}
ret
=
{}
for
k
in
keys
:
for
k
in
keys
:
ret
[
k
]
=
copy
(
tf
.
get_collection
(
k
))
ret
[
k
]
=
copy
(
tf
.
get_collection
(
k
))
return
ret
return
ret
def
restore_collection
(
backup
):
def
restore_collection
(
backup
):
for
k
,
v
in
six
.
iteritems
(
backup
):
for
k
,
v
in
six
.
iteritems
(
backup
):
del
tf
.
get_collection_ref
(
k
)[:]
del
tf
.
get_collection_ref
(
k
)[:]
tf
.
get_collection_ref
(
k
)
.
extend
(
v
)
tf
.
get_collection_ref
(
k
)
.
extend
(
v
)
def
clear_collection
(
keys
):
def
clear_collection
(
keys
):
for
k
in
keys
:
for
k
in
keys
:
del
tf
.
get_collection_ref
(
k
)[:]
del
tf
.
get_collection_ref
(
k
)[:]
@
contextmanager
@
contextmanager
def
freeze_collection
(
keys
):
def
freeze_collection
(
keys
):
backup
=
backup_collection
(
keys
)
backup
=
backup_collection
(
keys
)
yield
yield
restore_collection
(
backup
)
restore_collection
(
backup
)
def
get_tf_version
():
def
get_tf_version
():
return
int
(
tf
.
__version__
.
split
(
'.'
)[
1
])
return
int
(
tf
.
__version__
.
split
(
'.'
)[
1
])
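freeze_collection above snapshots the named graph collections, lets the wrapped block mutate them freely, and restores the snapshot afterwards. The same copy/yield/restore pattern, illustrated on a plain dict of lists rather than TensorFlow collections (a sketch of the mechanism, not the tensorpack API):

from contextlib import contextmanager
from copy import copy

_collections = {'summaries': ['loss'], 'update_ops': []}  # stand-in graph state

def backup_collection(keys):
    return {k: copy(_collections[k]) for k in keys}

def restore_collection(backup):
    for k, v in backup.items():
        _collections[k][:] = v  # replace contents in place, keep the list object

@contextmanager
def freeze_collection(keys):
    backup = backup_collection(keys)
    yield
    restore_collection(backup)

with freeze_collection(['summaries']):
    _collections['summaries'].append('tmp-summary')  # visible only inside
print(_collections['summaries'])  # ['loss'] -- additions were rolled back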
tensorpack/tfutils/gradproc.py
...
@@ -16,6 +16,7 @@ __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
           'ScaleGradient', 'MapGradient', 'apply_grad_processors',
           'GlobalNormClip']

def apply_grad_processors(grads, gradprocs):
    """
    :param grads: list of (grad, var).
...
@@ -32,6 +33,7 @@ def apply_grad_processors(grads, gradprocs):
        g = proc.process(g)
    return g

@six.add_metaclass(ABCMeta)
class GradientProcessor(object):
...
@@ -51,6 +53,7 @@ class GradientProcessor(object):

class GlobalNormClip(GradientProcessor):
    def __init__(self, global_norm):
        """ Clip by global norm
        Note that the global norm is the sum of norm for **all** gradients
...
@@ -63,11 +66,13 @@ class GlobalNormClip(GradientProcessor):
        g, _ = tf.clip_by_global_norm(g, self._norm, name='clip_by_global_norm')
        return list(zip(g, v))

class MapGradient(GradientProcessor):
    """
    Apply a function on all gradient if the name matches regex.
    Keep the other gradients unchanged.
    """
    def __init__(self, func, regex='.*'):
        """
        :param func: takes a grad or (grad, var) pair and returns a grad. If return None, the
...
@@ -77,7 +82,7 @@ class MapGradient(GradientProcessor):
        args = inspect.getargspec(func).args
        arg_num = len(args) - inspect.ismethod(func)
        assert arg_num in [1, 2], \
            "The function must take 1 or 2 arguments! ({})".format(args)
        if arg_num == 1:
            self.func = lambda grad, var: func(grad)
        else:
...
@@ -100,10 +105,12 @@ class MapGradient(GradientProcessor):
_summaried_gradient = set()

class SummaryGradient(MapGradient):
    """
    Summary history and RMS for each graident variable
    """
    def __init__(self):
        super(SummaryGradient, self).__init__(self._mapper)
...
@@ -115,10 +122,12 @@ class SummaryGradient(MapGradient):
        add_moving_summary(rms(grad, name=name + '/rms'))
        return grad

class CheckGradient(MapGradient):
    """
    Check for numeric issue.
    """
    def __init__(self):
        super(CheckGradient, self).__init__(self._mapper)
...
@@ -128,10 +137,12 @@ class CheckGradient(MapGradient):
        grad = tf.check_numerics(grad, 'CheckGradient-' + var.op.name)
        return grad

class ScaleGradient(MapGradient):
    """
    Scale certain gradient by a multiplier
    """
    def __init__(self, multipliers, log=True):
        """
        :param multipliers: list of (regex, float)
...
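The processors above are chained by apply_grad_processors, each taking and returning a list of (grad, var) pairs. How the MapGradient regex dispatch behaves can be seen with a plain-Python stand-in where gradients are floats and variables are just names (a sketch of the pattern, not tensorpack itself; appending '$' for a full-name match is this sketch's choice):

import re

class MapGradient(object):
    """Apply func to gradients whose variable name matches regex."""
    def __init__(self, func, regex='.*'):
        if not regex.endswith('$'):
            regex = regex + '$'  # match the whole name (assumption of this sketch)
        self.regex = regex
        self.func = func

    def process(self, grads):
        ret = []
        for grad, var in grads:
            if re.match(self.regex, var):
                ret.append((self.func(grad), var))
            else:
                ret.append((grad, var))  # keep non-matching gradients unchanged
        return ret

def apply_grad_processors(grads, gradprocs):
    g = grads
    for proc in gradprocs:
        g = proc.process(g)
    return g

grads = [(0.5, 'conv0/W'), (1.5, 'fc0/W'), (0.1, 'fc0/b')]
# scale only the fully-connected gradients by 0.1
procs = [MapGradient(lambda g: g * 0.1, regex='fc0/.*')]
print(apply_grad_processors(grads, procs))
# fc0/W and fc0/b are scaled by 0.1, conv0/W passes through untouched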
tensorpack/tfutils/modelutils.py
...
@@ -9,6 +9,7 @@ from ..utils import logger
__all__ = ['describe_model', 'get_shape_str']

def describe_model():
    """ print a description of the current model parameters """
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
...
@@ -40,5 +41,3 @@ def get_shape_str(tensors):
        assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
        shape_str = str(tensors.get_shape().as_list())
    return shape_str
tensorpack/tfutils/sessinit.py
...
@@ -20,6 +20,7 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
# TODO they initialize_all at the beginning by default.

@six.add_metaclass(ABCMeta)
class SessionInit(object):
    """ Base class for utilities to initialize a session"""
...
@@ -35,23 +36,29 @@ class SessionInit(object):
    def _init(self, sess):
        pass

class JustCurrentSession(SessionInit):
    """ Just use the current default session. This is a no-op placeholder"""
    def _init(self, sess):
        pass

class NewSession(SessionInit):
    """
    Create a new session. All variables will be initialized by their
    initializer.
    """
    def _init(self, sess):
        sess.run(tf.global_variables_initializer())

class SaverRestore(SessionInit):
    """
    Restore an old model saved by `ModelSaver`.
    """
    def __init__(self, model_path, prefix=None):
        """
        :param model_path: a model name (model-xxxx) or a ``checkpoint`` file.
...
@@ -71,7 +78,7 @@ class SaverRestore(SessionInit):
            new_path = model_path.split('.index')[0]
            if new_path != model_path:
                logger.warn(
                    "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
                model_path = new_path
        assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
        self.set_path(model_path)
...
@@ -146,10 +153,12 @@ class SaverRestore(SessionInit):
                logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
        return var_dict

class ParamRestore(SessionInit):
    """
    Restore variables from a dictionary.
    """
    def __init__(self, param_dict):
        """
        :param param_dict: a dict of {name: value}
...
@@ -158,7 +167,7 @@ class ParamRestore(SessionInit):
        self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}

    def _init(self, sess):
        variables = tf.get_collection(tf.GraphKeys().VARIABLES)  # TODO
        variable_names = set([get_savename_from_varname(k.name) for k in variables])
        param_names = set(six.iterkeys(self.prms))
...
@@ -174,14 +183,15 @@ class ParamRestore(SessionInit):
            logger.warn("Variable {} in the dict not found in the graph!".format(k))
        upd = SessionUpdate(sess,
-                           [v for v in variables if \
+                           [v for v in variables if
                            get_savename_from_varname(v.name) in intersect])
        logger.info("Restoring from dict ...")
        upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})

class ChainInit(SessionInit):
    """ Init a session by a list of SessionInit instance."""
    def __init__(self, sess_inits, new_session=True):
        """
        :params sess_inits: list of `SessionInit` instances.
...
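ChainInit composes several SessionInits so that, for example, a fresh-variable initializer can run before a partial restore overwrites some of the variables. The composition pattern itself is tiny; a plain-Python sketch with a FakeSession stand-in (the logging strings and FakeSession are illustrative, not tensorpack behavior):

class SessionInit(object):
    def init(self, sess):
        self._init(sess)
    def _init(self, sess):
        raise NotImplementedError

class NewSession(SessionInit):
    def _init(self, sess):
        sess.log.append('run global variable initializer')

class ParamRestore(SessionInit):
    def __init__(self, param_dict):
        self.prms = param_dict
    def _init(self, sess):
        sess.log.append('restore %d params from dict' % len(self.prms))

class ChainInit(SessionInit):
    """Run a list of SessionInit in order; later ones overwrite earlier ones."""
    def __init__(self, sess_inits):
        self.inits = sess_inits
    def _init(self, sess):
        for i in self.inits:
            i.init(sess)

class FakeSession(object):  # stand-in for tf.Session, only for this sketch
    def __init__(self):
        self.log = []

sess = FakeSession()
ChainInit([NewSession(), ParamRestore({'conv0/W': 0})]).init(sess)
print(sess.log)  # initializer first, then the dict restore on top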
tensorpack/tfutils/summary.py
...
@@ -15,6 +15,7 @@ from .symbolic_functions import rms
__all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
           'add_moving_summary', 'summary_moving_average']

def create_summary(name, v):
    """
    Return a tf.Summary object with name and simple scalar value v
...
@@ -25,6 +26,7 @@ def create_summary(name, v):
    s.value.add(tag=name, simple_value=v)
    return s

def add_activation_summary(x, name=None):
    """
    Add summary to graph for an activation tensor x.
...
@@ -44,6 +46,7 @@ def add_activation_summary(x, name=None):
    tf.summary.scalar(name + '-sparsity', tf.nn.zero_fraction(x))
    tf.summary.scalar(name + '-rms', rms(x))

def add_param_summary(summary_lists):
    """
    Add summary for all trainable variables matching the regex
...
@@ -54,6 +57,7 @@ def add_param_summary(summary_lists):
    ctx = get_current_tower_context()
    if ctx is not None and not ctx.is_main_training_tower:
        return

    def perform(var, action):
        ndim = var.get_shape().ndims
        name = var.name.replace(':0', '')
...
@@ -87,6 +91,7 @@ def add_param_summary(summary_lists):
        for act in actions:
            perform(p, act)

def add_moving_summary(v, *args):
    """
    :param v: tensor or list of tensor to summary
...
@@ -102,6 +107,7 @@ def add_moving_summary(v, *args):
        assert x.get_shape().ndims == 0, x.get_shape()
        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)

@memoized
def summary_moving_average(tensors=None):
    """
...
@@ -121,4 +127,3 @@ def summary_moving_average(tensors=None):
        name = re.sub('tower[p0-9]+/', '', c.op.name)
        tf.summary.scalar(name + '-summary', averager.average(c))
    return avg_maintain_op
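Both summary_moving_average here and get_savename_from_varname in varmanip.py below rely on the regex 'tower[p0-9]+/' to map a per-tower tensor name back to a canonical name. A quick plain-Python check of what that substitution does:

import re

for name in ['tower0/conv0/W', 'towerp3/cost', 'conv0/W']:
    print(name, '->', re.sub('tower[p0-9]+/', '', name))
# tower0/conv0/W -> conv0/W
# towerp3/cost   -> cost
# conv0/W        -> conv0/W   (names without a tower prefix pass through)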
tensorpack/tfutils/symbolic_functions.py
...
@@ -6,6 +6,7 @@ import tensorflow as tf
import numpy as np
from ..utils import logger

def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
    """
    :param logits: NxC
...
@@ -13,7 +14,8 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
    :returns: a float32 vector of length N with 0/1 values. 1 means incorrect prediction
    """
    return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
                   tf.float32, name=name)

def flatten(x):
    """
...
@@ -21,6 +23,7 @@ def flatten(x):
    """
    return tf.reshape(x, [-1])

def batch_flatten(x):
    """
    Flatten the tensor except the first dimension.
...
@@ -30,6 +33,7 @@ def batch_flatten(x):
        return tf.reshape(x, [-1, int(np.prod(shape))])
    return tf.reshape(x, tf.pack([tf.shape(x)[0], -1]))

def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
    """
    The class-balanced cross entropy loss,
...
@@ -53,6 +57,7 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
    cost = tf.sub(loss_pos, loss_neg, name=name)
    return cost

def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
    """
    The class-balanced cross entropy loss,
...
@@ -75,13 +80,14 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
    cost = tf.reduce_mean(cost * (1 - beta), name=name)

    #logstable = tf.log(1 + tf.exp(-tf.abs(z)))
-    #loss_pos = -beta * tf.reduce_mean(-y *
+    # loss_pos = -beta * tf.reduce_mean(-y *
    #(logstable - tf.minimum(0.0, z)))
-    #loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
+    # loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
    #(logstable + tf.maximum(z, 0.0)))
    #cost = tf.sub(loss_pos, loss_neg, name=name)
    return cost

def print_stat(x, message=None):
    """ a simple print op.
    Use it like: x = print_stat(x)
...
@@ -89,7 +95,8 @@ def print_stat(x, message=None):
    if message is None:
        message = x.op.name
    return tf.Print(x, [tf.shape(x), tf.reduce_mean(x), x], summarize=20,
                    message=message, name='print_' + x.op.name)

def rms(x, name=None):
    if name is None:
...
@@ -98,14 +105,16 @@ def rms(x, name=None):
    return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)

def huber_loss(x, delta=1, name='huber_loss'):
    sqrcost = tf.square(x)
    abscost = tf.abs(x)
    return tf.reduce_sum(
        tf.select(abscost < delta,
                  sqrcost * 0.5,
                  abscost * delta - 0.5 * delta ** 2),
        name=name)

def get_scalar_var(name, init_value, summary=False, trainable=False):
    """
...
@@ -113,8 +122,8 @@ def get_scalar_var(name, init_value, summary=False, trainable=False):
    :param summary: summary this variable
    """
    ret = tf.get_variable(name, shape=[],
                          initializer=tf.constant_initializer(init_value),
                          trainable=trainable)
    if summary:
        # this is recognized in callbacks.StatHolder
        tf.summary.scalar(name + '-summary', ret)
...
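huber_loss above selects elementwise between the quadratic branch 0.5*x^2 (for |x| < delta) and the linear branch delta*|x| - 0.5*delta^2 via tf.select. The same piecewise function in NumPy, to make the math explicit (a sketch equivalent, not the graph op):

import numpy as np

def huber_loss_np(x, delta=1.0):
    # quadratic near zero, linear in the tails, matching the tf.select above
    sqrcost = np.square(x)
    abscost = np.abs(x)
    return np.sum(np.where(abscost < delta,
                           0.5 * sqrcost,
                           abscost * delta - 0.5 * delta ** 2))

x = np.array([-3.0, -0.5, 0.0, 0.5, 2.0])
print(huber_loss_np(x))  # 2.5 + 0.125 + 0 + 0.125 + 1.5 = 4.25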
tensorpack/tfutils/tower.py
...
@@ -11,7 +11,9 @@ __all__ = ['get_current_tower_context', 'TowerContext']
_CurrentTowerContext = None

class TowerContext(object):
    def __init__(self, tower_name, is_training=None):
        """ tower_name: 'tower0', 'towerp0', or '' """
        self._name = tower_name
...
@@ -65,7 +67,7 @@ class TowerContext(object):
    def __enter__(self):
        global _CurrentTowerContext
        assert _CurrentTowerContext is None, \
            "Nesting TowerContext!"
        _CurrentTowerContext = self
        if len(self._name):
            self._scope = tf.name_scope(self._name)
...
@@ -78,7 +80,7 @@ class TowerContext(object):
        self._scope.__exit__(exc_type, exc_val, exc_tb)
        return False

def get_current_tower_context():
    global _CurrentTowerContext
    return _CurrentTowerContext
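TowerContext.__enter__ asserts that _CurrentTowerContext is None, so tower contexts cannot nest, and get_current_tower_context() gives any code built inside the tower access to the active context. The non-nesting guard in isolation (a plain-Python sketch without the TF name scope):

_CurrentTowerContext = None

class TowerContext(object):
    def __init__(self, tower_name):
        self._name = tower_name

    def __enter__(self):
        global _CurrentTowerContext
        assert _CurrentTowerContext is None, "Nesting TowerContext!"
        _CurrentTowerContext = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CurrentTowerContext
        _CurrentTowerContext = None
        return False

def get_current_tower_context():
    return _CurrentTowerContext

with TowerContext('tower0'):
    print(get_current_tower_context()._name)  # 'tower0'
print(get_current_tower_context())            # None again outside the context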
tensorpack/tfutils/varmanip.py
...
@@ -3,7 +3,8 @@
# File: varmanip.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import six, os
+import six
+import os
import tensorflow as tf
from collections import defaultdict
import re
...
@@ -13,7 +14,8 @@ from ..utils.naming import *
from .common import get_op_tensor_name

__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
           'get_savename_from_varname', 'is_training_name']

def get_savename_from_varname(
        varname, varname_prefix=None,
...
@@ -33,13 +35,15 @@ def get_savename_from_varname(
    name = re.sub('tower[p0-9]+/', '', name)
    if varname_prefix is not None \
            and name.startswith(varname_prefix):
        name = name[len(varname_prefix) + 1:]
    if savename_prefix is not None:
        name = savename_prefix + '/' + name
    return name

class SessionUpdate(object):
    """ Update the variables in a session """
    def __init__(self, sess, vars_to_update):
        """
        :param vars_to_update: a collection of variables to update
...
@@ -66,11 +70,12 @@ class SessionUpdate(object):
            if varshape != value.shape:
                # TODO only allow reshape when shape different by empty axis
                assert np.prod(varshape) == np.prod(value.shape), \
                    "{}: {}!={}".format(name, varshape, value.shape)
                logger.warn("Param {} is reshaped during assigning".format(name))
                value = value.reshape(varshape)
            self.sess.run(op, feed_dict={p: value})

def dump_session_params(path):
    """ Dump value of all trainable + to_save variables to a dict and save to `path` as
    npy format, loadable by ParamRestore
...
@@ -90,6 +95,7 @@ the same name".format(v.name))
    logger.info(str(result.keys()))
    np.save(path, result)

def dump_chkpt_vars(model_path):
    """ Dump all variables from a checkpoint to a dict"""
    if os.path.basename(model_path) == model_path:
...
@@ -101,6 +107,7 @@ def dump_chkpt_vars(model_path):
        result[n] = reader.get_tensor(n)
    return result

def is_training_name(name):
    """
    This is only used to improve logging.
...
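get_savename_from_varname normalizes a variable name in three steps: drop any tower prefix, drop an optional varname_prefix, then prepend an optional savename_prefix. A standalone sketch of exactly the steps visible in the hunk above:

import re

def get_savename_from_varname(varname, varname_prefix=None, savename_prefix=None):
    name = varname
    name = re.sub('tower[p0-9]+/', '', name)       # per-tower copy -> canonical name
    if varname_prefix is not None and name.startswith(varname_prefix):
        name = name[len(varname_prefix) + 1:]      # also drops the '/' after the prefix
    if savename_prefix is not None:
        name = savename_prefix + '/' + name
    return name

print(get_savename_from_varname('tower0/actor/conv0/W', varname_prefix='actor'))
# conv0/W
print(get_savename_from_varname('conv0/W', savename_prefix='mirror'))
# mirror/conv0/W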
tensorpack/train/__init__.py
...
@@ -8,6 +8,7 @@ import os.path
__all__ = []

def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else []
...
@@ -25,4 +26,3 @@ for _, module_name, _ in walk_packages(
    if module_name.startswith('_'):
        continue
    global_import(module_name)
tensorpack/train/base.py
...
@@ -21,8 +21,11 @@ from ..tfutils.summary import create_summary
__all__ = ['Trainer', 'StopTraining']

class StopTraining(BaseException):
    pass

@six.add_metaclass(ABCMeta)
class Trainer(object):
    """ Base class for a trainer."""
...
@@ -91,7 +94,7 @@ class Trainer(object):
        for val in summary.value:
            if val.WhichOneof('value') == 'simple_value':
                val.tag = re.sub('tower[p0-9]+/', '', val.tag)  # TODO move to subclasses
                suffix = '-summary'  # issue#6150
                if val.tag.endswith(suffix):
                    val.tag = val.tag[:-len(suffix)]
                self.stat_holder.add_stat(val.tag, val.simple_value)
...
@@ -99,7 +102,7 @@ class Trainer(object):
    def write_scalar_summary(self, name, val):
        self.summary_writer.add_summary(
            create_summary(name, val), get_global_step())
        self.stat_holder.add_stat(name, val)

    def setup(self):
...
@@ -138,7 +141,7 @@ class Trainer(object):
            callbacks.before_train()
            logger.info("Start training with global_step={}".format(get_global_step()))
            for epoch_num in range(
                    self.config.starting_epoch, self.config.max_epoch + 1):
                with timed_operation(
                        'Epoch {} (global_step {})'.format(
                            epoch_num, get_global_step() + self.config.step_per_epoch)):
...
@@ -147,7 +150,7 @@ class Trainer(object):
                        **get_tqdm_kwargs(leave=True)):
                    if self.coord.should_stop():
                        return
                    self.run_step()  # implemented by subclass
                    callbacks.trigger_step()  # not useful?
                # trigger epoch outside the timing region.
                self.trigger_epoch()
...
tensorpack/train/config.py
...
@@ -9,15 +9,17 @@ from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
from ..tfutils import (JustCurrentSession,
                       get_default_sess_config, SessionInit)
from .input_data import InputData

__all__ = ['TrainConfig']

class TrainConfig(object):
    """
    Config for training a model with a single loss
    """
    def __init__(self, **kwargs):
        """
        :param dataset: the dataset to train. a `DataFlow` instance.
...
tensorpack/train/feedfree.py
...
@@ -17,8 +17,10 @@ from .trainer import MultiPredictorTowerTrainer
__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer',
           'SimpleFeedfreeTrainer', 'QueueInputTrainer']

class FeedfreeTrainer(Trainer):
    """ A trainer which runs iteration without feed_dict (therefore faster) """
    def _trigger_epoch(self):
        # need to run summary_op every epoch
        # note that summary_op will take a data from the queue
...
@@ -33,7 +35,9 @@ class FeedfreeTrainer(Trainer):
        assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
        self._input_method._setup(self)

class SingleCostFeedfreeTrainer(FeedfreeTrainer):
    def _get_cost_and_grad(self):
        """ get the cost and gradient on a new tower"""
        actual_inputs = self._get_input_tensors()
...
@@ -41,35 +45,37 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
        cost_var = self.model.get_cost()
        # GATE_NONE faster?
        grads = self.config.optimizer.compute_gradients(
            cost_var,
            gate_gradients=tf.train.Optimizer.GATE_NONE,
            colocate_gradients_with_ops=False)
        add_moving_summary(cost_var)
        return cost_var, grads

    def run_step(self):
        """ Simply run self.train_op"""
        self.sess.run(self.train_op)
-        #if not hasattr(self, 'cnt'):
-        #self.cnt = 0
-        #else:
-        #self.cnt += 1
-        #if self.cnt % 10 == 0:
-        ## debug-benchmark code:
-        #run_metadata = tf.RunMetadata()
-        #self.sess.run([self.train_op],
-        #options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
-        #run_metadata=run_metadata
-        #)
-        #from tensorflow.python.client import timeline
-        #trace = timeline.Timeline(step_stats=run_metadata.step_stats)
-        #trace_file = open('timeline.ctf.json', 'w')
-        #trace_file.write(trace.generate_chrome_trace_format())
-        #import sys; sys.exit()
+        # if not hasattr(self, 'cnt'):
+        # self.cnt = 0
+        # else:
+        # self.cnt += 1
+        # if self.cnt % 10 == 0:
+        # # debug-benchmark code:
+        # run_metadata = tf.RunMetadata()
+        # self.sess.run([self.train_op],
+        # options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
+        # run_metadata=run_metadata
+        # )
+        # from tensorflow.python.client import timeline
+        # trace = timeline.Timeline(step_stats=run_metadata.step_stats)
+        # trace_file = open('timeline.ctf.json', 'w')
+        # trace_file.write(trace.generate_chrome_trace_format())
+        # import sys; sys.exit()

class SimpleFeedfreeTrainer(
        MultiPredictorTowerTrainer,
        SingleCostFeedfreeTrainer):
    def __init__(self, config):
        """
        A trainer with single cost, single training tower and feed-free input
...
@@ -80,7 +86,7 @@ class SimpleFeedfreeTrainer(
        super(SimpleFeedfreeTrainer, self).__init__(config)
        self._setup_predictor_factory(config.predict_tower)
        assert len(self.config.tower) == 1, \
            "SimpleFeedfreeTrainer doesn't support multigpu!"

    def _setup(self):
        super(SimpleFeedfreeTrainer, self)._setup()
...
@@ -94,6 +100,7 @@ class SimpleFeedfreeTrainer(
        # skip training
        #self.train_op = tf.group(*self.dequed_inputs)

class QueueInputTrainer(SimpleFeedfreeTrainer):
    def __init__(self, config, input_queue=None, predict_tower=None):
...
@@ -110,5 +117,5 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
            config.predict_tower = predict_tower
        assert len(config.tower) == 1, \
            "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
        super(QueueInputTrainer, self).__init__(config)
tensorpack/train/input_data.py
...
@@ -14,13 +14,16 @@ from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread

__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput',
           'DummyConstantInput']

@six.add_metaclass(ABCMeta)
class InputData(object):
    pass

class FeedInput(InputData):
    def __init__(self, ds):
        assert isinstance(ds, DataFlow), ds
        self.ds = ds
...
@@ -39,7 +42,9 @@ class FeedInput(InputData):
        feed = dict(zip(self.input_vars, data))
        return feed

class FeedfreeInput(InputData):
    def get_input_tensors(self):
        return self._get_input_tensors()
...
@@ -49,7 +54,9 @@ class FeedfreeInput(InputData):
        always create and return a list of new input tensors
        """

class EnqueueThread(threading.Thread):
    def __init__(self, trainer, queue, ds, input_placehdrs):
        super(EnqueueThread, self).__init__()
        self.name = 'EnqueueThread'
...
@@ -77,7 +84,7 @@ class EnqueueThread(threading.Thread):
                if self.coord.should_stop():
                    return
                feed = dict(zip(self.placehdrs, dp))
-                #print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
+                # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
                self.op.run(feed_dict=feed)
        except tf.errors.CancelledError as e:
            pass
...
@@ -91,7 +98,9 @@ class EnqueueThread(threading.Thread):
            pass
        logger.info("Enqueue Thread Exited.")

class QueueInput(FeedfreeInput):
    def __init__(self, ds, queue=None):
        """
        :param ds: a `DataFlow` instance
...
@@ -108,32 +117,34 @@ class QueueInput(FeedfreeInput):
    def _setup(self, trainer):
        self.input_placehdrs = trainer.model.get_input_vars()
        assert len(self.input_placehdrs) > 0, \
            "QueueInput can only be used with input placeholders!"
        if self.queue is None:
            self.queue = tf.FIFOQueue(
                50, [x.dtype for x in self.input_placehdrs],
                name='input_queue')
        self.thread = EnqueueThread(
            trainer, self.queue, self.ds, self.input_placehdrs)
        trainer.config.callbacks.append(StartProcOrThread(self.thread))

    def _get_input_tensors(self):
        ret = self.queue.dequeue(name='input_deque')
        if isinstance(ret, tf.Tensor):  # only one input
            ret = [ret]
        assert len(ret) == len(self.input_placehdrs)
        for qv, v in zip(ret, self.input_placehdrs):
            qv.set_shape(v.get_shape())
        # test the overhead of queue
-        #with tf.device('/gpu:0'):
-        #ret = [tf.Variable(tf.random_normal([128,224,224,3],
-        #    dtype=tf.float32), trainable=False),
-        #    tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
+        # with tf.device('/gpu:0'):
+        # ret = [tf.Variable(tf.random_normal([128,224,224,3],
+        #     dtype=tf.float32), trainable=False),
+        #     tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
        return ret

class DummyConstantInput(QueueInput):
    """ only for debugging performance issues """
    def __init__(self, ds, shapes):
        super(DummyConstantInput, self).__init__(ds)
        self.shapes = shapes
...
@@ -146,11 +157,13 @@ class DummyConstantInput(QueueInput):
        for idx, p in enumerate(placehdrs):
            with tf.device('/gpu:0'):
                ret.append(tf.get_variable('dummy-' + p.op.name,
                                           shape=self.shapes[idx], dtype=p.dtype, trainable=False,
                                           initializer=tf.constant_initializer()))
        return ret

class TensorInput(FeedfreeInput):
    def __init__(self, get_tensor_fn, size=None):
        self.get_tensor_fn = get_tensor_fn
        self._size = size
...
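EnqueueThread is a producer: it pulls datapoints from the DataFlow and pushes them into the TF queue until the coordinator asks it to stop, while QueueInput._get_input_tensors is the consumer side that dequeues and re-attaches static shapes. The producer/consumer skeleton with only the standard library instead of TF (a sketch; stop_event plays the role of tf.train.Coordinator and dataflow is a stand-in DataFlow):

import threading
try:
    import queue            # Python 3
except ImportError:
    import Queue as queue   # Python 2

q = queue.Queue(maxsize=50)     # cap mirrors the FIFOQueue capacity of 50 above
stop_event = threading.Event()  # stand-in for the coordinator

def dataflow():                 # stand-in DataFlow: an endless stream of datapoints
    i = 0
    while True:
        yield [i]
        i += 1

class EnqueueThread(threading.Thread):
    def __init__(self):
        super(EnqueueThread, self).__init__()
        self.daemon = True

    def run(self):
        for dp in dataflow():
            if stop_event.is_set():  # coordinator said stop
                return
            q.put(dp)                # blocks when the queue is full

EnqueueThread().start()
for _ in range(3):
    print(q.get())   # the "training loop" consumes: [0], [1], [2]
stop_event.set()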
tensorpack/train/multigpu.py
...
@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
-import itertools, re
+import itertools
+import re
from six.moves import zip, range

from ..utils import logger
...
@@ -12,7 +13,7 @@ from ..utils.naming import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils import (backup_collection, restore_collection,
                       get_global_step_var, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .base import Trainer
...
@@ -22,6 +23,7 @@ from .input_data import QueueInput
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']

class MultiGPUTrainer(Trainer):
    """ Base class for multi-gpu training"""
    @staticmethod
...
@@ -45,9 +47,11 @@ class MultiGPUTrainer(Trainer):
        restore_collection(backup)
        return grad_list

class SyncMultiGPUTrainer(MultiGPUTrainer,
                          SingleCostFeedfreeTrainer,
                          MultiPredictorTowerTrainer):
    def __init__(self, config, input_queue=None, predict_tower=None):
        if hasattr(config, 'dataset'):
            self._input_method = QueueInput(config.dataset, input_queue)
...
@@ -64,7 +68,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
        assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
        assert tf.test.is_gpu_available()

    @staticmethod
    def _average_grads(tower_grads):
        if len(tower_grads) == 1:
...
@@ -92,12 +95,12 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
    def _setup(self):
        super(SyncMultiGPUTrainer, self)._setup()
        grad_list = MultiGPUTrainer._multi_tower_grads(
            self.config.tower, lambda: self._get_cost_and_grad()[1])

        # debug tower performance:
        #ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
        #self.train_op = tf.group(*ops)
-        #return
+        # return

        grads = SyncMultiGPUTrainer._average_grads(grad_list)
        grads = apply_grad_processors(grads, self.model.get_gradient_processor())
...
@@ -109,13 +112,15 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
    def run_step(self):
        self.sess.run(self.train_op)

class AsyncMultiGPUTrainer(MultiGPUTrainer,
                           SingleCostFeedfreeTrainer,
                           MultiPredictorTowerTrainer):
    def __init__(self, config,
                 input_queue=None,
                 average_gradient=True,
                 predict_tower=None):
        if hasattr(config, 'dataset'):
            self._input_method = QueueInput(config.dataset, input_queue)
        else:
...
@@ -134,7 +139,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
    def _setup(self):
        super(AsyncMultiGPUTrainer, self)._setup()
        grad_list = MultiGPUTrainer._multi_tower_grads(
            self.config.tower, lambda: self._get_cost_and_grad()[1])
        gradprocs = self.model.get_gradient_processor()
        if self._average_gradient and self.config.nr_tower > 1:
            # pretend to average the grads, in order to make async and
...
@@ -157,7 +162,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
        self.training_threads = []
        for k in range(1, len(self.config.tower)):
            train_op = self.config.optimizer.apply_gradients(grad_list[k])

            def f(op=train_op):  # avoid late-binding
                self.sess.run([op])
                next(self.async_step_counter)
            th = LoopThread(f)
...
@@ -169,7 +175,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
    def run_step(self):
        if not self.async_running:
            self.async_running = True
            for th in self.training_threads:  # resume all threads
                th.resume()
        next(self.async_step_counter)
        self.sess.run(self.train_op)
...
@@ -183,7 +189,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
            async_step_total_cnt = int(re.findall(
                '[0-9]+', self.async_step_counter.__str__())[0])
            self.write_scalar_summary(
                'async_global_step', async_step_total_cnt)
        except:
            logger.exception("Cannot log async_global_step")
        super(AsyncMultiGPUTrainer, self)._trigger_epoch()
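_average_grads reduces the per-tower gradient lists into a single list by averaging each variable's gradient across towers, with the single-tower case short-circuited. In NumPy form, on the same list-of-(grad, var) layout used throughout these trainers (a sketch of the reduction, not the TF graph version):

import numpy as np

def average_grads(tower_grads):
    # tower_grads: one [(grad, varname), ...] list per tower, same variable order
    if len(tower_grads) == 1:
        return tower_grads[0]
    ret = []
    for grad_and_vars in zip(*tower_grads):       # group the same var across towers
        grads = [g for g, _ in grad_and_vars]
        var = grad_and_vars[0][1]
        ret.append((np.mean(grads, axis=0), var))
    return ret

tower0 = [(np.array([1.0, 2.0]), 'W'), (np.array([0.0]), 'b')]
tower1 = [(np.array([3.0, 4.0]), 'W'), (np.array([2.0]), 'b')]
for g, v in average_grads([tower0, tower1]):
    print(v, g)   # W [2. 3.]   b [1.]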
tensorpack/train/trainer.py
...
@@ -10,13 +10,14 @@ from .base import Trainer
from ..utils import logger, SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import (get_tensors_by_names, freeze_collection,
                       get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput, FeedfreeInput

__all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']

class PredictorFactory(object):
    """ Make predictors for a trainer"""
...
@@ -52,8 +53,10 @@ class PredictorFactory(object):
            build_multi_tower_prediction_graph(fn, self.towers)
        self.tower_built = True

class SimpleTrainer(Trainer):
    """ A naive demo trainer """
    def __init__(self, config):
        super(SimpleTrainer, self).__init__(config)
        self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
...
@@ -78,7 +81,7 @@ class SimpleTrainer(Trainer):
        grads = self.config.optimizer.compute_gradients(cost_var)
        grads = apply_grad_processors(grads,
                                      self.model.get_gradient_processor())

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
...
@@ -93,13 +96,15 @@ class SimpleTrainer(Trainer):
    def get_predict_func(self, input_names, output_names):
        return self._predictor_factory.get_predictor(input_names, output_names, 0)

class MultiPredictorTowerTrainer(Trainer):
    """ A trainer with possibly multiple prediction tower """
    def _setup_predictor_factory(self, predict_tower):
        # by default, use the first training gpu for prediction
        predict_tower = predict_tower or [0]
        self._predictor_factory = PredictorFactory(
            self.sess, self.model, predict_tower)

    def get_predict_func(self, input_names, output_names, tower=0):
        """
...
tensorpack/utils/__init__.py
...
@@ -12,6 +12,7 @@ These utils should be irrelevant to tensorflow.
__all__ = []

def _global_import(name):
    p = __import__(name, globals(), None, level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -23,7 +24,7 @@ _TO_IMPORT = set([
    'naming',
    'utils',
    'gpu'
])

_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in walk_packages(
...
@@ -36,5 +37,3 @@ for _, module_name, _ in walk_packages(
    if module_name in _TO_IMPORT:
        _global_import(module_name)
        __all__.append(module_name)
tensorpack/utils/argtools.py
...
@@ -5,10 +5,13 @@
import operator
-import inspect, six, functools
+import inspect
+import six
+import functools
import collections

__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']


def map_arg(**maps):
    """
...
@@ -26,11 +29,13 @@ def map_arg(**maps):
        return wrapper
    return deco


class memoized(object):
    '''Decorator. Caches a function's return value each time it is called.
    If called later with the same arguments, the cached value is returned
    (not reevaluated).
    '''

    def __init__(self, func):
        self.func = func
        self.cache = {}
...
@@ -60,8 +65,11 @@ class memoized(object):
        return functools.partial(self.__call__, obj)

_MEMOIZED_NOARGS = {}


def memoized_ignoreargs(func):
    h = hash(func)  # make sure it is hashable. is it necessary?

    def wrapper(*args, **kwargs):
        if func not in _MEMOIZED_NOARGS:
            res = func(*args, **kwargs)
...
@@ -70,15 +78,16 @@ def memoized_ignoreargs(func):
        return _MEMOIZED_NOARGS[func]
    return wrapper

-#_GLOBAL_MEMOIZED_CACHE = dict()
-#def global_memoized(func):
-#""" Make sure that the same `memoized` object is returned on different
-#calls to global_memoized(func)
-#"""
-#ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
-#if ret is None:
-#ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
-#return ret
+# _GLOBAL_MEMOIZED_CACHE = dict()
+# def global_memoized(func):
+# """ Make sure that the same `memoized` object is returned on different
+# calls to global_memoized(func)
+# """
+# ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
+# if ret is None:
+# ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
+# return ret


def shape2d(a):
    """
...
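A minimal usage sketch of the `memoized` decorator above, assuming tensorpack is importable; the function `expensive` is a hypothetical example, not part of this commit:

import tensorpack  # noqa
from tensorpack.utils.argtools import memoized

@memoized
def expensive(x):
    # the body runs only once per distinct argument
    print('computing', x)
    return x * x

expensive(3)   # prints 'computing 3' and returns 9
expensive(3)   # cache hit: returns 9 without recomputing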
tensorpack/utils/concurrency.py
...
@@ -23,10 +23,12 @@ __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
           'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
           'mask_sigint', 'start_proc_mask_signal']


class StoppableThread(threading.Thread):
    """
    A thread that has a 'stop' event.
    """

    def __init__(self):
        super(StoppableThread, self).__init__()
        self._stop_evt = threading.Event()
...
@@ -56,8 +58,10 @@ class StoppableThread(threading.Thread):
            except queue.Empty:
                pass


class LoopThread(StoppableThread):
    """ A pausable thread that simply runs a loop"""

    def __init__(self, func, pausable=True):
        """
        :param func: the function to run
...
@@ -89,6 +93,7 @@ class DIE(object):
    """ A placeholder class indicating end of queue """
    pass


def ensure_proc_terminate(proc):
    if isinstance(proc, list):
        for p in proc:
...
@@ -114,6 +119,7 @@ def mask_sigint():
    yield
    signal.signal(signal.SIGINT, sigint_handler)


def start_proc_mask_signal(proc):
    if not isinstance(proc, list):
        proc = [proc]
...
@@ -122,11 +128,12 @@ def start_proc_mask_signal(proc):
    for p in proc:
        p.start()


def subproc_call(cmd, timeout=None):
    try:
        output = subprocess.check_output(
            cmd, stderr=subprocess.STDOUT,
            shell=True, timeout=timeout)
        return output
    except subprocess.TimeoutExpired as e:
        logger.warn("Command timeout!")
...
@@ -135,10 +142,12 @@ def subproc_call(cmd, timeout=None):
        logger.warn("Commnad failed: {}".format(e.returncode))
        logger.warn(e.output)


class OrderedContainer(object):
    """
    Like a priority queue, but will always wait for item with index (x+1) before producing (x+2).
    """

    def __init__(self, start=0):
        self.ranks = []
        self.data = []
...
@@ -163,11 +172,13 @@ class OrderedContainer(object):
        self.wait_for += 1
        return rank, ret


class OrderedResultGatherProc(multiprocessing.Process):
    """
    Gather indexed data from a data queue, and produce results with the
    original index-based order.
    """

    def __init__(self, data_queue, nr_producer, start=0):
        """
        :param data_queue: a multiprocessing.Queue to produce input dp
...
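A minimal usage sketch for the process helpers above, assuming tensorpack is importable; the shell command and the worker function are hypothetical:

import multiprocessing
from tensorpack.utils.concurrency import (
    subproc_call, ensure_proc_terminate, start_proc_mask_signal)

# run a shell command with a timeout; on success the raw output bytes are
# returned, on timeout/failure a warning is logged by the library instead
out = subproc_call('echo hello', timeout=5)

def work():
    print('working')

# register cleanup so the worker dies with the parent, then start it with
# SIGINT masked so Ctrl-C is delivered only to the parent process
p = multiprocessing.Process(target=work)
ensure_proc_terminate(p)
start_proc_mask_signal(p)
p.join()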
tensorpack/utils/debug.py
...
@@ -7,6 +7,7 @@
import sys

__all__ = ['enable_call_trace']


def enable_call_trace():
    def tracer(frame, event, arg):
        if event == 'call':
...
@@ -21,9 +22,9 @@ def enable_call_trace():
            if caller:
                caller_line_no = caller.f_lineno
                caller_filename = caller.f_code.co_filename
-                print('Call to `%s` on line %s:%s from %s:%s' % \
+                print('Call to `%s` on line %s:%s from %s:%s' %
                      (func_name, func_filename, func_line_no,
                       caller_filename, caller_line_no))
            return
    sys.settrace(tracer)
...
@@ -32,6 +33,7 @@ if __name__ == '__main__':
    def b(a):
        print(2)

    def a():
        print(1)
        b(1)
...
tensorpack/utils/discretize.py
...
@@ -12,11 +12,14 @@ from six.moves import range
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']


@memoized
def log_once(s):
    logger.warn(s)

# just a placeholder


@six.add_metaclass(ABCMeta)
class Discretizer(object):
...
@@ -28,10 +31,13 @@ class Discretizer(object):
    def get_bin(self, v):
        pass


class Discretizer1D(Discretizer):
    pass


class UniformDiscretizer1D(Discretizer1D):

    def __init__(self, minv, maxv, spacing):
        """
        :params minv: minimum value of the first bin
...
@@ -54,8 +60,8 @@ class UniformDiscretizer1D(Discretizer1D):
            log_once("UniformDiscretizer1D: value larger than max!")
            return self.nr_bin - 1
        return int(np.clip(
            (v - self.minv) / self.spacing,
            0, self.nr_bin - 1))

    def get_bin_center(self, bin_id):
        return self.minv + self.spacing * (bin_id + 0.5)
...
@@ -69,17 +75,18 @@ class UniformDiscretizer1D(Discretizer1D):
        if v >= self.maxv or v <= self.minv:
            return ret
        try:
            for k in range(1, smooth_radius + 1):
                ret[b + k] = smooth_factor ** k
        except IndexError:
            pass
        for k in range(1, min(smooth_radius + 1, b + 1)):
            ret[b - k] = smooth_factor ** k
        ret /= ret.sum()
        return ret


class UniformDiscretizerND(Discretizer):

    def __init__(self, *min_max_spacing):
        """
        :params min_max_spacing: (minv, maxv, spacing) for each dimension
...
@@ -122,6 +129,5 @@ class UniformDiscretizerND(Discretizer):
if __name__ == '__main__':
    #u = UniformDiscretizer1D(-10, 10, 0.12)
    u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1))
-    import IPython as IP;
+    import IPython as IP
    IP.embed(config=IP.terminal.ipapp.load_default_config())
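The bin arithmetic in `UniformDiscretizer1D` reduces to two formulas; here is a standalone sketch with hypothetical settings, mirroring the `get_bin` and `get_bin_center` lines above:

import numpy as np

minv, spacing, nr_bin = 0.0, 0.5, 20   # hypothetical discretizer settings

def get_bin(v):
    # clip (v - minv) / spacing into [0, nr_bin - 1], as in the diff
    return int(np.clip((v - minv) / spacing, 0, nr_bin - 1))

def get_bin_center(bin_id):
    # bin centers sit half a spacing past the bin's left edge
    return minv + spacing * (bin_id + 0.5)

assert get_bin(3.2) == 6
assert get_bin_center(6) == 3.25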
tensorpack/utils/fs.py
...
@@ -3,13 +3,15 @@
# File: fs.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-import os, sys
+import os
+import sys
from six.moves import urllib
import errno
from . import logger

__all__ = ['mkdir_p', 'download', 'recursive_walk']


def mkdir_p(dirname):
    """ make a dir recursively, but do nothing if the dir exists"""
    assert dirname is not None
...
@@ -21,6 +23,7 @@ def mkdir_p(dirname):
        if e.errno != errno.EEXIST:
            raise e


def download(url, dir):
    mkdir_p(dir)
    fname = url.split('/')[-1]
...
@@ -29,7 +32,7 @@ def download(url, dir):
    def _progress(count, block_size, total_size):
        sys.stdout.write('\r>> Downloading %s %.1f%%' %
                         (fname,
                          min(float(count * block_size) / total_size,
                              1.0) * 100.0))
        sys.stdout.flush()
    try:
...
@@ -45,6 +48,7 @@ def download(url, dir):
    print('Succesfully downloaded ' + fname + " " + str(size) + ' bytes.')
    return fpath


def recursive_walk(rootdir):
    for r, dirs, files in os.walk(rootdir):
        for f in files:
...
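A usage sketch for the two helpers above; the URL and directory are hypothetical placeholders:

from tensorpack.utils.fs import mkdir_p, download

mkdir_p('/tmp/tp-demo')   # no-op if the directory already exists
# saves the file into the directory with a progress line, returns the local path
fpath = download('https://example.com/data.bin', '/tmp/tp-demo')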
tensorpack/utils/globvars.py
...
@@ -9,13 +9,15 @@ import argparse
__all__ = ['globalns', 'use_global_argument']

if six.PY2:
    class NS:
        pass
else:
    import types
    NS = types.SimpleNamespace

globalns = NS()


def use_global_argument(args):
    """
    Add the content of argparse.Namespace to globalns
...
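A sketch of how `use_global_argument` is meant to be used, going by its docstring; the `--lr` flag is hypothetical:

import argparse
from tensorpack.utils.globvars import globalns, use_global_argument

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=0.1)
use_global_argument(parser.parse_args(['--lr', '0.01']))
# assumption: the parsed attributes are copied onto globalns
print(globalns.lr)   # 0.01, now readable from any module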
tensorpack/utils/gpu.py
...
@@ -8,20 +8,22 @@ from .utils import change_env
__all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus']


def change_gpu(val):
    val = str(val)
    if val == '-1':
        val = ''
    return change_env('CUDA_VISIBLE_DEVICES', val)


def get_nr_gpu():
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    assert env is not None, 'gpu not set!'  # TODO
    return len(env.split(','))


def get_gpus():
    """ return a list of GPU physical id"""
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    assert env is not None, 'gpu not set!'  # TODO
    return map(int, env.strip().split(','))
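A usage sketch for the GPU helpers; the device ids are hypothetical, and `change_gpu` is assumed to behave as a context manager since it returns `change_env(...)`:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'   # hypothetical setting

from tensorpack.utils.gpu import change_gpu, get_nr_gpu, get_gpus
print(get_nr_gpu())       # 2
print(list(get_gpus()))   # [2, 3]; get_gpus returns a map object on Python 3
with change_gpu(0):       # assumption: change_env yields a context
    pass                  # code here sees CUDA_VISIBLE_DEVICES=0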
tensorpack/utils/loadcaffe.py
...
@@ -19,7 +19,9 @@ __all__ = ['load_caffe', 'get_caffe_pb']
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"


class CaffeLayerProcessor(object):

    def __init__(self, net):
        self.net = net
        self.layer_names = net._layer_names
...
@@ -42,14 +44,14 @@ class CaffeLayerProcessor(object):
            self.param_dict.update(dic)
        elif len(layer.blobs) != 0:
            logger.warn(
                "{} layer contains parameters but is not supported!".format(layer.type))
        return self.param_dict

    def proc_conv(self, idx, name, param):
        assert len(param) <= 2
        assert param[0].data.ndim == 4
        # caffe: ch_out, ch_in, h, w
        W = param[0].data.transpose(2, 3, 1, 0)
        if len(param) == 1:
            return {name + '/W': W}
        else:
...
@@ -65,7 +67,7 @@ class CaffeLayerProcessor(object):
            logger.info("FC layer {} takes spatial data.".format(name))
            W = param[0].data
            # original: outx(CxHxW)
            W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2, 3, 1, 0)
            # become: (HxWxC)xout
        else:
            W = param[0].data.transpose()
...
@@ -74,8 +76,8 @@ class CaffeLayerProcessor(object):
    def proc_bn(self, idx, name, param):
        assert param[2].data[0] == 1.0
        return {name + '/mean/EMA': param[0].data,
                name + '/variance/EMA': param[1].data}

    def proc_scale(self, idx, name, param):
        bottom_name = self.net.bottom_names[name][0]
...
@@ -89,7 +91,7 @@ class CaffeLayerProcessor(object):
                logger.info("Merge {} and {} into one BatchNorm layer".format(
                    name, name2))
                return {name2 + '/beta': param[1].data,
                        name2 + '/gamma': param[0].data}
        # assume this scaling layer is part of some BN
        logger.error("Could not find a BN layer corresponding to this Scale layer!")
        raise ValueError()
...
@@ -104,10 +106,11 @@ def load_caffe(model_desc, model_file):
    caffe.set_mode_cpu()
    net = caffe.Net(model_desc, model_file, caffe.TEST)
    param_dict = CaffeLayerProcessor(net).process()
-    logger.info("Model loaded from caffe. Params: " + \
+    logger.info("Model loaded from caffe. Params: " +
                 " ".join(sorted(param_dict.keys())))
    return param_dict


def get_caffe_pb():
    dir = get_dataset_path('caffe')
    caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
...
@@ -116,7 +119,7 @@ def get_caffe_pb():
    assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
    ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
    assert ret == 0, \
        "Command `protoc caffe.proto --python_out .` failed!"
    import imp
    return imp.load_source('caffepb', caffe_pb_file)
...
@@ -131,4 +134,3 @@ if __name__ == '__main__':
    import numpy as np
    np.save(args.output, ret)
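The axis shuffle in `proc_conv` above is easiest to see on shapes; a standalone numpy sketch with a hypothetical 7x7, 3-to-64-channel kernel:

import numpy as np

W_caffe = np.zeros((64, 3, 7, 7))     # caffe layout: (ch_out, ch_in, h, w)
W_tp = W_caffe.transpose(2, 3, 1, 0)  # target layout: (h, w, ch_in, ch_out)
assert W_tp.shape == (7, 7, 3, 64)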
tensorpack/utils/logger.py
...
@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import logging
-import os, shutil
+import os
+import shutil
import os.path
from termcolor import colored
from datetime import datetime
...
@@ -12,7 +13,9 @@ import sys
__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', 'warn_dependency']


class _MyFormatter(logging.Formatter):

    def format(self, record):
        date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
        msg = '%(message)s'
...
@@ -28,6 +31,7 @@ class _MyFormatter(logging.Formatter):
        self._fmt = fmt
        return super(_MyFormatter, self).format(record)


def _getlogger():
    logger = logging.getLogger('tensorpack')
    logger.propagate = False
...
@@ -45,6 +49,8 @@ def get_time_str():
# logger file and directory:
global LOG_FILE, LOG_DIR
LOG_DIR = None


def _set_file(path):
    if os.path.isfile(path):
        backup_name = path + '.' + get_time_str()
...
@@ -56,6 +62,7 @@ def _set_file(path):
    _logger.addHandler(hdl)
    _logger.info("Argv: " + ' '.join(sys.argv))


def set_logger_dir(dirname, action=None):
    """
    Set the directory for global logging.
...
@@ -98,11 +105,13 @@ _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception',
for func in _LOGGING_METHOD:
    locals()[func] = getattr(_logger, func)


def disable_logger():
    """ disable all logging ability from this moment"""
    for func in _LOGGING_METHOD:
        globals()[func] = lambda x: None


def auto_set_dir(action=None, overwrite=False):
    """ set log directory to a subdir inside 'train_log', with the name being
    the main python file currently running"""
...
@@ -112,9 +121,10 @@ def auto_set_dir(action=None, overwrite=False):
    mod = sys.modules['__main__']
    basename = os.path.basename(mod.__file__)
    set_logger_dir(
        os.path.join('train_log',
                     basename[:basename.rfind('.')]),
        action=action)


def warn_dependency(name, dependencies):
    warn("Failed to import '{}', {} won't be available'".format(dependencies, name))
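A usage sketch for the directory helpers above; the directory name is hypothetical, and `set_logger_dir` may ask what to do if the directory already exists:

from tensorpack.utils import logger

logger.set_logger_dir('train_log/demo')   # auto_set_dir() would derive the
                                          # name from the running script
logger.info('this goes to stdout and to the log file')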
tensorpack/utils/lut.py
...
@@ -7,10 +7,12 @@ import six
__all__ = ['LookUpTable']


class LookUpTable(object):

    def __init__(self, objlist):
        self.idx2obj = dict(enumerate(objlist))
        self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)}

    def size(self):
        return len(self.idx2obj)
...
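`LookUpTable` is a small bidirectional index; a sketch with a hypothetical vocabulary, using only the attributes visible above:

from tensorpack.utils.lut import LookUpTable

lut = LookUpTable(['cat', 'dog', 'bird'])
print(lut.size())            # 3
print(lut.idx2obj[1])        # 'dog'
print(lut.obj2idx['bird'])   # 2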
tensorpack/utils/rect.py
...
@@ -5,6 +5,7 @@
import numpy as np


class Rect(object):
    """
    A Rectangle.
...
@@ -68,7 +69,7 @@ class Rect(object):
    def roi(self, img):
        assert self.validate(img.shape[:2]), "{} vs {}".format(self, img.shape[:2])
        return img[self.y0:self.y1 + 1, self.x0:self.x1 + 1]

    def expand(self, frac):
        assert frac > 1.0, frac
...
@@ -92,7 +93,7 @@ class Rect(object):
        xmax = min(self.x1, img.shape[1])
        ymax = min(self.y1, img.shape[0])
        patch = img[ymin:ymax, xmin:xmax]
        ret[ystart:ystart + patch.shape[0], xstart:xstart + patch.shape[1]] = patch
        return ret

    __repr__ = __str__
...
@@ -101,6 +102,6 @@ class Rect(object):
if __name__ == '__main__':
    x = Rect(2, 1, 3, 3, allow_neg=True)
    img = np.random.rand(3, 3)
    print(img)
    print(x.roi_zeropad(img))
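The zero-padded ROI at the end of this diff clamps the requested box to the image and pastes the valid patch into a zero canvas; here is a standalone numpy sketch of that logic (the box coordinates are hypothetical, and the edge semantics follow the lines shown above):

import numpy as np

def roi_zeropad_sketch(img, x0, y0, x1, y1):
    # canvas sized to the requested (possibly out-of-bounds) box
    ret = np.zeros((y1 - y0 + 1, x1 - x0 + 1), dtype=img.dtype)
    xmin, ymin = max(x0, 0), max(y0, 0)
    xmax, ymax = min(x1, img.shape[1]), min(y1, img.shape[0])
    patch = img[ymin:ymax, xmin:xmax]
    ystart, xstart = ymin - y0, xmin - x0
    ret[ystart:ystart + patch.shape[0], xstart:xstart + patch.shape[1]] = patch
    return ret

print(roi_zeropad_sketch(np.random.rand(3, 3), -1, -1, 2, 2))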
tensorpack/utils/serialize.py
...
@@ -10,10 +10,12 @@ msgpack_numpy.patch()
__all__ = ['loads', 'dumps']


def dumps(obj):
-    #return dill.dumps(obj)
+    # return dill.dumps(obj)
    return msgpack.dumps(obj, use_bin_type=True)


def loads(buf):
-    #return dill.loads(buf)
+    # return dill.loads(buf)
    return msgpack.loads(buf)
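A round-trip sketch for `dumps` and `loads`; thanks to `msgpack_numpy.patch()` above, numpy arrays survive alongside plain Python containers:

import numpy as np
from tensorpack.utils.serialize import dumps, loads

buf = dumps([3, np.arange(4)])   # bytes
step, arr = loads(buf)           # 3, array([0, 1, 2, 3])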
tensorpack/utils/stats.py (diff collapsed)
tensorpack/utils/timer.py
...
@@ -14,10 +14,12 @@ from .stats import StatCounter
from . import logger

__all__ = ['total_timer', 'timed_operation',
           'print_total_timer', 'IterSpeedCounter']


class IterSpeedCounter(object):
    """ To count how often some code gets reached"""

    def __init__(self, print_every, name=None):
        self.cnt = 0
        self.print_every = int(print_every)
...
@@ -36,6 +38,7 @@ class IterSpeedCounter(object):
        logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
            self.name, t, self.cnt, t / self.cnt))


@contextmanager
def timed_operation(msg, log_start=False):
    if log_start:
...
@@ -47,6 +50,7 @@ def timed_operation(msg, log_start=False):
_TOTAL_TIMER_DATA = defaultdict(StatCounter)


@contextmanager
def total_timer(msg):
    start = time.time()
...
@@ -54,6 +58,7 @@ def total_timer(msg):
    t = time.time() - start
    _TOTAL_TIMER_DATA[msg].feed(t)


def print_total_timer():
    if len(_TOTAL_TIMER_DATA) == 0:
        return
...
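A usage sketch for the timing helpers above; the messages and sleep times are arbitrary:

import time
from tensorpack.utils.timer import (
    timed_operation, total_timer, print_total_timer)

with timed_operation('sleep'):   # logs the elapsed time on exit
    time.sleep(0.1)

for _ in range(3):
    with total_timer('loop body'):   # accumulates into a StatCounter
        time.sleep(0.01)
print_total_timer()                  # dumps the accumulated totals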
tensorpack/utils/utils.py (diff collapsed)
tensorpack/utils/viz.py (diff collapsed)