Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
37e98945
Commit
37e98945
authored
Jan 03, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix flake8 style in tensorpack/
parent
233b3b90
Changes
69
Show whitespace changes
Inline
Side-by-side
Showing
69 changed files
with
164 additions
and
202 deletions
+164
-202
tensorpack/RL/common.py
tensorpack/RL/common.py
+0
-1
tensorpack/RL/envbase.py
tensorpack/RL/envbase.py
+0
-1
tensorpack/RL/expreplay.py
tensorpack/RL/expreplay.py
+3
-1
tensorpack/RL/gymenv.py
tensorpack/RL/gymenv.py
+1
-0
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+5
-8
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+1
-4
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+1
-5
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+0
-1
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+1
-2
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+0
-3
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+1
-1
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+0
-1
tensorpack/callbacks/stats.py
tensorpack/callbacks/stats.py
+0
-2
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+0
-2
tensorpack/dataflow/dataset/bsds500.py
tensorpack/dataflow/dataset/bsds500.py
+1
-1
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+3
-5
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+7
-6
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+2
-2
tensorpack/dataflow/dataset/ptb.py
tensorpack/dataflow/dataset/ptb.py
+0
-2
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+2
-2
tensorpack/dataflow/dataset/visualqa.py
tensorpack/dataflow/dataset/visualqa.py
+2
-4
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+3
-3
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+3
-3
tensorpack/dataflow/imgaug/base.py
tensorpack/dataflow/imgaug/base.py
+1
-1
tensorpack/dataflow/imgaug/crop.py
tensorpack/dataflow/imgaug/crop.py
+1
-1
tensorpack/dataflow/imgaug/deform.py
tensorpack/dataflow/imgaug/deform.py
+1
-1
tensorpack/dataflow/imgaug/geometry.py
tensorpack/dataflow/imgaug/geometry.py
+1
-2
tensorpack/dataflow/imgaug/imgproc.py
tensorpack/dataflow/imgaug/imgproc.py
+4
-2
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+2
-3
tensorpack/dataflow/remote.py
tensorpack/dataflow/remote.py
+1
-0
tensorpack/dataflow/tf_func.py
tensorpack/dataflow/tf_func.py
+3
-3
tensorpack/models/_test.py
tensorpack/models/_test.py
+1
-1
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+2
-4
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+2
-4
tensorpack/models/fc.py
tensorpack/models/fc.py
+3
-3
tensorpack/models/image_sample.py
tensorpack/models/image_sample.py
+8
-9
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+0
-6
tensorpack/models/nonlin.py
tensorpack/models/nonlin.py
+2
-3
tensorpack/models/pool.py
tensorpack/models/pool.py
+3
-4
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+1
-0
tensorpack/predict/base.py
tensorpack/predict/base.py
+4
-3
tensorpack/predict/common.py
tensorpack/predict/common.py
+2
-7
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+3
-6
tensorpack/predict/dataset.py
tensorpack/predict/dataset.py
+1
-1
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+4
-2
tensorpack/tfutils/gradproc.py
tensorpack/tfutils/gradproc.py
+2
-1
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+2
-2
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+1
-1
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+4
-7
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+2
-2
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+2
-2
tensorpack/train/base.py
tensorpack/train/base.py
+0
-1
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+4
-3
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+1
-1
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+10
-8
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+4
-5
tensorpack/utils/argtools.py
tensorpack/utils/argtools.py
+9
-2
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+4
-2
tensorpack/utils/debug.py
tensorpack/utils/debug.py
+1
-0
tensorpack/utils/discretize.py
tensorpack/utils/discretize.py
+5
-12
tensorpack/utils/fs.py
tensorpack/utils/fs.py
+1
-0
tensorpack/utils/loadcaffe.py
tensorpack/utils/loadcaffe.py
+2
-7
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+13
-12
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+1
-1
tensorpack/utils/serialize.py
tensorpack/utils/serialize.py
+0
-3
tensorpack/utils/timer.py
tensorpack/utils/timer.py
+2
-0
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+0
-1
tensorpack/utils/viz.py
tensorpack/utils/viz.py
+2
-3
tox.ini
tox.ini
+6
-0
No files found.
tensorpack/RL/common.py
View file @
37e98945
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
numpy
as
np
from
collections
import
deque
from
collections
import
deque
from
.envbase
import
ProxyPlayer
from
.envbase
import
ProxyPlayer
...
...
tensorpack/RL/envbase.py
View file @
37e98945
...
@@ -7,7 +7,6 @@
...
@@ -7,7 +7,6 @@
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
from
collections
import
defaultdict
from
collections
import
defaultdict
import
six
import
six
import
random
from
..utils
import
get_rng
from
..utils
import
get_rng
__all__
=
[
'RLEnvironment'
,
'NaiveRLEnvironment'
,
'ProxyPlayer'
,
__all__
=
[
'RLEnvironment'
,
'NaiveRLEnvironment'
,
'ProxyPlayer'
,
...
...
tensorpack/RL/expreplay.py
View file @
37e98945
...
@@ -211,7 +211,9 @@ class ExpReplay(DataFlow, Callback):
...
@@ -211,7 +211,9 @@ class ExpReplay(DataFlow, Callback):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
from
.atari
import
AtariPlayer
from
.atari
import
AtariPlayer
import
sys
import
sys
predictor
=
lambda
x
:
np
.
array
([
1
,
1
,
1
,
1
])
def
predictor
(
x
):
np
.
array
([
1
,
1
,
1
,
1
])
player
=
AtariPlayer
(
sys
.
argv
[
1
],
viz
=
0
,
frame_skip
=
10
,
height_range
=
(
36
,
204
))
player
=
AtariPlayer
(
sys
.
argv
[
1
],
viz
=
0
,
frame_skip
=
10
,
height_range
=
(
36
,
204
))
E
=
ExpReplay
(
predictor
,
E
=
ExpReplay
(
predictor
,
player
=
player
,
player
=
player
,
...
...
tensorpack/RL/gymenv.py
View file @
37e98945
...
@@ -76,6 +76,7 @@ class GymEnv(RLEnvironment):
...
@@ -76,6 +76,7 @@ class GymEnv(RLEnvironment):
assert
isinstance
(
spc
,
gym
.
spaces
.
discrete
.
Discrete
)
assert
isinstance
(
spc
,
gym
.
spaces
.
discrete
.
Discrete
)
return
DiscreteActionSpace
(
spc
.
n
)
return
DiscreteActionSpace
(
spc
.
n
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
env
=
GymEnv
(
'Breakout-v0'
,
viz
=
0.1
)
env
=
GymEnv
(
'Breakout-v0'
,
viz
=
0.1
)
num
=
env
.
get_action_space
()
.
num_actions
()
num
=
env
.
get_action_space
()
.
num_actions
()
...
...
tensorpack/RL/simulator.py
View file @
37e98945
...
@@ -7,10 +7,8 @@ import tensorflow as tf
...
@@ -7,10 +7,8 @@ import tensorflow as tf
import
multiprocessing
as
mp
import
multiprocessing
as
mp
import
time
import
time
import
threading
import
threading
import
weakref
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
from
collections
import
defaultdict
,
namedtuple
from
collections
import
defaultdict
import
numpy
as
np
import
six
import
six
from
six.moves
import
queue
from
six.moves
import
queue
...
@@ -20,7 +18,6 @@ from ..callbacks import Callback
...
@@ -20,7 +18,6 @@ from ..callbacks import Callback
from
..tfutils.varmanip
import
SessionUpdate
from
..tfutils.varmanip
import
SessionUpdate
from
..predict
import
OfflinePredictor
from
..predict
import
OfflinePredictor
from
..utils
import
logger
from
..utils
import
logger
#from ..utils.timer import *
from
..utils.serialize
import
loads
,
dumps
from
..utils.serialize
import
loads
,
dumps
from
..utils.concurrency
import
LoopThread
,
ensure_proc_terminate
from
..utils.concurrency
import
LoopThread
,
ensure_proc_terminate
...
@@ -98,6 +95,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
...
@@ -98,6 +95,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
reward
,
isOver
=
player
.
action
(
action
)
reward
,
isOver
=
player
.
action
(
action
)
state
=
player
.
current_state
()
state
=
player
.
current_state
()
# compatibility
# compatibility
SimulatorProcess
=
SimulatorProcessStateExchange
SimulatorProcess
=
SimulatorProcessStateExchange
...
@@ -284,6 +282,7 @@ class WeightSync(Callback):
...
@@ -284,6 +282,7 @@ class WeightSync(Callback):
self
.
condvar
.
notify_all
()
self
.
condvar
.
notify_all
()
self
.
condvar
.
release
()
self
.
condvar
.
release
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
random
import
random
from
tensorpack.RL
import
NaiveRLEnvironment
from
tensorpack.RL
import
NaiveRLEnvironment
...
@@ -293,14 +292,13 @@ if __name__ == '__main__':
...
@@ -293,14 +292,13 @@ if __name__ == '__main__':
def
_build_player
(
self
):
def
_build_player
(
self
):
return
NaiveRLEnvironment
()
return
NaiveRLEnvironment
()
class
NaiveActioner
(
SimulatorActioner
):
class
NaiveActioner
(
SimulatorMaster
):
def
_get_action
(
self
,
state
):
def
_get_action
(
self
,
state
):
time
.
sleep
(
1
)
time
.
sleep
(
1
)
return
random
.
randint
(
1
,
12
)
return
random
.
randint
(
1
,
12
)
def
_on_episode_over
(
self
,
client
):
def
_on_episode_over
(
self
,
client
):
#print("Over: ", client.memory)
#
print("Over: ", client.memory)
client
.
memory
=
[]
client
.
memory
=
[]
client
.
state
=
0
client
.
state
=
0
...
@@ -312,5 +310,4 @@ if __name__ == '__main__':
...
@@ -312,5 +310,4 @@ if __name__ == '__main__':
ensure_proc_terminate
(
procs
)
ensure_proc_terminate
(
procs
)
th
.
start
()
th
.
start
()
import
time
time
.
sleep
(
100
)
time
.
sleep
(
100
)
tensorpack/callbacks/base.py
View file @
37e98945
...
@@ -3,10 +3,7 @@
...
@@ -3,10 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
import
sys
from
abc
import
ABCMeta
import
os
import
time
from
abc
import
abstractmethod
,
ABCMeta
import
six
import
six
__all__
=
[
'Callback'
,
'PeriodicCallback'
,
'ProxyCallback'
]
__all__
=
[
'Callback'
,
'PeriodicCallback'
,
'ProxyCallback'
]
...
...
tensorpack/callbacks/graph.py
View file @
37e98945
...
@@ -6,7 +6,6 @@
...
@@ -6,7 +6,6 @@
""" Graph related callbacks"""
""" Graph related callbacks"""
from
.base
import
Callback
from
.base
import
Callback
from
..utils
import
logger
__all__
=
[
'RunOp'
]
__all__
=
[
'RunOp'
]
...
@@ -26,7 +25,7 @@ class RunOp(Callback):
...
@@ -26,7 +25,7 @@ class RunOp(Callback):
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_op
=
self
.
setup_func
()
self
.
_op
=
self
.
setup_func
()
#self._op_name = self._op.name
#
self._op_name = self._op.name
def
_before_train
(
self
):
def
_before_train
(
self
):
if
self
.
run_before
:
if
self
.
run_before
:
...
@@ -35,6 +34,3 @@ class RunOp(Callback):
...
@@ -35,6 +34,3 @@ class RunOp(Callback):
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
if
self
.
run_epoch
:
if
self
.
run_epoch
:
self
.
_op
.
run
()
self
.
_op
.
run
()
# def _log(self):
#logger.info("Running op {} ...".format(self._op_name))
tensorpack/callbacks/group.py
View file @
37e98945
...
@@ -86,7 +86,6 @@ class Callbacks(Callback):
...
@@ -86,7 +86,6 @@ class Callbacks(Callback):
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
tm
=
CallbackTimeLogger
()
tm
=
CallbackTimeLogger
()
test_sess_restored
=
False
for
cb
in
self
.
cbs
:
for
cb
in
self
.
cbs
:
display_name
=
str
(
cb
)
display_name
=
str
(
cb
)
with
tm
.
timed_callback
(
display_name
):
with
tm
.
timed_callback
(
display_name
):
...
...
tensorpack/callbacks/inference.py
View file @
37e98945
...
@@ -2,14 +2,13 @@
...
@@ -2,14 +2,13 @@
# File: inference.py
# File: inference.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
numpy
as
np
import
numpy
as
np
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
sys
import
sys
import
six
import
six
from
six.moves
import
zip
from
six.moves
import
zip
from
..utils
import
logger
,
execute_only_once
from
..utils
import
logger
from
..utils.stats
import
RatioCounter
,
BinaryStatistics
from
..utils.stats
import
RatioCounter
,
BinaryStatistics
from
..tfutils
import
get_op_var_name
from
..tfutils
import
get_op_var_name
...
...
tensorpack/callbacks/inference_runner.py
View file @
37e98945
...
@@ -12,7 +12,6 @@ from ..dataflow import DataFlow
...
@@ -12,7 +12,6 @@ from ..dataflow import DataFlow
from
.base
import
Callback
from
.base
import
Callback
from
.inference
import
Inferencer
from
.inference
import
Inferencer
from
.dispatcher
import
OutputTensorDispatcer
from
.dispatcher
import
OutputTensorDispatcer
from
..tfutils
import
get_op_tensor_name
from
..utils
import
logger
,
get_tqdm
from
..utils
import
logger
,
get_tqdm
from
..train.input_data
import
FeedfreeInput
from
..train.input_data
import
FeedfreeInput
...
@@ -99,7 +98,6 @@ class InferenceRunner(Callback):
...
@@ -99,7 +98,6 @@ class InferenceRunner(Callback):
for
inf
in
self
.
infs
:
for
inf
in
self
.
infs
:
inf
.
before_inference
()
inf
.
before_inference
()
sess
=
tf
.
get_default_session
()
self
.
ds
.
reset_state
()
self
.
ds
.
reset_state
()
with
get_tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
with
get_tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
...
@@ -171,7 +169,6 @@ class FeedfreeInferenceRunner(Callback):
...
@@ -171,7 +169,6 @@ class FeedfreeInferenceRunner(Callback):
for
inf
in
self
.
infs
:
for
inf
in
self
.
infs
:
inf
.
before_inference
()
inf
.
before_inference
()
sess
=
tf
.
get_default_session
()
sz
=
self
.
_input_data
.
size
()
sz
=
self
.
_input_data
.
size
()
with
get_tqdm
(
total
=
sz
)
as
pbar
:
with
get_tqdm
(
total
=
sz
)
as
pbar
:
for
_
in
range
(
sz
):
for
_
in
range
(
sz
):
...
...
tensorpack/callbacks/param.py
View file @
37e98945
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
abc
import
abstractmethod
,
ABCMeta
,
abstractproperty
from
abc
import
abstractmethod
,
ABCMeta
import
operator
import
operator
import
six
import
six
import
os
import
os
...
...
tensorpack/callbacks/saver.py
View file @
37e98945
...
@@ -5,7 +5,6 @@
...
@@ -5,7 +5,6 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
import
os
import
os
import
shutil
import
shutil
import
re
from
.base
import
Callback
from
.base
import
Callback
from
..utils
import
logger
from
..utils
import
logger
...
...
tensorpack/callbacks/stats.py
View file @
37e98945
...
@@ -2,8 +2,6 @@
...
@@ -2,8 +2,6 @@
# File: stats.py
# File: stats.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
import
re
import
os
import
os
import
operator
import
operator
import
json
import
json
...
...
tensorpack/dataflow/common.py
View file @
37e98945
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
__future__
import
division
from
__future__
import
division
import
copy
import
numpy
as
np
import
numpy
as
np
from
collections
import
deque
,
defaultdict
from
collections
import
deque
,
defaultdict
from
six.moves
import
range
,
map
from
six.moves
import
range
,
map
...
@@ -48,7 +47,6 @@ class BatchData(ProxyDataFlow):
...
@@ -48,7 +47,6 @@ class BatchData(ProxyDataFlow):
super
(
BatchData
,
self
)
.
__init__
(
ds
)
super
(
BatchData
,
self
)
.
__init__
(
ds
)
if
not
remainder
:
if
not
remainder
:
try
:
try
:
s
=
ds
.
size
()
assert
batch_size
<=
ds
.
size
()
assert
batch_size
<=
ds
.
size
()
except
NotImplementedError
:
except
NotImplementedError
:
pass
pass
...
...
tensorpack/dataflow/dataset/bsds500.py
View file @
37e98945
...
@@ -8,7 +8,7 @@ import glob
...
@@ -8,7 +8,7 @@ import glob
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
from
...utils
import
logger
,
get_
rng
,
get_
dataset_path
from
...utils
import
logger
,
get_dataset_path
from
...utils.fs
import
download
from
...utils.fs
import
download
from
..base
import
RNGDataFlow
from
..base
import
RNGDataFlow
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
37e98945
...
@@ -5,16 +5,13 @@
...
@@ -5,16 +5,13 @@
# Yukun Chen <cykustc@gmail.com>
# Yukun Chen <cykustc@gmail.com>
import
os
import
os
import
sys
import
pickle
import
pickle
import
numpy
as
np
import
numpy
as
np
import
random
import
six
import
six
from
six.moves
import
urllib
,
range
from
six.moves
import
range
import
copy
import
copy
import
logging
from
...utils
import
logger
,
get_
rng
,
get_
dataset_path
from
...utils
import
logger
,
get_dataset_path
from
...utils.fs
import
download
from
...utils.fs
import
download
from
..base
import
RNGDataFlow
from
..base
import
RNGDataFlow
...
@@ -152,6 +149,7 @@ class Cifar100(CifarBase):
...
@@ -152,6 +149,7 @@ class Cifar100(CifarBase):
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
super
(
Cifar100
,
self
)
.
__init__
(
train_or_test
,
shuffle
,
dir
,
100
)
super
(
Cifar100
,
self
)
.
__init__
(
train_or_test
,
shuffle
,
dir
,
100
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
ds
=
Cifar10
(
'train'
)
ds
=
Cifar10
(
'train'
)
from
tensorpack.dataflow.dftools
import
dump_dataset_images
from
tensorpack.dataflow.dftools
import
dump_dataset_images
...
...
tensorpack/dataflow/dataset/ilsvrc.py
View file @
37e98945
...
@@ -5,11 +5,11 @@
...
@@ -5,11 +5,11 @@
import
os
import
os
import
tarfile
import
tarfile
import
cv2
import
cv2
import
six
import
numpy
as
np
import
numpy
as
np
from
six.moves
import
range
import
xml.etree.ElementTree
as
ET
import
xml.etree.ElementTree
as
ET
from
...utils
import
logger
,
get_
rng
,
get_
dataset_path
from
...utils
import
logger
,
get_dataset_path
from
...utils.loadcaffe
import
get_caffe_pb
from
...utils.loadcaffe
import
get_caffe_pb
from
...utils.fs
import
mkdir_p
,
download
from
...utils.fs
import
mkdir_p
,
download
from
...utils.timer
import
timed_operation
from
...utils.timer
import
timed_operation
...
@@ -195,10 +195,10 @@ class ILSVRC12(RNGDataFlow):
...
@@ -195,10 +195,10 @@ class ILSVRC12(RNGDataFlow):
box
=
root
.
find
(
'object'
)
.
find
(
'bndbox'
)
.
getchildren
()
box
=
root
.
find
(
'object'
)
.
find
(
'bndbox'
)
.
getchildren
()
box
=
map
(
lambda
x
:
float
(
x
.
text
),
box
)
box
=
map
(
lambda
x
:
float
(
x
.
text
),
box
)
#box[0] /= size[0]
#
box[0] /= size[0]
#box[1] /= size[1]
#
box[1] /= size[1]
#box[2] /= size[0]
#
box[2] /= size[0]
#box[3] /= size[1]
#
box[3] /= size[1]
return
np
.
asarray
(
box
,
dtype
=
'float32'
)
return
np
.
asarray
(
box
,
dtype
=
'float32'
)
with
timed_operation
(
'Loading Bounding Boxes ...'
):
with
timed_operation
(
'Loading Bounding Boxes ...'
):
...
@@ -218,6 +218,7 @@ class ILSVRC12(RNGDataFlow):
...
@@ -218,6 +218,7 @@ class ILSVRC12(RNGDataFlow):
logger
.
info
(
"{}/{} images have bounding box."
.
format
(
cnt
,
len
(
imglist
)))
logger
.
info
(
"{}/{} images have bounding box."
.
format
(
cnt
,
len
(
imglist
)))
return
ret
return
ret
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
meta
=
ILSVRCMeta
()
meta
=
ILSVRCMeta
()
# print(meta.get_synset_words_1000())
# print(meta.get_synset_words_1000())
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
37e98945
...
@@ -5,9 +5,8 @@
...
@@ -5,9 +5,8 @@
import
os
import
os
import
gzip
import
gzip
import
random
import
numpy
import
numpy
from
six.moves
import
urllib
,
range
from
six.moves
import
range
from
...utils
import
logger
,
get_dataset_path
from
...utils
import
logger
,
get_dataset_path
from
...utils.fs
import
download
from
...utils.fs
import
download
...
@@ -110,6 +109,7 @@ class Mnist(RNGDataFlow):
...
@@ -110,6 +109,7 @@ class Mnist(RNGDataFlow):
label
=
self
.
labels
[
k
]
label
=
self
.
labels
[
k
]
yield
[
img
,
label
]
yield
[
img
,
label
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
ds
=
Mnist
(
'train'
)
ds
=
Mnist
(
'train'
)
for
(
img
,
label
)
in
ds
.
get_data
():
for
(
img
,
label
)
in
ds
.
get_data
():
...
...
tensorpack/dataflow/dataset/ptb.py
View file @
37e98945
...
@@ -9,9 +9,7 @@ import numpy as np
...
@@ -9,9 +9,7 @@ import numpy as np
from
...utils
import
logger
,
get_dataset_path
from
...utils
import
logger
,
get_dataset_path
from
...utils.fs
import
download
from
...utils.fs
import
download
from
...utils.argtools
import
memoized_ignoreargs
from
...utils.argtools
import
memoized_ignoreargs
from
..base
import
RNGDataFlow
try
:
try
:
import
tensorflow
from
tensorflow.models.rnn.ptb
import
reader
as
tfreader
from
tensorflow.models.rnn.ptb
import
reader
as
tfreader
except
ImportError
:
except
ImportError
:
logger
.
warn_dependency
(
'PennTreeBank'
,
'tensorflow'
)
logger
.
warn_dependency
(
'PennTreeBank'
,
'tensorflow'
)
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
37e98945
...
@@ -5,9 +5,8 @@
...
@@ -5,9 +5,8 @@
import
os
import
os
import
numpy
as
np
import
numpy
as
np
from
six.moves
import
range
from
...utils
import
logger
,
get_
rng
,
get_
dataset_path
from
...utils
import
logger
,
get_dataset_path
from
..base
import
RNGDataFlow
from
..base
import
RNGDataFlow
try
:
try
:
...
@@ -71,6 +70,7 @@ class SVHNDigit(RNGDataFlow):
...
@@ -71,6 +70,7 @@ class SVHNDigit(RNGDataFlow):
c
=
SVHNDigit
(
'extra'
)
c
=
SVHNDigit
(
'extra'
)
return
np
.
concatenate
((
a
.
X
,
b
.
X
,
c
.
X
))
.
mean
(
axis
=
0
)
return
np
.
concatenate
((
a
.
X
,
b
.
X
,
c
.
X
))
.
mean
(
axis
=
0
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
a
=
SVHNDigit
(
'train'
)
a
=
SVHNDigit
(
'train'
)
b
=
SVHNDigit
.
get_per_pixel_mean
()
b
=
SVHNDigit
.
get_per_pixel_mean
()
tensorpack/dataflow/dataset/visualqa.py
View file @
37e98945
...
@@ -4,8 +4,7 @@
...
@@ -4,8 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
..base
import
DataFlow
from
..base
import
DataFlow
from
...utils
import
*
from
...utils.timer
import
timed_operation
from
...utils.timer
import
*
from
six.moves
import
zip
,
map
from
six.moves
import
zip
,
map
from
collections
import
Counter
from
collections
import
Counter
import
json
import
json
...
@@ -74,12 +73,11 @@ class VisualQA(DataFlow):
...
@@ -74,12 +73,11 @@ class VisualQA(DataFlow):
ret
=
cnt
.
most_common
(
n
)
ret
=
cnt
.
most_common
(
n
)
return
[
k
[
0
]
for
k
in
ret
]
return
[
k
[
0
]
for
k
in
ret
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
vqa
=
VisualQA
(
'/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json'
,
vqa
=
VisualQA
(
'/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json'
,
'/home/wyx/data/VQA/mscoco_train2014_annotations.json'
)
'/home/wyx/data/VQA/mscoco_train2014_annotations.json'
)
for
k
in
vqa
.
get_data
():
for
k
in
vqa
.
get_data
():
print
(
json
.
dumps
(
k
))
print
(
json
.
dumps
(
k
))
break
break
# vqa.get_common_question_words(100)
vqa
.
get_common_answer
(
100
)
vqa
.
get_common_answer
(
100
)
#from IPython import embed; embed()
tensorpack/dataflow/format.py
View file @
37e98945
...
@@ -6,10 +6,11 @@ import numpy as np
...
@@ -6,10 +6,11 @@ import numpy as np
from
six.moves
import
range
from
six.moves
import
range
import
os
import
os
from
..utils
import
logger
,
get_
rng
,
get_
tqdm
from
..utils
import
logger
,
get_tqdm
from
..utils.timer
import
timed_operation
from
..utils.timer
import
timed_operation
from
..utils.loadcaffe
import
get_caffe_pb
from
..utils.loadcaffe
import
get_caffe_pb
from
..utils.serialize
import
loads
from
..utils.serialize
import
loads
from
..utils.argtools
import
log_once
from
.base
import
RNGDataFlow
from
.base
import
RNGDataFlow
try
:
try
:
...
@@ -114,7 +115,6 @@ class LMDBData(RNGDataFlow):
...
@@ -114,7 +115,6 @@ class LMDBData(RNGDataFlow):
if
k
!=
'__keys__'
:
if
k
!=
'__keys__'
:
yield
[
k
,
v
]
yield
[
k
,
v
]
else
:
else
:
s
=
self
.
size
()
self
.
rng
.
shuffle
(
self
.
keys
)
self
.
rng
.
shuffle
(
self
.
keys
)
for
k
in
self
.
keys
:
for
k
in
self
.
keys
:
v
=
self
.
_txn
.
get
(
k
)
v
=
self
.
_txn
.
get
(
k
)
...
@@ -159,7 +159,7 @@ class CaffeLMDB(LMDBDataDecoder):
...
@@ -159,7 +159,7 @@ class CaffeLMDB(LMDBDataDecoder):
img
=
np
.
fromstring
(
datum
.
data
,
dtype
=
np
.
uint8
)
img
=
np
.
fromstring
(
datum
.
data
,
dtype
=
np
.
uint8
)
img
=
img
.
reshape
(
datum
.
channels
,
datum
.
height
,
datum
.
width
)
img
=
img
.
reshape
(
datum
.
channels
,
datum
.
height
,
datum
.
width
)
except
Exception
:
except
Exception
:
log_once
(
"Cannot read key {}"
.
format
(
k
))
log_once
(
"Cannot read key {}"
.
format
(
k
)
,
'warn'
)
return
None
return
None
return
[
img
.
transpose
(
1
,
2
,
0
),
datum
.
label
]
return
[
img
.
transpose
(
1
,
2
,
0
),
datum
.
label
]
...
...
tensorpack/dataflow/image.py
View file @
37e98945
...
@@ -4,8 +4,7 @@
...
@@ -4,8 +4,7 @@
import
numpy
as
np
import
numpy
as
np
import
cv2
import
cv2
import
copy
from
.base
import
RNGDataFlow
from
.base
import
RNGDataFlow
,
DataFlow
,
ProxyDataFlow
from
.common
import
MapDataComponent
,
MapData
from
.common
import
MapDataComponent
,
MapData
from
.imgaug
import
AugmentorList
from
.imgaug
import
AugmentorList
...
@@ -52,7 +51,8 @@ class AugmentImageComponent(MapDataComponent):
...
@@ -52,7 +51,8 @@ class AugmentImageComponent(MapDataComponent):
Augment the image component of datapoints
Augment the image component of datapoints
:param ds: a `DataFlow` instance.
:param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
:param index: the index (or list of indices) of the image component in the produced datapoints by `ds`. default to be 0
:param index: the index (or list of indices) of the image component
in the produced datapoints by `ds`. default to be 0
"""
"""
if
isinstance
(
augmentors
,
AugmentorList
):
if
isinstance
(
augmentors
,
AugmentorList
):
self
.
augs
=
augmentors
self
.
augs
=
augmentors
...
...
tensorpack/dataflow/imgaug/base.py
View file @
37e98945
...
@@ -55,7 +55,7 @@ class Augmentor(object):
...
@@ -55,7 +55,7 @@ class Augmentor(object):
def
_rand_range
(
self
,
low
=
1.0
,
high
=
None
,
size
=
None
):
def
_rand_range
(
self
,
low
=
1.0
,
high
=
None
,
size
=
None
):
if
high
is
None
:
if
high
is
None
:
low
,
high
=
0
,
low
low
,
high
=
0
,
low
if
size
==
None
:
if
size
is
None
:
size
=
[]
size
=
[]
return
self
.
rng
.
uniform
(
low
,
high
,
size
)
return
self
.
rng
.
uniform
(
low
,
high
,
size
)
...
...
tensorpack/dataflow/imgaug/crop.py
View file @
37e98945
...
@@ -74,7 +74,6 @@ class FixedCrop(ImageAugmentor):
...
@@ -74,7 +74,6 @@ class FixedCrop(ImageAugmentor):
self
.
_init
(
locals
())
self
.
_init
(
locals
())
def
_augment
(
self
,
img
,
_
):
def
_augment
(
self
,
img
,
_
):
orig_shape
=
img
.
shape
return
img
[
self
.
rect
.
y0
:
self
.
rect
.
y1
+
1
,
return
img
[
self
.
rect
.
y0
:
self
.
rect
.
y1
+
1
,
self
.
rect
.
x0
:
self
.
rect
.
x0
+
1
]
self
.
rect
.
x0
:
self
.
rect
.
x0
+
1
]
...
@@ -174,5 +173,6 @@ class RandomCropRandomShape(ImageAugmentor):
...
@@ -174,5 +173,6 @@ class RandomCropRandomShape(ImageAugmentor):
y0
,
x0
,
h
,
w
=
param
y0
,
x0
,
h
,
w
=
param
return
img
[
y0
:
y0
+
h
,
x0
:
x0
+
w
]
return
img
[
y0
:
y0
+
h
,
x0
:
x0
+
w
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
print
(
perturb_BB
([
100
,
100
],
Rect
(
3
,
3
,
50
,
50
),
50
))
print
(
perturb_BB
([
100
,
100
],
Rect
(
3
,
3
,
50
,
50
),
50
))
tensorpack/dataflow/imgaug/deform.py
View file @
37e98945
...
@@ -26,7 +26,7 @@ class GaussianMap(object):
...
@@ -26,7 +26,7 @@ class GaussianMap(object):
y
=
y
.
astype
(
'float32'
)
/
ret
.
shape
[
0
]
-
anchor
[
0
]
y
=
y
.
astype
(
'float32'
)
/
ret
.
shape
[
0
]
-
anchor
[
0
]
x
=
x
.
astype
(
'float32'
)
/
ret
.
shape
[
1
]
-
anchor
[
1
]
x
=
x
.
astype
(
'float32'
)
/
ret
.
shape
[
1
]
-
anchor
[
1
]
g
=
np
.
exp
(
-
(
x
**
2
+
y
**
2
)
/
self
.
sigma
)
g
=
np
.
exp
(
-
(
x
**
2
+
y
**
2
)
/
self
.
sigma
)
#cv2.imshow(" ", g)
#
cv2.imshow(" ", g)
# cv2.waitKey()
# cv2.waitKey()
return
g
return
g
...
...
tensorpack/dataflow/imgaug/geometry.py
View file @
37e98945
...
@@ -6,7 +6,6 @@
...
@@ -6,7 +6,6 @@
from
.base
import
ImageAugmentor
from
.base
import
ImageAugmentor
import
math
import
math
import
cv2
import
cv2
import
numpy
as
np
__all__
=
[
'Rotation'
,
'RotationAndCropValid'
]
__all__
=
[
'Rotation'
,
'RotationAndCropValid'
]
...
@@ -59,7 +58,7 @@ class RotationAndCropValid(ImageAugmentor):
...
@@ -59,7 +58,7 @@ class RotationAndCropValid(ImageAugmentor):
newh
=
min
(
newh
,
ret
.
shape
[
0
])
newh
=
min
(
newh
,
ret
.
shape
[
0
])
newx
=
int
(
center
[
0
]
-
neww
*
0.5
)
newx
=
int
(
center
[
0
]
-
neww
*
0.5
)
newy
=
int
(
center
[
1
]
-
newh
*
0.5
)
newy
=
int
(
center
[
1
]
-
newh
*
0.5
)
#print(ret.shape, deg, newx, newy, neww, newh)
#
print(ret.shape, deg, newx, newy, neww, newh)
return
ret
[
newy
:
newy
+
newh
,
newx
:
newx
+
neww
]
return
ret
[
newy
:
newy
+
newh
,
newx
:
newx
+
neww
]
@
staticmethod
@
staticmethod
...
...
tensorpack/dataflow/imgaug/imgproc.py
View file @
37e98945
...
@@ -131,7 +131,8 @@ class Clip(ImageAugmentor):
...
@@ -131,7 +131,8 @@ class Clip(ImageAugmentor):
class
Saturation
(
ImageAugmentor
):
class
Saturation
(
ImageAugmentor
):
def
__init__
(
self
,
alpha
=
0.4
):
def
__init__
(
self
,
alpha
=
0.4
):
""" Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
""" Saturation,
see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
"""
"""
super
(
Saturation
,
self
)
.
__init__
()
super
(
Saturation
,
self
)
.
__init__
()
assert
alpha
<
1
assert
alpha
<
1
...
@@ -150,7 +151,8 @@ class Lighting(ImageAugmentor):
...
@@ -150,7 +151,8 @@ class Lighting(ImageAugmentor):
def
__init__
(
self
,
std
,
eigval
,
eigvec
):
def
__init__
(
self
,
std
,
eigval
,
eigvec
):
""" Lighting noise.
""" Lighting noise.
See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
The implementation follows 'fb.resnet.torch': https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184
The implementation follows 'fb.resnet.torch':
https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184
:param eigvec: each column is one eigen vector
:param eigvec: each column is one eigen vector
"""
"""
...
...
tensorpack/dataflow/prefetch.py
View file @
37e98945
...
@@ -4,10 +4,8 @@
...
@@ -4,10 +4,8 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
multiprocessing
as
mp
import
multiprocessing
as
mp
from
threading
import
Thread
import
itertools
import
itertools
from
six.moves
import
range
,
zip
from
six.moves
import
range
,
zip
from
six.moves.queue
import
Queue
import
uuid
import
uuid
import
os
import
os
...
@@ -127,7 +125,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
...
@@ -127,7 +125,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
:param ds: a `DataFlow` instance.
:param ds: a `DataFlow` instance.
:param nr_proc: number of processes to use. When larger than 1, order
:param nr_proc: number of processes to use. When larger than 1, order
of datapoints will be random.
of datapoints will be random.
:param pipedir: a local directory where the pipes would be. Useful if you're running on non-local FS such as NFS.
:param pipedir: a local directory where the pipes would be.
Useful if you're running on non-local FS such as NFS.
"""
"""
super
(
PrefetchDataZMQ
,
self
)
.
__init__
(
ds
)
super
(
PrefetchDataZMQ
,
self
)
.
__init__
(
ds
)
try
:
try
:
...
...
tensorpack/dataflow/remote.py
View file @
37e98945
...
@@ -51,6 +51,7 @@ class RemoteData(DataFlow):
...
@@ -51,6 +51,7 @@ class RemoteData(DataFlow):
dp
=
loads
(
self
.
socket
.
recv
(
copy
=
False
))
dp
=
loads
(
self
.
socket
.
recv
(
copy
=
False
))
yield
dp
yield
dp
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
sys
import
sys
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
...
tensorpack/dataflow/tf_func.py
View file @
37e98945
...
@@ -53,9 +53,6 @@ class TFFuncMapper(ProxyDataFlow):
...
@@ -53,9 +53,6 @@ class TFFuncMapper(ProxyDataFlow):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
from
.raw
import
FakeData
from
.raw
import
FakeData
from
.prefetch
import
PrefetchDataZMQ
from
.image
import
AugmentImageComponent
from
.
import
imgaug
ds
=
FakeData
([[
224
,
224
,
3
]],
100000
,
random
=
False
)
ds
=
FakeData
([[
224
,
224
,
3
]],
100000
,
random
=
False
)
def
tf_aug
(
v
):
def
tf_aug
(
v
):
...
@@ -69,6 +66,9 @@ if __name__ == '__main__':
...
@@ -69,6 +66,9 @@ if __name__ == '__main__':
tf_aug
,
tf_aug
,
lambda
dp
,
f
:
[
f
([
dp
[
0
]])[
0
]]
lambda
dp
,
f
:
[
f
([
dp
[
0
]])[
0
]]
)
)
# from .prefetch import PrefetchDataZMQ
# from .image import AugmentImageComponent
# from . import imgaug
# ds = AugmentImageComponent(ds,
# ds = AugmentImageComponent(ds,
# [imgaug.Brightness(0.1, clip=False),
# [imgaug.Brightness(0.1, clip=False),
# imgaug.Contrast((0.8, 1.2), clip=False),
# imgaug.Contrast((0.8, 1.2), clip=False),
...
...
tensorpack/models/_test.py
View file @
37e98945
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
import
numpy
as
np
import
unittest
import
unittest
...
@@ -29,6 +28,7 @@ def run_test_case(case):
...
@@ -29,6 +28,7 @@ def run_test_case(case):
suite
=
unittest
.
TestLoader
()
.
loadTestsFromTestCase
(
case
)
suite
=
unittest
.
TestLoader
()
.
loadTestsFromTestCase
(
case
)
unittest
.
TextTestRunner
(
verbosity
=
2
)
.
run
(
suite
)
unittest
.
TextTestRunner
(
verbosity
=
2
)
.
run
(
suite
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
tensorpack
import
tensorpack
from
tensorpack.utils
import
logger
from
tensorpack.utils
import
logger
...
...
tensorpack/models/batch_norm.py
View file @
37e98945
...
@@ -6,8 +6,6 @@
...
@@ -6,8 +6,6 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.contrib.framework
import
add_model_variable
from
tensorflow.contrib.framework
import
add_model_variable
from
tensorflow.python.training
import
moving_averages
from
tensorflow.python.training
import
moving_averages
from
copy
import
copy
import
re
from
..tfutils.common
import
get_tf_version
from
..tfutils.common
import
get_tf_version
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.tower
import
get_current_tower_context
...
@@ -65,7 +63,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
...
@@ -65,7 +63,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
if
use_local_stat
:
if
use_local_stat
:
# training tower
# training tower
if
ctx
.
is_training
:
if
ctx
.
is_training
:
#reuse = tf.get_variable_scope().reuse
#
reuse = tf.get_variable_scope().reuse
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
False
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
False
):
# BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
# BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
with
tf
.
name_scope
(
None
):
# https://github.com/tensorflow/tensorflow/issues/2740
with
tf
.
name_scope
(
None
):
# https://github.com/tensorflow/tensorflow/issues/2740
...
@@ -86,7 +84,6 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
...
@@ -86,7 +84,6 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
mean_var_name
=
ema
.
average_name
(
batch_mean
)
mean_var_name
=
ema
.
average_name
(
batch_mean
)
var_var_name
=
ema
.
average_name
(
batch_var
)
var_var_name
=
ema
.
average_name
(
batch_var
)
sc
=
tf
.
get_variable_scope
()
if
ctx
.
is_main_tower
:
if
ctx
.
is_main_tower
:
# main tower, but needs to use global stat. global stat must be from outside
# main tower, but needs to use global stat. global stat must be from outside
# TODO when reuse=True, the desired variable name could
# TODO when reuse=True, the desired variable name could
...
@@ -187,6 +184,7 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
...
@@ -187,6 +184,7 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
else
:
else
:
return
tf
.
identity
(
xn
,
name
=
'output'
)
return
tf
.
identity
(
xn
,
name
=
'output'
)
if
get_tf_version
()
>=
12
:
if
get_tf_version
()
>=
12
:
BatchNorm
=
BatchNormV2
BatchNorm
=
BatchNormV2
else
:
else
:
...
...
tensorpack/models/conv2d.py
View file @
37e98945
...
@@ -3,12 +3,9 @@
...
@@ -3,12 +3,9 @@
# File: conv2d.py
# File: conv2d.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
import
math
from
._common
import
layer_register
,
shape2d
,
shape4d
from
._common
import
layer_register
,
shape2d
,
shape4d
from
..utils
import
logger
from
..utils
import
logger
from
..utils.argtools
import
shape2d
__all__
=
[
'Conv2D'
,
'Deconv2D'
]
__all__
=
[
'Conv2D'
,
'Deconv2D'
]
...
@@ -63,7 +60,8 @@ def Conv2D(x, out_channel, kernel_shape,
...
@@ -63,7 +60,8 @@ def Conv2D(x, out_channel, kernel_shape,
conv
=
tf
.
concat
(
3
,
outputs
)
conv
=
tf
.
concat
(
3
,
outputs
)
if
nl
is
None
:
if
nl
is
None
:
logger
.
warn
(
logger
.
warn
(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead."
)
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. "
"Please use argscope instead."
)
nl
=
tf
.
nn
.
relu
nl
=
tf
.
nn
.
relu
return
nl
(
tf
.
nn
.
bias_add
(
conv
,
b
)
if
use_bias
else
conv
,
name
=
'output'
)
return
nl
(
tf
.
nn
.
bias_add
(
conv
,
b
)
if
use_bias
else
conv
,
name
=
'output'
)
...
...
tensorpack/models/fc.py
View file @
37e98945
...
@@ -4,10 +4,10 @@
...
@@ -4,10 +4,10 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
import
math
from
._common
import
layer_register
from
._common
import
layer_register
from
..tfutils
import
symbolic_functions
as
symbf
from
..tfutils
import
symbolic_functions
as
symbf
from
..utils
import
logger
__all__
=
[
'FullyConnected'
]
__all__
=
[
'FullyConnected'
]
...
@@ -31,7 +31,6 @@ def FullyConnected(x, out_dim,
...
@@ -31,7 +31,6 @@ def FullyConnected(x, out_dim,
in_dim
=
x
.
get_shape
()
.
as_list
()[
1
]
in_dim
=
x
.
get_shape
()
.
as_list
()[
1
]
if
W_init
is
None
:
if
W_init
is
None
:
#W_init = tf.uniform_unit_scaling_initializer(factor=1.43)
W_init
=
tf
.
contrib
.
layers
.
variance_scaling_initializer
()
W_init
=
tf
.
contrib
.
layers
.
variance_scaling_initializer
()
if
b_init
is
None
:
if
b_init
is
None
:
b_init
=
tf
.
constant_initializer
()
b_init
=
tf
.
constant_initializer
()
...
@@ -42,6 +41,7 @@ def FullyConnected(x, out_dim,
...
@@ -42,6 +41,7 @@ def FullyConnected(x, out_dim,
prod
=
tf
.
nn
.
xw_plus_b
(
x
,
W
,
b
)
if
use_bias
else
tf
.
matmul
(
x
,
W
)
prod
=
tf
.
nn
.
xw_plus_b
(
x
,
W
,
b
)
if
use_bias
else
tf
.
matmul
(
x
,
W
)
if
nl
is
None
:
if
nl
is
None
:
logger
.
warn
(
logger
.
warn
(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead."
)
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated."
" Please use argscope instead."
)
nl
=
tf
.
nn
.
relu
nl
=
tf
.
nn
.
relu
return
nl
(
prod
,
name
=
'output'
)
return
nl
(
prod
,
name
=
'output'
)
tensorpack/models/image_sample.py
View file @
37e98945
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
._common
import
layer_register
from
._common
import
layer_register
from
._test
import
TestModel
__all__
=
[
'ImageSample'
]
__all__
=
[
'ImageSample'
]
...
@@ -82,7 +83,7 @@ def ImageSample(inputs, borderMode='repeat'):
...
@@ -82,7 +83,7 @@ def ImageSample(inputs, borderMode='repeat'):
diffy
,
diffx
=
tf
.
split
(
3
,
2
,
diff
)
diffy
,
diffx
=
tf
.
split
(
3
,
2
,
diff
)
neg_diffy
,
neg_diffx
=
tf
.
split
(
3
,
2
,
neg_diff
)
neg_diffy
,
neg_diffx
=
tf
.
split
(
3
,
2
,
neg_diff
)
#prod = tf.reduce_prod(diff, 3, keep_dims=True)
#
prod = tf.reduce_prod(diff, 3, keep_dims=True)
# diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
# diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
# tf.reduce_max(diff), diff], summarize=50)
# tf.reduce_max(diff), diff], summarize=50)
...
@@ -100,8 +101,6 @@ def ImageSample(inputs, borderMode='repeat'):
...
@@ -100,8 +101,6 @@ def ImageSample(inputs, borderMode='repeat'):
ret
=
ret
*
tf
.
cast
(
mask
,
tf
.
float32
)
ret
=
ret
*
tf
.
cast
(
mask
,
tf
.
float32
)
return
ret
return
ret
from
._test
import
TestModel
class
TestSample
(
TestModel
):
class
TestSample
(
TestModel
):
...
@@ -128,9 +127,9 @@ class TestSample(TestModel):
...
@@ -128,9 +127,9 @@ class TestSample(TestModel):
bimg
=
np
.
random
.
rand
(
2
,
h
,
w
,
3
)
.
astype
(
'float32'
)
bimg
=
np
.
random
.
rand
(
2
,
h
,
w
,
3
)
.
astype
(
'float32'
)
# mat = np.array([
# mat = np.array([
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
#
[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
#
[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
#], dtype='float32') #2x2x2x2
#
], dtype='float32') #2x2x2x2
mat
=
(
np
.
random
.
rand
(
2
,
5
,
5
,
2
)
-
0.2
)
*
np
.
array
([
h
+
3
,
w
+
3
])
mat
=
(
np
.
random
.
rand
(
2
,
5
,
5
,
2
)
-
0.2
)
*
np
.
array
([
h
+
3
,
w
+
3
])
true_res
=
np_sample
(
bimg
,
np
.
floor
(
mat
+
0.5
)
.
astype
(
'int32'
))
true_res
=
np_sample
(
bimg
,
np
.
floor
(
mat
+
0.5
)
.
astype
(
'int32'
))
...
@@ -140,10 +139,10 @@ class TestSample(TestModel):
...
@@ -140,10 +139,10 @@ class TestSample(TestModel):
self
.
assertTrue
((
res
==
true_res
)
.
all
())
self
.
assertTrue
((
res
==
true_res
)
.
all
())
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
import
sys
im
=
cv2
.
imread
(
'cat.jpg'
)
im
=
cv2
.
imread
(
'cat.jpg'
)
im
=
im
.
reshape
((
1
,)
+
im
.
shape
)
.
astype
(
'float32'
)
im
=
im
.
reshape
((
1
,)
+
im
.
shape
)
.
astype
(
'float32'
)
imv
=
tf
.
Variable
(
im
)
imv
=
tf
.
Variable
(
im
)
...
@@ -160,8 +159,8 @@ if __name__ == '__main__':
...
@@ -160,8 +159,8 @@ if __name__ == '__main__':
sess
=
tf
.
Session
()
sess
=
tf
.
Session
()
sess
.
run
(
tf
.
global_variables_initializer
())
sess
.
run
(
tf
.
global_variables_initializer
())
#out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#
out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#out = sess.run(output)
#
out = sess.run(output)
# print(out[0].min())
# print(out[0].min())
# print(out[0].max())
# print(out[0].max())
# print(out[0].sum())
# print(out[0].sum())
...
...
tensorpack/models/model_desc.py
View file @
37e98945
...
@@ -4,25 +4,19 @@
...
@@ -4,25 +4,19 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
re
import
tensorflow
as
tf
import
tensorflow
as
tf
from
collections
import
namedtuple
import
inspect
import
inspect
import
pickle
import
pickle
import
six
import
six
from
..utils
import
logger
,
INPUT_VARS_KEY
from
..utils
import
logger
,
INPUT_VARS_KEY
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.gradproc
import
CheckGradient
from
..tfutils.gradproc
import
CheckGradient
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.tower
import
get_current_tower_context
__all__
=
[
'ModelDesc'
,
'InputVar'
,
'ModelFromMetaGraph'
]
__all__
=
[
'ModelDesc'
,
'InputVar'
,
'ModelFromMetaGraph'
]
#_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])
class
InputVar
(
object
):
class
InputVar
(
object
):
def
__init__
(
self
,
type
,
shape
,
name
,
sparse
=
False
):
def
__init__
(
self
,
type
,
shape
,
name
,
sparse
=
False
):
self
.
type
=
type
self
.
type
=
type
self
.
shape
=
shape
self
.
shape
=
shape
...
...
tensorpack/models/nonlin.py
View file @
37e98945
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
copy
import
copy
from
._common
import
layer_register
from
._common
import
layer_register
from
.batch_norm
import
BatchNorm
from
.batch_norm
import
BatchNorm
...
@@ -63,8 +62,8 @@ def LeakyReLU(x, alpha, name=None):
...
@@ -63,8 +62,8 @@ def LeakyReLU(x, alpha, name=None):
if
name
is
None
:
if
name
is
None
:
name
=
'output'
name
=
'output'
return
tf
.
maximum
(
x
,
alpha
*
x
,
name
=
name
)
return
tf
.
maximum
(
x
,
alpha
*
x
,
name
=
name
)
#alpha = float(alpha)
#
alpha = float(alpha)
#x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
#
x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
# return tf.mul(x, 0.5, name=name)
# return tf.mul(x, 0.5, name=name)
...
...
tensorpack/models/pool.py
View file @
37e98945
...
@@ -8,6 +8,8 @@ import numpy as np
...
@@ -8,6 +8,8 @@ import numpy as np
from
._common
import
layer_register
,
shape4d
from
._common
import
layer_register
,
shape4d
from
..utils.argtools
import
shape2d
from
..utils.argtools
import
shape2d
from
..tfutils
import
symbolic_functions
as
symbf
from
..tfutils
import
symbolic_functions
as
symbf
from
._test
import
TestModel
__all__
=
[
'MaxPooling'
,
'FixedUnPooling'
,
'AvgPooling'
,
'GlobalAvgPooling'
,
__all__
=
[
'MaxPooling'
,
'FixedUnPooling'
,
'AvgPooling'
,
'GlobalAvgPooling'
,
'BilinearUpSample'
]
'BilinearUpSample'
]
...
@@ -131,7 +133,7 @@ def BilinearUpSample(x, shape):
...
@@ -131,7 +133,7 @@ def BilinearUpSample(x, shape):
:param x: input NHWC tensor
:param x: input NHWC tensor
:param shape: an integer, the upsample factor
:param shape: an integer, the upsample factor
"""
"""
#inp_shape = tf.shape(x)
#
inp_shape = tf.shape(x)
# return tf.image.resize_bilinear(x,
# return tf.image.resize_bilinear(x,
# tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
# tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
# align_corners=True)
# align_corners=True)
...
@@ -172,9 +174,6 @@ def BilinearUpSample(x, shape):
...
@@ -172,9 +174,6 @@ def BilinearUpSample(x, shape):
return
deconv
return
deconv
from
._test
import
TestModel
class
TestPool
(
TestModel
):
class
TestPool
(
TestModel
):
def
test_fixed_unpooling
(
self
):
def
test_fixed_unpooling
(
self
):
...
...
tensorpack/models/regularize.py
View file @
37e98945
...
@@ -17,6 +17,7 @@ __all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']
...
@@ -17,6 +17,7 @@ __all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']
def
_log_regularizer
(
name
):
def
_log_regularizer
(
name
):
logger
.
info
(
"Apply regularizer for {}"
.
format
(
name
))
logger
.
info
(
"Apply regularizer for {}"
.
format
(
name
))
l2_regularizer
=
tf
.
contrib
.
layers
.
l2_regularizer
l2_regularizer
=
tf
.
contrib
.
layers
.
l2_regularizer
l1_regularizer
=
tf
.
contrib
.
layers
.
l1_regularizer
l1_regularizer
=
tf
.
contrib
.
layers
.
l1_regularizer
...
...
tensorpack/predict/base.py
View file @
37e98945
...
@@ -3,11 +3,11 @@
...
@@ -3,11 +3,11 @@
# File: base.py
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
abc
import
abstractmethod
,
ABCMeta
,
abstractproperty
from
abc
import
abstractmethod
,
ABCMeta
import
tensorflow
as
tf
import
tensorflow
as
tf
import
six
import
six
from
..utils.naming
import
*
from
..utils.naming
import
PREDICT_TOWER
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
get_tensors_by_names
,
TowerContext
from
..tfutils
import
get_tensors_by_names
,
TowerContext
...
@@ -128,7 +128,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
...
@@ -128,7 +128,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self
.
predictors
=
[]
self
.
predictors
=
[]
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
# TODO backup summary keys?
# TODO backup summary keys?
fn
=
lambda
_
:
config
.
model
.
build_graph
(
config
.
model
.
get_input_vars
())
def
fn
(
_
):
config
.
model
.
build_graph
(
config
.
model
.
get_input_vars
())
build_multi_tower_prediction_graph
(
fn
,
towers
)
build_multi_tower_prediction_graph
(
fn
,
towers
)
self
.
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
self
.
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
...
...
tensorpack/predict/common.py
View file @
37e98945
...
@@ -2,19 +2,14 @@
...
@@ -2,19 +2,14 @@
# File: common.py
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
from
collections
import
namedtuple
from
collections
import
namedtuple
import
six
import
six
from
six.moves
import
zip
from
tensorpack.models
import
ModelDesc
from
tensorpack.models
import
ModelDesc
from
..utils
import
logger
from
..tfutils
import
get_default_sess_config
from
..tfutils
import
get_default_sess_config
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
.base
import
OfflinePredictor
from
.base
import
OfflinePredictor
import
multiprocessing
__all__
=
[
'PredictConfig'
,
'get_predict_func'
,
'PredictResult'
]
__all__
=
[
'PredictConfig'
,
'get_predict_func'
,
'PredictResult'
]
PredictResult
=
namedtuple
(
'PredictResult'
,
[
'input'
,
'output'
])
PredictResult
=
namedtuple
(
'PredictResult'
,
[
'input'
,
'output'
])
...
@@ -53,7 +48,7 @@ class PredictConfig(object):
...
@@ -53,7 +48,7 @@ class PredictConfig(object):
self
.
input_names
=
kwargs
.
pop
(
'input_var_names'
,
None
)
self
.
input_names
=
kwargs
.
pop
(
'input_var_names'
,
None
)
if
self
.
input_names
is
not
None
:
if
self
.
input_names
is
not
None
:
pass
pass
#logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
#
logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
if
self
.
input_names
is
None
:
if
self
.
input_names
is
None
:
# neither options is set, assume all inputs
# neither options is set, assume all inputs
raw_vars
=
self
.
model
.
get_input_vars_desc
()
raw_vars
=
self
.
model
.
get_input_vars_desc
()
...
@@ -61,7 +56,7 @@ class PredictConfig(object):
...
@@ -61,7 +56,7 @@ class PredictConfig(object):
self
.
output_names
=
kwargs
.
pop
(
'output_names'
,
None
)
self
.
output_names
=
kwargs
.
pop
(
'output_names'
,
None
)
if
self
.
output_names
is
None
:
if
self
.
output_names
is
None
:
self
.
output_names
=
kwargs
.
pop
(
'output_var_names'
)
self
.
output_names
=
kwargs
.
pop
(
'output_var_names'
)
#logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
#
logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
assert
len
(
self
.
input_names
),
self
.
input_names
assert
len
(
self
.
input_names
),
self
.
input_names
for
v
in
self
.
input_names
:
for
v
in
self
.
input_names
:
assert_type
(
v
,
six
.
string_types
)
assert_type
(
v
,
six
.
string_types
)
...
...
tensorpack/predict/concurrency.py
View file @
37e98945
...
@@ -5,10 +5,8 @@
...
@@ -5,10 +5,8 @@
import
multiprocessing
import
multiprocessing
import
threading
import
threading
import
tensorflow
as
tf
import
time
import
six
import
six
from
six.moves
import
queue
,
range
,
zip
from
six.moves
import
queue
,
range
from
..utils.concurrency
import
DIE
from
..utils.concurrency
import
DIE
from
..tfutils.modelutils
import
describe_model
from
..tfutils.modelutils
import
describe_model
...
@@ -49,7 +47,6 @@ class MultiProcessPredictWorker(multiprocessing.Process):
...
@@ -49,7 +47,6 @@ class MultiProcessPredictWorker(multiprocessing.Process):
from
tensorpack.models._common
import
disable_layer_logging
from
tensorpack.models._common
import
disable_layer_logging
disable_layer_logging
()
disable_layer_logging
()
self
.
predictor
=
OfflinePredictor
(
self
.
config
)
self
.
predictor
=
OfflinePredictor
(
self
.
config
)
import
sys
if
self
.
idx
==
0
:
if
self
.
idx
==
0
:
with
self
.
predictor
.
graph
.
as_default
():
with
self
.
predictor
.
graph
.
as_default
():
describe_model
()
describe_model
()
...
@@ -136,9 +133,9 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
...
@@ -136,9 +133,9 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
""" :param predictors: a list of OnlinePredictor"""
""" :param predictors: a list of OnlinePredictor"""
assert
len
(
predictors
)
assert
len
(
predictors
)
for
k
in
predictors
:
for
k
in
predictors
:
#assert isinstance(k, OnlinePredictor), type(k)
#
assert isinstance(k, OnlinePredictor), type(k)
# TODO use predictors.return_input here
# TODO use predictors.return_input here
assert
k
.
return_input
==
False
assert
not
k
.
return_input
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
len
(
predictors
)
*
100
)
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
len
(
predictors
)
*
100
)
self
.
threads
=
[
self
.
threads
=
[
PredictorWorkerThread
(
PredictorWorkerThread
(
...
...
tensorpack/predict/dataset.py
View file @
37e98945
...
@@ -9,7 +9,7 @@ import multiprocessing
...
@@ -9,7 +9,7 @@ import multiprocessing
import
os
import
os
import
six
import
six
from
..dataflow
import
DataFlow
,
BatchData
from
..dataflow
import
DataFlow
from
..dataflow.dftools
import
dataflow_to_process_queue
from
..dataflow.dftools
import
dataflow_to_process_queue
from
..utils.concurrency
import
ensure_proc_terminate
,
OrderedResultGatherProc
,
DIE
from
..utils.concurrency
import
ensure_proc_terminate
,
OrderedResultGatherProc
,
DIE
from
..utils
import
logger
,
get_tqdm
from
..utils
import
logger
,
get_tqdm
...
...
tensorpack/tfutils/common.py
View file @
37e98945
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
# File: common.py
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
..utils.naming
import
*
from
..utils.naming
import
GLOBAL_STEP_VAR_NAME
,
GLOBAL_STEP_OP_NAME
import
tensorflow
as
tf
import
tensorflow
as
tf
from
copy
import
copy
from
copy
import
copy
import
six
import
six
...
@@ -36,7 +36,7 @@ def get_default_sess_config(mem_fraction=0.99):
...
@@ -36,7 +36,7 @@ def get_default_sess_config(mem_fraction=0.99):
conf
.
gpu_options
.
allocator_type
=
'BFC'
conf
.
gpu_options
.
allocator_type
=
'BFC'
conf
.
gpu_options
.
allow_growth
=
True
conf
.
gpu_options
.
allow_growth
=
True
conf
.
allow_soft_placement
=
True
conf
.
allow_soft_placement
=
True
#conf.log_device_placement = True
#
conf.log_device_placement = True
return
conf
return
conf
...
@@ -74,6 +74,7 @@ def get_op_tensor_name(name):
...
@@ -74,6 +74,7 @@ def get_op_tensor_name(name):
else
:
else
:
return
name
,
name
+
':0'
return
name
,
name
+
':0'
get_op_var_name
=
get_op_tensor_name
get_op_var_name
=
get_op_tensor_name
...
@@ -88,6 +89,7 @@ def get_tensors_by_names(names):
...
@@ -88,6 +89,7 @@ def get_tensors_by_names(names):
ret
.
append
(
G
.
get_tensor_by_name
(
varn
))
ret
.
append
(
G
.
get_tensor_by_name
(
varn
))
return
ret
return
ret
get_vars_by_names
=
get_tensors_by_names
get_vars_by_names
=
get_tensors_by_names
...
...
tensorpack/tfutils/gradproc.py
View file @
37e98945
...
@@ -103,6 +103,7 @@ class MapGradient(GradientProcessor):
...
@@ -103,6 +103,7 @@ class MapGradient(GradientProcessor):
ret
.
append
((
grad
,
var
))
ret
.
append
((
grad
,
var
))
return
ret
return
ret
_summaried_gradient
=
set
()
_summaried_gradient
=
set
()
...
@@ -133,7 +134,7 @@ class CheckGradient(MapGradient):
...
@@ -133,7 +134,7 @@ class CheckGradient(MapGradient):
def
_mapper
(
self
,
grad
,
var
):
def
_mapper
(
self
,
grad
,
var
):
# this is very slow.... see #3649
# this is very slow.... see #3649
#op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
#
op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
grad
=
tf
.
check_numerics
(
grad
,
'CheckGradient-'
+
var
.
op
.
name
)
grad
=
tf
.
check_numerics
(
grad
,
'CheckGradient-'
+
var
.
op
.
name
)
return
grad
return
grad
...
...
tensorpack/tfutils/sessinit.py
View file @
37e98945
...
@@ -5,7 +5,6 @@
...
@@ -5,7 +5,6 @@
import
os
import
os
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
from
collections
import
defaultdict
from
collections
import
defaultdict
import
re
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
import
six
import
six
...
@@ -120,7 +119,8 @@ class SaverRestore(SessionInit):
...
@@ -120,7 +119,8 @@ class SaverRestore(SessionInit):
ckpt_vars
=
reader
.
get_variable_to_shape_map
()
.
keys
()
ckpt_vars
=
reader
.
get_variable_to_shape_map
()
.
keys
()
for
v
in
ckpt_vars
:
for
v
in
ckpt_vars
:
if
v
.
startswith
(
PREDICT_TOWER
):
if
v
.
startswith
(
PREDICT_TOWER
):
logger
.
error
(
"Found {} in checkpoint. But anything from prediction tower shouldn't be saved."
.
format
(
v
.
name
))
logger
.
error
(
"Found {} in checkpoint. "
"But anything from prediction tower shouldn't be saved."
.
format
(
v
.
name
))
return
set
(
ckpt_vars
)
return
set
(
ckpt_vars
)
def
_get_vars_to_restore_multimap
(
self
,
vars_available
):
def
_get_vars_to_restore_multimap
(
self
,
vars_available
):
...
...
tensorpack/tfutils/summary.py
View file @
37e98945
...
@@ -7,7 +7,7 @@ import tensorflow as tf
...
@@ -7,7 +7,7 @@ import tensorflow as tf
import
re
import
re
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
from
..utils.naming
import
*
from
..utils.naming
import
MOVING_SUMMARY_VARS_KEY
from
.tower
import
get_current_tower_context
from
.tower
import
get_current_tower_context
from
.
import
get_global_step_var
from
.
import
get_global_step_var
from
.symbolic_functions
import
rms
from
.symbolic_functions
import
rms
...
...
tensorpack/tfutils/symbolic_functions.py
View file @
37e98945
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
import
numpy
as
np
import
numpy
as
np
from
..utils
import
logger
def
prediction_incorrect
(
logits
,
label
,
topk
=
1
,
name
=
'incorrect_vector'
):
def
prediction_incorrect
(
logits
,
label
,
topk
=
1
,
name
=
'incorrect_vector'
):
...
@@ -79,12 +78,10 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
...
@@ -79,12 +78,10 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
cost
=
tf
.
nn
.
weighted_cross_entropy_with_logits
(
logits
,
y
,
pos_weight
)
cost
=
tf
.
nn
.
weighted_cross_entropy_with_logits
(
logits
,
y
,
pos_weight
)
cost
=
tf
.
reduce_mean
(
cost
*
(
1
-
beta
),
name
=
name
)
cost
=
tf
.
reduce_mean
(
cost
*
(
1
-
beta
),
name
=
name
)
#logstable = tf.log(1 + tf.exp(-tf.abs(z)))
# logstable = tf.log(1 + tf.exp(-tf.abs(z)))
# loss_pos = -beta * tf.reduce_mean(-y *
# loss_pos = -beta * tf.reduce_mean(-y * (logstable - tf.minimum(0.0, z)))
#(logstable - tf.minimum(0.0, z)))
# loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) * (logstable + tf.maximum(z, 0.0)))
# loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
# cost = tf.sub(loss_pos, loss_neg, name=name)
#(logstable + tf.maximum(z, 0.0)))
#cost = tf.sub(loss_pos, loss_neg, name=name)
return
cost
return
cost
...
...
tensorpack/tfutils/tower.py
View file @
37e98945
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
import
re
import
re
from
..utils.naming
import
*
from
..utils.naming
import
PREDICT_TOWER
__all__
=
[
'get_current_tower_context'
,
'TowerContext'
]
__all__
=
[
'get_current_tower_context'
,
'TowerContext'
]
...
@@ -49,7 +49,7 @@ class TowerContext(object):
...
@@ -49,7 +49,7 @@ class TowerContext(object):
with
tf
.
variable_scope
(
self
.
_name
)
as
scope
:
with
tf
.
variable_scope
(
self
.
_name
)
as
scope
:
with
tf
.
variable_scope
(
scope
,
reuse
=
False
):
with
tf
.
variable_scope
(
scope
,
reuse
=
False
):
scope
=
tf
.
get_variable_scope
()
scope
=
tf
.
get_variable_scope
()
assert
scope
.
reuse
==
Fal
se
assert
not
scope
.
reu
se
return
tf
.
get_variable
(
*
args
,
**
kwargs
)
return
tf
.
get_variable
(
*
args
,
**
kwargs
)
def
find_tensor_in_main_tower
(
self
,
graph
,
name
):
def
find_tensor_in_main_tower
(
self
,
graph
,
name
):
...
...
tensorpack/tfutils/varmanip.py
View file @
37e98945
...
@@ -10,7 +10,7 @@ from collections import defaultdict
...
@@ -10,7 +10,7 @@ from collections import defaultdict
import
re
import
re
import
numpy
as
np
import
numpy
as
np
from
..utils
import
logger
from
..utils
import
logger
from
..utils.naming
import
*
from
..utils.naming
import
PREDICT_TOWER
from
.common
import
get_op_tensor_name
from
.common
import
get_op_tensor_name
__all__
=
[
'SessionUpdate'
,
'dump_session_params'
,
'dump_chkpt_vars'
,
__all__
=
[
'SessionUpdate'
,
'dump_session_params'
,
'dump_chkpt_vars'
,
...
@@ -51,7 +51,7 @@ class SessionUpdate(object):
...
@@ -51,7 +51,7 @@ class SessionUpdate(object):
self
.
sess
=
sess
self
.
sess
=
sess
self
.
assign_ops
=
defaultdict
(
list
)
self
.
assign_ops
=
defaultdict
(
list
)
for
v
in
vars_to_update
:
for
v
in
vars_to_update
:
#p = tf.placeholder(v.dtype, shape=v.get_shape())
#
p = tf.placeholder(v.dtype, shape=v.get_shape())
with
tf
.
device
(
'/cpu:0'
):
with
tf
.
device
(
'/cpu:0'
):
p
=
tf
.
placeholder
(
v
.
dtype
)
p
=
tf
.
placeholder
(
v
.
dtype
)
savename
=
get_savename_from_varname
(
v
.
name
)
savename
=
get_savename_from_varname
(
v
.
name
)
...
...
tensorpack/train/base.py
View file @
37e98945
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
signal
import
re
import
re
import
weakref
import
weakref
import
six
import
six
...
...
tensorpack/train/feedfree.py
View file @
37e98945
...
@@ -10,7 +10,7 @@ from ..tfutils import get_global_step_var
...
@@ -10,7 +10,7 @@ from ..tfutils import get_global_step_var
from
..tfutils.tower
import
TowerContext
from
..tfutils.tower
import
TowerContext
from
..tfutils.gradproc
import
apply_grad_processors
from
..tfutils.gradproc
import
apply_grad_processors
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
.input_data
import
QueueInput
,
FeedfreeInput
,
DummyConstantInput
from
.input_data
import
QueueInput
,
FeedfreeInput
from
.base
import
Trainer
from
.base
import
Trainer
from
.trainer
import
MultiPredictorTowerTrainer
from
.trainer
import
MultiPredictorTowerTrainer
...
@@ -98,7 +98,7 @@ class SimpleFeedfreeTrainer(
...
@@ -98,7 +98,7 @@ class SimpleFeedfreeTrainer(
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
()),
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
()),
summary_moving_average
(),
name
=
'train_op'
)
summary_moving_average
(),
name
=
'train_op'
)
# skip training
# skip training
#self.train_op = tf.group(*self.dequed_inputs)
#
self.train_op = tf.group(*self.dequed_inputs)
class
QueueInputTrainer
(
SimpleFeedfreeTrainer
):
class
QueueInputTrainer
(
SimpleFeedfreeTrainer
):
...
@@ -114,7 +114,8 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
...
@@ -114,7 +114,8 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
"""
"""
config
.
data
=
QueueInput
(
config
.
dataset
,
input_queue
)
config
.
data
=
QueueInput
(
config
.
dataset
,
input_queue
)
if
predict_tower
is
not
None
:
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!"
)
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig.predict_tower instead!"
)
config
.
predict_tower
=
predict_tower
config
.
predict_tower
=
predict_tower
assert
len
(
config
.
tower
)
==
1
,
\
assert
len
(
config
.
tower
)
==
1
,
\
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
...
...
tensorpack/train/input_data.py
View file @
37e98945
...
@@ -86,7 +86,7 @@ class EnqueueThread(threading.Thread):
...
@@ -86,7 +86,7 @@ class EnqueueThread(threading.Thread):
feed
=
dict
(
zip
(
self
.
placehdrs
,
dp
))
feed
=
dict
(
zip
(
self
.
placehdrs
,
dp
))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self
.
op
.
run
(
feed_dict
=
feed
)
self
.
op
.
run
(
feed_dict
=
feed
)
except
tf
.
errors
.
CancelledError
as
e
:
except
tf
.
errors
.
CancelledError
:
pass
pass
except
Exception
:
except
Exception
:
logger
.
exception
(
"Exception in EnqueueThread:"
)
logger
.
exception
(
"Exception in EnqueueThread:"
)
...
...
tensorpack/train/multigpu.py
View file @
37e98945
...
@@ -9,9 +9,9 @@ import re
...
@@ -9,9 +9,9 @@ import re
from
six.moves
import
zip
,
range
from
six.moves
import
zip
,
range
from
..utils
import
logger
from
..utils
import
logger
from
..utils.naming
import
*
from
..utils.naming
import
SUMMARY_BACKUP_KEYS
from
..utils.concurrency
import
LoopThread
from
..utils.concurrency
import
LoopThread
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..tfutils.summary
import
summary_moving_average
from
..tfutils
import
(
backup_collection
,
restore_collection
,
from
..tfutils
import
(
backup_collection
,
restore_collection
,
get_global_step_var
,
TowerContext
)
get_global_step_var
,
TowerContext
)
from
..tfutils.gradproc
import
apply_grad_processors
,
ScaleGradient
from
..tfutils.gradproc
import
apply_grad_processors
,
ScaleGradient
...
@@ -36,7 +36,7 @@ class MultiGPUTrainer(Trainer):
...
@@ -36,7 +36,7 @@ class MultiGPUTrainer(Trainer):
for
idx
,
t
in
enumerate
(
towers
):
for
idx
,
t
in
enumerate
(
towers
):
with
tf
.
device
(
'/gpu:{}'
.
format
(
t
)),
\
with
tf
.
device
(
'/gpu:{}'
.
format
(
t
)),
\
tf
.
variable_scope
(
global_scope
,
reuse
=
idx
>
0
),
\
tf
.
variable_scope
(
global_scope
,
reuse
=
idx
>
0
),
\
TowerContext
(
'tower{}'
.
format
(
idx
))
as
scope
:
TowerContext
(
'tower{}'
.
format
(
idx
)):
logger
.
info
(
"Building graph for training tower {}..."
.
format
(
idx
))
logger
.
info
(
"Building graph for training tower {}..."
.
format
(
idx
))
grad_list
.
append
(
get_tower_grad_func
())
grad_list
.
append
(
get_tower_grad_func
())
...
@@ -60,7 +60,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -60,7 +60,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
if
predict_tower
is
not
None
:
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!"
)
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig.predict_tower instead!"
)
config
.
predict_tower
=
predict_tower
config
.
predict_tower
=
predict_tower
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
...
@@ -82,7 +83,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -82,7 +83,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
if
None
in
nones
and
len
(
nones
)
!=
1
:
if
None
in
nones
and
len
(
nones
)
!=
1
:
raise
RuntimeError
(
"Gradient w.r.t {} is None in some but not all towers!"
.
format
(
v
.
name
))
raise
RuntimeError
(
"Gradient w.r.t {} is None in some but not all towers!"
.
format
(
v
.
name
))
elif
nones
[
0
]
is
None
:
elif
nones
[
0
]
is
None
:
logger
.
warn
(
"No Gradient w.r.t {}"
.
format
(
v
ar
.
op
.
name
))
logger
.
warn
(
"No Gradient w.r.t {}"
.
format
(
v
.
op
.
name
))
continue
continue
try
:
try
:
grad
=
tf
.
add_n
(
all_grad
)
/
float
(
len
(
tower_grads
))
grad
=
tf
.
add_n
(
all_grad
)
/
float
(
len
(
tower_grads
))
...
@@ -98,8 +99,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -98,8 +99,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
# debug tower performance:
# debug tower performance:
#ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
#
ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
#self.train_op = tf.group(*ops)
#
self.train_op = tf.group(*ops)
# return
# return
grads
=
SyncMultiGPUTrainer
.
_average_grads
(
grad_list
)
grads
=
SyncMultiGPUTrainer
.
_average_grads
(
grad_list
)
...
@@ -129,7 +130,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -129,7 +130,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
if
predict_tower
is
not
None
:
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!"
)
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig.predict_tower instead!"
)
config
.
predict_tower
=
predict_tower
config
.
predict_tower
=
predict_tower
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
...
...
tensorpack/train/trainer.py
View file @
37e98945
...
@@ -3,18 +3,16 @@
...
@@ -3,18 +3,16 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
import
time
from
six.moves
import
zip
from
.base
import
Trainer
from
.base
import
Trainer
from
..utils
import
logger
,
SUMMARY_BACKUP_KEYS
,
PREDICT_TOWER
from
..utils
import
SUMMARY_BACKUP_KEYS
,
PREDICT_TOWER
from
..tfutils
import
(
get_tensors_by_names
,
freeze_collection
,
from
..tfutils
import
(
get_tensors_by_names
,
freeze_collection
,
get_global_step_var
,
TowerContext
)
get_global_step_var
,
TowerContext
)
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..predict
import
OnlinePredictor
,
build_multi_tower_prediction_graph
from
..predict
import
OnlinePredictor
,
build_multi_tower_prediction_graph
from
..tfutils.gradproc
import
apply_grad_processors
from
..tfutils.gradproc
import
apply_grad_processors
from
.input_data
import
FeedInput
,
FeedfreeInput
from
.input_data
import
FeedInput
__all__
=
[
'SimpleTrainer'
,
'MultiPredictorTowerTrainer'
]
__all__
=
[
'SimpleTrainer'
,
'MultiPredictorTowerTrainer'
]
...
@@ -49,7 +47,8 @@ class PredictorFactory(object):
...
@@ -49,7 +47,8 @@ class PredictorFactory(object):
# build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
# build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
with
tf
.
name_scope
(
None
),
\
with
tf
.
name_scope
(
None
),
\
freeze_collection
(
SUMMARY_BACKUP_KEYS
):
freeze_collection
(
SUMMARY_BACKUP_KEYS
):
fn
=
lambda
_
:
self
.
model
.
build_graph
(
self
.
model
.
get_input_vars
())
def
fn
(
_
):
self
.
model
.
build_graph
(
self
.
model
.
get_input_vars
())
build_multi_tower_prediction_graph
(
fn
,
self
.
towers
)
build_multi_tower_prediction_graph
(
fn
,
self
.
towers
)
self
.
tower_built
=
True
self
.
tower_built
=
True
...
...
tensorpack/utils/argtools.py
View file @
37e98945
...
@@ -9,8 +9,9 @@ import inspect
...
@@ -9,8 +9,9 @@ import inspect
import
six
import
six
import
functools
import
functools
import
collections
import
collections
from
.
import
logger
__all__
=
[
'map_arg'
,
'memoized'
,
'shape2d'
,
'memoized_ignoreargs'
]
__all__
=
[
'map_arg'
,
'memoized'
,
'shape2d'
,
'memoized_ignoreargs'
,
'log_once'
]
def
map_arg
(
**
maps
):
def
map_arg
(
**
maps
):
...
@@ -64,11 +65,12 @@ class memoized(object):
...
@@ -64,11 +65,12 @@ class memoized(object):
'''Support instance methods.'''
'''Support instance methods.'''
return
functools
.
partial
(
self
.
__call__
,
obj
)
return
functools
.
partial
(
self
.
__call__
,
obj
)
_MEMOIZED_NOARGS
=
{}
_MEMOIZED_NOARGS
=
{}
def
memoized_ignoreargs
(
func
):
def
memoized_ignoreargs
(
func
):
h
=
hash
(
func
)
# make sure it is hashable.
is it necessary?
h
ash
(
func
)
# make sure it is hashable. TODO
is it necessary?
def
wrapper
(
*
args
,
**
kwargs
):
def
wrapper
(
*
args
,
**
kwargs
):
if
func
not
in
_MEMOIZED_NOARGS
:
if
func
not
in
_MEMOIZED_NOARGS
:
...
@@ -99,3 +101,8 @@ def shape2d(a):
...
@@ -99,3 +101,8 @@ def shape2d(a):
assert
len
(
a
)
==
2
assert
len
(
a
)
==
2
return
list
(
a
)
return
list
(
a
)
raise
RuntimeError
(
"Illegal shape: {}"
.
format
(
a
))
raise
RuntimeError
(
"Illegal shape: {}"
.
format
(
a
))
@
memoized
def
log_once
(
message
,
func
):
getattr
(
logger
,
func
)(
message
)
tensorpack/utils/concurrency.py
View file @
37e98945
...
@@ -11,13 +11,15 @@ from contextlib import contextmanager
...
@@ -11,13 +11,15 @@ from contextlib import contextmanager
import
signal
import
signal
import
weakref
import
weakref
import
six
import
six
from
six.moves
import
queue
from
.
import
logger
if
six
.
PY2
:
if
six
.
PY2
:
import
subprocess32
as
subprocess
import
subprocess32
as
subprocess
else
:
else
:
import
subprocess
import
subprocess
from
six.moves
import
queue
from
.
import
logger
__all__
=
[
'StoppableThread'
,
'LoopThread'
,
'ensure_proc_terminate'
,
__all__
=
[
'StoppableThread'
,
'LoopThread'
,
'ensure_proc_terminate'
,
'OrderedResultGatherProc'
,
'OrderedContainer'
,
'DIE'
,
'OrderedResultGatherProc'
,
'OrderedContainer'
,
'DIE'
,
...
...
tensorpack/utils/debug.py
View file @
37e98945
...
@@ -28,6 +28,7 @@ def enable_call_trace():
...
@@ -28,6 +28,7 @@ def enable_call_trace():
return
return
sys
.
settrace
(
tracer
)
sys
.
settrace
(
tracer
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
enable_call_trace
()
enable_call_trace
()
...
...
tensorpack/utils/discretize.py
View file @
37e98945
...
@@ -3,8 +3,7 @@
...
@@ -3,8 +3,7 @@
# File: discretize.py
# File: discretize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
.
import
logger
from
.argtools
import
log_once
from
.argtools
import
memoized
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
import
numpy
as
np
import
numpy
as
np
import
six
import
six
...
@@ -13,13 +12,6 @@ from six.moves import range
...
@@ -13,13 +12,6 @@ from six.moves import range
__all__
=
[
'UniformDiscretizer1D'
,
'UniformDiscretizerND'
]
__all__
=
[
'UniformDiscretizer1D'
,
'UniformDiscretizerND'
]
@
memoized
def
log_once
(
s
):
logger
.
warn
(
s
)
# just a placeholder
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
Discretizer
(
object
):
class
Discretizer
(
object
):
...
@@ -54,10 +46,10 @@ class UniformDiscretizer1D(Discretizer1D):
...
@@ -54,10 +46,10 @@ class UniformDiscretizer1D(Discretizer1D):
def
get_bin
(
self
,
v
):
def
get_bin
(
self
,
v
):
if
v
<
self
.
minv
:
if
v
<
self
.
minv
:
log_once
(
"UniformDiscretizer1D: value smaller than min!"
)
log_once
(
"UniformDiscretizer1D: value smaller than min!"
,
'warn'
)
return
0
return
0
if
v
>
self
.
maxv
:
if
v
>
self
.
maxv
:
log_once
(
"UniformDiscretizer1D: value larger than max!"
)
log_once
(
"UniformDiscretizer1D: value larger than max!"
,
'warn'
)
return
self
.
nr_bin
-
1
return
self
.
nr_bin
-
1
return
int
(
np
.
clip
(
return
int
(
np
.
clip
(
(
v
-
self
.
minv
)
/
self
.
spacing
,
(
v
-
self
.
minv
)
/
self
.
spacing
,
...
@@ -126,8 +118,9 @@ class UniformDiscretizerND(Discretizer):
...
@@ -126,8 +118,9 @@ class UniformDiscretizerND(Discretizer):
bin_id_nd
=
self
.
get_nd_bin_ids
(
bin_id
)
bin_id_nd
=
self
.
get_nd_bin_ids
(
bin_id
)
return
[
self
.
discretizers
[
k
]
.
get_bin_center
(
bin_id_nd
[
k
])
for
k
in
range
(
self
.
n
)]
return
[
self
.
discretizers
[
k
]
.
get_bin_center
(
bin_id_nd
[
k
])
for
k
in
range
(
self
.
n
)]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
#u = UniformDiscretizer1D(-10, 10, 0.12)
#
u = UniformDiscretizer1D(-10, 10, 0.12)
u
=
UniformDiscretizerND
((
0
,
100
,
1
),
(
0
,
100
,
1
),
(
0
,
100
,
1
))
u
=
UniformDiscretizerND
((
0
,
100
,
1
),
(
0
,
100
,
1
),
(
0
,
100
,
1
))
import
IPython
as
IP
import
IPython
as
IP
IP
.
embed
(
config
=
IP
.
terminal
.
ipapp
.
load_default_config
())
IP
.
embed
(
config
=
IP
.
terminal
.
ipapp
.
load_default_config
())
tensorpack/utils/fs.py
View file @
37e98945
...
@@ -54,5 +54,6 @@ def recursive_walk(rootdir):
...
@@ -54,5 +54,6 @@ def recursive_walk(rootdir):
for
f
in
files
:
for
f
in
files
:
yield
os
.
path
.
join
(
r
,
f
)
yield
os
.
path
.
join
(
r
,
f
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
download
(
'http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz'
,
'.'
)
download
(
'http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz'
,
'.'
)
tensorpack/utils/loadcaffe.py
View file @
37e98945
...
@@ -3,14 +3,9 @@
...
@@ -3,14 +3,9 @@
# File: loadcaffe.py
# File: loadcaffe.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
collections
import
namedtuple
,
defaultdict
from
abc
import
abstractmethod
import
numpy
as
np
import
numpy
as
np
import
copy
import
os
import
os
from
six.moves
import
zip
from
.utils
import
change_env
,
get_dataset_path
from
.utils
import
change_env
,
get_dataset_path
from
.fs
import
download
from
.fs
import
download
from
.
import
logger
from
.
import
logger
...
@@ -115,7 +110,7 @@ def get_caffe_pb():
...
@@ -115,7 +110,7 @@ def get_caffe_pb():
dir
=
get_dataset_path
(
'caffe'
)
dir
=
get_dataset_path
(
'caffe'
)
caffe_pb_file
=
os
.
path
.
join
(
dir
,
'caffe_pb2.py'
)
caffe_pb_file
=
os
.
path
.
join
(
dir
,
'caffe_pb2.py'
)
if
not
os
.
path
.
isfile
(
caffe_pb_file
):
if
not
os
.
path
.
isfile
(
caffe_pb_file
):
proto_path
=
download
(
CAFFE_PROTO_URL
,
dir
)
download
(
CAFFE_PROTO_URL
,
dir
)
assert
os
.
path
.
isfile
(
os
.
path
.
join
(
dir
,
'caffe.proto'
))
assert
os
.
path
.
isfile
(
os
.
path
.
join
(
dir
,
'caffe.proto'
))
ret
=
os
.
system
(
'cd {} && protoc caffe.proto --python_out .'
.
format
(
dir
))
ret
=
os
.
system
(
'cd {} && protoc caffe.proto --python_out .'
.
format
(
dir
))
assert
ret
==
0
,
\
assert
ret
==
0
,
\
...
@@ -123,6 +118,7 @@ def get_caffe_pb():
...
@@ -123,6 +118,7 @@ def get_caffe_pb():
import
imp
import
imp
return
imp
.
load_source
(
'caffepb'
,
caffe_pb_file
)
return
imp
.
load_source
(
'caffepb'
,
caffe_pb_file
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
argparse
import
argparse
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
@@ -132,5 +128,4 @@ if __name__ == '__main__':
...
@@ -132,5 +128,4 @@ if __name__ == '__main__':
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
ret
=
load_caffe
(
args
.
model
,
args
.
weights
)
ret
=
load_caffe
(
args
.
model
,
args
.
weights
)
import
numpy
as
np
np
.
save
(
args
.
output
,
ret
)
np
.
save
(
args
.
output
,
ret
)
tensorpack/utils/logger.py
View file @
37e98945
...
@@ -11,11 +11,11 @@ from datetime import datetime
...
@@ -11,11 +11,11 @@ from datetime import datetime
from
six.moves
import
input
from
six.moves
import
input
import
sys
import
sys
__all__
=
[
'set_logger_dir'
,
'disable_logger'
,
'auto_set_dir'
,
'warn_dependency'
]
__all__
=
[
'set_logger_dir'
,
'disable_logger'
,
'auto_set_dir'
,
'warn_dependency'
]
class
_MyFormatter
(
logging
.
Formatter
):
class
_MyFormatter
(
logging
.
Formatter
):
def
format
(
self
,
record
):
def
format
(
self
,
record
):
date
=
colored
(
'[
%(asctime)
s @
%(filename)
s:
%(lineno)
d]'
,
'green'
)
date
=
colored
(
'[
%(asctime)
s @
%(filename)
s:
%(lineno)
d]'
,
'green'
)
msg
=
'
%(message)
s'
msg
=
'
%(message)
s'
...
@@ -40,12 +40,19 @@ def _getlogger():
...
@@ -40,12 +40,19 @@ def _getlogger():
handler
.
setFormatter
(
_MyFormatter
(
datefmt
=
'
%
m
%
d
%
H:
%
M:
%
S'
))
handler
.
setFormatter
(
_MyFormatter
(
datefmt
=
'
%
m
%
d
%
H:
%
M:
%
S'
))
logger
.
addHandler
(
handler
)
logger
.
addHandler
(
handler
)
return
logger
return
logger
_logger
=
_getlogger
()
_logger
=
_getlogger
()
_LOGGING_METHOD
=
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'warn'
,
'exception'
,
'debug'
]
# export logger functions
for
func
in
_LOGGING_METHOD
:
locals
()[
func
]
=
getattr
(
_logger
,
func
)
def
get_time_str
():
def
get_time_str
():
return
datetime
.
now
()
.
strftime
(
'
%
m
%
d-
%
H
%
M
%
S'
)
return
datetime
.
now
()
.
strftime
(
'
%
m
%
d-
%
H
%
M
%
S'
)
# logger file and directory:
# logger file and directory:
global
LOG_FILE
,
LOG_DIR
global
LOG_FILE
,
LOG_DIR
LOG_DIR
=
None
LOG_DIR
=
None
...
@@ -55,7 +62,7 @@ def _set_file(path):
...
@@ -55,7 +62,7 @@ def _set_file(path):
if
os
.
path
.
isfile
(
path
):
if
os
.
path
.
isfile
(
path
):
backup_name
=
path
+
'.'
+
get_time_str
()
backup_name
=
path
+
'.'
+
get_time_str
()
shutil
.
move
(
path
,
backup_name
)
shutil
.
move
(
path
,
backup_name
)
info
(
"Log file '{}' backuped to '{}'"
.
format
(
path
,
backup_name
))
info
(
"Log file '{}' backuped to '{}'"
.
format
(
path
,
backup_name
))
# noqa: F821
hdl
=
logging
.
FileHandler
(
hdl
=
logging
.
FileHandler
(
filename
=
path
,
encoding
=
'utf-8'
,
mode
=
'w'
)
filename
=
path
,
encoding
=
'utf-8'
,
mode
=
'w'
)
hdl
.
setFormatter
(
_MyFormatter
(
datefmt
=
'
%
m
%
d
%
H:
%
M:
%
S'
))
hdl
.
setFormatter
(
_MyFormatter
(
datefmt
=
'
%
m
%
d
%
H:
%
M:
%
S'
))
...
@@ -83,12 +90,12 @@ If you're resuming from a previous run you can choose to keep it.""")
...
@@ -83,12 +90,12 @@ If you're resuming from a previous run you can choose to keep it.""")
if
act
==
'b'
:
if
act
==
'b'
:
backup_name
=
dirname
+
get_time_str
()
backup_name
=
dirname
+
get_time_str
()
shutil
.
move
(
dirname
,
backup_name
)
shutil
.
move
(
dirname
,
backup_name
)
info
(
"Directory '{}' backuped to '{}'"
.
format
(
dirname
,
backup_name
))
info
(
"Directory '{}' backuped to '{}'"
.
format
(
dirname
,
backup_name
))
# noqa: F821
elif
act
==
'd'
:
elif
act
==
'd'
:
shutil
.
rmtree
(
dirname
)
shutil
.
rmtree
(
dirname
)
elif
act
==
'n'
:
elif
act
==
'n'
:
dirname
=
dirname
+
get_time_str
()
dirname
=
dirname
+
get_time_str
()
info
(
"Use a new log directory {}"
.
format
(
dirname
))
info
(
"Use a new log directory {}"
.
format
(
dirname
))
# noqa: F821
elif
act
==
'k'
:
elif
act
==
'k'
:
pass
pass
else
:
else
:
...
@@ -100,12 +107,6 @@ If you're resuming from a previous run you can choose to keep it.""")
...
@@ -100,12 +107,6 @@ If you're resuming from a previous run you can choose to keep it.""")
_set_file
(
LOG_FILE
)
_set_file
(
LOG_FILE
)
_LOGGING_METHOD
=
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'warn'
,
'exception'
,
'debug'
]
# export logger functions
for
func
in
_LOGGING_METHOD
:
locals
()[
func
]
=
getattr
(
_logger
,
func
)
def
disable_logger
():
def
disable_logger
():
""" disable all logging ability from this moment"""
""" disable all logging ability from this moment"""
for
func
in
_LOGGING_METHOD
:
for
func
in
_LOGGING_METHOD
:
...
@@ -127,4 +128,4 @@ def auto_set_dir(action=None, overwrite=False):
...
@@ -127,4 +128,4 @@ def auto_set_dir(action=None, overwrite=False):
def
warn_dependency
(
name
,
dependencies
):
def
warn_dependency
(
name
,
dependencies
):
warn
(
"Failed to import '{}', {} won't be available'"
.
format
(
dependencies
,
name
))
warn
(
"Failed to import '{}', {} won't be available'"
.
format
(
dependencies
,
name
))
# noqa: F821
tensorpack/utils/naming.py
View file @
37e98945
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# File: naming.py
# File: naming.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
GLOBAL_STEP_OP_NAME
=
'global_step'
GLOBAL_STEP_OP_NAME
=
'global_step'
GLOBAL_STEP_VAR_NAME
=
'global_step:0'
GLOBAL_STEP_VAR_NAME
=
'global_step:0'
...
@@ -14,7 +15,6 @@ MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
...
@@ -14,7 +15,6 @@ MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
# placeholders for input variables
# placeholders for input variables
INPUT_VARS_KEY
=
'INPUT_VARIABLES'
INPUT_VARS_KEY
=
'INPUT_VARIABLES'
import
tensorflow
as
tf
SUMMARY_BACKUP_KEYS
=
[
tf
.
GraphKeys
.
SUMMARIES
,
MOVING_SUMMARY_VARS_KEY
]
SUMMARY_BACKUP_KEYS
=
[
tf
.
GraphKeys
.
SUMMARIES
,
MOVING_SUMMARY_VARS_KEY
]
# export all upper case variables
# export all upper case variables
...
...
tensorpack/utils/serialize.py
View file @
37e98945
...
@@ -6,16 +6,13 @@
...
@@ -6,16 +6,13 @@
import
msgpack
import
msgpack
import
msgpack_numpy
import
msgpack_numpy
msgpack_numpy
.
patch
()
msgpack_numpy
.
patch
()
#import dill
__all__
=
[
'loads'
,
'dumps'
]
__all__
=
[
'loads'
,
'dumps'
]
def
dumps
(
obj
):
def
dumps
(
obj
):
# return dill.dumps(obj)
return
msgpack
.
dumps
(
obj
,
use_bin_type
=
True
)
return
msgpack
.
dumps
(
obj
,
use_bin_type
=
True
)
def
loads
(
buf
):
def
loads
(
buf
):
# return dill.loads(buf)
return
msgpack
.
loads
(
buf
)
return
msgpack
.
loads
(
buf
)
tensorpack/utils/timer.py
View file @
37e98945
...
@@ -48,6 +48,7 @@ def timed_operation(msg, log_start=False):
...
@@ -48,6 +48,7 @@ def timed_operation(msg, log_start=False):
logger
.
info
(
'{} finished, time:{:.2f}sec.'
.
format
(
logger
.
info
(
'{} finished, time:{:.2f}sec.'
.
format
(
msg
,
time
.
time
()
-
start
))
msg
,
time
.
time
()
-
start
))
_TOTAL_TIMER_DATA
=
defaultdict
(
StatCounter
)
_TOTAL_TIMER_DATA
=
defaultdict
(
StatCounter
)
...
@@ -66,4 +67,5 @@ def print_total_timer():
...
@@ -66,4 +67,5 @@ def print_total_timer():
logger
.
info
(
"Total Time: {} -> {:.2f} sec, {} times, {:.3g} sec/time"
.
format
(
logger
.
info
(
"Total Time: {} -> {:.2f} sec, {} times, {:.3g} sec/time"
.
format
(
k
,
v
.
sum
,
v
.
count
,
v
.
average
))
k
,
v
.
sum
,
v
.
count
,
v
.
average
))
atexit
.
register
(
print_total_timer
)
atexit
.
register
(
print_total_timer
)
tensorpack/utils/utils.py
View file @
37e98945
...
@@ -8,7 +8,6 @@ from contextlib import contextmanager
...
@@ -8,7 +8,6 @@ from contextlib import contextmanager
import
inspect
import
inspect
from
datetime
import
datetime
from
datetime
import
datetime
from
tqdm
import
tqdm
from
tqdm
import
tqdm
import
time
import
numpy
as
np
import
numpy
as
np
__all__
=
[
'change_env'
,
__all__
=
[
'change_env'
,
...
...
tensorpack/utils/viz.py
View file @
37e98945
...
@@ -106,7 +106,6 @@ def build_patch_list(patch_list,
...
@@ -106,7 +106,6 @@ def build_patch_list(patch_list,
ph
,
pw
=
patch_list
.
shape
[
1
:
3
]
ph
,
pw
=
patch_list
.
shape
[
1
:
3
]
if
border
is
None
:
if
border
is
None
:
border
=
int
(
0.1
*
min
(
ph
,
pw
))
border
=
int
(
0.1
*
min
(
ph
,
pw
))
mh
,
mw
=
max
(
max_height
,
ph
+
border
),
max
(
max_width
,
pw
+
border
)
if
nr_row
is
None
:
if
nr_row
is
None
:
nr_row
=
minnone
(
nr_row
,
max_height
/
(
ph
+
border
))
nr_row
=
minnone
(
nr_row
,
max_height
/
(
ph
+
border
))
if
nr_col
is
None
:
if
nr_col
is
None
:
...
@@ -204,13 +203,13 @@ def dump_dataflow_images(df, index=0, batched=True,
...
@@ -204,13 +203,13 @@ def dump_dataflow_images(df, index=0, batched=True,
if
viz
is
not
None
:
if
viz
is
not
None
:
vizlist
.
append
(
img
)
vizlist
.
append
(
img
)
if
viz
is
not
None
and
len
(
vizlist
)
>=
vizsize
:
if
viz
is
not
None
and
len
(
vizlist
)
>=
vizsize
:
patch
=
next
(
build_patch_list
(
next
(
build_patch_list
(
vizlist
[:
vizsize
],
vizlist
[:
vizsize
],
nr_row
=
viz
[
0
],
nr_col
=
viz
[
1
],
viz
=
True
))
nr_row
=
viz
[
0
],
nr_col
=
viz
[
1
],
viz
=
True
))
vizlist
=
vizlist
[
vizsize
:]
vizlist
=
vizlist
[
vizsize
:]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
cv2
imglist
=
[]
imglist
=
[]
for
i
in
range
(
100
):
for
i
in
range
(
100
):
fname
=
"{:03d}.png"
.
format
(
i
)
fname
=
"{:03d}.png"
.
format
(
i
)
...
...
tox.ini
0 → 100644
View file @
37e98945
[flake8]
max-line-length
=
120
exclude
=
.git,
__init__.py,
snippet,
docs
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment