Shashank Suhas / seminar-breakout / Commits / fb2a051c

Commit fb2a051c, authored Jan 02, 2017 by Yuxin Wu

    run autopep8 over tensorpack/

Parent: 59553585

Showing 100 changed files with 1047 additions and 532 deletions (+1047 -532).
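The page does not record the exact autopep8 invocation behind this commit. As a rough sketch only (using autopep8's Python API instead of its command-line tool, with default options, is an assumption), a repository-wide pass like this one could be reproduced as follows:

# Hypothetical reproduction of a repo-wide autopep8 pass; not the command used for this commit.
import os
import autopep8


def reformat_package(root="tensorpack"):
    """Apply autopep8's default PEP 8 fixes in place to every .py file under `root`."""
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            if not name.endswith(".py"):
                continue
            path = os.path.join(dirpath, name)
            with open(path) as f:
                original = f.read()
            fixed = autopep8.fix_code(original)  # style fixes only, no semantic changes
            if fixed != original:
                with open(path, "w") as f:
                    f.write(fixed)


if __name__ == "__main__":
    reformat_package()

The equivalent command-line form would be along the lines of `autopep8 --in-place --recursive tensorpack/`.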
Changed files:

tensorpack/RL/__init__.py  +2 -1
tensorpack/RL/common.py  +9 -1
tensorpack/RL/envbase.py  +13 -1
tensorpack/RL/expreplay.py  +50 -45
tensorpack/RL/gymenv.py  +4 -2
tensorpack/RL/history.py  +2 -1
tensorpack/RL/simulator.py  +22 -5
tensorpack/__init__.py  +1 -1
tensorpack/callbacks/__init__.py  +2 -1
tensorpack/callbacks/base.py  +5 -1
tensorpack/callbacks/concurrency.py  +4 -2
tensorpack/callbacks/dispatcher.py  +2 -0
tensorpack/callbacks/dump.py  +2 -1
tensorpack/callbacks/graph.py  +3 -1
tensorpack/callbacks/group.py  +4 -0
tensorpack/callbacks/inference.py  +7 -1
tensorpack/callbacks/inference_runner.py  +17 -12
tensorpack/callbacks/param.py  +24 -10
tensorpack/callbacks/saver.py  +14 -10
tensorpack/callbacks/stats.py  +13 -4
tensorpack/dataflow/__init__.py  +2 -2
tensorpack/dataflow/base.py  +4 -1
tensorpack/dataflow/common.py  +28 -4
tensorpack/dataflow/dataset/__init__.py  +2 -1
tensorpack/dataflow/dataset/bsds500.py  +5 -3
tensorpack/dataflow/dataset/cifar.py  +16 -7
tensorpack/dataflow/dataset/ilsvrc.py  +13 -8
tensorpack/dataflow/dataset/mnist.py  +12 -6
tensorpack/dataflow/dataset/ptb.py  +2 -2
tensorpack/dataflow/dataset/svhn.py  +4 -3
tensorpack/dataflow/dataset/visualqa.py  +6 -2
tensorpack/dataflow/dftools.py  +9 -4
tensorpack/dataflow/format.py  +22 -8
tensorpack/dataflow/image.py  +6 -2
tensorpack/dataflow/imgaug/__init__.py  +1 -1
tensorpack/dataflow/imgaug/_test.py  +4 -4
tensorpack/dataflow/imgaug/base.py  +5 -1
tensorpack/dataflow/imgaug/crop.py  +26 -16
tensorpack/dataflow/imgaug/deform.py  +18 -12
tensorpack/dataflow/imgaug/geometry.py  +22 -18
tensorpack/dataflow/imgaug/imgproc.py  +21 -8
tensorpack/dataflow/imgaug/meta.py  +11 -2
tensorpack/dataflow/imgaug/noise.py  +5 -0
tensorpack/dataflow/imgaug/noname.py  +13 -6
tensorpack/dataflow/imgaug/paste.py  +10 -3
tensorpack/dataflow/prefetch.py  +11 -2
tensorpack/dataflow/raw.py  +8 -1
tensorpack/dataflow/remote.py  +4 -2
tensorpack/dataflow/tf_func.py  +14 -12
tensorpack/models/__init__.py  +2 -1
tensorpack/models/_common.py  +9 -4
tensorpack/models/_test.py  +3 -2
tensorpack/models/batch_norm.py  +19 -16
tensorpack/models/conv2d.py  +10 -4
tensorpack/models/fc.py  +3 -1
tensorpack/models/image_sample.py  +31 -26
tensorpack/models/model_desc.py  +14 -6
tensorpack/models/nonlin.py  +5 -1
tensorpack/models/pool.py  +30 -19
tensorpack/models/regularize.py  +2 -1
tensorpack/models/shapes.py  +1 -0
tensorpack/models/softmax.py  +2 -1
tensorpack/predict/__init__.py  +1 -1
tensorpack/predict/base.py  +22 -12
tensorpack/predict/common.py  +6 -3
tensorpack/predict/concurrency.py  +18 -10
tensorpack/predict/dataset.py  +18 -10
tensorpack/tfutils/argscope.py  +2 -0
tensorpack/tfutils/common.py  +13 -3
tensorpack/tfutils/gradproc.py  +12 -1
tensorpack/tfutils/modelutils.py  +1 -2
tensorpack/tfutils/sessinit.py  +14 -4
tensorpack/tfutils/summary.py  +6 -1
tensorpack/tfutils/symbolic_functions.py  +21 -12
tensorpack/tfutils/tower.py  +4 -2
tensorpack/tfutils/varmanip.py  +11 -4
tensorpack/train/__init__.py  +1 -1
tensorpack/train/base.py  +7 -4
tensorpack/train/config.py  +3 -1
tensorpack/train/feedfree.py  +28 -21
tensorpack/train/input_data.py  +26 -13
tensorpack/train/multigpu.py  +22 -16
tensorpack/train/trainer.py  +9 -4
tensorpack/utils/__init__.py  +2 -3
tensorpack/utils/argtools.py  +20 -11
tensorpack/utils/concurrency.py  +13 -2
tensorpack/utils/debug.py  +5 -3
tensorpack/utils/discretize.py  +14 -8
tensorpack/utils/fs.py  +6 -2
tensorpack/utils/globvars.py  +3 -1
tensorpack/utils/gpu.py  +3 -1
tensorpack/utils/loadcaffe.py  +11 -9
tensorpack/utils/logger.py  +14 -4
tensorpack/utils/lut.py  +3 -1
tensorpack/utils/rect.py  +4 -3
tensorpack/utils/serialize.py  +4 -2
tensorpack/utils/stats.py  +11 -2
tensorpack/utils/timer.py  +6 -1
tensorpack/utils/utils.py  +22 -14
tensorpack/utils/viz.py  +26 -17
tensorpack/RL/__init__.py

@@ -8,6 +8,8 @@ import os
import os.path

__all__ = []


def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)

@@ -20,4 +22,3 @@ for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        _global_import(module_name)
tensorpack/RL/common.py

@@ -9,7 +9,8 @@ from collections import deque
from .envbase import ProxyPlayer

__all__ = ['PreventStuckPlayer', 'LimitLengthPlayer', 'AutoRestartPlayer',
           'MapPlayerState']


class PreventStuckPlayer(ProxyPlayer):
    """ Prevent the player from getting stuck (repeating a no-op)

@@ -17,6 +18,7 @@ class PreventStuckPlayer(ProxyPlayer):
    where the agent needs to press the 'start' button to start playing.
    """
    # TODO hash the state as well?

    def __init__(self, player, nr_repeat, action):
        """
        It does auto-reset, but doesn't auto-restart the underlying player.

@@ -40,10 +42,12 @@ class PreventStuckPlayer(ProxyPlayer):
        super(PreventStuckPlayer, self).restart_episode()
        self.act_que.clear()


class LimitLengthPlayer(ProxyPlayer):
    """ Limit the total number of actions in an episode.
        Will auto restart the underlying player on timeout
    """

    def __init__(self, player, limit):
        super(LimitLengthPlayer, self).__init__(player)
        self.limit = limit

@@ -64,9 +68,11 @@ class LimitLengthPlayer(ProxyPlayer):
            self.player.restart_episode()
            self.cnt = 0


class AutoRestartPlayer(ProxyPlayer):
    """ Auto-restart the player on episode ends,
        in case some player wasn't designed to do so. """

    def action(self, act):
        r, isOver = self.player.action(act)
        if isOver:

@@ -74,7 +80,9 @@ class AutoRestartPlayer(ProxyPlayer):
            self.player.restart_episode()
        return r, isOver


class MapPlayerState(ProxyPlayer):

    def __init__(self, player, func):
        super(MapPlayerState, self).__init__(player)
        self.func = func
tensorpack/RL/envbase.py

@@ -13,8 +13,10 @@ from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
           'DiscreteActionSpace']


@six.add_metaclass(ABCMeta)
class RLEnvironment(object):

    def __init__(self):
        self.reset_stat()

@@ -60,13 +62,15 @@ class RLEnvironment(object):
            s = self.current_state()
            act = func(s)
            r, isOver = self.action(act)
            # print r
            if isOver:
                s = [self.stats[k] for k in stat]
                self.reset_stat()
                return s if len(s) > 1 else s[0]


class ActionSpace(object):

    def __init__(self):
        self.rng = get_rng(self)

@@ -77,7 +81,9 @@ class ActionSpace(object):
    def num_actions(self):
        raise NotImplementedError()


class DiscreteActionSpace(ActionSpace):

    def __init__(self, num):
        super(DiscreteActionSpace, self).__init__()
        self.num = num

@@ -94,19 +100,25 @@ class DiscreteActionSpace(ActionSpace):
    def __str__(self):
        return "DiscreteActionSpace({})".format(self.num)


class NaiveRLEnvironment(RLEnvironment):
    """ for testing only"""

    def __init__(self):
        self.k = 0

    def current_state(self):
        self.k += 1
        return self.k

    def action(self, act):
        self.k = act
        return (self.k, self.k > 10)


class ProxyPlayer(RLEnvironment):
    """ Serve as a proxy another player """

    def __init__(self, player):
        self.player = player
tensorpack/RL/expreplay.py

@@ -10,14 +10,15 @@ import six
from six.moves import queue

from ..dataflow import DataFlow
from ..utils import logger, get_tqdm, get_rng
from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback

__all__ = ['ExpReplay']

Experience = namedtuple('Experience',
                        ['state', 'action', 'reward', 'isOver'])


class ExpReplay(DataFlow, Callback):
    """

@@ -27,19 +28,20 @@ class ExpReplay(DataFlow, Callback):
    This implementation provides the interface as an DataFlow.
    This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
    """

    def __init__(self,
                 predictor_io_names,
                 player,
                 batch_size=32,
                 memory_size=1e6, init_memory_size=50000,
                 exploration=1, end_exploration=0.1,
                 exploration_epoch_anneal=0.002,
                 reward_clip=None,
                 update_frequency=1, history_len=1):
        """
        :param predictor: a callabale running the up-to-date network.
            called with a state, return a distribution.

@@ -78,10 +80,10 @@ class ExpReplay(DataFlow, Callback):
    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        # if len(self.mem):
        # from copy import deepcopy  # quickly fill the memory for debug
        # self.mem.append(deepcopy(self.mem[0]))
        # return
        old_s = self.player.current_state()
        if self.rng.rand() <= self.exploration:
            act = self.rng.choice(range(self.num_actions))

@@ -115,19 +117,19 @@ class ExpReplay(DataFlow, Callback):
        while True:
            batch_exp = [self._sample_one() for _ in range(self.batch_size)]

            # import cv2  # for debug
            # def view_state(state, next_state):
            #     """ for debugging state representation"""
            #     r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
            #     r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
            #     r = np.concatenate([r, r2], axis=0)
            #     print r.shape
            #     cv2.imshow("state", r)
            #     cv2.waitKey()
            # exp = batch_exp[0]
            # print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
            # if exp[2] or exp[4]:
            #     view_state(exp[0], exp[1])

            yield self._process_batch(batch_exp)
            self._populate_job_queue.put(1)

@@ -141,9 +143,10 @@ class ExpReplay(DataFlow, Callback):
        # when x.isOver==True, (x+1).state is of a different episode
        idx = self.rng.randint(len(self.mem) - self.history_len - 1)
        samples = [self.mem[k] for k in range(idx, idx + self.history_len + 1)]

        def concat(idx):
            v = [x.state for x in samples[idx:idx + self.history_len]]
            return np.concatenate(v, axis=2)
        state = concat(0)
        next_state = concat(1)

@@ -155,12 +158,12 @@ class ExpReplay(DataFlow, Callback):
        # zero-fill state before starting
        zero_fill = False
        for k in range(1, self.history_len):
            if samples[start_idx - k].isOver:
                zero_fill = True
            if zero_fill:
                state[:, :, -k - 1] = 0
                if k + 2 <= self.history_len:
                    next_state[:, :, -k - 2] = 0
        return (state, next_state, reward, action, isOver)

    def _process_batch(self, batch_exp):

@@ -178,6 +181,7 @@ class ExpReplay(DataFlow, Callback):
    def _before_train(self):
        # spawn a separate thread to run policy, can speed up 1.3x
        self._populate_job_queue = queue.Queue(maxsize=1)

        def populate_job_func():
            self._populate_job_queue.get()
            with self.trainer.sess.as_default():

@@ -203,22 +207,23 @@ class ExpReplay(DataFlow, Callback):
        pass
        self.player.reset_stat()


if __name__ == '__main__':
    from .atari import AtariPlayer
    import sys
    predictor = lambda x: np.array([1, 1, 1, 1])
    player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
    E = ExpReplay(predictor,
                  player=player,
                  num_actions=player.get_action_space().num_actions(),
                  populate_size=1001,
                  history_len=4)
    E._init_memory()

    for k in E.get_data():
        import IPython as IP
        IP.embed(config=IP.terminal.ipapp.load_default_config())
        pass
        # import IPython;
        # IPython.embed(config=IPython.terminal.ipapp.load_default_config())
        # break
tensorpack/RL/gymenv.py

@@ -9,7 +9,7 @@ from ..utils import logger
try:
    import gym
    # TODO
    # gym.undo_logger_setup()
    # https://github.com/openai/gym/pull/199
    # not sure does it cause other problems
    __all__ = ['GymEnv']

@@ -26,11 +26,13 @@ from .envbase import RLEnvironment, DiscreteActionSpace
_ENV_LOCK = threading.Lock()


class GymEnv(RLEnvironment):
    """
    An OpenAI/gym wrapper. Can optionally auto restart.
    Only support discrete action space now
    """

    def __init__(self, name, dumpdir=None, viz=False, auto_restart=True):
        with _ENV_LOCK:
            self.gymenv = gym.make(name)

@@ -82,7 +84,7 @@ if __name__ == '__main__':
    rng = get_rng(num)
    while True:
        act = rng.choice(range(num))
        # print act
        r, o = env.action(act)
        env.current_state()
        if r != 0 or o:
tensorpack/RL/history.py

@@ -9,10 +9,12 @@ from .envbase import ProxyPlayer
__all__ = ['HistoryFramePlayer']


class HistoryFramePlayer(ProxyPlayer):
    """ Include history frames in state, or use black images
        Assume player will do auto-restart.
    """

    def __init__(self, player, hist_len):
        """
        :param hist_len: total length of the state, including the current

@@ -49,4 +51,3 @@ class HistoryFramePlayer(ProxyPlayer):
        super(HistoryFramePlayer, self).restart_episode()
        self.history.clear()
        self.history.append(self.player.current_state())
tensorpack/RL/simulator.py

@@ -25,8 +25,8 @@ from ..utils.serialize import loads, dumps
from ..utils.concurrency import LoopThread, ensure_proc_terminate

__all__ = ['SimulatorProcess', 'SimulatorMaster',
           'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
           'TransitionExperience', 'WeightSync']

try:
    import zmq

@@ -34,8 +34,10 @@ except ImportError:
    logger.warn_dependency('Simulator', 'zmq')
    __all__ = []


class TransitionExperience(object):
    """ A transition of state, or experience"""

    def __init__(self, state, action, reward, **kwargs):
        """ kwargs: whatever other attribute you want to save"""
        self.state = state

@@ -44,6 +46,7 @@ class TransitionExperience(object):
        for k, v in six.iteritems(kwargs):
            setattr(self, k, v)


@six.add_metaclass(ABCMeta)
class SimulatorProcessBase(mp.Process):

@@ -63,6 +66,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
    A process that simulates a player and communicates to master to
    send states and receive the next action
    """

    def __init__(self, idx, pipe_c2s, pipe_s2c):
        """
        :param idx: idx of this process

@@ -81,7 +85,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
        s2c_socket = context.socket(zmq.DEALER)
        s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
        # s2c_socket.set_hwm(5)
        s2c_socket.connect(self.s2c)

        state = player.current_state()

@@ -97,12 +101,14 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
# compatibility
SimulatorProcess = SimulatorProcessStateExchange


class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
        It should produce action for each simulator, as well as
        defining callbacks when a transition or an episode is finished.
    """

    class ClientState(object):

        def __init__(self):
            self.memory = []    # list of Experience

@@ -174,9 +180,11 @@ class SimulatorMaster(threading.Thread):
    def __del__(self):
        self.context.destroy(linger=0)


class SimulatorProcessDF(SimulatorProcessBase):
    """ A simulator which contains a forward model itself, allowing
    it to produce data points directly """

    def __init__(self, idx, pipe_c2s):
        super(SimulatorProcessDF, self).__init__(idx)
        self.pipe_c2s = pipe_c2s

@@ -202,12 +210,14 @@ class SimulatorProcessDF(SimulatorProcessBase):
    def get_data(self):
        pass


class SimulatorProcessSharedWeight(SimulatorProcessDF):
    """ A simulator process with an extra thread waiting for event,
    and take shared weight from shm.
    Start me under some CUDA_VISIBLE_DEVICES set!
    """

    def __init__(self, idx, pipe_c2s, condvar, shared_dic, pred_config):
        super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
        self.condvar = condvar

@@ -220,7 +230,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
        with self.predictor.graph.as_default():
            vars_to_update = self._params_to_update()
            self.sess_updater = SessionUpdate(
                self.predictor.session, vars_to_update)
            # TODO setup callback for explore?
            self.predictor.graph.finalize()

@@ -245,8 +255,10 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
        # can be overwritten to update more params
        return tf.trainable_variables()


class WeightSync(Callback):
    """ Sync weight from main process to shared_dic and notify"""

    def __init__(self, condvar, shared_dic):
        self.condvar = condvar
        self.shared_dic = shared_dic

@@ -260,6 +272,7 @@ class WeightSync(Callback):
    def _before_train(self):
        self._sync()

    def _trigger_epoch(self):
        self._sync()

@@ -274,13 +287,18 @@ class WeightSync(Callback):
if __name__ == '__main__':
    import random
    from tensorpack.RL import NaiveRLEnvironment

    class NaiveSimulator(SimulatorProcess):

        def _build_player(self):
            return NaiveRLEnvironment()

    class NaiveActioner(SimulatorActioner):

        def _get_action(self, state):
            time.sleep(1)
            return random.randint(1, 12)

        def _on_episode_over(self, client):
            # print("Over: ", client.memory)
            client.memory = []

@@ -296,4 +314,3 @@ if __name__ == '__main__':
    import time
    time.sleep(100)
tensorpack/__init__.py

@@ -2,7 +2,7 @@
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import numpy  # avoid https://github.com/tensorflow/tensorflow/issues/2034
import cv2  # avoid https://github.com/tensorflow/tensorflow/issues/1924
from tensorpack.train import *
tensorpack/callbacks/__init__.py

@@ -7,6 +7,8 @@ import os
__all__ = []


def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)

@@ -23,4 +25,3 @@ for _, module_name, _ in walk_packages(
            continue
        if not module_name.startswith('_'):
            _global_import(module_name)
tensorpack/callbacks/base.py

@@ -11,6 +11,7 @@ import six
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']


@six.add_metaclass(ABCMeta)
class Callback(object):
    """ Base class for all callbacks """

@@ -72,7 +73,9 @@ class Callback(object):
    def __str__(self):
        return type(self).__name__


class ProxyCallback(Callback):

    def __init__(self, cb):
        self.cb = cb

@@ -91,11 +94,13 @@ class ProxyCallback(Callback):
    def __str__(self):
        return "Proxy-" + str(self.cb)


class PeriodicCallback(ProxyCallback):
    """
    A callback to be triggered after every `period` epochs.
    Doesn't work for trigger_step
    """

    def __init__(self, cb, period):
        """
        :param cb: a `Callback`

@@ -111,4 +116,3 @@ class PeriodicCallback(ProxyCallback):
    def __str__(self):
        return "Periodic-" + str(self.cb)
tensorpack/callbacks/concurrency.py

@@ -9,7 +9,9 @@ from ..utils import logger
__all__ = ['StartProcOrThread']


class StartProcOrThread(Callback):

    def __init__(self, procs_threads):
        """
        Start extra threads and processes before training

@@ -20,7 +22,7 @@ class StartProcOrThread(Callback):
        self._procs_threads = procs_threads

    def _before_train(self):
        logger.info("Starting " +
                    ', '.join([k.name for k in self._procs_threads]))
        # avoid sigint get handled by other processes
        start_proc_mask_signal(self._procs_threads)
tensorpack/callbacks/dispatcher.py

@@ -6,7 +6,9 @@ from ..tfutils.common import get_op_tensor_name
__all__ = ['OutputTensorDispatcer']


class OutputTensorDispatcer(object):

    def __init__(self):
        self._names = []
        self._idxs = []
tensorpack/callbacks/dump.py

@@ -12,10 +12,12 @@ from ..tfutils import get_op_var_name
__all__ = ['DumpParamAsImage']


class DumpParamAsImage(Callback):
    """
    Dump a variable to image(s) after every epoch to logger.LOG_DIR.
    """

    def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
        """
        :param var_name: the name of the variable.

@@ -59,4 +61,3 @@ class DumpParamAsImage(Callback):
        if self.clip:
            res = np.clip(res, 0, 255)
        cv2.imwrite(fname, res.astype('uint8'))
tensorpack/callbacks/graph.py

@@ -10,8 +10,10 @@ from ..utils import logger
__all__ = ['RunOp']


class RunOp(Callback):
    """ Run an op periodically"""

    def __init__(self, setup_func, run_before=True, run_epoch=True):
        """
        :param setup_func: a function that returns the op in the graph

@@ -34,5 +36,5 @@ class RunOp(Callback):
        if self.run_epoch:
            self._op.run()

    # def _log(self):
        #logger.info("Running op {} ...".format(self._op_name))
tensorpack/callbacks/group.py

@@ -12,7 +12,9 @@ from ..utils import logger
__all__ = ['Callbacks']


class CallbackTimeLogger(object):

    def __init__(self):
        self.times = []
        self.tot = 0

@@ -39,10 +41,12 @@ class CallbackTimeLogger(object):
            "Callbacks took {:.3f} sec in total. {}".format(
                self.tot, '; '.join(msgs)))


class Callbacks(Callback):
    """
    A container to hold all callbacks, and execute them in the right order and proper session.
    """

    def __init__(self, cbs):
        """
        :param cbs: a list of `Callbacks`
tensorpack/callbacks/inference.py

@@ -14,7 +14,8 @@ from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_var_name

__all__ = ['ClassificationError',
           'ScalarStats', 'Inferencer', 'BinaryClassificationStats']


@six.add_metaclass(ABCMeta)
class Inferencer(object):

@@ -59,12 +60,14 @@ class Inferencer(object):
    def _get_output_tensors(self):
        pass


class ScalarStats(Inferencer):
    """
    Write some scalar tensor to both stat and summary.
    The output of the given Ops must be a scalar.
    The value will be averaged over all data points in the inference dataflow.
    """

    def __init__(self, names_to_print, prefix='validation'):
        """
        :param names_to_print: list of names of tensors, or just a name

@@ -96,6 +99,7 @@ class ScalarStats(Inferencer):
            ret[name] = stat
        return ret


class ClassificationError(Inferencer):
    """
    Compute classification error in batch mode, from a `wrong` variable

@@ -109,6 +113,7 @@ class ClassificationError(Inferencer):
    testing (because the size of test set might not be a multiple of batch size).
    Therefore the result is different from averaging the error rate of each batch.
    """

    def __init__(self, wrong_var_name='incorrect_vector', summary_name='val_error'):
        """
        :param wrong_var_name: name of the `wrong` variable

@@ -138,6 +143,7 @@ class ClassificationError(Inferencer):
    def _after_inference(self):
        return {self.summary_name: self.err_stat.ratio}


class BinaryClassificationStats(Inferencer):
    """ Compute precision/recall in binary classification, given the
    prediction vector and the label vector.
tensorpack/callbacks/inference_runner.py

@@ -18,6 +18,7 @@ from ..train.input_data import FeedfreeInput
__all__ = ['InferenceRunner']


def summary_inferencer(trainer, infs):
    for inf in infs:
        ret = inf.after_inference()

@@ -29,6 +30,7 @@ def summary_inferencer(trainer, infs):
            continue
        trainer.write_scalar_summary(k, v)


class InferenceRunner(Callback):
    """
    A callback that runs different kinds of inferencer.

@@ -54,16 +56,17 @@ class InferenceRunner(Callback):
        self.input_tensors = input_tensors

    def _setup_graph(self):
        self._find_input_tensors()  # these are all tensor names
        self._find_output_tensors()  # may be either tensor name or op name
        self.pred_func = self.trainer.get_predict_func(
            self.input_tensors, self.output_tensors)

    def _find_input_tensors(self):
        if self.input_tensors is None:
            input_vars = self.trainer.model.get_reuse_placehdrs()
            # TODO even if it works here, sparse still is unavailable
            # because get_tensor_by_name doesn't work for sparse

            def get_name(x):
                if isinstance(x, tf.SparseTensor):
                    return x.op.name.split('/')[0]

@@ -79,6 +82,7 @@ class InferenceRunner(Callback):
        IOTensor = InferenceRunner.IOTensor
        self.output_tensors = list(filter(
            lambda x: x not in self.input_tensors, all_names))

        def find_oid(idxs):
            ret = []
            for idx in idxs:

@@ -102,7 +106,7 @@ class InferenceRunner(Callback):
                outputs = self.pred_func(dp)
                for inf, tensormap in zip(self.infs, self.inf_to_tensors):
                    inf_output = [(outputs if k.isOutput else dp)[k.index]
                                  for k in tensormap]
                    inf.datapoint(inf_output)
                pbar.update()
        self._write_summary_after_inference()

@@ -110,6 +114,7 @@ class InferenceRunner(Callback):
    def _write_summary_after_inference(self):
        summary_inferencer(self.trainer, self.infs)


class FeedfreeInferenceRunner(Callback):
    IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])

@@ -139,9 +144,9 @@ class FeedfreeInferenceRunner(Callback):
        if self.input_tensor_names is not None:
            assert isinstance(self.input_tensor_names, list)
            self._input_tensors = [k for idx, k in enumerate(self._input_tensors)
                                   if model_placehdrs[idx].name in self.input_tensor_names]
            assert len(self._input_tensors) == len(self.input_tensor_names), \
                "names of input tensors are not defined in the Model"

@@ -152,6 +157,7 @@ class FeedfreeInferenceRunner(Callback):
        IOTensor = InferenceRunner.IOTensor
        self.output_tensors = all_names

        def find_oid(idxs):
            ret = []
            for idx in idxs:

@@ -161,7 +167,6 @@ class FeedfreeInferenceRunner(Callback):
        self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
        # list of list of (var_name: IOTensor)

    def _trigger_epoch(self):
        for inf in self.infs:
            inf.before_inference()

@@ -170,11 +175,11 @@ class FeedfreeInferenceRunner(Callback):
        sz = self._input_data.size()
        with get_tqdm(total=sz) as pbar:
            for _ in range(sz):
                # outputs = self.pred_func(dp)
                # for inf, tensormap in zip(self.infs, self.inf_to_tensors):
                #     inf_output = [(outputs if k.isOutput else dp)[k.index]
                #                   for k in tensormap]
                #     inf.datapoint(inf_output)
                pbar.update()
        self._write_summary_after_inference()
tensorpack/callbacks/param.py

@@ -17,6 +17,8 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
           'ScheduledHyperParamSetter', 'StatMonitorParamSetter',
           'HyperParamSetterWithFunc',
           'HyperParam', 'GraphVarParam', 'ObjAttrParam']


@six.add_metaclass(ABCMeta)
class HyperParam(object):
    """ Base class for a hyper param"""

@@ -35,8 +37,10 @@ class HyperParam(object):
        """ A name to display"""
        return self._readable_name


class GraphVarParam(HyperParam):
    """ a variable in the graph can be a hyperparam"""

    def __init__(self, name, shape=[]):
        self.name = name
        self.shape = shape

@@ -56,13 +60,15 @@ class GraphVarParam(HyperParam):
        self.assign_op = self.var.assign(self.val_holder)

    def set_value(self, v):
        self.assign_op.eval(feed_dict={self.val_holder: v})

    def get_value(self):
        return self.var.eval()


class ObjAttrParam(HyperParam):
    """ an attribute of an object can be a hyperparam"""

    def __init__(self, obj, attrname, readable_name=None):
        """ :param readable_name: default to be attrname."""
        self.obj = obj

@@ -78,6 +84,7 @@ class ObjAttrParam(HyperParam):
    def get_value(self, v):
        return getattr(self.obj, self.attrname)


class HyperParamSetter(Callback):
    """
    Base class to set hyperparameters after every epoch.

@@ -126,10 +133,12 @@ class HyperParamSetter(Callback):
        if v is not None:
            self.param.set_value(v)


class HumanHyperParamSetter(HyperParamSetter):
    """
    Set hyperparameters by loading the value from a file each time it get called.
    """

    def __init__(self, param, file_name='hyper.txt'):
        """
        :param file_name: a file containing the value of the variable.

@@ -149,7 +158,7 @@ class HumanHyperParamSetter(HyperParamSetter):
            with open(self.file_name) as f:
                lines = f.readlines()
            lines = [s.strip().split(':') for s in lines]
            dic = {str(k): float(v) for k, v in lines}
            ret = dic[self.param.readable_name]
            return ret
        except:

@@ -158,10 +167,12 @@ class HumanHyperParamSetter(HyperParamSetter):
                self.param.readable_name, self.file_name))
            return None


class ScheduledHyperParamSetter(HyperParamSetter):
    """
    Set hyperparameters by a predefined schedule.
    """

    def __init__(self, param, schedule, interp=None):
        """
        :param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]

@@ -196,7 +207,9 @@ class ScheduledHyperParamSetter(HyperParamSetter):
        v = (self.epoch_num - laste) * 1. / (e - laste) * (v - lastv) + lastv
        return v


class HyperParamSetterWithFunc(HyperParamSetter):

    def __init__(self, param, func):
        """Set hyperparameter by a func
        new_value = f(epoch_num, old_value)

@@ -207,10 +220,12 @@ class HyperParamSetterWithFunc(HyperParamSetter):
    def _get_value_to_set(self):
        return self.f(self.epoch_num, self.get_current_value())


class StatMonitorParamSetter(HyperParamSetter):

    def __init__(self, param, stat_name, value_func, threshold,
                 last_k, reverse=False):
        """
        Set hyperparameter by a func, when a specific stat wasn't
        decreasing/increasing enough in the last $k$ epochs.

@@ -236,22 +251,21 @@ class StatMonitorParamSetter(HyperParamSetter):
    def _get_value_to_set(self):
        holder = self.trainer.stat_holder
        hist = holder.get_stat_history(self.stat_name)
        if len(hist) < self.last_k + 1 or \
                self.epoch_num - self.last_changed_epoch < self.last_k:
            return None
        hist = hist[-self.last_k - 1:]    # len==last_k+1
        hist_first = hist[0]
        if not self.reverse:
            hist_min = min(hist)
            if hist_min < hist_first - self.threshold:  # small enough
                return None
        else:
            hist_max = max(hist)
            if hist_max > hist_first + self.threshold:  # large enough
                return None
        self.last_changed_epoch = self.epoch_num
        logger.info("[StatMonitorParamSetter] Triggered, history: " +
                    ','.join(map(str, hist)))
        return self.value_func(self.get_current_value())
tensorpack/callbacks/saver.py

@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
import os
import shutil
import re
from .base import Callback

@@ -13,12 +14,14 @@ from ..tfutils import get_global_step
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']


class ModelSaver(Callback):
    """
    Save the model to logger directory.
    """

    def __init__(self, keep_recent=10, keep_freq=0.5,
                 var_collections=None):
        """
        :param keep_recent: see `tf.train.Saver` documentation.
        :param keep_freq: see `tf.train.Saver` documentation.

@@ -71,9 +74,9 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
        try:
            if not self.meta_graph_written:
                self.saver.export_meta_graph(
                    os.path.join(logger.LOG_DIR,
                                 'graph-{}.meta'.format(logger.get_time_str())),
                    collection_list=self.graph.get_all_collection_keys())
                self.meta_graph_written = True
            self.saver.save(
                tf.get_default_session(),

@@ -83,7 +86,9 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
        except (OSError, IOError):
            # disk error sometimes.. just ignore it
            logger.exception("Exception in ModelSaver.trigger_epoch!")


class MinSaver(Callback):

    def __init__(self, monitor_stat, reverse=True, filename=None):
        self.monitor_stat = monitor_stat
        self.reverse = reverse

@@ -116,15 +121,14 @@ class MinSaver(Callback):
                "Cannot find a checkpoint state. Do you forget to use ModelSaver?")
        path = ckpt.model_checkpoint_path
        newname = os.path.join(logger.LOG_DIR,
                               self.filename or
                               ('max-' if self.reverse else 'min-' + self.monitor_stat + '.tfmodel'))
        shutil.copy(path, newname)
        logger.info("Model with {} '{}' saved.".format(
            'maximum' if self.reverse else 'minimum', self.monitor_stat))


class MaxSaver(MinSaver):

    def __init__(self, monitor_stat):
        super(MaxSaver, self).__init__(monitor_stat, True)
tensorpack/callbacks/stats.py

@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
import re
import os
import operator
import json

@@ -13,10 +14,12 @@ from ..tfutils.common import get_global_step
__all__ = ['StatHolder', 'StatPrinter', 'SendStat']


class StatHolder(object):
    """
    A holder to keep all statistics aside from tensorflow events.
    """

    def __init__(self, log_dir):
        """
        :param log_dir: directory to save the stats.

@@ -62,9 +65,11 @@ class StatHolder(object):
        ret = []
        for h in self.stat_history:
            v = h.get(key, None)
            if v is not None:
                ret.append(v)
        v = self.stat_now.get(key, None)
        if v is not None:
            ret.append(v)
        return ret

    def finalize(self):

@@ -88,13 +93,15 @@ class StatHolder(object):
            with open(tmp_filename, 'w') as f:
                json.dump(self.stat_history, f)
            os.rename(tmp_filename, self.filename)
        except IOError:  # disk error sometimes..
            logger.exception("Exception in StatHolder.finalize()!")


class StatPrinter(Callback):
    """
    Control what stats to print.
    """

    def __init__(self, print_tag=None):
        """
        :param print_tag: a list of regex to match scalar summary to print.

@@ -116,6 +123,7 @@ class StatPrinter(Callback):
        self._stat_holder.finalize()
        self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)


class SendStat(Callback):
    """
    Execute a command with some specific stats.

@@ -126,6 +134,7 @@ class SendStat(Callback):
            -d body={validation_error} > /dev/null 2>&1',
            'validation_error')
    """

    def __init__(self, command, stats):
        self.command = command
        if not isinstance(stats, list):
tensorpack/dataflow/__init__.py

@@ -12,6 +12,7 @@ from . import imgaug
__all__ = ['dataset', 'imgaug', 'dftools']


def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)

@@ -24,6 +25,5 @@ __SKIP = ['dftools', 'dataset', 'imgaug']
for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_') and \
            module_name not in __SKIP:
        _global_import(module_name)
tensorpack/dataflow/base.py

@@ -10,6 +10,7 @@ from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']


@six.add_metaclass(ABCMeta)
class DataFlow(object):
    """ Base class for all DataFlow """

@@ -17,7 +18,6 @@ class DataFlow(object):
    class Infinity:
        pass

    @abstractmethod
    def get_data(self):
        """

@@ -44,11 +44,14 @@ class DataFlow(object):
class RNGDataFlow(DataFlow):
    """ A dataflow with rng"""

    def reset_state(self):
        self.rng = get_rng(self)


class ProxyDataFlow(DataFlow):
    """ Base class for DataFlow that proxies another"""

    def __init__(self, ds):
        """
        :param ds: a :mod:`DataFlow` instance to proxy
tensorpack/dataflow/common.py

@@ -15,7 +15,9 @@ __all__ = ['BatchData', 'FixedSizeData', 'MapData',
           'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
           'LocallyShuffleData', 'TestDataSpeed', 'BatchDataByShape']


class TestDataSpeed(ProxyDataFlow):

    def __init__(self, ds, size=1000):
        super(TestDataSpeed, self).__init__(ds)
        self.test_size = size

@@ -31,7 +33,9 @@ class TestDataSpeed(ProxyDataFlow):
            for dp in self.ds.get_data():
                pbar.update()


class BatchData(ProxyDataFlow):

    def __init__(self, ds, batch_size, remainder=False):
        """
        Group data in `ds` into batches.

@@ -91,11 +95,13 @@ class BatchData(ProxyDataFlow):
                    raise
                except:
                    logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
                    import IPython as IP
                    IP.embed(config=IP.terminal.ipapp.load_default_config())
        return result


class BatchDataByShape(BatchData):

    def __init__(self, ds, batch_size, idx):
        """ Group datapoint of the same shape together to batches

@@ -119,10 +125,12 @@ class BatchDataByShape(BatchData):
                yield BatchData._aggregate_batch(holder)
                del holder[:]


class FixedSizeData(ProxyDataFlow):
    """ Generate data from another DataFlow, but with a fixed epoch size.
        The state of the underlying DataFlow is maintained among each epoch.
    """

    def __init__(self, ds, size):
        """
        :param ds: a :mod:`DataFlow` to produce data

@@ -154,10 +162,12 @@ class FixedSizeData(ProxyDataFlow):
            if cnt == self._size:
                return


class RepeatedData(ProxyDataFlow):
    """ Take data points from another `DataFlow` and produce them until
        it's exhausted for certain amount of times.
    """

    def __init__(self, ds, nr):
        """
        :param ds: a :mod:`DataFlow` instance.

@@ -184,8 +194,10 @@ class RepeatedData(ProxyDataFlow):
                for dp in self.ds.get_data():
                    yield dp


class MapData(ProxyDataFlow):
    """ Apply map/filter a function on the datapoint"""

    def __init__(self, ds, func):
        """
        :param ds: a :mod:`DataFlow` instance.

@@ -202,8 +214,10 @@ class MapData(ProxyDataFlow):
            if ret is not None:
                yield ret


class MapDataComponent(ProxyDataFlow):
    """ Apply map/filter on the given index in the datapoint"""

    def __init__(self, ds, func, index=0):
        """
        :param ds: a :mod:`DataFlow` instance.

@@ -222,11 +236,13 @@ class MapDataComponent(ProxyDataFlow):
                dp[self.index] = repl   # NOTE modifying
                yield dp


class RandomChooseData(RNGDataFlow):
    """
    Randomly choose from several DataFlow. Stop producing when any of them is
    exhausted.
    """

    def __init__(self, df_lists):
        """
        :param df_lists: list of dataflow, or list of (dataflow, probability) tuple

@@ -257,10 +273,12 @@ class RandomChooseData(RNGDataFlow):
            except StopIteration:
                return


class RandomMixData(RNGDataFlow):
    """
    Randomly choose from several dataflow, and will eventually exhaust all dataflow. So it's a perfect mix.
    """

    def __init__(self, df_lists):
        """
        :param df_lists: list of dataflow.

@@ -285,14 +303,16 @@ class RandomMixData(RNGDataFlow):
        idxs = np.array(list(map(
            lambda x: np.searchsorted(sums, x, 'right'), idxs)))
        itrs = [k.get_data() for k in self.df_lists]
        assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs) - 1)
        for k in idxs:
            yield next(itrs[k])


class ConcatData(DataFlow):
    """
    Concatenate several dataflows.
    """

    def __init__(self, df_lists):
        """
        :param df_lists: list of :mod:`DataFlow` instances

@@ -311,6 +331,7 @@ class ConcatData(DataFlow):
            for dp in d.get_data():
                yield dp


class JoinData(DataFlow):
    """
    Join the components from each DataFlow.

@@ -321,6 +342,7 @@ class JoinData(DataFlow):
        df2: [dp3, dp4]
        join: [dp1, dp2, dp3, dp4]
    """

    def __init__(self, df_lists):
        """
        :param df_lists: list of :mod:`DataFlow` instances

@@ -329,7 +351,7 @@ class JoinData(DataFlow):
        self._size = self.df_lists[0].size()
        for d in self.df_lists:
            assert d.size() == self._size, \
                "All DataFlow must have the same size! {} != {}".format(d.size(), self._size)

    def reset_state(self):
        for d in self.df_lists:

@@ -352,7 +374,9 @@ class JoinData(DataFlow):
            for itr in itrs:
                del itr


class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):

    def __init__(self, ds, cache_size, nr_reuse=1):
        """
        Cache a number of datapoints and shuffle them.

@@ -393,10 +417,10 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
                    yield v
                return


def SelectComponent(ds, idxs):
    """
    :param ds: a :mod:`DataFlow` instance
    :param idxs: a list of datapoint component index of the original dataflow
    """
    return MapData(ds, lambda dp: [dp[i] for i in idxs])
tensorpack/dataflow/dataset/__init__.py

@@ -7,6 +7,8 @@ import os
import os.path

__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)

@@ -19,4 +21,3 @@ for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        global_import(module_name)
tensorpack/dataflow/dataset/bsds500.py

@@ -3,7 +3,8 @@
# File: bsds500.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import os
import glob
import cv2
import numpy as np

@@ -21,6 +22,7 @@ except ImportError:
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W, IMG_H = 481, 321


class BSDS500(RNGDataFlow):
    """
    `Berkeley Segmentation Data Set and Benchmarks 500

@@ -65,7 +67,7 @@ class BSDS500(RNGDataFlow):
            im = cv2.imread(f, cv2.IMREAD_COLOR)
            assert im is not None
            if im.shape[0] > im.shape[1]:
                im = np.transpose(im, (1, 0, 2))
            assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W))
            imgid = os.path.basename(f).split('.')[0]

@@ -96,5 +98,5 @@ class BSDS500(RNGDataFlow):
if __name__ == '__main__':
    a = BSDS500('val')
    for k in a.get_data():
        cv2.imshow("haha", k[1].astype('uint8') * 255)
        cv2.waitKey(1000)
tensorpack/dataflow/dataset/cifar.py

@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
#         Yukun Chen <cykustc@gmail.com>

import os
import sys
import pickle
import numpy as np
import random

@@ -23,6 +24,7 @@ __all__ = ['Cifar10', 'Cifar100']
DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'


def maybe_download_and_extract(dest_directory, cifar_classnum):
    """Download and extract the tarball from Alex's website.
       copied from tensorflow example """

@@ -42,6 +44,7 @@ def maybe_download_and_extract(dest_directory, cifar_classnum):
    import tarfile
    tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def read_cifar(filenames, cifar_classnum):
    assert cifar_classnum == 10 or cifar_classnum == 100
    ret = []

@@ -54,7 +57,7 @@ def read_cifar(filenames, cifar_classnum):
        data = dic[b'data']
        if cifar_classnum == 10:
            label = dic[b'labels']
            IMG_NUM = 10000  # cifar10 data are split into blocks of 10000
        elif cifar_classnum == 100:
            label = dic[b'fine_labels']
            IMG_NUM = 50000 if 'train' in fname else 10000

@@ -65,6 +68,7 @@ def read_cifar(filenames, cifar_classnum):
            ret.append([img, label[k]])
    return ret


def get_filenames(dir, cifar_classnum):
    assert cifar_classnum == 10 or cifar_classnum == 100
    if cifar_classnum == 10:

@@ -77,11 +81,13 @@ def get_filenames(dir, cifar_classnum):
            os.path.join(dir, 'cifar-100-python', 'test')]
    return filenames


class CifarBase(RNGDataFlow):
    """
    Return [image, label],
        image is 32x32x3 in the range [0,255]
    """

    def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10):
        """
        Args:

@@ -132,13 +138,17 @@ class CifarBase(RNGDataFlow):
        return three values as mean of each channel
        """
        mean = self.get_per_pixel_mean()
        return np.mean(mean, axis=(0, 1))


class Cifar10(CifarBase):

    def __init__(self, train_or_test, shuffle=True, dir=None):
        super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10)


class Cifar100(CifarBase):

    def __init__(self, train_or_test, shuffle=True, dir=None):
        super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)

@@ -149,7 +159,6 @@ if __name__ == '__main__':
    print(mean)
    dump_dataset_images(ds, '/tmp/cifar', 100)
    # for (img, label) in ds.get_data():
    # from IPython import embed; embed()
    # break
tensorpack/dataflow/dataset/ilsvrc.py

@@ -19,10 +19,12 @@ __all__ = ['ILSVRCMeta', 'ILSVRC12']
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"


class ILSVRCMeta(object):
    """
    Some metadata for ILSVRC dataset.
    """

    def __init__(self, dir=None):
        if dir is None:
            dir = get_dataset_path('ilsvrc_metadata')

@@ -82,14 +84,16 @@ class ILSVRCMeta(object):
        with open(mean_file, 'rb') as f:
            obj.ParseFromString(f.read())
        arr = np.array(obj.data).reshape((3, 256, 256)).astype('float32')
        arr = np.transpose(arr, [1, 2, 0])
        if size is not None:
            arr = cv2.resize(arr, size[::-1])
        return arr


class ILSVRC12(RNGDataFlow):

    def __init__(self, dir, name, meta_dir=None, shuffle=True,
                 dir_structure='original', include_bb=False):
        """
        :param dir: A directory containing a subdir named `name`, where the
            original ILSVRC12_`name`.tar gets decompressed.

@@ -145,7 +149,7 @@ class ILSVRC12(RNGDataFlow):
        if include_bb:
            bbdir = os.path.join(dir, 'bbox') if not \
                isinstance(include_bb, six.string_types) else include_bb
            assert name == 'train', 'Bounding box only available for training'
            self.bblist = ILSVRC12.get_training_bbox(bbdir, self.imglist)
        self.include_bb = include_bb

@@ -171,11 +175,11 @@ class ILSVRC12(RNGDataFlow):
            im = cv2.imread(fname.strip(), cv2.IMREAD_COLOR)
            assert im is not None, fname
            if im.ndim == 2:
                im = np.expand_dims(im, 2).repeat(3, 2)
            if self.include_bb:
                bb = self.bblist[k]
                if bb is None:
                    bb = [0, 0, im.shape[1] - 1, im.shape[0] - 1]
                yield [im, label, bb]
            else:
                yield [im, label]

@@ -216,12 +220,13 @@ class ILSVRC12(RNGDataFlow):
if __name__ == '__main__':
    meta = ILSVRCMeta()
    # print(meta.get_synset_words_1000())
    ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', include_bb=True,
                  shuffle=False)
    ds.reset_state()

    for k in ds.get_data():
        from IPython import embed
        embed()
        break
tensorpack/dataflow/dataset/mnist.py

@@ -17,6 +17,7 @@ __all__ = ['Mnist']
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
    """Download the data from Yann's website, unless it's already here."""
    filepath = os.path.join(work_directory, filename)

@@ -25,18 +26,20 @@ def maybe_download(filename, work_directory):
        download(SOURCE_URL + filename, work_directory)
    return filepath


def _read32(bytestream):
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)

@@ -46,24 +49,27 @@ def extract_images(filename):
    data = data.astype('float32') / 255.0
    return data


def extract_labels(filename):
    """Extract the labels into a 1D uint8 numpy array [index]."""
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
        return labels


class Mnist(RNGDataFlow):
    """
    Return [image, label],
        image is 28x28 in the range [0,1]
    """

    def __init__(self, train_or_test, shuffle=True, dir=None):
        """
        Args:

@@ -107,6 +113,6 @@ class Mnist(RNGDataFlow):
if __name__ == '__main__':
    ds = Mnist('train')
    for (img, label) in ds.get_data():
        from IPython import embed
        embed()
        break
tensorpack/dataflow/dataset/ptb.py

@@ -24,6 +24,7 @@ TRAIN_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.tra
VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'
TEST_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'


@memoized_ignoreargs
def get_PennTreeBank(data_dir=None):
    if data_dir is None:

@@ -35,6 +36,5 @@ def get_PennTreeBank(data_dir=None):
    # TODO these functions in TF might not be available in the future
    word_to_id = tfreader._build_vocab(os.path.join(data_dir, 'ptb.train.txt'))
    data3 = [np.asarray(tfreader._file_to_word_ids(
        os.path.join(data_dir, fname), word_to_id))
        for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']]
    return data3, word_to_id
tensorpack/dataflow/dataset/svhn.py  View file @ fb2a051c
...
@@ -19,6 +19,7 @@ except ImportError:
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"


class SVHNDigit(RNGDataFlow):
    """
    SVHN Cropped Digit Dataset.
...
@@ -41,12 +42,12 @@ class SVHNDigit(RNGDataFlow):
        assert name in ['train', 'test', 'extra'], name
        filename = os.path.join(data_dir, name + '_32x32.mat')
        assert os.path.isfile(filename), \
            "File {} not found! Please download it from {}.".format(filename, SVHN_URL)
        logger.info("Loading {} ...".format(filename))
        data = scipy.io.loadmat(filename)
        self.X = data['X'].transpose(3, 0, 1, 2)
        self.Y = data['y'].reshape((-1))
        self.Y[self.Y == 10] = 0
        SVHNDigit._Cache[name] = (self.X, self.Y)

    def size(self):
...
tensorpack/dataflow/dataset/visualqa.py  View file @ fb2a051c
...
@@ -12,6 +12,7 @@ import json
__all__ = ['VisualQA']


def read_json(fname):
    f = open(fname)
    ret = json.load(f)
...
@@ -19,11 +20,14 @@ def read_json(fname):
    return ret

# TODO shuffle


class VisualQA(DataFlow):
    """
    Visual QA dataset. See http://visualqa.org/
    Simply read q/a json file and produce q/a pairs in their original format.
    """

    def __init__(self, question_file, annotation_file):
        with timed_operation('Reading VQA JSON file'):
            qobj, aobj = list(map(read_json, [question_file, annotation_file]))
...
@@ -62,7 +66,7 @@ class VisualQA(DataFlow):
        """ Get the n most common words in questions
        n=4600 ~= thresh 6
        """
        from nltk.tokenize import word_tokenize  # will need to download 'punckt'
        cnt = Counter()
        for q in self.questions:
            cnt.update(word_tokenize(q['question'].lower()))
...
@@ -72,7 +76,7 @@ class VisualQA(DataFlow):
if __name__ == '__main__':
    vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
                   '/home/wyx/data/VQA/mscoco_train2014_annotations.json')
    for k in vqa.get_data():
        print(json.dumps(k))
        break
...
tensorpack/dataflow/dftools.py  View file @ fb2a051c
...
@@ -2,7 +2,8 @@
# File: dftools.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import sys
import os
import cv2
import multiprocessing as mp
import six
...
@@ -23,6 +24,8 @@ else:
    __all__.extend(['dump_dataflow_to_lmdb'])

# TODO pass a name_func to write label as filename?


def dump_dataset_images(ds, dirname, max_count=None, index=0):
    """ Dump images from a `DataFlow` to a directory.
...
@@ -43,6 +46,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
        img = dp[index]
        cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)


def dump_dataflow_to_lmdb(ds, lmdb_path):
    """ Dump a `Dataflow` ds to a lmdb database, where the key is the index
    and the data is the serialized datapoint.
...
@@ -56,8 +60,8 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
    assert not os.path.isfile(lmdb_path), "LMDB file exists!"
    ds.reset_state()
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)    # need sync() at the end
    try:
        sz = ds.size()
    except NotImplementedError:
...
@@ -87,7 +91,9 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
    the queue once you start it. Each element is (task_id, dp).
    """
    q = mp.Queue(size)

    class EnqueProc(mp.Process):

        def __init__(self, ds, q, nr_consumer):
            super(EnqueProc, self).__init__()
            self.ds = ds
...
@@ -104,4 +110,3 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
    proc = EnqueProc(ds, q, nr_consumer)
    return q, proc
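The dump helpers above pair naturally with the LMDB readers in format.py below. A minimal usage sketch (not part of this commit; the dataset choice and output path are only illustrative):

    from tensorpack.dataflow.dataset import Mnist
    from tensorpack.dataflow.dftools import dump_dataflow_to_lmdb

    # serialize every datapoint of a DataFlow into an LMDB file,
    # keyed by its index in the stream
    ds = Mnist('train', shuffle=False)
    dump_dataflow_to_lmdb(ds, '/tmp/mnist-train.lmdb')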
tensorpack/dataflow/format.py  View file @ fb2a051c
...
@@ -40,10 +40,13 @@ Adapters for different data format.
"""

# TODO lazy load


class HDF5Data(RNGDataFlow):
    """
    Zip data from different paths in an HDF5 file. Will load all data into memory.
    """

    def __init__(self, filename, data_paths, shuffle=True):
        """
        :param filename: h5 data file.
...
@@ -54,7 +57,7 @@ class HDF5Data(RNGDataFlow):
        logger.info("Loading {} to memory...".format(filename))
        self.dps = [self.f[k].value for k in data_paths]
        lens = [len(k) for k in self.dps]
        assert all([k == lens[0] for k in lens])
        self._size = lens[0]
        self.shuffle = shuffle
...
@@ -71,6 +74,7 @@ class HDF5Data(RNGDataFlow):
class LMDBData(RNGDataFlow):
    """ Read a lmdb and produce k,v pair """

    def __init__(self, lmdb_path, shuffle=True):
        self._lmdb_path = lmdb_path
        self._shuffle = shuffle
...
@@ -78,9 +82,9 @@ class LMDBData(RNGDataFlow):
    def open_lmdb(self):
        self._lmdb = lmdb.open(self._lmdb_path,
                               subdir=os.path.isdir(self._lmdb_path),
                               readonly=True, lock=False, readahead=False,
                               map_size=1099511627776 * 2, max_readers=100)
        self._txn = self._lmdb.begin()
        self._size = self._txn.stat()['entries']
        if self._shuffle:
...
@@ -116,7 +120,9 @@ class LMDBData(RNGDataFlow):
                v = self._txn.get(k)
                yield [k, v]


class LMDBDataDecoder(LMDBData):

    def __init__(self, lmdb_path, decoder, shuffle=True):
        """
        :param decoder: a function taking k, v and return a data point,
...
@@ -128,18 +134,24 @@ class LMDBDataDecoder(LMDBData):
    def get_data(self):
        for dp in super(LMDBDataDecoder, self).get_data():
            v = self.decoder(dp[0], dp[1])
            if v:
                yield v


class LMDBDataPoint(LMDBDataDecoder):
    """ Read a LMDB file where each value is a serialized datapoint"""

    def __init__(self, lmdb_path, shuffle=True):
        super(LMDBDataPoint, self).__init__(
            lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle)


class CaffeLMDB(LMDBDataDecoder):
    """ Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """

    def __init__(self, lmdb_path, shuffle=True):
        cpb = get_caffe_pb()

        def decoder(k, v):
            try:
                datum = cpb.Datum()
...
@@ -152,10 +164,12 @@ class CaffeLMDB(LMDBDataDecoder):
            return [img.transpose(1, 2, 0), datum.label]

        super(CaffeLMDB, self).__init__(
            lmdb_path, decoder=decoder, shuffle=shuffle)


class SVMLightData(RNGDataFlow):
    """ Read X,y from a svmlight file """

    def __init__(self, filename, shuffle=True):
        self.X, self.y = sklearn.datasets.load_svmlight_file(filename)
        self.X = np.asarray(self.X.todense())
...
@@ -169,4 +183,4 @@ class SVMLightData(RNGDataFlow):
        if self.shuffle:
            self.rng.shuffle(idxs)
        for id in idxs:
            yield [self.X[id, :], self.y[id]]
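Reading such a file back is symmetric: LMDBDataPoint deserializes each value into the original datapoint. A minimal sketch, assuming an LMDB file produced by dump_dataflow_to_lmdb (the path is hypothetical):

    from tensorpack.dataflow.format import LMDBDataPoint

    ds = LMDBDataPoint('/tmp/mnist-train.lmdb', shuffle=False)
    ds.reset_state()                   # required before iteration, as with other DataFlows
    for img, label in ds.get_data():   # each value decodes back to [image, label]
        print(img.shape, label)
        break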
tensorpack/dataflow/image.py  View file @ fb2a051c
...
@@ -11,7 +11,9 @@ from .imgaug import AugmentorList
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']


class ImageFromFile(RNGDataFlow):

    def __init__(self, files, channel=3, resize=None, shuffle=False):
        """
        Generate images of 1 channel or 3 channels (in RGB order) from list of files.
...
@@ -39,11 +41,12 @@ class ImageFromFile(RNGDataFlow):
            if self.resize is not None:
                im = cv2.resize(im, self.resize[::-1])
            if self.channel == 1:
                im = im[:, :, np.newaxis]
            yield [im]


class AugmentImageComponent(MapDataComponent):

    def __init__(self, ds, augmentors, index=0):
        """
        Augment the image component of datapoints
...
@@ -64,7 +67,8 @@ class AugmentImageComponent(MapDataComponent):
class AugmentImageComponents(MapData):

    def __init__(self, ds, augmentors, index=(0, 1)):
        """ Augment a list of images of the same shape, with the same parameters
        :param ds: a `DataFlow` instance.
        :param augmentors: a list of `ImageAugmentor` instance to be applied in order.
...
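A minimal sketch of how AugmentImageComponent is typically wired into a pipeline (the dataset and the particular augmentors are only examples, not part of this commit):

    from tensorpack.dataflow import AugmentImageComponent, imgaug
    from tensorpack.dataflow.dataset import Cifar10

    ds = Cifar10('train')
    augmentors = [
        imgaug.Flip(horiz=True),       # applied in order to dp[0], the image component
        imgaug.Brightness(20),
        imgaug.CenterCrop((30, 30)),
    ]
    ds = AugmentImageComponent(ds, augmentors)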
tensorpack/dataflow/imgaug/__init__.py  View file @ fb2a051c
...
@@ -7,6 +7,7 @@ from pkgutil import walk_packages
__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -19,4 +20,3 @@ for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        global_import(module_name)
tensorpack/dataflow/imgaug/_test.py  View file @ fb2a051c
...
@@ -10,15 +10,15 @@ from .crop import *
from .imgproc import *
from .noname import *
from .deform import *
from .noise import SaltPepperNoise


anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([
    Contrast((0.8, 1.2)),
    Flip(horiz=True),
    GaussianDeform(anchors, (360, 480), 0.2, randrange=20),
    # RandomCropRandomShape(0.3),
    SaltPepperNoise()
])
...
tensorpack/dataflow/imgaug/base.py  View file @ fb2a051c
...
@@ -9,6 +9,7 @@ from six.moves import zip
__all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']


@six.add_metaclass(ABCMeta)
class Augmentor(object):
    """ Base class for an augmentor"""
...
@@ -58,7 +59,9 @@ class Augmentor(object):
            size = []
        return self.rng.uniform(low, high, size)


class ImageAugmentor(Augmentor):

    def augment(self, img):
        """
        Perform augmentation on the image in-place.
...
@@ -71,10 +74,12 @@ class ImageAugmentor(Augmentor):
    def _fprop_coord(self, coord, param):
        return coord


class AugmentorList(ImageAugmentor):
    """
    Augment by a list of augmentors
    """

    def __init__(self, augmentors):
        """
        :param augmentors: list of `ImageAugmentor` instance to be applied
...
@@ -107,4 +112,3 @@ class AugmentorList(ImageAugmentor):
        """ Will reset state of each augmentor """
        for a in self.augs:
            a.reset_state()
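The subclass contract used by every augmentor in this package: _get_augment_params() draws the random parameters from self.rng, and _augment() applies them deterministically, so the same parameters can be reused across components of one datapoint. A minimal sketch with a hypothetical augmentor (not from this commit):

    from tensorpack.dataflow.imgaug import ImageAugmentor

    class RandomChannelShuffle(ImageAugmentor):
        """ Randomly permute the channels of an image (illustrative only). """

        def _get_augment_params(self, img):
            return self.rng.permutation(img.shape[2])   # self.rng is a numpy RandomState

        def _augment(self, img, perm):
            return img[:, :, perm]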
tensorpack/dataflow/imgaug/crop.py  View file @ fb2a051c
...
@@ -10,10 +10,12 @@ from six.moves import range
import numpy as np

__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop',
           'RandomCropRandomShape', 'perturb_BB', 'RandomCropAroundBox']


class RandomCrop(ImageAugmentor):
    """ Randomly crop the image into a smaller one """

    def __init__(self, crop_shape):
        """
        :param crop_shape: a shape like (h, w)
...
@@ -25,7 +27,7 @@ class RandomCrop(ImageAugmentor):
    def _get_augment_params(self, img):
        orig_shape = img.shape
        assert orig_shape[0] >= self.crop_shape[0] \
            and orig_shape[1] >= self.crop_shape[1], orig_shape
        diffh = orig_shape[0] - self.crop_shape[0]
        h0 = 0 if diffh == 0 else self.rng.randint(diffh)
        diffw = orig_shape[1] - self.crop_shape[1]
...
@@ -34,13 +36,15 @@ class RandomCrop(ImageAugmentor):
    def _augment(self, img, param):
        h0, w0 = param
        return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]

    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


class CenterCrop(ImageAugmentor):
    """ Crop the image at the center"""

    def __init__(self, crop_shape):
        """
        :param crop_shape: a shape like (h, w)
...
@@ -52,13 +56,15 @@ class CenterCrop(ImageAugmentor):
        orig_shape = img.shape
        h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
        w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
        return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]

    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


class FixedCrop(ImageAugmentor):
    """ Crop a rectangle at a given location"""

    def __init__(self, rect):
        """
        Two arguments defined the range in both axes to crop, min inclued, max excluded.
...
@@ -69,15 +75,16 @@ class FixedCrop(ImageAugmentor):
    def _augment(self, img, _):
        orig_shape = img.shape
        return img[self.rect.y0: self.rect.y1 + 1,
                   self.rect.x0: self.rect.x0 + 1]

    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


def perturb_BB(image_shape, bb, max_pertub_pixel,
               rng=None, max_aspect_ratio_diff=0.3, max_try=100):
    """
    Perturb a bounding box.
    :param image_shape: [h, w]
...
@@ -113,6 +120,7 @@ class RandomCropAroundBox(ImageAugmentor):
    """
    Crop a box around a bounding box
    """

    def __init__(self, perturb_ratio, max_aspect_ratio_diff=0.3):
        """
        :param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
...
@@ -124,9 +132,9 @@ class RandomCropAroundBox(ImageAugmentor):
    def _get_augment_params(self, img):
        shape = img.shape[:2]
        box = Rect(0, 0, shape[1] - 1, shape[0] - 1)
        dist = self.perturb_ratio * np.sqrt(shape[0] * shape[1])
        newbox = perturb_BB(shape, box, dist,
                            self.rng, self.max_aspect_ratio_diff)
        return newbox

    def _augment(self, img, newbox):
...
@@ -135,10 +143,12 @@ class RandomCropAroundBox(ImageAugmentor):
    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


class RandomCropRandomShape(ImageAugmentor):

    def __init__(self, wmin, hmin,
                 wmax=None, hmax=None, max_aspect_ratio=None):
        """
        Randomly crop a box of shape (h, w), sampled from [min, max](inclusive).
        If max is None, will use the input image shape.
...
@@ -151,18 +161,18 @@ class RandomCropRandomShape(ImageAugmentor):
    def _get_augment_params(self, img):
        hmax = self.hmax or img.shape[0]
        wmax = self.wmax or img.shape[1]
        h = self.rng.randint(self.hmin, hmax + 1)
        w = self.rng.randint(self.wmin, wmax + 1)
        diffh = img.shape[0] - h
        diffw = img.shape[1] - w
        assert diffh >= 0 and diffw >= 0
        y0 = 0 if diffh == 0 else self.rng.randint(diffh)
        x0 = 0 if diffw == 0 else self.rng.randint(diffw)
        return (y0, x0, h, w)

    def _augment(self, img, param):
        y0, x0, h, w = param
        return img[y0:y0 + h, x0:x0 + w]


if __name__ == '__main__':
    print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
tensorpack/dataflow/imgaug/deform.py  View file @ fb2a051c
...
@@ -10,8 +10,10 @@ __all__ = ['GaussianDeform', 'GaussianMap']
# TODO really needs speedup


class GaussianMap(object):
    """ Generate gaussian weighted deformation map"""

    def __init__(self, image_shape, sigma=0.5):
        assert len(image_shape) == 2
        self.shape = image_shape
...
@@ -25,17 +27,18 @@ class GaussianMap(object):
        x = x.astype('float32') / ret.shape[1] - anchor[1]
        g = np.exp(-(x**2 + y**2) / self.sigma)
        #cv2.imshow(" ", g)
        # cv2.waitKey()
        return g


def np_sample(img, coords):
    # a numpy implementation of ImageSample layer
    coords = np.maximum(coords, 0)
    coords = np.minimum(coords, np.array([img.shape[0] - 1, img.shape[1] - 1]))
    lcoor = np.floor(coords).astype('int32')
    ucoor = lcoor + 1
    ucoor = np.minimum(ucoor, np.array([img.shape[0] - 1, img.shape[1] - 1]))
    diff = coords - lcoor
    neg_diff = 1.0 - diff
...
@@ -46,17 +49,20 @@ def np_sample(img, coords):
    diffy, diffx = np.split(diff, 2, axis=2)
    ndiffy, ndiffx = np.split(neg_diff, 2, axis=2)
    ret = img[lcoory, lcoorx, :] * ndiffx * ndiffy + \
        img[ucoory, ucoorx, :] * diffx * diffy + \
        img[lcoory, ucoorx, :] * ndiffy * diffx + \
        img[ucoory, lcoorx, :] * diffy * ndiffx
    return ret[:, :, 0, :]

# TODO input/output with different shape


class GaussianDeform(ImageAugmentor):
    """
    Some kind of deformation. Quite slow.
    """

    def __init__(self, anchors, shape, sigma=0.5, randrange=None):
        """
        :param anchors: in [0,1] coordinate
...
@@ -69,13 +75,13 @@ class GaussianDeform(ImageAugmentor):
        self.anchors = anchors
        self.K = len(self.anchors)
        self.shape = shape
        self.grid = np.mgrid[0:self.shape[0], 0:self.shape[1]].transpose(1, 2, 0)
        self.grid = self.grid.astype('float32')  # HxWx2
        gm = GaussianMap(self.shape, sigma=sigma)
        self.gws = np.array([gm.get_gaussian_weight(ank)
                             for ank in self.anchors], dtype='float32')  # KxHxW
        self.gws = self.gws.transpose(1, 2, 0)  # HxWxK
        if randrange is None:
            self.randrange = self.shape[0] / 8
        else:
...
tensorpack/dataflow/imgaug/geometry.py  View file @ fb2a051c
...
@@ -10,11 +10,13 @@ import numpy as np
__all__ = ['Rotation', 'RotationAndCropValid']


class Rotation(ImageAugmentor):
    """ Random rotate the image w.r.t a random center"""

    def __init__(self, max_deg, center_range=(0, 1),
                 interp=cv2.INTER_CUBIC, border=cv2.BORDER_REPLICATE):
        """
        :param max_deg: max abs value of the rotation degree
        :param center_range: the location of the rotation center
...
@@ -24,19 +26,21 @@ class Rotation(ImageAugmentor):
    def _get_augment_params(self, img):
        center = img.shape[1::-1] * self._rand_range(
            self.center_range[0], self.center_range[1], (2,))
        deg = self._rand_range(-self.max_deg, self.max_deg)
        return cv2.getRotationMatrix2D(tuple(center), deg, 1)

    def _augment(self, img, rot_m):
        ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
                             flags=self.interp, borderMode=self.border)
        return ret


class RotationAndCropValid(ImageAugmentor):
    """ Random rotate and crop the largest possible rect without the border
    This will produce images of different shapes.
    """

    def __init__(self, max_deg, interp=cv2.INTER_CUBIC):
        super(RotationAndCropValid, self).__init__()
        self._init(locals())
...
@@ -46,39 +50,39 @@ class RotationAndCropValid(ImageAugmentor):
        return deg

    def _augment(self, img, deg):
        center = (img.shape[1] * 0.5, img.shape[0] * 0.5)
        rot_m = cv2.getRotationMatrix2D(center, deg, 1)
        ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
                             flags=self.interp, borderMode=cv2.BORDER_CONSTANT)
        neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
        neww = min(neww, ret.shape[1])
        newh = min(newh, ret.shape[0])
        newx = int(center[0] - neww * 0.5)
        newy = int(center[1] - newh * 0.5)
        #print(ret.shape, deg, newx, newy, neww, newh)
        return ret[newy:newy + newh, newx:newx + neww]

    @staticmethod
    def largest_rotated_rect(w, h, angle):
        """ http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders """
        angle = angle / 180.0 * math.pi
        if w <= 0 or h <= 0:
            return 0, 0
        width_is_longer = w >= h
        side_long, side_short = (w, h) if width_is_longer else (h, w)
        # since the solutions for angle, -angle and 180-angle are all the same,
        # if suffices to look at the first quadrant and the absolute values of sin,cos:
        sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
        if side_short <= 2. * sin_a * cos_a * side_long:
            # half constrained case: two crop corners touch the longer side,
            # the other two corners are on the mid-line parallel to the longer line
            x = 0.5 * side_short
            wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
        else:
            # fully constrained case: crop touches all 4 sides
            cos_2a = cos_a * cos_a - sin_a * sin_a
            wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
        return int(wr), int(hr)
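As a quick numeric check of largest_rotated_rect above: a 400x300 image rotated by 30 degrees falls into the half-constrained branch, so the sketch below should print roughly (300, 173) (x = 150, wr = x / sin 30°, hr = x / cos 30°, truncated to ints):

    from tensorpack.dataflow.imgaug.geometry import RotationAndCropValid

    print(RotationAndCropValid.largest_rotated_rect(400, 300, 30))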
tensorpack/dataflow/imgaug/imgproc.py  View file @ fb2a051c
...
@@ -7,12 +7,14 @@ import numpy as np
import cv2

__all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur',
           'Gamma', 'Clip', 'Saturation', 'Lighting']


class Brightness(ImageAugmentor):
    """
    Random adjust brightness.
    """

    def __init__(self, delta, clip=True):
        """
        Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True.
...
@@ -31,11 +33,13 @@ class Brightness(ImageAugmentor):
            img = np.clip(img, 0, 255)
        return img


class Contrast(ImageAugmentor):
    """
    Apply x = (x - mean) * contrast_factor + mean to each channel
    and clip to [0, 255]
    """

    def __init__(self, factor_range, clip=True):
        """
        :param factor_range: an interval to random sample the `contrast_factor`.
...
@@ -48,18 +52,20 @@ class Contrast(ImageAugmentor):
        return self._rand_range(*self.factor_range)

    def _augment(self, img, r):
        mean = np.mean(img, axis=(0, 1), keepdims=True)
        img = (img - mean) * r + mean
        if self.clip:
            img = np.clip(img, 0, 255)
        return img


class MeanVarianceNormalize(ImageAugmentor):
    """
    Linearly scales image to have zero mean and unit norm.
    x = (x - mean) / adjusted_stddev
    where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
    """

    def __init__(self, all_channel=True):
        """
        :param all_channel: if True, normalize all channels together. else separately.
...
@@ -71,14 +77,15 @@ class MeanVarianceNormalize(ImageAugmentor):
            mean = np.mean(img)
            std = np.std(img)
        else:
            mean = np.mean(img, axis=(0, 1), keepdims=True)
            std = np.std(img, axis=(0, 1), keepdims=True)
        std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape)))
        img = (img - mean) / std
        return img


class GaussianBlur(ImageAugmentor):

    def __init__(self, max_size=3):
        """:params max_size: (maximum kernel size-1)/2"""
        super(GaussianBlur, self).__init__()
...
@@ -92,10 +99,11 @@ class GaussianBlur(ImageAugmentor):
    def _augment(self, img, s):
        return cv2.GaussianBlur(img, s, sigmaX=0, sigmaY=0,
                                borderType=cv2.BORDER_REPLICATE)


class Gamma(ImageAugmentor):

    def __init__(self, range=(-0.5, 0.5)):
        super(Gamma, self).__init__()
        self._init(locals())
...
@@ -109,7 +117,9 @@ class Gamma(ImageAugmentor):
        img = cv2.LUT(img, lut).astype('float32')
        return img


class Clip(ImageAugmentor):

    def __init__(self, min=0, max=255):
        self._init(locals())
...
@@ -117,7 +127,9 @@ class Clip(ImageAugmentor):
        img = np.clip(img, self.min, self.max)
        return img


class Saturation(ImageAugmentor):

    def __init__(self, alpha=0.4):
        """ Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
        """
...
@@ -130,9 +142,11 @@ class Saturation(ImageAugmentor):
    def _augment(self, img, v):
        grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        return img * v + (grey * (1 - v))[:, :, np.newaxis]


class Lighting(ImageAugmentor):

    def __init__(self, std, eigval, eigvec):
        """ Lighting noise.
        See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
...
@@ -143,7 +157,7 @@ class Lighting(ImageAugmentor):
        eigval = np.asarray(eigval)
        eigvec = np.asarray(eigvec)
        assert eigval.shape == (3,)
        assert eigvec.shape == (3, 3)
        self._init(locals())

    def _get_augment_params(self, img):
...
@@ -156,4 +170,3 @@ class Lighting(ImageAugmentor):
        inc = np.dot(self.eigvec, v).reshape((3,))
        img += inc
        return img
tensorpack/dataflow/imgaug/meta.py  View file @ fb2a051c
...
@@ -7,14 +7,18 @@
from .base import ImageAugmentor

__all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
           'RandomOrderAug']


class Identity(ImageAugmentor):

    def _augment(self, img, _):
        return img


class RandomApplyAug(ImageAugmentor):
    """ Randomly apply the augmentor with a prob. Otherwise do nothing"""

    def __init__(self, aug, prob):
        self._init(locals())
        super(RandomApplyAug, self).__init__()
...
@@ -37,7 +41,9 @@ class RandomApplyAug(ImageAugmentor):
        else:
            return self.aug._augment(img, prm[1])


class RandomChooseAug(ImageAugmentor):

    def __init__(self, aug_lists):
        """
        :param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
...
@@ -65,7 +71,9 @@ class RandomChooseAug(ImageAugmentor):
        idx, prm = prm
        return self.aug_lists[idx]._augment(img, prm)


class RandomOrderAug(ImageAugmentor):

    def __init__(self, aug_lists):
        """
        Shuffle the augmentors into random order.
...
@@ -93,10 +101,12 @@ class RandomOrderAug(ImageAugmentor):
            img = self.aug_lists[k]._augment(img, prms[k])
        return img


class MapImage(ImageAugmentor):
    """
    Map the image array by a function.
    """

    def __init__(self, func):
        """
        :param func: a function which takes a image array and return a augmented one
...
@@ -105,4 +115,3 @@ class MapImage(ImageAugmentor):
    def _augment(self, img, _):
        return self.func(img)
tensorpack/dataflow/imgaug/noise.py  View file @ fb2a051c
...
@@ -9,7 +9,9 @@ import cv2
__all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']


class JpegNoise(ImageAugmentor):

    def __init__(self, quality_range=(40, 100)):
        super(JpegNoise, self).__init__()
        self._init(locals())
...
@@ -23,6 +25,7 @@ class JpegNoise(ImageAugmentor):
class GaussianNoise(ImageAugmentor):

    def __init__(self, sigma=1, clip=True):
        """
        Add a gaussian noise N(0, sigma^2) of the same shape to img.
...
@@ -39,7 +42,9 @@ class GaussianNoise(ImageAugmentor):
            ret = np.clip(ret, 0, 255)
        return ret


class SaltPepperNoise(ImageAugmentor):

    def __init__(self, white_prob=0.05, black_prob=0.05):
        """ Salt and pepper noise.
        Randomly set some elements in img to 0 or 255, regardless of its channels.
...
tensorpack/dataflow/imgaug/noname.py  View file @ fb2a051c
...
@@ -10,10 +10,12 @@ import cv2
__all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge']


class Flip(ImageAugmentor):
    """
    Random flip.
    """

    def __init__(self, horiz=False, vert=False, prob=0.5):
        """
        Only one of horiz, vert can be set.
...
@@ -45,8 +47,10 @@ class Flip(ImageAugmentor):
    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


class Resize(ImageAugmentor):
    """ Resize image to a target size"""

    def __init__(self, shape, interp=cv2.INTER_CUBIC):
        """
        :param shape: shape in (h, w)
...
@@ -59,13 +63,15 @@ class Resize(ImageAugmentor):
            img, self.shape[::-1], interpolation=self.interp)
        if img.ndim == 3 and ret.ndim == 2:
            ret = ret[:, :, np.newaxis]
        return ret


class ResizeShortestEdge(ImageAugmentor):
    """ Resize the shortest edge to a certain number while
    keeping the aspect ratio
    """

    def __init__(self, size):
        size = size * 1.0
        self._init(locals())
...
@@ -76,13 +82,15 @@ class ResizeShortestEdge(ImageAugmentor):
        desSize = map(int, [scale * w, scale * h])
        ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC)
        if img.ndim == 3 and ret.ndim == 2:
            ret = ret[:, :, np.newaxis]
        return ret


class RandomResize(ImageAugmentor):
    """ randomly rescale w and h of the image"""

    def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15,
                 interp=cv2.INTER_CUBIC):
        """
        :param xrange: (min, max) scaling ratio
        :param yrange: (min, max) scaling ratio
...
@@ -112,6 +120,5 @@ class RandomResize(ImageAugmentor):
    def _augment(self, img, dsize):
        ret = cv2.resize(img, dsize, interpolation=self.interp)
        if img.ndim == 3 and ret.ndim == 2:
            ret = ret[:, :, np.newaxis]
        return ret
tensorpack/dataflow/imgaug/paste.py  View file @ fb2a051c
...
@@ -9,11 +9,12 @@ from abc import abstractmethod
import numpy as np

__all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
           'RandomPaste']


class BackgroundFiller(object):
    """ Base class for all BackgroundFiller"""

    def fill(self, background_shape, img):
        """
        Return a proper background image of background_shape, given img
...
@@ -28,8 +29,10 @@ class BackgroundFiller(object):
    def _fill(self, background_shape, img):
        pass


class ConstantBackgroundFiller(BackgroundFiller):
    """ Fill the background by a constant """

    def __init__(self, value):
        """
        :param value: the value to fill the background.
...
@@ -44,10 +47,12 @@ class ConstantBackgroundFiller(BackgroundFiller):
            return_shape = background_shape
        return np.zeros(return_shape) + self.value


class CenterPaste(ImageAugmentor):
    """
    Paste the image onto the center of a background canvas.
    """

    def __init__(self, background_shape, background_filler=None):
        """
        :param background_shape: shape of the background canvas.
...
@@ -66,16 +71,18 @@ class CenterPaste(ImageAugmentor):
            self.background_shape, img)
        y0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
        x0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
        background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
        return background

    def _fprop_coord(self, coord, param):
        raise NotImplementedError()


class RandomPaste(CenterPaste):
    """
    Randomly paste the image onto a background convas
    """

    def _get_augment_params(self, img):
        img_shape = img.shape[:2]
        assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
...
@@ -89,5 +96,5 @@ class RandomPaste(CenterPaste):
        img_shape = img.shape[:2]
        background = self.background_filler.fill(self.background_shape, img)
        background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
        return background
tensorpack/dataflow/prefetch.py  View file @ fb2a051c
...
@@ -13,7 +13,7 @@ import os
from .base import ProxyDataFlow
from ..utils.concurrency import (ensure_proc_terminate,
                                 mask_sigint, start_proc_mask_signal)
from ..utils.serialize import loads, dumps
from ..utils import logger
from ..utils.gpu import change_gpu
...
@@ -28,6 +28,7 @@ else:
class PrefetchProcess(mp.Process):

    def __init__(self, ds, queue, reset_after_spawn=True):
        """
        :param ds: ds to take data from
...
@@ -46,10 +47,12 @@ class PrefetchProcess(mp.Process):
        for dp in self.ds.get_data():
            self.queue.put(dp)


class PrefetchData(ProxyDataFlow):
    """
    Prefetch data from a `DataFlow` using multiprocessing
    """

    def __init__(self, ds, nr_prefetch, nr_proc=1):
        """
        :param ds: a `DataFlow` instance.
...
@@ -82,6 +85,7 @@ class PrefetchData(ProxyDataFlow):
        # do nothing. all ds are reset once and only once in spawned processes
        pass


def BlockParallel(ds, queue_size):
    # TODO more doc
    """
...
@@ -92,7 +96,9 @@ def BlockParallel(ds, queue_size):
    """
    return PrefetchData(ds, queue_size, 1)


class PrefetchProcessZMQ(mp.Process):

    def __init__(self, ds, conn_name):
        """
        :param ds: a `DataFlow` instance.
...
@@ -112,8 +118,10 @@ class PrefetchProcessZMQ(mp.Process):
        for dp in self.ds.get_data():
            self.socket.send(dumps(dp), copy=False)


class PrefetchDataZMQ(ProxyDataFlow):
    """ Work the same as `PrefetchData`, but faster. """

    def __init__(self, ds, nr_proc=1, pipedir=None):
        """
        :param ds: a `DataFlow` instance.
...
@@ -176,9 +184,11 @@ class PrefetchDataZMQ(ProxyDataFlow):
        except:
            pass


class PrefetchOnGPUs(PrefetchDataZMQ):
    """ Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
    variable"""

    def __init__(self, ds, gpus, pipedir=None):
        self.gpus = gpus
        super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
...
@@ -188,4 +198,3 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
        for gpu, proc in zip(self.gpus, self.procs):
            with change_gpu(gpu):
                proc.start()
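A minimal sketch of plugging these prefetchers into a pipeline (the dataset and worker count are only illustrative): PrefetchDataZMQ forks nr_proc copies of the underlying DataFlow and streams serialized datapoints back to the consumer over a ZMQ pipe.

    from tensorpack.dataflow import PrefetchDataZMQ
    from tensorpack.dataflow.dataset import Mnist

    ds = Mnist('train')
    ds = PrefetchDataZMQ(ds, nr_proc=4)   # 4 worker processes feeding one consumer
    ds.reset_state()
    for dp in ds.get_data():
        pass                              # consume datapoints as usual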
tensorpack/dataflow/raw.py  View file @ fb2a051c
...
@@ -17,8 +17,10 @@ except:
else:
    __all__.append('DataFromSocket')


class FakeData(RNGDataFlow):
    """ Generate fake fixed data of given shapes"""

    def __init__(self, shapes, size, random=True, dtype='float32'):
        """
        :param shapes: a list of lists/tuples
...
@@ -44,8 +46,10 @@ class FakeData(RNGDataFlow):
            for _ in range(self._size):
                yield copy.deepcopy(v)


class DataFromQueue(DataFlow):
    """ Produce data from a queue """

    def __init__(self, queue):
        self.queue = queue
...
@@ -53,8 +57,10 @@ class DataFromQueue(DataFlow):
        while True:
            yield self.queue.get()


class DataFromList(RNGDataFlow):
    """ Produce data from a list"""

    def __init__(self, lst, shuffle=True):
        super(DataFromList, self).__init__()
        self.lst = lst
...
@@ -73,8 +79,10 @@ class DataFromList(RNGDataFlow):
            for k in idxs:
                yield self.lst[k]


class DataFromSocket(DataFlow):
    """ Produce data from a zmq socket"""

    def __init__(self, socket_name):
        self._name = socket_name
...
@@ -89,4 +97,3 @@ class DataFromSocket(DataFlow):
            yield dp
        finally:
            ctx.destroy(linger=0)
tensorpack/dataflow/remote.py  View file @ fb2a051c
...
@@ -17,6 +17,7 @@ from .common import RepeatedData
from ..utils import logger
from ..utils.serialize import dumps, loads


def serve_data(ds, addr):
    ctx = zmq.Context()
    socket = ctx.socket(zmq.PUSH)
...
@@ -36,7 +37,9 @@ def serve_data(ds, addr):
        if not ctx.closed:
            ctx.destroy(0)


class RemoteData(DataFlow):

    def __init__(self, addr):
        self.ctx = zmq.Context()
        self.socket = self.ctx.socket(zmq.PULL)
...
@@ -54,7 +57,7 @@ if __name__ == '__main__':
    from .raw import FakeData
    addr = "tcp://127.0.0.1:8877"
    if sys.argv[1] == 'serve':
        ds = FakeData([(128, 244, 244, 3)], 1000)
        serve_data(ds, addr)
    else:
        ds = RemoteData(addr)
...
@@ -62,4 +65,3 @@ if __name__ == '__main__':
        with tqdm(total=10000) as pbar:
            for k in ds.get_data():
                pbar.update()
tensorpack/dataflow/tf_func.py  View file @ fb2a051c
...
@@ -14,9 +14,11 @@ except ImportError:
else:
    __all__ = ['TFFuncMapper']


class TFFuncMapper(ProxyDataFlow):

    def __init__(self, ds,
                 get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'):
        """
        :param get_placeholders: a function returning the placeholders
        :param symbf: a symbolic function taking the placeholders
...
@@ -39,7 +41,7 @@ class TFFuncMapper(ProxyDataFlow):
        def run_func(vals):
            return self.sess.run(self.output_vars,
                                 feed_dict=dict(zip(self.placeholders, vals)))
        self.run_func = run_func

    def get_data(self):
...
@@ -63,16 +65,16 @@ if __name__ == '__main__':
        v = tf.image.random_flip_left_right(v)
        return v
    ds = TFFuncMapper(ds,
                      lambda: [tf.placeholder(tf.float32, [224, 224, 3], name='img')],
                      tf_aug,
                      lambda dp, f: [f([dp[0]])[0]]
                      )
    # ds = AugmentImageComponent(ds,
    #        [imgaug.Brightness(0.1, clip=False),
    #         imgaug.Contrast((0.8, 1.2), clip=False),
    #         imgaug.Flip(horiz=True)
    #        ])
    # ds = PrefetchDataZMQ(ds, 4)
    ds.reset_state()

    import tqdm
...
tensorpack/models/__init__.py  View file @ fb2a051c
...
@@ -12,6 +12,7 @@ from ..utils import logger
__all__ = ['LinearWrap']


def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
@@ -32,6 +33,7 @@ class LinearWrap(object):
    """

    class TFModuleFunc(object):

        def __init__(self, mod, tensor):
            self._mod = mod
            self._t = tensor
...
@@ -88,4 +90,3 @@ class LinearWrap(object):
    def print_tensor(self):
        print(self._t)
        return self
tensorpack/models/_common.py  View file @ fb2a051c
...
@@ -5,7 +5,8 @@
import tensorflow as tf
from functools import wraps
import six
import copy
import os

from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str
...
@@ -16,13 +17,16 @@ from ..utils.argtools import shape2d
# make sure each layer is only logged once
_layer_logged = set()


def disable_layer_logging():
    class ContainEverything:

        def __contains__(self, x):
            return True
    # can use nonlocal in python3, but how
    globals()['_layer_logged'] = ContainEverything()


def layer_register(
        summary_activation=False,
        log_shape=True,
...
@@ -42,13 +46,13 @@ def layer_register(
    def wrapped_func(*args, **kwargs):
        if use_scope:
            name, inputs = args[0], args[1]
            args = args[1:]  # actual positional args used to call func
            assert isinstance(name, six.string_types), name
        else:
            assert not log_shape and not summary_activation
            if isinstance(args[0], six.string_types):
                name, inputs = args[0], args[1]
                args = args[1:]  # actual positional args used to call func
            else:
                inputs = args[0]
                name = None
...
@@ -97,13 +101,14 @@ def layer_register(
    # need some special handling for sphinx to work with the arguments
    on_doc = os.environ.get('READTHEDOCS') == 'True' \
        or os.environ.get('TENSORPACK_DOC_BUILDING')
    if on_doc:
        from decorator import decorator
        wrapper = decorator(wrapper)

    return wrapper


def shape4d(a):
    # for use with tensorflow NHWC ops
    return [1] + shape2d(a) + [1]
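What layer_register gives a decorated function, as far as this file shows: the first positional argument becomes the layer name (used for scoping when use_scope is on), the second becomes the input, and shapes are logged once per layer. A minimal sketch with a hypothetical layer (not from this commit; exact logging behaviour is an assumption):

    import tensorflow as tf
    from tensorpack.models._common import layer_register

    @layer_register()
    def ScaleBy(x, init_scale=1.0):
        # one trainable scalar multiplying the whole input (illustrative only)
        scale = tf.get_variable('scale', [],
                                initializer=tf.constant_initializer(init_scale))
        return tf.identity(x * scale, name='output')

    # called as ScaleBy('scale1', some_tensor, init_scale=2.0) inside a model graph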
tensorpack/models/_test.py  View file @ fb2a051c
...
@@ -7,7 +7,9 @@ import tensorflow as tf
import numpy as np
import unittest


class TestModel(unittest.TestCase):

    def run_variable(self, var):
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
...
@@ -22,6 +24,7 @@ class TestModel(unittest.TestCase):
        else:
            return tf.Variable(args[0])


def run_test_case(case):
    suite = unittest.TestLoader().loadTestsFromTestCase(case)
    unittest.TextTestRunner(verbosity=2).run(suite)
...
@@ -34,5 +37,3 @@ if __name__ == '__main__':
    subs = tensorpack.models._test.TestModel.__subclasses__()
    for cls in subs:
        run_test_case(cls)
tensorpack/models/batch_norm.py  View file @ fb2a051c
...
@@ -18,6 +18,8 @@ __all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
# decay: being too close to 1 leads to slow start-up. torch use 0.9.
# eps: torch: 1e-5. Lasagne: 1e-4


@layer_register(log_shape=False)
def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    """
...
@@ -41,9 +43,9 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    n_out = shape[-1]  # channel
    assert n_out is not None
    beta = tf.get_variable('beta', [n_out],
                           initializer=tf.constant_initializer())
    gamma = tf.get_variable('gamma', [n_out],
                            initializer=tf.constant_initializer(1.0))

    if len(shape) == 2:
        batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
...
@@ -66,7 +68,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        #reuse = tf.get_variable_scope().reuse
        with tf.variable_scope(tf.get_variable_scope(), reuse=False):
            # BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
            with tf.name_scope(None):  # https://github.com/tensorflow/tensorflow/issues/2740
                # TODO if reuse=True, try to find and use the existing statistics
                # how to use multiple tensors to update one EMA? seems impossbile
                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
...
@@ -93,7 +95,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
            ema_mean = tf.get_variable('mean/' + emaname, [n_out])
            ema_var = tf.get_variable('variance/' + emaname, [n_out])
        else:
            # use statistics in another tower
            G = tf.get_default_graph()
            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name + ':0')
            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name + ':0')
...
@@ -111,6 +113,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        return tf.nn.batch_normalization(
            x, ema_mean, ema_var, beta, gamma, epsilon, 'output')


@layer_register(log_shape=False)
def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    """
...
@@ -135,9 +138,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        x = tf.reshape(x, [-1, 1, 1, n_out])

    beta = tf.get_variable('beta', [n_out],
                           initializer=tf.constant_initializer())
    gamma = tf.get_variable('gamma', [n_out],
                            initializer=tf.constant_initializer(1.0))
    # x * gamma + beta

    ctx = get_current_tower_context()
...
@@ -147,22 +150,22 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        logger.warn("[BatchNorm] use_local_stat != is_training")

    moving_mean = tf.get_variable('mean/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
    moving_var = tf.get_variable('variance/EMA', [n_out],
                                 initializer=tf.constant_initializer(), trainable=False)

    if use_local_stat:
        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
                                                           epsilon=epsilon, is_training=True)

        # maintain EMA only in the main training tower
        if ctx.is_main_training_tower:
            update_op1 = moving_averages.assign_moving_average(
                moving_mean, batch_mean, decay, zero_debias=False,
                name='mean_ema_op')
            update_op2 = moving_averages.assign_moving_average(
                moving_var, batch_var, decay, zero_debias=False,
                name='var_ema_op')
            add_model_variable(moving_mean)
            add_model_variable(moving_var)
    else:
...
@@ -171,9 +174,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        # consider some fixed-param tasks, such as load model and fine tune one layer

        # fused seems slower in inference
        # xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
        #                                   moving_mean, moving_var,
        #                                   epsilon=epsilon, is_training=False, name='output')
        xn = tf.nn.batch_normalization(
            x, moving_mean, moving_var, beta, gamma, epsilon)
...
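The two assign_moving_average calls above keep exponential moving averages of the batch statistics for use at inference time. A small numeric illustration of that update rule (assuming plain EMA without zero-debias, which matches zero_debias=False above):

    import numpy as np

    decay = 0.9
    moving_mean = np.zeros(3, dtype='float32')
    for _ in range(5):
        batch_mean = np.random.rand(3).astype('float32')
        # same recurrence as assign_moving_average with zero_debias=False
        moving_mean = decay * moving_mean + (1 - decay) * batch_mean
    print(moving_mean)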
tensorpack/models/conv2d.py  View file @ fb2a051c
...
@@ -12,6 +12,7 @@ from ..utils.argtools import shape2d
__all__ = ['Conv2D', 'Deconv2D']


@layer_register()
def Conv2D(x, out_channel, kernel_shape,
           padding='SAME', stride=1,
...
@@ -61,14 +62,18 @@ def Conv2D(x, out_channel, kernel_shape,
                   for i, k in zip(inputs, kernels)]
        conv = tf.concat(3, outputs)
    if nl is None:
        logger.warn(
            "[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
        nl = tf.nn.relu
    return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')


class StaticDynamicShape(object):

    def __init__(self, static, dynamic):
        self.static = static
        self.dynamic = dynamic

    def apply(self, f):
        try:
            st = f(self.static)
...
@@ -76,11 +81,12 @@ class StaticDynamicShape(object):
        except:
            return StaticDynamicShape(None, f(self.dynamic))


@layer_register()
def Deconv2D(x, out_shape, kernel_shape,
             stride, padding='SAME',
             W_init=None, b_init=None, nl=tf.identity, use_bias=True):
    """
    2D deconvolution on 4D inputs.
...
tensorpack/models/fc.py  View file @ fb2a051c
...
@@ -11,6 +11,7 @@ from ..tfutils import symbolic_functions as symbf
__all__ = ['FullyConnected']


@layer_register()
def FullyConnected(x, out_dim,
                   W_init=None, b_init=None,
...
@@ -40,6 +41,7 @@ def FullyConnected(x, out_dim,
        b = tf.get_variable('b', [out_dim], initializer=b_init)
    prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
    if nl is None:
        logger.warn(
            "[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
        nl = tf.nn.relu
    return nl(prod, name='output')
tensorpack/models/image_sample.py  View file @ fb2a051c
...
@@ -12,6 +12,8 @@ __all__ = ['ImageSample']
# XXX TODO ugly.
# really need to fix this after tensorflow supports advanced indexing
# See github:tensorflow#418,#206


def sample(img, coords):
    """
    :param img: bxhxwxc
...
@@ -33,14 +35,15 @@ def sample(img, coords):
    # bxh2xw2
    batch_add = tf.range(tf.shape(img)[0]) * (shape[0] * shape[1])
    batch_add = tf.reshape(batch_add, [-1, 1, 1])  # bx1x1

    flat_coords = coords + batch_add
    img = tf.reshape(img, [-1, shape[2]])  # bhw x c
    sampled = tf.gather(img, flat_coords)
    return sampled


@layer_register()
def ImageSample(inputs, borderMode='repeat'):
    """
...
@@ -59,7 +62,7 @@ def ImageSample(inputs, borderMode='repeat'):
    assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
    input_shape = template.get_shape().as_list()[1:]
    assert None not in input_shape, \
        "Images in ImageSample layer must have fully-defined shape"
    assert borderMode in ['repeat', 'constant']

    orig_mapping = mapping
...
@@ -68,7 +71,7 @@ def ImageSample(inputs, borderMode='repeat'):
    ucoor = lcoor + 1

    diff = mapping - lcoor
    neg_diff = 1.0 - diff  # bxh2xw2x2

    lcoory, lcoorx = tf.split(3, 2, lcoor)
    ucoory, ucoorx = tf.split(3, 2, ucoor)
...
@@ -80,55 +83,59 @@ def ImageSample(inputs, borderMode='repeat'):
    neg_diffy, neg_diffx = tf.split(3, 2, neg_diff)

    #prod = tf.reduce_prod(diff, 3, keep_dims=True)
    # diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
    #                        tf.reduce_max(diff), diff], summarize=50)

    ret = tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy,
                    sample(template, ucoor) * diffx * diffy,
                    sample(template, lyux) * neg_diffy * diffx,
                    sample(template, uylx) * diffy * neg_diffx], name='sampled')
    if borderMode == 'constant':
        max_coor = tf.constant([input_shape[0] - 1, input_shape[1] - 1], dtype=tf.float32)
        mask = tf.greater_equal(orig_mapping, 0.0)
        mask2 = tf.less_equal(orig_mapping, max_coor)
        mask = tf.logical_and(mask, mask2)  # bxh2xw2x2
        mask = tf.reduce_all(mask, [3])  # bxh2xw2 boolean
        mask = tf.expand_dims(mask, 3)
        ret = ret * tf.cast(mask, tf.float32)
    return ret


from ._test import TestModel


class TestSample(TestModel):

    def test_sample(self):
        import numpy as np
        h, w = 3, 4

        def np_sample(img, coords):
            # a reference implementation
            coords = np.maximum(coords, 0)
            coords = np.minimum(coords,
                                np.array([img.shape[1] - 1, img.shape[2] - 1]))
            xs = coords[:, :, :, 1].reshape((img.shape[0], -1))
            ys = coords[:, :, :, 0].reshape((img.shape[0], -1))

            ret = np.zeros((img.shape[0], coords.shape[1], coords.shape[2],
                            img.shape[3]), dtype='float32')
            for k in range(img.shape[0]):
                xss, yss = xs[k], ys[k]
                ret[k, :, :, :] = img[k, yss, xss, :].reshape((coords.shape[1],
                                                               coords.shape[2], 3))
            return ret

        bimg = np.random.rand(2, h, w, 3).astype('float32')

        # mat = np.array([
        # [[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
        # [[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
        # ], dtype='float32')  #2x2x2x2
        mat = (np.random.rand(2, 5, 5, 2) - 0.2) * np.array([h + 3, w + 3])
        true_res = np_sample(bimg, np.floor(mat + 0.5).astype('int32'))

        inp, mapping = self.make_variable(bimg, mat)
        output = sample(inp, tf.cast(tf.floor(mapping + 0.5), tf.int32))
        res = self.run_variable(output)
        self.assertTrue((res == true_res).all())
...
@@ -146,7 +153,7 @@ if __name__ == '__main__':
    diff = 200
    for x in range(w):
        for y in range(h):
            mapping[0, y, x, :] = np.array([y - diff + 0.4, x - diff + 0.5])

    mapv = tf.Variable(mapping)
    output = ImageSample('sample', [imv, mapv], borderMode='constant')
...
@@ -155,12 +162,10 @@ if __name__ == '__main__':
    #out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
    #out = sess.run(output)
    # print(out[0].min())
    # print(out[0].max())
    # print(out[0].sum())
    out = sess.run([output])[0]
    im = out[0]
    cv2.imwrite('sampled.jpg', im)
tensorpack/models/model_desc.py  View file @ fb2a051c
...
@@ -16,21 +16,27 @@ from ..tfutils.common import get_tensors_by_names
from ..tfutils.gradproc import CheckGradient
from ..tfutils.tower import get_current_tower_context

__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']

#_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])


class InputVar(object):

    def __init__(self, type, shape, name, sparse=False):
        self.type = type
        self.shape = shape
        self.name = name
        self.sparse = sparse

    def dumps(self):
        return pickle.dumps(self)

    @staticmethod
    def loads(buf):
        return pickle.loads(buf)


@six.add_metaclass(ABCMeta)
class ModelDesc(object):
    """ Base class for a model description """
...
@@ -99,22 +105,24 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
    def get_gradient_processor(self):
        """ Return a list of GradientProcessor. They will be executed in order"""
        return [  # SummaryGradient(),
            CheckGradient()]


class ModelFromMetaGraph(ModelDesc):
    """
    Load the whole exact TF graph from a saved meta_graph.
    Only useful for inference.
    """

    def __init__(self, filename):
        tf.train.import_meta_graph(filename)
        all_coll = tf.get_default_graph().get_all_collection_keys()
        for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
                  tf.GraphKeys().VARIABLES]:
            assert k in all_coll, \
                "Collection {} not found in metagraph!".format(k)

    def _get_input_vars(self):
        col = tf.get_collection(INPUT_VARS_KEY)
...
tensorpack/models/nonlin.py  View file @ fb2a051c
...
...
@@ -11,6 +11,7 @@ from .batch_norm import BatchNorm
__all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']


@layer_register()
def Maxout(x, num_unit):
    """
...
...
@@ -31,6 +32,7 @@ def Maxout(x, num_unit):
    x = tf.reshape(x, [-1, ch / num_unit, num_unit])
    return tf.reduce_max(x, ndim, name='output')


@layer_register(log_shape=False)
def PReLU(x, init=tf.constant_initializer(0.001), name=None):
    """
...
...
@@ -47,6 +49,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
        name = 'output'
    return tf.mul(x, 0.5, name=name)


@layer_register(use_scope=False, log_shape=False)
def LeakyReLU(x, alpha, name=None):
    """
...
...
@@ -62,7 +65,8 @@ def LeakyReLU(x, alpha, name=None):
    return tf.maximum(x, alpha * x, name=name)
    #alpha = float(alpha)
    #x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
-    #return tf.mul(x, 0.5, name=name)
+    # return tf.mul(x, 0.5, name=name)


@layer_register(log_shape=False, use_scope=False)
def BNReLU(x, name=None):
...
...
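A small NumPy check (not part of the diff; the exact reshape done by the layer groups channels the same way) of what Maxout above computes: channels are split into groups of num_unit and the maximum is taken within each group.

import numpy as np

x = np.random.rand(1, 4, 4, 6).astype('float32')   # NHWC with 6 channels
num_unit = 3
out = x.reshape(1, 4, 4, 6 // num_unit, num_unit).max(axis=-1)
assert out.shape == (1, 4, 4, 2)                   # 6 channels -> 2 maxout outputs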
tensorpack/models/pool.py  View file @ fb2a051c
...
...
@@ -12,6 +12,7 @@ from ..tfutils import symbolic_functions as symbf
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling', 'BilinearUpSample']


@layer_register()
def MaxPooling(x, shape, stride=None, padding='VALID'):
    """
...
...
@@ -32,6 +33,7 @@ def MaxPooling(x, shape, stride=None, padding='VALID'):
    return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding)


@layer_register()
def AvgPooling(x, shape, stride=None, padding='VALID'):
    """
...
...
@@ -52,6 +54,7 @@ def AvgPooling(x, shape, stride=None, padding='VALID'):
    return tf.nn.avg_pool(x, ksize=shape, strides=stride, padding=padding)


@layer_register()
def GlobalAvgPooling(x):
    """
...
...
@@ -65,6 +68,8 @@ def GlobalAvgPooling(x):
    return tf.reduce_mean(x, [1, 2])


# https://github.com/tensorflow/tensorflow/issues/2169
def UnPooling2x2ZeroFilled(x):
    out = tf.concat(3, [x, tf.zeros_like(x)])
    out = tf.concat(2, [out, tf.zeros_like(out)])
...
...
@@ -79,6 +84,7 @@ def UnPooling2x2ZeroFilled(x):
        ret.set_shape([None, None, None, sh[3]])
    return ret


@layer_register()
def FixedUnPooling(x, shape, unpool_mat=None):
    """
...
...
@@ -108,8 +114,8 @@ def FixedUnPooling(x, shape, unpool_mat=None):
    # perform a tensor-matrix kronecker product
    fx = symbf.flatten(tf.transpose(x, [0, 3, 1, 2]))
    fx = tf.expand_dims(fx, -1)  # (bchw)x1
-    mat = tf.expand_dims(symbf.flatten(unpool_mat), 0) # 1x(shxsw)
-    prod = tf.matmul(fx, mat) # (bchw) x(shxsw)
+    mat = tf.expand_dims(symbf.flatten(unpool_mat), 0)  # 1x(shxsw)
+    prod = tf.matmul(fx, mat)  # (bchw) x(shxsw)
    prod = tf.reshape(prod, tf.pack(
        [-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]]))
    prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1])
...
...
@@ -117,6 +123,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
        [-1, input_shape[1] * shape[0], input_shape[2] * shape[1], input_shape[3]]))
    return prod


@layer_register()
def BilinearUpSample(x, shape):
    """
...
...
@@ -125,9 +132,9 @@ def BilinearUpSample(x, shape):
    :param shape: an integer, the upsample factor
    """
    #inp_shape = tf.shape(x)
-    #return tf.image.resize_bilinear(x,
-    #        tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
-    #        align_corners=True)
+    # return tf.image.resize_bilinear(x,
+    #        tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
+    #        align_corners=True)
    inp_shape = x.get_shape().as_list()
    ch = inp_shape[3]
...
...
@@ -136,7 +143,6 @@ def BilinearUpSample(x, shape):
    shape = int(shape)
    filter_shape = 2 * shape

    def bilinear_conv_filler(s):
        """
        s: width, height of the conv filter
...
...
@@ -147,7 +153,7 @@ def BilinearUpSample(x, shape):
        ret = np.zeros((s, s), dtype='float32')
        for x in range(s):
            for y in range(s):
-                ret[x,y] = (1-abs(x/f-c))*(1-abs(y/f-c))
+                ret[x, y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
        return ret
    w = bilinear_conv_filler(filter_shape)
    w = np.repeat(w, ch * ch).reshape((filter_shape, filter_shape, ch, ch))
...
...
@@ -155,17 +161,22 @@ def BilinearUpSample(x, shape):
        shape=(filter_shape, filter_shape, ch, ch), name='bilinear_upsample_filter')
    deconv = tf.nn.conv2d_transpose(x, weight_var,
-            tf.shape(x) * tf.constant([1, shape, shape, 1], tf.int32),
-            [1, shape, shape, 1], 'SAME')
+                                    tf.shape(x) * tf.constant([1, shape, shape, 1], tf.int32),
+                                    [1, shape, shape, 1], 'SAME')
-    if inp_shape[1]: inp_shape[1] *= shape
-    if inp_shape[2]: inp_shape[2] *= shape
+    if inp_shape[1]:
+        inp_shape[1] *= shape
+    if inp_shape[2]:
+        inp_shape[2] *= shape
    deconv.set_shape(inp_shape)
    return deconv

from ._test import TestModel


class TestPool(TestModel):
    def test_fixed_unpooling(self):
        h, w = 3, 4
        mat = np.random.rand(h, w, 3).astype('float32')
...
...
@@ -173,13 +184,13 @@ class TestPool(TestModel):
        inp = tf.reshape(inp, [1, h, w, 3])
        output = FixedUnPooling('unpool', inp, 2)
        res = self.run_variable(output)
-        self.assertEqual(res.shape, (1, 2*h, 2*w, 3))
+        self.assertEqual(res.shape, (1, 2 * h, 2 * w, 3))

        # mat is on cornser
-        ele = res[0, ::2,::2, 0]
-        self.assertTrue((ele == mat[:,:,0]).all())
+        ele = res[0, ::2, ::2, 0]
+        self.assertTrue((ele == mat[:, :, 0]).all())
        # the rest are zeros
-        res[0, ::2,::2, :] = 0
+        res[0, ::2, ::2, :] = 0
        self.assertTrue((res == 0).all())

    def test_upsample(self):
...
...
@@ -191,7 +202,7 @@ class TestPool(TestModel):
        inp = tf.reshape(inp, [1, h, w, 1])
        output = BilinearUpSample('upsample', inp, scale)
-        res = self.run_variable(output)[0, :,:, 0]
+        res = self.run_variable(output)[0, :, :, 0]

        from skimage.transform import rescale
        res2 = rescale(mat, scale)
...
...
@@ -199,9 +210,9 @@ class TestPool(TestModel):
        diff = np.abs(res2 - res)

        # not equivalent to rescale on edge?
-        diff[0,:] = 0
-        diff[:, 0] = 0
+        diff[0, :] = 0
+        diff[:, 0] = 0
        if not diff.max() < 1e-4:
-            import IPython;
+            import IPython
            IPython.embed(config=IPython.terminal.ipapp.load_default_config())
        self.assertTrue(diff.max() < 1e-4)
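A NumPy sketch (not part of the commit) of the "tensor-matrix kronecker product" idea behind FixedUnPooling above: every input value is multiplied by a small unpool_mat, which by default has a single 1 in the top-left corner, so each value lands in one corner of its upsampled block and the rest stays zero, which is exactly what test_fixed_unpooling verifies.

import numpy as np

x = np.arange(6, dtype='float32').reshape(2, 3)    # one channel of one image
unpool_mat = np.array([[1., 0.],
                       [0., 0.]], dtype='float32')
out = np.kron(x, unpool_mat)                       # shape (4, 6)
assert (out[::2, ::2] == x).all()                  # original values sit on the corners
assert out[1::2, :].sum() == 0                     # everything else is zero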
tensorpack/models/regularize.py  View file @ fb2a051c
...
...
@@ -12,6 +12,7 @@ from ._common import layer_register
__all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']


@memoized
def _log_regularizer(name):
    logger.info("Apply regularizer for {}".format(name))
...
...
@@ -19,6 +20,7 @@ def _log_regularizer(name):
l2_regularizer = tf.contrib.layers.l2_regularizer
l1_regularizer = tf.contrib.layers.l1_regularizer


def regularize_cost(regex, func, name=None):
    """
    Apply a regularizer on every trainable variable matching the regex.
...
...
@@ -48,4 +50,3 @@ def Dropout(x, keep_prob=0.5, is_training=None):
        is_training = get_current_tower_context().is_training
    keep_prob = tf.constant(keep_prob if is_training else 1.0)
    return tf.nn.dropout(x, keep_prob)
tensorpack/models/shapes.py  View file @ fb2a051c
...
...
@@ -8,6 +8,7 @@ from ._common import layer_register
__all__ = ['ConcatWith']


@layer_register(use_scope=False, log_shape=False)
def ConcatWith(x, dim, tensor):
    """
...
...
tensorpack/models/softmax.py  View file @ fb2a051c
...
...
@@ -8,6 +8,7 @@ from ._common import layer_register
__all__ = ['SoftMax']


@layer_register()
def SoftMax(x, use_temperature=False, temperature_init=1.0):
    """
...
...
@@ -16,6 +17,6 @@ def SoftMax(x, use_temperature=False, temperature_init=1.0):
    """
    if use_temperature:
        t = tf.get_variable('invtemp', [],
-                initializer=tf.constant_initializer(1.0 / float(temperature_init)))
+                            initializer=tf.constant_initializer(1.0 / float(temperature_init)))
        x = x * t
    return tf.nn.softmax(x, name='output')
tensorpack/predict/__init__.py  View file @ fb2a051c
...
...
@@ -8,6 +8,7 @@ import os.path
__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
...
@@ -25,4 +26,3 @@ for _, module_name, _ in walk_packages(
    if module_name.startswith('_'):
        continue
    global_import(module_name)
tensorpack/predict/base.py  View file @ fb2a051c
...
...
@@ -12,9 +12,10 @@ from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext

__all__ = ['OnlinePredictor', 'OfflinePredictor',
-        'AsyncPredictorBase', 'MultiTowerOfflinePredictor',
-        'build_multi_tower_prediction_graph', 'DataParallelOfflinePredictor']
+           'AsyncPredictorBase', 'MultiTowerOfflinePredictor',
+           'build_multi_tower_prediction_graph', 'DataParallelOfflinePredictor']


@six.add_metaclass(ABCMeta)
class PredictorBase(object):
...
...
@@ -46,7 +47,9 @@ class PredictorBase(object):
        :return: output as defined by the config
        """


class AsyncPredictorBase(PredictorBase):

    @abstractmethod
    def put_task(self, dp, callback=None):
        """
...
...
@@ -67,7 +70,9 @@ class AsyncPredictorBase(PredictorBase):
        # in Tornado, Future.result() doesn't wait
        return fut.result()


class OnlinePredictor(PredictorBase):

    def __init__(self, sess, input_tensors, output_tensors, return_input=False):
        self.session = sess
        self.return_input = return_input
...
...
@@ -85,6 +90,7 @@ class OnlinePredictor(PredictorBase):

class OfflinePredictor(OnlinePredictor):
    """ Build a predictor from a given config, in an independent graph"""

    def __init__(self, config):
        self.graph = tf.Graph()
        with self.graph.as_default():
...
...
@@ -98,7 +104,7 @@ class OfflinePredictor(OnlinePredictor):
            sess = tf.Session(config=config.session_config)
            config.session_init.init(sess)
            super(OfflinePredictor, self).__init__(
-                    sess, input_vars, output_vars, config.return_input)
+                sess, input_vars, output_vars, config.return_input)


def build_multi_tower_prediction_graph(build_tower_fn, towers):
...
...
@@ -108,13 +114,15 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
    """
    for k in towers:
        logger.info(
-                "Building graph for predictor tower {}...".format(k))
+            "Building graph for predictor tower {}...".format(k))
        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                TowerContext('{}{}'.format(PREDICT_TOWER, k)):
            build_tower_fn(k)
            tf.get_variable_scope().reuse_variables()


class MultiTowerOfflinePredictor(OnlinePredictor):

    def __init__(self, config, towers):
        self.graph = tf.Graph()
        self.predictors = []
...
...
@@ -130,8 +138,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
            for k in towers:
                output_vars = get_tensors_by_names(
-                        ['{}{}/'.format(PREDICT_TOWER, k) + n \
-                            for n in config.output_names])
+                    ['{}{}/'.format(PREDICT_TOWER, k) + n
+                     for n in config.output_names])
                self.predictors.append(OnlinePredictor(
                    self.sess, input_vars, output_vars, config.return_input))
...
...
@@ -142,7 +150,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
    def get_predictors(self, n):
        return [self.predictors[k % len(self.predictors)] for k in range(n)]


class DataParallelOfflinePredictor(OnlinePredictor):

    def __init__(self, config, towers):
        self.graph = tf.Graph()
        with self.graph.as_default():
...
...
@@ -152,19 +162,19 @@ class DataParallelOfflinePredictor(OnlinePredictor):
            for k in towers:
                towername = PREDICT_TOWER + str(k)
                input_vars = config.model.build_placeholders(
-                        prefix=towername + '-')
+                    prefix=towername + '-')
                logger.info(
-                        "Building graph for predictor tower {}...".format(k))
+                    "Building graph for predictor tower {}...".format(k))
                with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                        TowerContext(towername, is_training=False):
                    config.model.build_graph(input_vars)
                    tf.get_variable_scope().reuse_variables()
                input_var_names.extend([k.name for k in input_vars])
                output_vars.extend(get_tensors_by_names(
-                    [towername + '/' + n \
-                        for n in config.output_names]))
+                    [towername + '/' + n
+                     for n in config.output_names]))
            input_vars = get_tensors_by_names(input_var_names)
            config.session_init.init(sess)
            super(DataParallelOfflinePredictor, self).__init__(
-                    sess, input_vars, output_vars, config.return_input)
+                sess, input_vars, output_vars, config.return_input)
tensorpack/predict/common.py  View file @ fb2a051c
...
...
@@ -15,11 +15,13 @@ from .base import OfflinePredictor
import multiprocessing

-__all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']
+__all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']

PredictResult = namedtuple('PredictResult', ['input', 'output'])


class PredictConfig(object):

    def __init__(self, **kwargs):
        """
        The config used by `get_predict_func`.
...
...
@@ -61,12 +63,14 @@ class PredictConfig(object):
            self.output_names = kwargs.pop('output_var_names')
            #logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
        assert len(self.input_names), self.input_names
-        for v in self.input_names: assert_type(v, six.string_types)
+        for v in self.input_names:
+            assert_type(v, six.string_types)
        assert len(self.output_names), self.output_names
        self.return_input = kwargs.pop('return_input', False)
        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))


def get_predict_func(config):
    """
    Produce a offline predictor run inside a new session.
...
...
@@ -76,4 +80,3 @@ def get_predict_func(config):
    a list of output values defined in ``config.output_var_names``.
    """
    return OfflinePredictor(config)
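A hedged usage sketch (the model class, checkpoint path and tensor names are placeholders, not from this commit) of the inference entry points above: build a PredictConfig with a session initializer plus input/output names, then call the resulting predictor on a datapoint.

pred_config = PredictConfig(
    model=MyModel(),                                  # a ModelDesc subclass (hypothetical)
    session_init=SaverRestore('/path/to/model-10000'),
    input_names=['input'],
    output_names=['output'])
predict_func = get_predict_func(pred_config)          # same as OfflinePredictor(pred_config)
outputs = predict_func([image_batch])                 # returns the requested output values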
tensorpack/predict/concurrency.py  View file @ fb2a051c
...
...
@@ -3,7 +3,8 @@
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import multiprocessing, threading
+import multiprocessing
+import threading
import tensorflow as tf
import time
import six
...
...
@@ -25,10 +26,12 @@ except ImportError:
    __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
else:
    __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
-            'MultiThreadAsyncPredictor']
+               'MultiThreadAsyncPredictor']


class MultiProcessPredictWorker(multiprocessing.Process):
    """ Base class for predict worker that runs offline in multiprocess"""

    def __init__(self, idx, config):
        """
        :param idx: index of the worker. the 0th worker will print log.
...
...
@@ -51,8 +54,10 @@ class MultiProcessPredictWorker(multiprocessing.Process):
            with self.predictor.graph.as_default():
                describe_model()


class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
    """ An offline predictor worker that takes input and produces output by queue"""

    def __init__(self, idx, inqueue, outqueue, config):
        """
        :param inqueue: input queue to get data point. elements are (task_id, dp)
...
...
@@ -76,6 +81,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):

class PredictorWorkerThread(threading.Thread):

    def __init__(self, queue, pred_func, id, batch_size=5):
        super(PredictorWorkerThread, self).__init__()
        self.queue = queue
...
...
@@ -88,13 +94,13 @@ class PredictorWorkerThread(threading.Thread):
        while True:
            batched, futures = self.fetch_batch()
            outputs = self.func(batched)
-            #print "Worker {} batched {} Queue {}".format(
-            #    self.id, len(futures), self.queue.qsize())
-            # debug, for speed testing
-            #if not hasattr(self, 'xxx'):
-            #    self.xxx = outputs = self.func(batched)
-            #else:
-            #    outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
+            # print "Worker {} batched {} Queue {}".format(
+            #     self.id, len(futures), self.queue.qsize())
+            # debug, for speed testing
+            # if not hasattr(self, 'xxx'):
+            #     self.xxx = outputs = self.func(batched)
+            # else:
+            #     outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
            for idx, f in enumerate(futures):
                f.set_result([k[idx] for k in outputs])
...
...
@@ -119,11 +125,13 @@ class PredictorWorkerThread(threading.Thread):
                cnt += 1
            return batched, futures


class MultiThreadAsyncPredictor(AsyncPredictorBase):
    """
    An multithread online async predictor which run a list of PredictorBase.
    It would do an extra batching internally.
    """

    def __init__(self, predictors, batch_size=5):
        """ :param predictors: a list of OnlinePredictor"""
        assert len(predictors)
...
...
@@ -131,7 +139,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
            #assert isinstance(k, OnlinePredictor), type(k)
            # TODO use predictors.return_input here
            assert k.return_input == False
-        self.input_queue = queue.Queue(maxsize=len(predictors)*100)
+        self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
        self.threads = [
            PredictorWorkerThread(
                self.input_queue, f, id, batch_size=batch_size)
...
...
tensorpack/predict/dataset.py  View file @ fb2a051c
...
...
@@ -20,10 +20,12 @@ from .common import PredictConfig
from .base import OfflinePredictor

__all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
-        'MultiProcessDatasetPredictor']
+           'MultiProcessDatasetPredictor']


@six.add_metaclass(ABCMeta)
class DatasetPredictorBase(object):

    def __init__(self, config, dataset):
        """
        :param config: a `PredictConfig` instance.
...
...
@@ -45,10 +47,12 @@ class DatasetPredictorBase(object):
        """
        return list(self.get_result())


class SimpleDatasetPredictor(DatasetPredictorBase):
    """
    Run the predict_config on a given `DataFlow`.
    """

    def __init__(self, config, dataset):
        super(SimpleDatasetPredictor, self).__init__(config, dataset)
        self.predictor = OfflinePredictor(config)
...
...
@@ -60,14 +64,17 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
            sz = self.dataset.size()
        except NotImplementedError:
            sz = 0
-        with get_tqdm(total=sz, disable=(sz==0)) as pbar:
+        with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
            for dp in self.dataset.get_data():
                res = self.predictor(dp)
                yield res
                pbar.update()

# TODO allow unordered


class MultiProcessDatasetPredictor(DatasetPredictorBase):

    def __init__(self, config, dataset, nr_proc, use_gpu=True, ordered=True):
        """
        Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported.
...
...
@@ -87,14 +94,14 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        self.ordered = ordered

        self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
-                self.dataset, nr_proc * 2, self.nr_proc)  # put (idx, dp) to inqueue
+            self.dataset, nr_proc * 2, self.nr_proc)  # put (idx, dp) to inqueue

        if use_gpu:
            try:
                gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
                assert len(gpus) >= self.nr_proc, \
-                        "nr_proc={} while only {} gpus available".format(self.nr_proc, len(gpus))
+                    "nr_proc={} while only {} gpus available".format(self.nr_proc, len(gpus))
            except KeyError:
                # TODO number of GPUs not checked
                gpus = list(range(self.nr_proc))
...
...
@@ -103,8 +110,8 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        # worker produces (idx, result) to outqueue
        self.outqueue = multiprocessing.Queue()
        self.workers = [MultiProcessQueuePredictWorker(
-                i, self.inqueue, self.outqueue, self.config)
-                for i in range(self.nr_proc)]
+            i, self.inqueue, self.outqueue, self.config)
+            for i in range(self.nr_proc)]

        # start inqueue and workers
        self.inqueue_proc.start()
...
...
@@ -118,7 +125,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        if ordered:
            self.result_queue = OrderedResultGatherProc(
-                    self.outqueue, nr_producer=self.nr_proc)
+                self.outqueue, nr_producer=self.nr_proc)
            self.result_queue.start()
            ensure_proc_terminate(self.result_queue)
        else:
...
...
@@ -130,7 +137,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
            sz = self.dataset.size()
        except NotImplementedError:
            sz = 0
-        with get_tqdm(total=sz, disable=(sz==0)) as pbar:
+        with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
            die_cnt = 0
            while True:
                res = self.result_queue.get()
...
...
@@ -147,4 +154,5 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
            self.result_queue.join()
            self.result_queue.terminate()
            for p in self.workers:
-                p.join(); p.terminate()
+                p.join()
+                p.terminate()
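A hedged sketch (the dataflow and config objects are placeholders) of SimpleDatasetPredictor above: it runs an OfflinePredictor over a DataFlow with a tqdm progress bar, yielding one result per datapoint.

pred = SimpleDatasetPredictor(pred_config, dataflow)
for outputs in pred.get_result():      # generator over per-datapoint results
    handle(outputs)                    # hypothetical consumer
results = pred.get_all_result()        # or collect everything into a list at once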
tensorpack/tfutils/argscope.py  View file @ fb2a051c
...
...
@@ -12,6 +12,7 @@ __all__ = ['argscope', 'get_arg_scope']

_ArgScopeStack = []


@contextmanager
def argscope(layers, **param):
    if not isinstance(layers, list):
...
...
@@ -33,6 +34,7 @@ def argscope(layers, **param):
    yield
    del _ArgScopeStack[-1]


def get_arg_scope():
    """
    :returns: the current argscope.
...
...
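A hedged usage sketch of argscope above (MaxPooling serves only as an example of a registered layer): keyword arguments pushed onto the scope stack act as defaults for the listed layers until the with-block exits.

with argscope(MaxPooling, padding='SAME'):
    x = MaxPooling('pool0', x, 2)   # picks up padding='SAME' from the scope
    x = MaxPooling('pool1', x, 2)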
tensorpack/tfutils/common.py  View file @ fb2a051c
...
...
@@ -22,6 +22,7 @@ __all__ = ['get_default_sess_config',
           'freeze_collection',
           'get_tf_version']


def get_default_sess_config(mem_fraction=0.99):
    """
    Return a better session config to use as default.
...
...
@@ -38,6 +39,7 @@ def get_default_sess_config(mem_fraction=0.99):
    #conf.log_device_placement = True
    return conf


def get_global_step_var():
    """ :returns: the global_step variable in the current graph. create if not existed"""
    try:
...
...
@@ -45,19 +47,21 @@ def get_global_step_var():
    except KeyError:
        scope = tf.get_variable_scope()
        assert scope.name == '', \
-                "Creating global_step_var under a variable scope would cause problems!"
+            "Creating global_step_var under a variable scope would cause problems!"
        with tf.variable_scope(scope, reuse=False):
            var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
-                    initializer=tf.constant_initializer(dtype=tf.int32),
-                    trainable=False, dtype=tf.int32)
+                                  initializer=tf.constant_initializer(dtype=tf.int32),
+                                  trainable=False, dtype=tf.int32)
        return var


def get_global_step():
    """ :returns: global_step value in current graph and session"""
    return tf.train.global_step(
        tf.get_default_session(), get_global_step_var())


def get_op_tensor_name(name):
    """
    Tensor name is assumed to be ``op_name + ':0'``
...
...
@@ -72,6 +76,7 @@ def get_op_tensor_name(name):
get_op_var_name = get_op_tensor_name


def get_tensors_by_names(names):
    """
    Get a list of tensors in the default graph by a list of names
...
...
@@ -85,26 +90,31 @@ def get_tensors_by_names(names):
get_vars_by_names = get_tensors_by_names


def backup_collection(keys):
    ret = {}
    for k in keys:
        ret[k] = copy(tf.get_collection(k))
    return ret


def restore_collection(backup):
    for k, v in six.iteritems(backup):
        del tf.get_collection_ref(k)[:]
        tf.get_collection_ref(k).extend(v)


def clear_collection(keys):
    for k in keys:
        del tf.get_collection_ref(k)[:]


@contextmanager
def freeze_collection(keys):
    backup = backup_collection(keys)
    yield
    restore_collection(backup)


def get_tf_version():
    return int(tf.__version__.split('.')[1])
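A small sketch (assumed import path, not part of the diff) of freeze_collection above: the listed collections are backed up on entry and restored on exit, so anything added to them inside the block is discarded.

import tensorflow as tf
# from tensorpack.tfutils.common import freeze_collection   # assumed import path

with freeze_collection(['my_collection']):
    tf.add_to_collection('my_collection', tf.constant(1))
assert tf.get_collection('my_collection') == []   # restored to the empty backup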
tensorpack/tfutils/gradproc.py  View file @ fb2a051c
...
...
@@ -16,6 +16,7 @@ __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
           'ScaleGradient', 'MapGradient', 'apply_grad_processors',
           'GlobalNormClip']


def apply_grad_processors(grads, gradprocs):
    """
    :param grads: list of (grad, var).
...
...
@@ -32,6 +33,7 @@ def apply_grad_processors(grads, gradprocs):
        g = proc.process(g)
    return g


@six.add_metaclass(ABCMeta)
class GradientProcessor(object):
...
...
@@ -51,6 +53,7 @@ class GradientProcessor(object):

class GlobalNormClip(GradientProcessor):

    def __init__(self, global_norm):
        """ Clip by global norm
        Note that the global norm is the sum of norm for **all** gradients
...
...
@@ -63,11 +66,13 @@ class GlobalNormClip(GradientProcessor):
        g, _ = tf.clip_by_global_norm(g, self._norm, name='clip_by_global_norm')
        return list(zip(g, v))


class MapGradient(GradientProcessor):
    """
    Apply a function on all gradient if the name matches regex.
    Keep the other gradients unchanged.
    """

    def __init__(self, func, regex='.*'):
        """
        :param func: takes a grad or (grad, var) pair and returns a grad. If return None, the
...
...
@@ -77,7 +82,7 @@ class MapGradient(GradientProcessor):
        args = inspect.getargspec(func).args
        arg_num = len(args) - inspect.ismethod(func)
        assert arg_num in [1, 2], \
-                "The function must take 1 or 2 arguments! ({})".format(args)
+            "The function must take 1 or 2 arguments! ({})".format(args)
        if arg_num == 1:
            self.func = lambda grad, var: func(grad)
        else:
...
...
@@ -100,10 +105,12 @@ class MapGradient(GradientProcessor):
_summaried_gradient = set()


class SummaryGradient(MapGradient):
    """
    Summary history and RMS for each graident variable
    """

    def __init__(self):
        super(SummaryGradient, self).__init__(self._mapper)
...
...
@@ -115,10 +122,12 @@ class SummaryGradient(MapGradient):
        add_moving_summary(rms(grad, name=name + '/rms'))
        return grad


class CheckGradient(MapGradient):
    """
    Check for numeric issue.
    """

    def __init__(self):
        super(CheckGradient, self).__init__(self._mapper)
...
...
@@ -128,10 +137,12 @@ class CheckGradient(MapGradient):
        grad = tf.check_numerics(grad, 'CheckGradient-' + var.op.name)
        return grad


class ScaleGradient(MapGradient):
    """
    Scale certain gradient by a multiplier
    """

    def __init__(self, multipliers, log=True):
        """
        :param multipliers: list of (regex, float)
...
...
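A hedged sketch (the regex and multiplier are illustrative only) of combining the processors above inside a model's get_gradient_processor(): scale the gradients of matching variables, then check every gradient for numeric issues.

    def get_gradient_processor(self):
        return [ScaleGradient([('fc.*/W', 0.1)]),   # 0.1x gradient for matching variables
                CheckGradient()]                    # tf.check_numerics on every gradient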
tensorpack/tfutils/modelutils.py  View file @ fb2a051c
...
...
@@ -9,6 +9,7 @@ from ..utils import logger
__all__ = ['describe_model', 'get_shape_str']


def describe_model():
    """ print a description of the current model parameters """
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
...
...
@@ -40,5 +41,3 @@ def get_shape_str(tensors):
        assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
        shape_str = str(tensors.get_shape().as_list())
    return shape_str
tensorpack/tfutils/sessinit.py  View file @ fb2a051c
...
...
@@ -20,6 +20,7 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',

# TODO they initialize_all at the beginning by default.


@six.add_metaclass(ABCMeta)
class SessionInit(object):
    """ Base class for utilities to initialize a session"""
...
...
@@ -35,23 +36,29 @@ class SessionInit(object):
    def _init(self, sess):
        pass


class JustCurrentSession(SessionInit):
    """ Just use the current default session. This is a no-op placeholder"""

    def _init(self, sess):
        pass


class NewSession(SessionInit):
    """
    Create a new session. All variables will be initialized by their
    initializer.
    """

    def _init(self, sess):
        sess.run(tf.global_variables_initializer())


class SaverRestore(SessionInit):
    """
    Restore an old model saved by `ModelSaver`.
    """

    def __init__(self, model_path, prefix=None):
        """
        :param model_path: a model name (model-xxxx) or a ``checkpoint`` file.
...
...
@@ -71,7 +78,7 @@ class SaverRestore(SessionInit):
            new_path = model_path.split('.index')[0]
            if new_path != model_path:
                logger.warn(
-                        "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
+                    "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
                model_path = new_path
        assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
        self.set_path(model_path)
...
...
@@ -146,10 +153,12 @@ class SaverRestore(SessionInit):
                logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
        return var_dict


class ParamRestore(SessionInit):
    """
    Restore variables from a dictionary.
    """

    def __init__(self, param_dict):
        """
        :param param_dict: a dict of {name: value}
...
...
@@ -158,7 +167,7 @@ class ParamRestore(SessionInit):
        self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}

    def _init(self, sess):
-        variables = tf.get_collection(tf.GraphKeys().VARIABLES) # TODO
+        variables = tf.get_collection(tf.GraphKeys().VARIABLES)  # TODO

        variable_names = set([get_savename_from_varname(k.name) for k in variables])
        param_names = set(six.iterkeys(self.prms))
...
...
@@ -174,14 +183,15 @@ class ParamRestore(SessionInit):
            logger.warn("Variable {} in the dict not found in the graph!".format(k))

        upd = SessionUpdate(sess,
-                [v for v in variables if \
-                    get_savename_from_varname(v.name) in intersect])
+                            [v for v in variables if
+                             get_savename_from_varname(v.name) in intersect])
        logger.info("Restoring from dict ...")
        upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})


class ChainInit(SessionInit):
    """ Init a session by a list of SessionInit instance."""

    def __init__(self, sess_inits, new_session=True):
        """
        :params sess_inits: list of `SessionInit` instances.
...
...
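A hedged sketch (paths and the variable name are placeholders) of composing the initializers above: ChainInit applies a list of SessionInit objects in order, for example restoring a checkpoint and then overwriting one variable from a numpy dict.

import numpy as np

session_init = ChainInit([
    SaverRestore('/path/to/train_log/model-5000'),
    ParamRestore({'fc/W': np.zeros((100, 10), dtype='float32')}),
])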
tensorpack/tfutils/summary.py  View file @ fb2a051c
...
...
@@ -15,6 +15,7 @@ from .symbolic_functions import rms
__all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
           'add_moving_summary', 'summary_moving_average']


def create_summary(name, v):
    """
    Return a tf.Summary object with name and simple scalar value v
...
...
@@ -25,6 +26,7 @@ def create_summary(name, v):
    s.value.add(tag=name, simple_value=v)
    return s


def add_activation_summary(x, name=None):
    """
    Add summary to graph for an activation tensor x.
...
...
@@ -44,6 +46,7 @@ def add_activation_summary(x, name=None):
    tf.summary.scalar(name + '-sparsity', tf.nn.zero_fraction(x))
    tf.summary.scalar(name + '-rms', rms(x))


def add_param_summary(summary_lists):
    """
    Add summary for all trainable variables matching the regex
...
...
@@ -54,6 +57,7 @@ def add_param_summary(summary_lists):
    ctx = get_current_tower_context()
    if ctx is not None and not ctx.is_main_training_tower:
        return

    def perform(var, action):
        ndim = var.get_shape().ndims
        name = var.name.replace(':0', '')
...
...
@@ -87,6 +91,7 @@ def add_param_summary(summary_lists):
        for act in actions:
            perform(p, act)


def add_moving_summary(v, *args):
    """
    :param v: tensor or list of tensor to summary
...
...
@@ -102,6 +107,7 @@ def add_moving_summary(v, *args):
        assert x.get_shape().ndims == 0, x.get_shape()
        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)


@memoized
def summary_moving_average(tensors=None):
    """
...
...
@@ -121,4 +127,3 @@ def summary_moving_average(tensors=None):
        name = re.sub('tower[p0-9]+/', '', c.op.name)
        tf.summary.scalar(name + '-summary', averager.average(c))
    return avg_maintain_op
tensorpack/tfutils/symbolic_functions.py  View file @ fb2a051c
...
...
@@ -6,6 +6,7 @@ import tensorflow as tf
import numpy as np
from ..utils import logger


def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
    """
    :param logits: NxC
...
...
@@ -13,7 +14,8 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
    :returns: a float32 vector of length N with 0/1 values. 1 means incorrect prediction
    """
    return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
-            tf.float32, name=name)
+                   tf.float32, name=name)


def flatten(x):
    """
...
...
@@ -21,6 +23,7 @@ def flatten(x):
    """
    return tf.reshape(x, [-1])


def batch_flatten(x):
    """
    Flatten the tensor except the first dimension.
...
...
@@ -30,6 +33,7 @@ def batch_flatten(x):
        return tf.reshape(x, [-1, int(np.prod(shape))])
    return tf.reshape(x, tf.pack([tf.shape(x)[0], -1]))


def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
    """
    The class-balanced cross entropy loss,
...
...
@@ -53,6 +57,7 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
    cost = tf.sub(loss_pos, loss_neg, name=name)
    return cost


def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
    """
    The class-balanced cross entropy loss,
...
...
@@ -75,13 +80,14 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
    cost = tf.reduce_mean(cost * (1 - beta), name=name)

    #logstable = tf.log(1 + tf.exp(-tf.abs(z)))
-    #loss_pos = -beta * tf.reduce_mean(-y *
-    #(logstable - tf.minimum(0.0, z)))
-    #loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
-    #(logstable + tf.maximum(z, 0.0)))
+    # loss_pos = -beta * tf.reduce_mean(-y *
+    #(logstable - tf.minimum(0.0, z)))
+    # loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
+    #(logstable + tf.maximum(z, 0.0)))
    #cost = tf.sub(loss_pos, loss_neg, name=name)
    return cost


def print_stat(x, message=None):
    """ a simple print op.
    Use it like: x = print_stat(x)
...
...
@@ -89,7 +95,8 @@ def print_stat(x, message=None):
    if message is None:
        message = x.op.name
    return tf.Print(x, [tf.shape(x), tf.reduce_mean(x), x], summarize=20,
-            message=message, name='print_' + x.op.name)
+                    message=message, name='print_' + x.op.name)


def rms(x, name=None):
    if name is None:
...
...
@@ -98,14 +105,16 @@ def rms(x, name=None):
            return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
    return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)


def huber_loss(x, delta=1, name='huber_loss'):
    sqrcost = tf.square(x)
    abscost = tf.abs(x)
    return tf.reduce_sum(
-        tf.select(abscost < delta, sqrcost * 0.5, abscost * delta - 0.5 * delta ** 2),
-        name=name)
+        tf.select(abscost < delta, sqrcost * 0.5, abscost * delta - 0.5 * delta ** 2),
+        name=name)


def get_scalar_var(name, init_value, summary=False, trainable=False):
    """
...
...
@@ -113,8 +122,8 @@ def get_scalar_var(name, init_value, summary=False, trainable=False):
    :param summary: summary this variable
    """
    ret = tf.get_variable(name, shape=[],
-            initializer=tf.constant_initializer(init_value),
-            trainable=trainable)
+                          initializer=tf.constant_initializer(init_value),
+                          trainable=trainable)
    if summary:
        # this is recognized in callbacks.StatHolder
        tf.summary.scalar(name + '-summary', ret)
...
...
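A NumPy rendering (not part of the diff) of the piecewise expression huber_loss above builds with tf.select: quadratic inside |x| < delta, linear outside, summed over all elements.

import numpy as np

def np_huber_loss(x, delta=1.0):
    sqrcost = np.square(x)
    abscost = np.abs(x)
    return np.sum(np.where(abscost < delta,
                           sqrcost * 0.5,
                           abscost * delta - 0.5 * delta ** 2))

print(np_huber_loss(np.array([0.5, 2.0])))   # 0.125 + 1.5 = 1.625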
tensorpack/tfutils/tower.py  View file @ fb2a051c
...
...
@@ -11,7 +11,9 @@ __all__ = ['get_current_tower_context', 'TowerContext']

_CurrentTowerContext = None


class TowerContext(object):

    def __init__(self, tower_name, is_training=None):
        """ tower_name: 'tower0', 'towerp0', or '' """
        self._name = tower_name
...
...
@@ -65,7 +67,7 @@ class TowerContext(object):
    def __enter__(self):
        global _CurrentTowerContext
        assert _CurrentTowerContext is None, \
-                "Nesting TowerContext!"
+            "Nesting TowerContext!"
        _CurrentTowerContext = self
        if len(self._name):
            self._scope = tf.name_scope(self._name)
...
...
@@ -78,7 +80,7 @@ class TowerContext(object):
        self._scope.__exit__(exc_type, exc_val, exc_tb)
        return False


def get_current_tower_context():
    global _CurrentTowerContext
    return _CurrentTowerContext
tensorpack/tfutils/varmanip.py  View file @ fb2a051c
...
...
@@ -3,7 +3,8 @@
# File: varmanip.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import six, os
+import six
+import os
import tensorflow as tf
from collections import defaultdict
import re
...
...
@@ -13,7 +14,8 @@ from ..utils.naming import *
from .common import get_op_tensor_name

__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
-        'get_savename_from_varname', 'is_training_name']
+           'get_savename_from_varname', 'is_training_name']


def get_savename_from_varname(
        varname, varname_prefix=None,
...
...
@@ -33,13 +35,15 @@ def get_savename_from_varname(
    name = re.sub('tower[p0-9]+/', '', name)
    if varname_prefix is not None \
            and name.startswith(varname_prefix):
-        name = name[len(varname_prefix)+1:]
+        name = name[len(varname_prefix) + 1:]
    if savename_prefix is not None:
        name = savename_prefix + '/' + name
    return name


class SessionUpdate(object):
    """ Update the variables in a session """

    def __init__(self, sess, vars_to_update):
        """
        :param vars_to_update: a collection of variables to update
...
...
@@ -66,11 +70,12 @@ class SessionUpdate(object):
                if varshape != value.shape:
                    # TODO only allow reshape when shape different by empty axis
                    assert np.prod(varshape) == np.prod(value.shape), \
-                            "{}: {}!={}".format(name, varshape, value.shape)
+                        "{}: {}!={}".format(name, varshape, value.shape)
                    logger.warn("Param {} is reshaped during assigning".format(name))
                    value = value.reshape(varshape)
            self.sess.run(op, feed_dict={p: value})


def dump_session_params(path):
    """ Dump value of all trainable + to_save variables to a dict and save to `path` as
    npy format, loadable by ParamRestore
...
...
@@ -90,6 +95,7 @@ the same name".format(v.name))
    logger.info(str(result.keys()))
    np.save(path, result)


def dump_chkpt_vars(model_path):
    """ Dump all variables from a checkpoint to a dict"""
    if os.path.basename(model_path) == model_path:
...
...
@@ -101,6 +107,7 @@ def dump_chkpt_vars(model_path):
        result[n] = reader.get_tensor(n)
    return result


def is_training_name(name):
    """
    This is only used to improve logging.
...
...
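Two quick examples (plain re, not part of the diff) of the substitution get_savename_from_varname above relies on: the per-tower prefix is stripped so that variables from any training tower map to the same save name.

import re

print(re.sub('tower[p0-9]+/', '', 'tower0/conv1/W:0'))    # conv1/W:0
print(re.sub('tower[p0-9]+/', '', 'towerp2/fc/b:0'))      # fc/b:0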
tensorpack/train/__init__.py  View file @ fb2a051c
...
...
@@ -8,6 +8,7 @@ import os.path
__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else []
...
...
@@ -25,4 +26,3 @@ for _, module_name, _ in walk_packages(
    if module_name.startswith('_'):
        continue
    global_import(module_name)
tensorpack/train/base.py  View file @ fb2a051c
...
...
@@ -21,8 +21,11 @@ from ..tfutils.summary import create_summary

__all__ = ['Trainer', 'StopTraining']


class StopTraining(BaseException):
    pass


@six.add_metaclass(ABCMeta)
class Trainer(object):
    """ Base class for a trainer."""
...
...
@@ -91,7 +94,7 @@ class Trainer(object):
        for val in summary.value:
            if val.WhichOneof('value') == 'simple_value':
                val.tag = re.sub('tower[p0-9]+/', '', val.tag)
                # TODO move to subclasses
-                suffix = '-summary' # issue#6150
+                suffix = '-summary'  # issue#6150
                if val.tag.endswith(suffix):
                    val.tag = val.tag[:-len(suffix)]
                self.stat_holder.add_stat(val.tag, val.simple_value)
...
...
@@ -99,7 +102,7 @@ class Trainer(object):
    def write_scalar_summary(self, name, val):
        self.summary_writer.add_summary(
-                create_summary(name, val), get_global_step())
+            create_summary(name, val), get_global_step())
        self.stat_holder.add_stat(name, val)

    def setup(self):
...
...
@@ -138,7 +141,7 @@ class Trainer(object):
            callbacks.before_train()
            logger.info("Start training with global_step={}".format(get_global_step()))
            for epoch_num in range(
-                self.config.starting_epoch, self.config.max_epoch + 1):
+                    self.config.starting_epoch, self.config.max_epoch + 1):
                with timed_operation(
                    'Epoch {} (global_step {})'.format(
                        epoch_num, get_global_step() + self.config.step_per_epoch)):
...
...
@@ -147,7 +150,7 @@ class Trainer(object):
                        **get_tqdm_kwargs(leave=True)):
                        if self.coord.should_stop():
                            return
-                        self.run_step() # implemented by subclass
+                        self.run_step()  # implemented by subclass
                        callbacks.trigger_step()   # not useful?
                # trigger epoch outside the timing region.
                self.trigger_epoch()
...
...
tensorpack/train/config.py  View file @ fb2a051c
...
...
@@ -9,15 +9,17 @@ from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
from ..tfutils import (JustCurrentSession,
-        get_default_sess_config, SessionInit)
+                       get_default_sess_config, SessionInit)
from .input_data import InputData

__all__ = ['TrainConfig']


class TrainConfig(object):
    """
    Config for training a model with a single loss
    """

    def __init__(self, **kwargs):
        """
        :param dataset: the dataset to train. a `DataFlow` instance.
...
...
tensorpack/train/feedfree.py  View file @ fb2a051c
...
...
@@ -17,8 +17,10 @@ from .trainer import MultiPredictorTowerTrainer

__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer',
           'SimpleFeedfreeTrainer', 'QueueInputTrainer']


class FeedfreeTrainer(Trainer):
    """ A trainer which runs iteration without feed_dict (therefore faster) """

    def _trigger_epoch(self):
        # need to run summary_op every epoch
        # note that summary_op will take a data from the queue
...
...
@@ -33,7 +35,9 @@ class FeedfreeTrainer(Trainer):
        assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
        self._input_method._setup(self)


class SingleCostFeedfreeTrainer(FeedfreeTrainer):

    def _get_cost_and_grad(self):
        """ get the cost and gradient on a new tower"""
        actual_inputs = self._get_input_tensors()
...
...
@@ -41,35 +45,37 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
            cost_var = self.model.get_cost()
        # GATE_NONE faster?
        grads = self.config.optimizer.compute_gradients(
-                cost_var,
-                gate_gradients=tf.train.Optimizer.GATE_NONE,
-                colocate_gradients_with_ops=False)
+            cost_var,
+            gate_gradients=tf.train.Optimizer.GATE_NONE,
+            colocate_gradients_with_ops=False)
        add_moving_summary(cost_var)
        return cost_var, grads

    def run_step(self):
        """ Simply run self.train_op"""
        self.sess.run(self.train_op)
-        #if not hasattr(self, 'cnt'):
-        #self.cnt = 0
-        #else:
-        #self.cnt += 1
-        #if self.cnt % 10 == 0:
-        ## debug-benchmark code:
-        #run_metadata = tf.RunMetadata()
-        #self.sess.run([self.train_op],
-        #options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
-        #run_metadata=run_metadata
-        #)
-        #from tensorflow.python.client import timeline
-        #trace = timeline.Timeline(step_stats=run_metadata.step_stats)
-        #trace_file = open('timeline.ctf.json', 'w')
-        #trace_file.write(trace.generate_chrome_trace_format())
-        #import sys; sys.exit()
+        # if not hasattr(self, 'cnt'):
+        # self.cnt = 0
+        # else:
+        # self.cnt += 1
+        # if self.cnt % 10 == 0:
+        # # debug-benchmark code:
+        # run_metadata = tf.RunMetadata()
+        # self.sess.run([self.train_op],
+        # options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
+        # run_metadata=run_metadata
+        # )
+        # from tensorflow.python.client import timeline
+        # trace = timeline.Timeline(step_stats=run_metadata.step_stats)
+        # trace_file = open('timeline.ctf.json', 'w')
+        # trace_file.write(trace.generate_chrome_trace_format())
+        # import sys; sys.exit()


class SimpleFeedfreeTrainer(
        MultiPredictorTowerTrainer,
        SingleCostFeedfreeTrainer):

    def __init__(self, config):
        """
        A trainer with single cost, single training tower and feed-free input
...
...
@@ -80,7 +86,7 @@ class SimpleFeedfreeTrainer(
        super(SimpleFeedfreeTrainer, self).__init__(config)
        self._setup_predictor_factory(config.predict_tower)
        assert len(self.config.tower) == 1, \
-                "SimpleFeedfreeTrainer doesn't support multigpu!"
+            "SimpleFeedfreeTrainer doesn't support multigpu!"

    def _setup(self):
        super(SimpleFeedfreeTrainer, self)._setup()
...
...
@@ -94,6 +100,7 @@ class SimpleFeedfreeTrainer(
        # skip training
        #self.train_op = tf.group(*self.dequed_inputs)


class QueueInputTrainer(SimpleFeedfreeTrainer):

    def __init__(self, config, input_queue=None, predict_tower=None):
...
...
@@ -110,5 +117,5 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
            config.predict_tower = predict_tower
        assert len(config.tower) == 1, \
-                "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
+            "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
        super(QueueInputTrainer, self).__init__(config)
tensorpack/train/input_data.py  View file @ fb2a051c
...
...
@@ -14,13 +14,16 @@ from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread

__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput',
-        'DummyConstantInput']
+           'DummyConstantInput']


@six.add_metaclass(ABCMeta)
class InputData(object):
    pass


class FeedInput(InputData):

    def __init__(self, ds):
        assert isinstance(ds, DataFlow), ds
        self.ds = ds
...
...
@@ -39,7 +42,9 @@ class FeedInput(InputData):
        feed = dict(zip(self.input_vars, data))
        return feed


class FeedfreeInput(InputData):

    def get_input_tensors(self):
        return self._get_input_tensors()
...
...
@@ -49,7 +54,9 @@ class FeedfreeInput(InputData):
        always create and return a list of new input tensors
        """


class EnqueueThread(threading.Thread):

    def __init__(self, trainer, queue, ds, input_placehdrs):
        super(EnqueueThread, self).__init__()
        self.name = 'EnqueueThread'
...
...
@@ -77,7 +84,7 @@ class EnqueueThread(threading.Thread):
                if self.coord.should_stop():
                    return
                feed = dict(zip(self.placehdrs, dp))
-                #print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
+                # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
                self.op.run(feed_dict=feed)
        except tf.errors.CancelledError as e:
            pass
...
...
@@ -91,7 +98,9 @@ class EnqueueThread(threading.Thread):
            pass
        logger.info("Enqueue Thread Exited.")


class QueueInput(FeedfreeInput):

    def __init__(self, ds, queue=None):
        """
        :param ds: a `DataFlow` instance
...
...
@@ -108,32 +117,34 @@ class QueueInput(FeedfreeInput):
    def _setup(self, trainer):
        self.input_placehdrs = trainer.model.get_input_vars()
        assert len(self.input_placehdrs) > 0, \
-                "QueueInput can only be used with input placeholders!"
+            "QueueInput can only be used with input placeholders!"
        if self.queue is None:
            self.queue = tf.FIFOQueue(
-                    50, [x.dtype for x in self.input_placehdrs],
-                    name='input_queue')
+                50, [x.dtype for x in self.input_placehdrs],
+                name='input_queue')
        self.thread = EnqueueThread(
-                trainer, self.queue, self.ds, self.input_placehdrs)
+            trainer, self.queue, self.ds, self.input_placehdrs)
        trainer.config.callbacks.append(StartProcOrThread(self.thread))

    def _get_input_tensors(self):
        ret = self.queue.dequeue(name='input_deque')
-        if isinstance(ret, tf.Tensor): # only one input
+        if isinstance(ret, tf.Tensor):  # only one input
            ret = [ret]
        assert len(ret) == len(self.input_placehdrs)
        for qv, v in zip(ret, self.input_placehdrs):
            qv.set_shape(v.get_shape())

        # test the overhead of queue
-        #with tf.device('/gpu:0'):
-        #ret = [tf.Variable(tf.random_normal([128,224,224,3],
-        #    dtype=tf.float32), trainable=False),
-        #    tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
+        # with tf.device('/gpu:0'):
+        # ret = [tf.Variable(tf.random_normal([128,224,224,3],
+        #     dtype=tf.float32), trainable=False),
+        #     tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
        return ret


class DummyConstantInput(QueueInput):
    """ only for debugging performance issues """

    def __init__(self, ds, shapes):
        super(DummyConstantInput, self).__init__(ds)
        self.shapes = shapes
...
...
@@ -146,11 +157,13 @@ class DummyConstantInput(QueueInput):
        for idx, p in enumerate(placehdrs):
            with tf.device('/gpu:0'):
                ret.append(tf.get_variable(
                    'dummy-' + p.op.name,
-                        shape=self.shapes[idx], dtype=p.dtype, trainable=False,
-                        initializer=tf.constant_initializer()))
+                    shape=self.shapes[idx], dtype=p.dtype, trainable=False,
+                    initializer=tf.constant_initializer()))
        return ret


class TensorInput(FeedfreeInput):

    def __init__(self, get_tensor_fn, size=None):
        self.get_tensor_fn = get_tensor_fn
        self._size = size
...
...
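A minimal sketch (toy shapes, not part of the commit) of the pattern QueueInput above implements: a FIFOQueue buffers datapoints, an EnqueueThread keeps feeding it from the DataFlow, and the training graph dequeues tensors instead of using feed_dict.

import tensorflow as tf

placehdr = tf.placeholder(tf.float32, shape=[28, 28], name='input')
queue = tf.FIFOQueue(50, [placehdr.dtype], name='input_queue')
enqueue_op = queue.enqueue([placehdr])         # run repeatedly by the enqueue thread
dequeued = queue.dequeue(name='input_deque')   # consumed by the training tower
dequeued.set_shape(placehdr.get_shape())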
tensorpack/train/multigpu.py  View file @ fb2a051c
...
...
@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
-import itertools, re
+import itertools
+import re
from six.moves import zip, range

from ..utils import logger
...
...
@@ -12,7 +13,7 @@ from ..utils.naming import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils import (backup_collection, restore_collection,
-        get_global_step_var, TowerContext)
+                       get_global_step_var, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient

from .base import Trainer
...
...
@@ -22,6 +23,7 @@ from .input_data import QueueInput

__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']


class MultiGPUTrainer(Trainer):
    """ Base class for multi-gpu training"""
    @staticmethod
...
...
@@ -45,9 +47,11 @@ class MultiGPUTrainer(Trainer):
                restore_collection(backup)
        return grad_list


class SyncMultiGPUTrainer(MultiGPUTrainer,
-        SingleCostFeedfreeTrainer,
-        MultiPredictorTowerTrainer):
+                          SingleCostFeedfreeTrainer,
+                          MultiPredictorTowerTrainer):

    def __init__(self, config, input_queue=None, predict_tower=None):
        if hasattr(config, 'dataset'):
            self._input_method = QueueInput(config.dataset, input_queue)
...
...
@@ -64,7 +68,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
        assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
        assert tf.test.is_gpu_available()

    @staticmethod
    def _average_grads(tower_grads):
        if len(tower_grads) == 1:
...
...
@@ -92,12 +95,12 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
    def _setup(self):
        super(SyncMultiGPUTrainer, self)._setup()
        grad_list = MultiGPUTrainer._multi_tower_grads(
-                self.config.tower, lambda: self._get_cost_and_grad()[1])
+            self.config.tower, lambda: self._get_cost_and_grad()[1])

        # debug tower performance:
        #ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
        #self.train_op = tf.group(*ops)
-        #return
+        # return

        grads = SyncMultiGPUTrainer._average_grads(grad_list)
        grads = apply_grad_processors(grads, self.model.get_gradient_processor())
...
...
@@ -109,13 +112,15 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
    def run_step(self):
        self.sess.run(self.train_op)


class AsyncMultiGPUTrainer(MultiGPUTrainer,
-        SingleCostFeedfreeTrainer,
-        MultiPredictorTowerTrainer):
+                           SingleCostFeedfreeTrainer,
+                           MultiPredictorTowerTrainer):

    def __init__(self, config,
-            input_queue=None,
-            average_gradient=True,
-            predict_tower=None):
+                 input_queue=None,
+                 average_gradient=True,
+                 predict_tower=None):
        if hasattr(config, 'dataset'):
            self._input_method = QueueInput(config.dataset, input_queue)
        else:
...
...
@@ -134,7 +139,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
    def _setup(self):
        super(AsyncMultiGPUTrainer, self)._setup()
        grad_list = MultiGPUTrainer._multi_tower_grads(
-                self.config.tower, lambda: self._get_cost_and_grad()[1])
+            self.config.tower, lambda: self._get_cost_and_grad()[1])
        gradprocs = self.model.get_gradient_processor()
        if self._average_gradient and self.config.nr_tower > 1:
            # pretend to average the grads, in order to make async and
...
...
@@ -157,7 +162,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
        self.training_threads = []
        for k in range(1, len(self.config.tower)):
            train_op = self.config.optimizer.apply_gradients(grad_list[k])
-            def f(op=train_op): # avoid late-binding
+
+            def f(op=train_op):  # avoid late-binding
                self.sess.run([op])
                next(self.async_step_counter)
            th = LoopThread(f)
...
...
@@ -169,7 +175,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
    def run_step(self):
        if not self.async_running:
            self.async_running = True
-            for th in self.training_threads: # resume all threads
+            for th in self.training_threads:  # resume all threads
                th.resume()
        next(self.async_step_counter)
        self.sess.run(self.train_op)
...
...
@@ -183,7 +189,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
            async_step_total_cnt = int(re.findall(
                '[0-9]+', self.async_step_counter.__str__())[0])
            self.write_scalar_summary(
-                    'async_global_step', async_step_total_cnt)
+                'async_global_step', async_step_total_cnt)
        except:
            logger.exception("Cannot log async_global_step")
        super(AsyncMultiGPUTrainer, self)._trigger_epoch()
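A hedged sketch of what SyncMultiGPUTrainer._average_grads above does (its body is elided in this hunk; tf is assumed to be imported as in the file): for every variable, average the gradients produced by the individual towers and keep the variable from the first tower.

def average_grads(tower_grads):          # tower_grads: one [(grad, var), ...] list per tower
    if len(tower_grads) == 1:
        return tower_grads[0]
    ret = []
    for grad_and_vars in zip(*tower_grads):
        grads = [g for g, _ in grad_and_vars]
        v = grad_and_vars[0][1]
        ret.append((tf.add_n(grads) / float(len(grads)), v))
    return ret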
tensorpack/train/trainer.py  View file @ fb2a051c
...
...
@@ -10,13 +10,14 @@ from .base import Trainer
from ..utils import logger, SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import (get_tensors_by_names, freeze_collection,
-        get_global_step_var, TowerContext)
+                       get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput, FeedfreeInput

-__all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']
+__all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']


class PredictorFactory(object):
    """ Make predictors for a trainer"""
...
...
@@ -52,8 +53,10 @@ class PredictorFactory(object):
            build_multi_tower_prediction_graph(fn, self.towers)
        self.tower_built = True


class SimpleTrainer(Trainer):
    """ A naive demo trainer """

    def __init__(self, config):
        super(SimpleTrainer, self).__init__(config)
        self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
...
...
@@ -78,7 +81,7 @@ class SimpleTrainer(Trainer):
        grads = self.config.optimizer.compute_gradients(cost_var)
        grads = apply_grad_processors(grads,
-                self.model.get_gradient_processor())
+                                      self.model.get_gradient_processor())
        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
...
...
@@ -93,13 +96,15 @@ class SimpleTrainer(Trainer):
    def get_predict_func(self, input_names, output_names):
        return self._predictor_factory.get_predictor(input_names, output_names, 0)


class MultiPredictorTowerTrainer(Trainer):
    """ A trainer with possibly multiple prediction tower """

    def _setup_predictor_factory(self, predict_tower):
        # by default, use the first training gpu for prediction
        predict_tower = predict_tower or [0]
        self._predictor_factory = PredictorFactory(
-                self.sess, self.model, predict_tower)
+            self.sess, self.model, predict_tower)

    def get_predict_func(self, input_names, output_names, tower=0):
        """
...
...
tensorpack/utils/__init__.py  View file @ fb2a051c
...
...
@@ -12,6 +12,7 @@ These utils should be irrelevant to tensorflow.
__all__ = []


def _global_import(name):
    p = __import__(name, globals(), None, level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
...
...
@@ -23,7 +24,7 @@ _TO_IMPORT = set([
    'naming',
    'utils',
    'gpu'
-    ])
+])

_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in walk_packages(
...
...
@@ -36,5 +37,3 @@ for _, module_name, _ in walk_packages(
    if module_name in _TO_IMPORT:
        _global_import(module_name)
        __all__.append(module_name)
tensorpack/utils/argtools.py  View file @ fb2a051c
...
...
@@ -5,10 +5,13 @@

import operator
-import inspect, six, functools
+import inspect
+import six
+import functools
import collections

-__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']
+__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']


def map_arg(**maps):
    """
...
...
@@ -26,11 +29,13 @@ def map_arg(**maps):
        return wrapper
    return deco


class memoized(object):
    '''Decorator. Caches a function's return value each time it is called.
    If called later with the same arguments, the cached value is returned
    (not reevaluated).
    '''

    def __init__(self, func):
        self.func = func
        self.cache = {}
...
...
@@ -60,8 +65,11 @@ class memoized(object):
        return functools.partial(self.__call__, obj)

_MEMOIZED_NOARGS = {}


def memoized_ignoreargs(func):
    h = hash(func)  # make sure it is hashable. is it necessary?

    def wrapper(*args, **kwargs):
        if func not in _MEMOIZED_NOARGS:
            res = func(*args, **kwargs)
...
...
@@ -70,15 +78,16 @@ def memoized_ignoreargs(func):
            return _MEMOIZED_NOARGS[func]
    return wrapper

-#_GLOBAL_MEMOIZED_CACHE = dict()
-#def global_memoized(func):
-    #""" Make sure that the same `memoized` object is returned on different
-    #calls to global_memoized(func)
-    #"""
-    #ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
-    #if ret is None:
-        #ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
-    #return ret
+# _GLOBAL_MEMOIZED_CACHE = dict()
+# def global_memoized(func):
+#     """ Make sure that the same `memoized` object is returned on different
+#     calls to global_memoized(func)
+#     """
+#     ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
+#     if ret is None:
+#         ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
+#     return ret


def shape2d(a):
    """
...
...
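A usage sketch (not part of the diff) of the memoized decorator above: each distinct argument tuple is computed once and the cached value is returned afterwards.

@memoized
def fib(n):
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(30))   # fast: every fib(k) is evaluated only once thanks to the cache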
tensorpack/utils/concurrency.py
View file @
fb2a051c
...
...
@@ -23,10 +23,12 @@ __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
           'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
           'mask_sigint', 'start_proc_mask_signal']


class StoppableThread(threading.Thread):
    """
    A thread that has a 'stop' event.
    """

    def __init__(self):
        super(StoppableThread, self).__init__()
        self._stop_evt = threading.Event()
...
...
@@ -56,8 +58,10 @@ class StoppableThread(threading.Thread):
        except queue.Empty:
            pass


class LoopThread(StoppableThread):
    """ A pausable thread that simply runs a loop"""

    def __init__(self, func, pausable=True):
        """
        :param func: the function to run
...
...
@@ -89,6 +93,7 @@ class DIE(object):
    """ A placeholder class indicating end of queue """
    pass


def ensure_proc_terminate(proc):
    if isinstance(proc, list):
        for p in proc:
...
...
@@ -114,6 +119,7 @@ def mask_sigint():
    yield
    signal.signal(signal.SIGINT, sigint_handler)


def start_proc_mask_signal(proc):
    if not isinstance(proc, list):
        proc = [proc]
...
...
@@ -122,11 +128,12 @@ def start_proc_mask_signal(proc):
    for p in proc:
        p.start()


def subproc_call(cmd, timeout=None):
    try:
        output = subprocess.check_output(
            cmd, stderr=subprocess.STDOUT, shell=True, timeout=timeout)
        return output
    except subprocess.TimeoutExpired as e:
        logger.warn("Command timeout!")
...
...
@@ -135,10 +142,12 @@ def subproc_call(cmd, timeout=None):
        logger.warn("Commnad failed: {}".format(e.returncode))
        logger.warn(e.output)


class OrderedContainer(object):
    """
    Like a priority queue, but will always wait for item with index (x+1) before producing (x+2).
    """

    def __init__(self, start=0):
        self.ranks = []
        self.data = []
...
...
@@ -163,11 +172,13 @@ class OrderedContainer(object):
        self.wait_for += 1
        return rank, ret


class OrderedResultGatherProc(multiprocessing.Process):
    """
    Gather indexed data from a data queue, and produce results with the
    original index-based order.
    """

    def __init__(self, data_queue, nr_producer, start=0):
        """
        :param data_queue: a multiprocessing.Queue to produce input dp
...
...
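As context for the `subproc_call` hunk above, a minimal usage sketch (not part of the commit; the command and the None-on-failure behavior are assumptions based only on the lines shown):

from tensorpack.utils.concurrency import subproc_call

# runs the command through the shell with a timeout and captures stdout/stderr
out = subproc_call("echo hello", timeout=5)
if out is not None:   # on timeout/failure the helper only logs a warning (assumption)
    print(out.decode('utf-8').strip())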
tensorpack/utils/debug.py
View file @
fb2a051c
...
...
@@ -7,6 +7,7 @@
import sys

__all__ = ['enable_call_trace']


def enable_call_trace():
    def tracer(frame, event, arg):
        if event == 'call':
...
...
@@ -21,9 +22,9 @@ def enable_call_trace():
            if caller:
                caller_line_no = caller.f_lineno
                caller_filename = caller.f_code.co_filename
                print('Call to `%s` on line %s:%s from %s:%s' %
                      (func_name, func_filename, func_line_no,
                       caller_filename, caller_line_no))
            return
    sys.settrace(tracer)
...
...
@@ -32,6 +33,7 @@ if __name__ == '__main__':
    def b(a):
        print(2)

    def a():
        print(1)
        b(1)
...
...
tensorpack/utils/discretize.py
View file @
fb2a051c
...
...
@@ -12,11 +12,14 @@ from six.moves import range
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']


@memoized
def log_once(s):
    logger.warn(s)
# just a placeholder


@six.add_metaclass(ABCMeta)
class Discretizer(object):
...
...
@@ -28,10 +31,13 @@ class Discretizer(object):
    def get_bin(self, v):
        pass


class Discretizer1D(Discretizer):
    pass


class UniformDiscretizer1D(Discretizer1D):

    def __init__(self, minv, maxv, spacing):
        """
        :params minv: minimum value of the first bin
...
...
@@ -54,8 +60,8 @@ class UniformDiscretizer1D(Discretizer1D):
            log_once("UniformDiscretizer1D: value larger than max!")
            return self.nr_bin - 1
        return int(np.clip(
            (v - self.minv) / self.spacing,
            0, self.nr_bin - 1))

    def get_bin_center(self, bin_id):
        return self.minv + self.spacing * (bin_id + 0.5)
...
...
@@ -69,17 +75,18 @@ class UniformDiscretizer1D(Discretizer1D):
        if v >= self.maxv or v <= self.minv:
            return ret
        try:
            for k in range(1, smooth_radius + 1):
                ret[b + k] = smooth_factor ** k
        except IndexError:
            pass
        for k in range(1, min(smooth_radius + 1, b + 1)):
            ret[b - k] = smooth_factor ** k
        ret /= ret.sum()
        return ret


class UniformDiscretizerND(Discretizer):

    def __init__(self, *min_max_spacing):
        """
        :params min_max_spacing: (minv, maxv, spacing) for each dimension
...
...
@@ -122,6 +129,5 @@ class UniformDiscretizerND(Discretizer):

if __name__ == '__main__':
    # u = UniformDiscretizer1D(-10, 10, 0.12)
    u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1))
    import IPython as IP
    IP.embed(config=IP.terminal.ipapp.load_default_config())
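As context for the `UniformDiscretizer1D` hunk above, a minimal usage sketch (not part of the commit); the numbers follow directly from the `get_bin` and `get_bin_center` lines shown:

from tensorpack.utils.discretize import UniformDiscretizer1D

d = UniformDiscretizer1D(0, 10, 0.5)    # bins of width 0.5 covering [0, 10]
print(d.get_bin(3.2))                   # int((3.2 - 0) / 0.5) == 6
print(d.get_bin_center(6))              # 0 + 0.5 * (6 + 0.5) == 3.25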
tensorpack/utils/fs.py
View file @
fb2a051c
...
...
@@ -3,13 +3,15 @@
# File: fs.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import sys
from six.moves import urllib
import errno
from . import logger

__all__ = ['mkdir_p', 'download', 'recursive_walk']


def mkdir_p(dirname):
    """ make a dir recursively, but do nothing if the dir exists"""
    assert dirname is not None
...
...
@@ -21,6 +23,7 @@ def mkdir_p(dirname):
        if e.errno != errno.EEXIST:
            raise e


def download(url, dir):
    mkdir_p(dir)
    fname = url.split('/')[-1]
...
...
@@ -29,7 +32,7 @@ def download(url, dir):
    def _progress(count, block_size, total_size):
        sys.stdout.write('\r>> Downloading %s %.1f%%' %
                         (fname, min(float(count * block_size) / total_size, 1.0) * 100.0))
        sys.stdout.flush()
    try:
...
...
@@ -45,6 +48,7 @@ def download(url, dir):
    print('Succesfully downloaded ' + fname + " " + str(size) + ' bytes.')
    return fpath


def recursive_walk(rootdir):
    for r, dirs, files in os.walk(rootdir):
        for f in files:
...
...
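As context for the `mkdir_p`/`download` hunks above, a minimal usage sketch (not part of the commit; the URL and directory are placeholders):

from tensorpack.utils.fs import mkdir_p, download

mkdir_p('/tmp/tensorpack-demo')    # no-op if the directory already exists
# fetches the file into the directory and returns the local path (placeholder URL)
fpath = download('https://example.com/some_file.bin', '/tmp/tensorpack-demo')
print(fpath)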
tensorpack/utils/globvars.py
View file @
fb2a051c
...
...
@@ -9,13 +9,15 @@ import argparse
__all__ = ['globalns', 'use_global_argument']

if six.PY2:
    class NS:
        pass
else:
    import types
    NS = types.SimpleNamespace

globalns = NS()


def use_global_argument(args):
    """
    Add the content of argparse.Namespace to globalns
...
...
tensorpack/utils/gpu.py
View file @
fb2a051c
...
...
@@ -8,20 +8,22 @@ from .utils import change_env
__all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus']


def change_gpu(val):
    val = str(val)
    if val == '-1':
        val = ''
    return change_env('CUDA_VISIBLE_DEVICES', val)


def get_nr_gpu():
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    assert env is not None, 'gpu not set!'  # TODO
    return len(env.split(','))


def get_gpus():
    """ return a list of GPU physical id"""
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    assert env is not None, 'gpu not set!'  # TODO
    return map(int, env.strip().split(','))
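As context for the gpu helpers above, a minimal usage sketch (not part of the commit); it assumes `change_env`, imported at the top of this file, behaves as a context manager as it does elsewhere in tensorpack:

import os
from tensorpack.utils.gpu import change_gpu, get_nr_gpu

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'     # the helpers read this variable
print(get_nr_gpu())                            # -> 2

with change_gpu(1):                            # temporarily restrict to GPU 1 (assumed context-manager use)
    print(os.environ['CUDA_VISIBLE_DEVICES'])  # -> '1'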
tensorpack/utils/loadcaffe.py
View file @
fb2a051c
...
...
@@ -19,7 +19,9 @@ __all__ = ['load_caffe', 'get_caffe_pb']
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"


class CaffeLayerProcessor(object):

    def __init__(self, net):
        self.net = net
        self.layer_names = net._layer_names
...
...
@@ -42,14 +44,14 @@ class CaffeLayerProcessor(object):
                self.param_dict.update(dic)
            elif len(layer.blobs) != 0:
                logger.warn(
                    "{} layer contains parameters but is not supported!".format(layer.type))
        return self.param_dict

    def proc_conv(self, idx, name, param):
        assert len(param) <= 2
        assert param[0].data.ndim == 4
        # caffe: ch_out, ch_in, h, w
        W = param[0].data.transpose(2, 3, 1, 0)
        if len(param) == 1:
            return {name + '/W': W}
        else:
...
...
@@ -65,7 +67,7 @@ class CaffeLayerProcessor(object):
            logger.info("FC layer {} takes spatial data.".format(name))
            W = param[0].data
            # original: outx(CxHxW)
            W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2, 3, 1, 0)
            # become: (HxWxC)xout
        else:
            W = param[0].data.transpose()
...
...
@@ -74,8 +76,8 @@ class CaffeLayerProcessor(object):
    def proc_bn(self, idx, name, param):
        assert param[2].data[0] == 1.0
        return {name + '/mean/EMA': param[0].data,
                name + '/variance/EMA': param[1].data}

    def proc_scale(self, idx, name, param):
        bottom_name = self.net.bottom_names[name][0]
...
...
@@ -89,7 +91,7 @@ class CaffeLayerProcessor(object):
                logger.info("Merge {} and {} into one BatchNorm layer".format(name, name2))
                return {name2 + '/beta': param[1].data,
                        name2 + '/gamma': param[0].data}
        # assume this scaling layer is part of some BN
        logger.error("Could not find a BN layer corresponding to this Scale layer!")
        raise ValueError()
...
...
@@ -104,10 +106,11 @@ def load_caffe(model_desc, model_file):
    caffe.set_mode_cpu()
    net = caffe.Net(model_desc, model_file, caffe.TEST)
    param_dict = CaffeLayerProcessor(net).process()
    logger.info("Model loaded from caffe. Params: " +
                " ".join(sorted(param_dict.keys())))
    return param_dict


def get_caffe_pb():
    dir = get_dataset_path('caffe')
    caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
...
...
@@ -116,7 +119,7 @@ def get_caffe_pb():
    assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
    ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
    assert ret == 0, \
        "Command `protoc caffe.proto --python_out .` failed!"
    import imp
    return imp.load_source('caffepb', caffe_pb_file)
...
...
@@ -131,4 +134,3 @@ if __name__ == '__main__':
    import numpy as np
    np.save(args.output, ret)
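As context for the `proc_conv` hunk above, a small numpy sketch (not part of the commit) of the axis reordering it performs, from caffe's (ch_out, ch_in, h, w) layout to an (h, w, ch_in, ch_out) kernel; the shape is an arbitrary example:

import numpy as np

W_caffe = np.zeros((64, 3, 7, 7))        # (ch_out, ch_in, h, w)
W = W_caffe.transpose(2, 3, 1, 0)        # same transpose as in proc_conv
print(W.shape)                           # (7, 7, 3, 64)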
tensorpack/utils/logger.py
View file @
fb2a051c
...
...
@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import logging
import os
import shutil
import os.path
from termcolor import colored
from datetime import datetime
...
...
@@ -12,7 +13,9 @@ import sys
__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', 'warn_dependency']


class _MyFormatter(logging.Formatter):

    def format(self, record):
        date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
        msg = '%(message)s'
...
...
@@ -28,6 +31,7 @@ class _MyFormatter(logging.Formatter):
            self._fmt = fmt
        return super(_MyFormatter, self).format(record)


def _getlogger():
    logger = logging.getLogger('tensorpack')
    logger.propagate = False
...
...
@@ -45,6 +49,8 @@ def get_time_str():
# logger file and directory:
global LOG_FILE, LOG_DIR
LOG_DIR = None


def _set_file(path):
    if os.path.isfile(path):
        backup_name = path + '.' + get_time_str()
...
...
@@ -56,6 +62,7 @@ def _set_file(path):
    _logger.addHandler(hdl)
    _logger.info("Argv: " + ' '.join(sys.argv))


def set_logger_dir(dirname, action=None):
    """
    Set the directory for global logging.
...
...
@@ -98,11 +105,13 @@ _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception',
for func in _LOGGING_METHOD:
    locals()[func] = getattr(_logger, func)


def disable_logger():
    """ disable all logging ability from this moment"""
    for func in _LOGGING_METHOD:
        globals()[func] = lambda x: None


def auto_set_dir(action=None, overwrite=False):
    """ set log directory to a subdir inside 'train_log', with the name being
    the main python file currently running"""
...
...
@@ -112,9 +121,10 @@ def auto_set_dir(action=None, overwrite=False):
    mod = sys.modules['__main__']
    basename = os.path.basename(mod.__file__)
    set_logger_dir(
        os.path.join('train_log', basename[:basename.rfind('.')]),
        action=action)


def warn_dependency(name, dependencies):
    warn("Failed to import '{}', {} won't be available'".format(dependencies, name))
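As context for the logger hunks above, a minimal usage sketch (not part of the commit; the directory name is a placeholder):

from tensorpack.utils import logger

logger.set_logger_dir('train_log/my_experiment')     # also records sys.argv in the log
logger.info("logging to {}".format(logger.LOG_DIR))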
tensorpack/utils/lut.py
View file @
fb2a051c
...
...
@@ -7,10 +7,12 @@ import six
__all__ = ['LookUpTable']


class LookUpTable(object):

    def __init__(self, objlist):
        self.idx2obj = dict(enumerate(objlist))
        self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)}

    def size(self):
        return len(self.idx2obj)
...
...
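As context for the `LookUpTable` hunk above, a minimal usage sketch (not part of the commit); only the attributes and the `size` method visible in the hunk are used, the object list is arbitrary:

from tensorpack.utils.lut import LookUpTable

lut = LookUpTable(['cat', 'dog', 'bird'])
print(lut.size())          # -> 3
print(lut.idx2obj[0])      # -> 'cat'
print(lut.obj2idx['dog'])  # -> 1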
tensorpack/utils/rect.py
View file @
fb2a051c
...
...
@@ -5,6 +5,7 @@
import numpy as np


class Rect(object):
    """
    A Rectangle.
...
...
@@ -68,7 +69,7 @@ class Rect(object):
    def roi(self, img):
        assert self.validate(img.shape[:2]), "{} vs {}".format(self, img.shape[:2])
        return img[self.y0:self.y1 + 1, self.x0:self.x1 + 1]

    def expand(self, frac):
        assert frac > 1.0, frac
...
...
@@ -92,7 +93,7 @@ class Rect(object):
        xmax = min(self.x1, img.shape[1])
        ymax = min(self.y1, img.shape[0])
        patch = img[ymin:ymax, xmin:xmax]
        ret[ystart:ystart + patch.shape[0], xstart:xstart + patch.shape[1]] = patch
        return ret

    __repr__ = __str__
...
...
@@ -101,6 +102,6 @@ class Rect(object):
if __name__ == '__main__':
    x = Rect(2, 1, 3, 3, allow_neg=True)
    img = np.random.rand(3, 3)
    print(img)
    print(x.roi_zeropad(img))
tensorpack/utils/serialize.py
View file @
fb2a051c
...
...
@@ -10,10 +10,12 @@ msgpack_numpy.patch()
__all__ = ['loads', 'dumps']


def dumps(obj):
    # return dill.dumps(obj)
    return msgpack.dumps(obj, use_bin_type=True)


def loads(buf):
    # return dill.loads(buf)
    return msgpack.loads(buf)
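As context for the `dumps`/`loads` hunk above, a minimal round-trip sketch (not part of the commit); it relies on the msgpack_numpy patch applied at the top of the module, and the array is arbitrary:

import numpy as np
from tensorpack.utils.serialize import dumps, loads

arr = np.arange(6, dtype='float32').reshape(2, 3)
buf = dumps(arr)      # bytes, via msgpack with use_bin_type=True
print(loads(buf))     # the same 2x3 array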
tensorpack/utils/stats.py
View file @
fb2a051c
This diff is collapsed.
tensorpack/utils/timer.py
View file @
fb2a051c
...
...
@@ -14,10 +14,12 @@ from .stats import StatCounter
from . import logger

__all__ = ['total_timer', 'timed_operation',
           'print_total_timer', 'IterSpeedCounter']


class IterSpeedCounter(object):
    """ To count how often some code gets reached"""

    def __init__(self, print_every, name=None):
        self.cnt = 0
        self.print_every = int(print_every)
...
...
@@ -36,6 +38,7 @@ class IterSpeedCounter(object):
        logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
            self.name, t, self.cnt, t / self.cnt))


@contextmanager
def timed_operation(msg, log_start=False):
    if log_start:
...
...
@@ -47,6 +50,7 @@ def timed_operation(msg, log_start=False):
_TOTAL_TIMER_DATA = defaultdict(StatCounter)


@contextmanager
def total_timer(msg):
    start = time.time()
...
...
@@ -54,6 +58,7 @@ def total_timer(msg):
    t = time.time() - start
    _TOTAL_TIMER_DATA[msg].feed(t)


def print_total_timer():
    if len(_TOTAL_TIMER_DATA) == 0:
        return
...
...
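As context for the timer hunks above, a minimal usage sketch (not part of the commit; the messages and workloads are arbitrary):

from tensorpack.utils.timer import timed_operation, total_timer, print_total_timer

with timed_operation("toy workload"):      # logs the elapsed time once
    sum(range(10 ** 6))

for _ in range(3):                         # total_timer accumulates across calls
    with total_timer("one step"):
        sum(range(10 ** 5))
print_total_timer()                        # reports the accumulated stats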
tensorpack/utils/utils.py
View file @
fb2a051c
This diff is collapsed.
tensorpack/utils/viz.py
View file @
fb2a051c
This diff is collapsed.