Shashank Suhas / seminar-breakout · Commits · 37e98945

Commit 37e98945, authored Jan 03, 2017 by Yuxin Wu

    fix flake8 style in tensorpack/

Parent: 233b3b90

Showing 69 changed files with 164 additions and 202 deletions (+164 −202).
Changed files:

tensorpack/RL/common.py  (+0 −1)
tensorpack/RL/envbase.py  (+0 −1)
tensorpack/RL/expreplay.py  (+3 −1)
tensorpack/RL/gymenv.py  (+1 −0)
tensorpack/RL/simulator.py  (+5 −8)
tensorpack/callbacks/base.py  (+1 −4)
tensorpack/callbacks/graph.py  (+1 −5)
tensorpack/callbacks/group.py  (+0 −1)
tensorpack/callbacks/inference.py  (+1 −2)
tensorpack/callbacks/inference_runner.py  (+0 −3)
tensorpack/callbacks/param.py  (+1 −1)
tensorpack/callbacks/saver.py  (+0 −1)
tensorpack/callbacks/stats.py  (+0 −2)
tensorpack/dataflow/common.py  (+0 −2)
tensorpack/dataflow/dataset/bsds500.py  (+1 −1)
tensorpack/dataflow/dataset/cifar.py  (+3 −5)
tensorpack/dataflow/dataset/ilsvrc.py  (+7 −6)
tensorpack/dataflow/dataset/mnist.py  (+2 −2)
tensorpack/dataflow/dataset/ptb.py  (+0 −2)
tensorpack/dataflow/dataset/svhn.py  (+2 −2)
tensorpack/dataflow/dataset/visualqa.py  (+2 −4)
tensorpack/dataflow/format.py  (+3 −3)
tensorpack/dataflow/image.py  (+3 −3)
tensorpack/dataflow/imgaug/base.py  (+1 −1)
tensorpack/dataflow/imgaug/crop.py  (+1 −1)
tensorpack/dataflow/imgaug/deform.py  (+1 −1)
tensorpack/dataflow/imgaug/geometry.py  (+1 −2)
tensorpack/dataflow/imgaug/imgproc.py  (+4 −2)
tensorpack/dataflow/prefetch.py  (+2 −3)
tensorpack/dataflow/remote.py  (+1 −0)
tensorpack/dataflow/tf_func.py  (+3 −3)
tensorpack/models/_test.py  (+1 −1)
tensorpack/models/batch_norm.py  (+2 −4)
tensorpack/models/conv2d.py  (+2 −4)
tensorpack/models/fc.py  (+3 −3)
tensorpack/models/image_sample.py  (+8 −9)
tensorpack/models/model_desc.py  (+0 −6)
tensorpack/models/nonlin.py  (+2 −3)
tensorpack/models/pool.py  (+3 −4)
tensorpack/models/regularize.py  (+1 −0)
tensorpack/predict/base.py  (+4 −3)
tensorpack/predict/common.py  (+2 −7)
tensorpack/predict/concurrency.py  (+3 −6)
tensorpack/predict/dataset.py  (+1 −1)
tensorpack/tfutils/common.py  (+4 −2)
tensorpack/tfutils/gradproc.py  (+2 −1)
tensorpack/tfutils/sessinit.py  (+2 −2)
tensorpack/tfutils/summary.py  (+1 −1)
tensorpack/tfutils/symbolic_functions.py  (+4 −7)
tensorpack/tfutils/tower.py  (+2 −2)
tensorpack/tfutils/varmanip.py  (+2 −2)
tensorpack/train/base.py  (+0 −1)
tensorpack/train/feedfree.py  (+4 −3)
tensorpack/train/input_data.py  (+1 −1)
tensorpack/train/multigpu.py  (+10 −8)
tensorpack/train/trainer.py  (+4 −5)
tensorpack/utils/argtools.py  (+9 −2)
tensorpack/utils/concurrency.py  (+4 −2)
tensorpack/utils/debug.py  (+1 −0)
tensorpack/utils/discretize.py  (+5 −12)
tensorpack/utils/fs.py  (+1 −0)
tensorpack/utils/loadcaffe.py  (+2 −7)
tensorpack/utils/logger.py  (+13 −12)
tensorpack/utils/naming.py  (+1 −1)
tensorpack/utils/serialize.py  (+0 −3)
tensorpack/utils/timer.py  (+2 −0)
tensorpack/utils/utils.py  (+0 −1)
tensorpack/utils/viz.py  (+2 −3)
tox.ini  (+6 −0)
tensorpack/RL/common.py (+0 −1)

```diff
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import numpy as np
 from collections import deque

 from .envbase import ProxyPlayer
```
tensorpack/RL/envbase.py (+0 −1)

```diff
@@ -7,7 +7,6 @@
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict
 import six
-import random
 from ..utils import get_rng

 __all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
```
tensorpack/RL/expreplay.py (+3 −1)

```diff
@@ -211,7 +211,9 @@ class ExpReplay(DataFlow, Callback):
 if __name__ == '__main__':
     from .atari import AtariPlayer
     import sys
-    predictor = lambda x: np.array([1, 1, 1, 1])
+
+    def predictor(x):
+        np.array([1, 1, 1, 1])
     player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
     E = ExpReplay(predictor,
                   player=player,
```
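The change above is flake8's E731 ("do not assign a lambda expression, use a def"). A minimal standalone sketch of the two forms — note the rewritten `def` in the hunk evaluates `np.array(...)` without returning it, so a fully equivalent version would add `return`:

```python
import numpy as np

# Old style, flagged by flake8:
predictor = lambda x: np.array([1, 1, 1, 1])  # noqa: E731


# Preferred style; `return` added here to keep the lambda's value:
def predictor(x):
    return np.array([1, 1, 1, 1])


print(predictor(None))  # -> [1 1 1 1]
```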
tensorpack/RL/gymenv.py (+1 −0)

```diff
@@ -76,6 +76,7 @@ class GymEnv(RLEnvironment):
         assert isinstance(spc, gym.spaces.discrete.Discrete)
         return DiscreteActionSpace(spc.n)

+
 if __name__ == '__main__':
     env = GymEnv('Breakout-v0', viz=0.1)
     num = env.get_action_space().num_actions()
```
tensorpack/RL/simulator.py (+5 −8)

```diff
@@ -7,10 +7,8 @@ import tensorflow as tf
 import multiprocessing as mp
 import time
 import threading
 import weakref
 from abc import abstractmethod, ABCMeta
-from collections import defaultdict, namedtuple
-import numpy as np
+from collections import defaultdict
 import six
 from six.moves import queue
@@ -20,7 +18,6 @@ from ..callbacks import Callback
 from ..tfutils.varmanip import SessionUpdate
 from ..predict import OfflinePredictor
 from ..utils import logger
-#from ..utils.timer import *
 from ..utils.serialize import loads, dumps
 from ..utils.concurrency import LoopThread, ensure_proc_terminate
@@ -98,6 +95,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
                 reward, isOver = player.action(action)
                 state = player.current_state()

+
 # compatibility
 SimulatorProcess = SimulatorProcessStateExchange
@@ -284,6 +282,7 @@ class WeightSync(Callback):
         self.condvar.notify_all()
         self.condvar.release()

+
 if __name__ == '__main__':
     import random
     from tensorpack.RL import NaiveRLEnvironment
@@ -293,14 +292,13 @@ if __name__ == '__main__':
         def _build_player(self):
             return NaiveRLEnvironment()

-    class NaiveActioner(SimulatorActioner):
+    class NaiveActioner(SimulatorMaster):
         def _get_action(self, state):
             time.sleep(1)
             return random.randint(1, 12)

         def _on_episode_over(self, client):
-            #print("Over: ", client.memory)
+            # print("Over: ", client.memory)
             client.memory = []
             client.state = 0
@@ -312,5 +310,4 @@ if __name__ == '__main__':
     ensure_proc_terminate(procs)
     th.start()

-    import time
     time.sleep(100)
```
tensorpack/callbacks/base.py (+1 −4)

```diff
@@ -3,10 +3,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
-import sys
-import os
-import time
-from abc import abstractmethod, ABCMeta
+from abc import ABCMeta
 import six

 __all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']
```
tensorpack/callbacks/graph.py (+1 −5)

```diff
@@ -6,7 +6,6 @@
 """ Graph related callbacks"""

 from .base import Callback
-from ..utils import logger

 __all__ = ['RunOp']
@@ -26,7 +25,7 @@ class RunOp(Callback):
     def _setup_graph(self):
         self._op = self.setup_func()
-        #self._op_name = self._op.name
+        # self._op_name = self._op.name

     def _before_train(self):
         if self.run_before:
@@ -35,6 +34,3 @@ class RunOp(Callback):
     def _trigger_epoch(self):
         if self.run_epoch:
             self._op.run()
-
-    # def _log(self):
-        #logger.info("Running op {} ...".format(self._op_name))
```
tensorpack/callbacks/group.py (+0 −1)

```diff
@@ -86,7 +86,6 @@ class Callbacks(Callback):
     def _trigger_epoch(self):
         tm = CallbackTimeLogger()
-        test_sess_restored = False
         for cb in self.cbs:
             display_name = str(cb)
             with tm.timed_callback(display_name):
```
tensorpack/callbacks/inference.py (+1 −2)

```diff
@@ -2,14 +2,13 @@
 # File: inference.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
 import numpy as np
 from abc import ABCMeta, abstractmethod
-import sys
 import six
 from six.moves import zip

-from ..utils import logger, execute_only_once
+from ..utils import logger
 from ..utils.stats import RatioCounter, BinaryStatistics
 from ..tfutils import get_op_var_name
```
tensorpack/callbacks/inference_runner.py (+0 −3)

```diff
@@ -12,7 +12,6 @@ from ..dataflow import DataFlow
 from .base import Callback
 from .inference import Inferencer
 from .dispatcher import OutputTensorDispatcer
-from ..tfutils import get_op_tensor_name
 from ..utils import logger, get_tqdm
 from ..train.input_data import FeedfreeInput
@@ -99,7 +98,6 @@ class InferenceRunner(Callback):
         for inf in self.infs:
             inf.before_inference()

-        sess = tf.get_default_session()
         self.ds.reset_state()
         with get_tqdm(total=self.ds.size()) as pbar:
             for dp in self.ds.get_data():
@@ -171,7 +169,6 @@ class FeedfreeInferenceRunner(Callback):
         for inf in self.infs:
             inf.before_inference()

-        sess = tf.get_default_session()
         sz = self._input_data.size()
         with get_tqdm(total=sz) as pbar:
             for _ in range(sz):
```
tensorpack/callbacks/param.py (+1 −1)

```diff
@@ -4,7 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
-from abc import abstractmethod, ABCMeta, abstractproperty
+from abc import abstractmethod, ABCMeta
 import operator
 import six
 import os
```
tensorpack/callbacks/saver.py (+0 −1)

```diff
@@ -5,7 +5,6 @@
 import tensorflow as tf
 import os
 import shutil
-import re

 from .base import Callback
 from ..utils import logger
```
tensorpack/callbacks/stats.py (+0 −2)

```diff
@@ -2,8 +2,6 @@
 # File: stats.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import tensorflow as tf
-import re
 import os
 import operator
 import json
```
tensorpack/dataflow/common.py (+0 −2)

```diff
@@ -3,7 +3,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 from __future__ import division
-import copy
 import numpy as np
 from collections import deque, defaultdict
 from six.moves import range, map
@@ -48,7 +47,6 @@ class BatchData(ProxyDataFlow):
         super(BatchData, self).__init__(ds)
         if not remainder:
             try:
-                s = ds.size()
                 assert batch_size <= ds.size()
             except NotImplementedError:
                 pass
```
tensorpack/dataflow/dataset/bsds500.py (+1 −1)

```diff
@@ -8,7 +8,7 @@ import glob
 import cv2
 import numpy as np

-from ...utils import logger, get_rng, get_dataset_path
+from ...utils import logger, get_dataset_path
 from ...utils.fs import download
 from ..base import RNGDataFlow
```
tensorpack/dataflow/dataset/cifar.py (+3 −5)

```diff
@@ -5,16 +5,13 @@
 #         Yukun Chen <cykustc@gmail.com>
 import os
-import sys
 import pickle
 import numpy as np
-import random
 import six
-from six.moves import urllib, range
+from six.moves import range
 import copy
-import logging

-from ...utils import logger, get_rng, get_dataset_path
+from ...utils import logger, get_dataset_path
 from ...utils.fs import download
 from ..base import RNGDataFlow
@@ -152,6 +149,7 @@ class Cifar100(CifarBase):
     def __init__(self, train_or_test, shuffle=True, dir=None):
         super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)

+
 if __name__ == '__main__':
     ds = Cifar10('train')
     from tensorpack.dataflow.dftools import dump_dataset_images
```
tensorpack/dataflow/dataset/ilsvrc.py (+7 −6)

```diff
@@ -5,11 +5,11 @@
 import os
 import tarfile
 import cv2
 import six
 import numpy as np
 from six.moves import range
 import xml.etree.ElementTree as ET

-from ...utils import logger, get_rng, get_dataset_path
+from ...utils import logger, get_dataset_path
 from ...utils.loadcaffe import get_caffe_pb
 from ...utils.fs import mkdir_p, download
 from ...utils.timer import timed_operation
@@ -195,10 +195,10 @@ class ILSVRC12(RNGDataFlow):
             box = root.find('object').find('bndbox').getchildren()
             box = map(lambda x: float(x.text), box)
-            #box[0] /= size[0]
-            #box[1] /= size[1]
-            #box[2] /= size[0]
-            #box[3] /= size[1]
+            # box[0] /= size[0]
+            # box[1] /= size[1]
+            # box[2] /= size[0]
+            # box[3] /= size[1]
             return np.asarray(box, dtype='float32')

         with timed_operation('Loading Bounding Boxes ...'):
@@ -218,6 +218,7 @@ class ILSVRC12(RNGDataFlow):
         logger.info("{}/{} images have bounding box.".format(cnt, len(imglist)))
         return ret

+
 if __name__ == '__main__':
     meta = ILSVRCMeta()
     # print(meta.get_synset_words_1000())
```
tensorpack/dataflow/dataset/mnist.py (+2 −2)

```diff
@@ -5,9 +5,8 @@
 import os
 import gzip
-import random
 import numpy
-from six.moves import urllib, range
+from six.moves import range

 from ...utils import logger, get_dataset_path
 from ...utils.fs import download
@@ -110,6 +109,7 @@ class Mnist(RNGDataFlow):
             label = self.labels[k]
             yield [img, label]

+
 if __name__ == '__main__':
     ds = Mnist('train')
     for (img, label) in ds.get_data():
```
tensorpack/dataflow/dataset/ptb.py (+0 −2)

```diff
@@ -9,9 +9,7 @@ import numpy as np
 from ...utils import logger, get_dataset_path
 from ...utils.fs import download
 from ...utils.argtools import memoized_ignoreargs
-from ..base import RNGDataFlow

 try:
-    import tensorflow
     from tensorflow.models.rnn.ptb import reader as tfreader
 except ImportError:
     logger.warn_dependency('PennTreeBank', 'tensorflow')
```
tensorpack/dataflow/dataset/svhn.py (+2 −2)

```diff
@@ -5,9 +5,8 @@
 import os
 import numpy as np
-from six.moves import range

-from ...utils import logger, get_rng, get_dataset_path
+from ...utils import logger, get_dataset_path
 from ..base import RNGDataFlow

 try:
@@ -71,6 +70,7 @@ class SVHNDigit(RNGDataFlow):
         c = SVHNDigit('extra')
         return np.concatenate((a.X, b.X, c.X)).mean(axis=0)

+
 if __name__ == '__main__':
     a = SVHNDigit('train')
     b = SVHNDigit.get_per_pixel_mean()
```
tensorpack/dataflow/dataset/visualqa.py (+2 −4)

```diff
@@ -4,8 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 from ..base import DataFlow
-from ...utils import *
-from ...utils.timer import *
+from ...utils.timer import timed_operation
 from six.moves import zip, map
 from collections import Counter
 import json
@@ -74,12 +73,11 @@ class VisualQA(DataFlow):
         ret = cnt.most_common(n)
         return [k[0] for k in ret]

+
 if __name__ == '__main__':
     vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
                    '/home/wyx/data/VQA/mscoco_train2014_annotations.json')
     for k in vqa.get_data():
         print(json.dumps(k))
         break
     # vqa.get_common_question_words(100)
     vqa.get_common_answer(100)
-    #from IPython import embed; embed()
```
tensorpack/dataflow/format.py (+3 −3)

```diff
@@ -6,10 +6,11 @@ import numpy as np
 from six.moves import range
 import os

-from ..utils import logger, get_rng, get_tqdm
+from ..utils import logger, get_tqdm
 from ..utils.timer import timed_operation
 from ..utils.loadcaffe import get_caffe_pb
 from ..utils.serialize import loads
+from ..utils.argtools import log_once
 from .base import RNGDataFlow

 try:
@@ -114,7 +115,6 @@ class LMDBData(RNGDataFlow):
                 if k != '__keys__':
                     yield [k, v]
         else:
-            s = self.size()
             self.rng.shuffle(self.keys)
             for k in self.keys:
                 v = self._txn.get(k)
@@ -159,7 +159,7 @@ class CaffeLMDB(LMDBDataDecoder):
             img = np.fromstring(datum.data, dtype=np.uint8)
             img = img.reshape(datum.channels, datum.height, datum.width)
         except Exception:
-            log_once("Cannot read key {}".format(k))
+            log_once("Cannot read key {}".format(k), 'warn')
             return None
         return [img.transpose(1, 2, 0), datum.label]
```
tensorpack/dataflow/image.py (+3 −3)

```diff
@@ -4,8 +4,7 @@
 import numpy as np
 import cv2
-import copy
-from .base import RNGDataFlow, DataFlow, ProxyDataFlow
+from .base import RNGDataFlow
 from .common import MapDataComponent, MapData
 from .imgaug import AugmentorList
@@ -52,7 +51,8 @@ class AugmentImageComponent(MapDataComponent):
         Augment the image component of datapoints
         :param ds: a `DataFlow` instance.
         :param augmentors: a list of `ImageAugmentor` instance to be applied in order.
-        :param index: the index (or list of indices) of the image component in the produced datapoints by `ds`. default to be 0
+        :param index: the index (or list of indices) of the image component
+            in the produced datapoints by `ds`. default to be 0
         """
         if isinstance(augmentors, AugmentorList):
             self.augs = augmentors
```
tensorpack/dataflow/imgaug/base.py (+1 −1)

```diff
@@ -55,7 +55,7 @@ class Augmentor(object):
     def _rand_range(self, low=1.0, high=None, size=None):
         if high is None:
             low, high = 0, low
-        if size == None:
+        if size is None:
             size = []
         return self.rng.uniform(low, high, size)
```
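The `size == None` → `size is None` change is flake8's E711. For plain objects the two usually agree, but `==` can be overloaded; with numpy inputs the difference is observable. A small illustration, independent of tensorpack:

```python
import numpy as np

size = np.zeros(3)
print(size == None)  # noqa: E711  -- element-wise: [False False False]
print(size is None)  # identity check: False

size = None
print(size is None)  # True -- the check the augmentor actually wants
```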
tensorpack/dataflow/imgaug/crop.py (+1 −1)

```diff
@@ -74,7 +74,6 @@ class FixedCrop(ImageAugmentor):
         self._init(locals())

     def _augment(self, img, _):
-        orig_shape = img.shape
         return img[self.rect.y0: self.rect.y1 + 1,
                    self.rect.x0: self.rect.x0 + 1]
@@ -174,5 +173,6 @@ class RandomCropRandomShape(ImageAugmentor):
         y0, x0, h, w = param
         return img[y0:y0 + h, x0:x0 + w]

+
 if __name__ == '__main__':
     print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
```
tensorpack/dataflow/imgaug/deform.py (+1 −1)

```diff
@@ -26,7 +26,7 @@ class GaussianMap(object):
         y = y.astype('float32') / ret.shape[0] - anchor[0]
         x = x.astype('float32') / ret.shape[1] - anchor[1]
         g = np.exp(-(x**2 + y**2) / self.sigma)
-        #cv2.imshow(" ", g)
+        # cv2.imshow(" ", g)
         # cv2.waitKey()
         return g
```
tensorpack/dataflow/imgaug/geometry.py (+1 −2)

```diff
@@ -6,7 +6,6 @@
 from .base import ImageAugmentor
 import math
 import cv2
-import numpy as np

 __all__ = ['Rotation', 'RotationAndCropValid']
@@ -59,7 +58,7 @@ class RotationAndCropValid(ImageAugmentor):
         newh = min(newh, ret.shape[0])
         newx = int(center[0] - neww * 0.5)
         newy = int(center[1] - newh * 0.5)
-        #print(ret.shape, deg, newx, newy, neww, newh)
+        # print(ret.shape, deg, newx, newy, neww, newh)
         return ret[newy:newy + newh, newx:newx + neww]

     @staticmethod
```
tensorpack/dataflow/imgaug/imgproc.py (+4 −2)

```diff
@@ -131,7 +131,8 @@ class Clip(ImageAugmentor):
 class Saturation(ImageAugmentor):
     def __init__(self, alpha=0.4):
-        """ Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
+        """ Saturation,
+        see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
         """
         super(Saturation, self).__init__()
         assert alpha < 1
@@ -150,7 +151,8 @@ class Lighting(ImageAugmentor):
     def __init__(self, std, eigval, eigvec):
         """ Lighting noise.
             See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
-            The implementation follows 'fb.resnet.torch': https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184
+            The implementation follows 'fb.resnet.torch':
+            https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184

             :param eigvec: each column is one eigen vector
         """
```
tensorpack/dataflow/prefetch.py (+2 −3)

```diff
@@ -4,10 +4,8 @@
 from __future__ import print_function
 import multiprocessing as mp
-from threading import Thread
 import itertools
 from six.moves import range, zip
-from six.moves.queue import Queue
 import uuid
 import os
@@ -127,7 +125,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
         :param ds: a `DataFlow` instance.
         :param nr_proc: number of processes to use. When larger than 1, order
             of datapoints will be random.
-        :param pipedir: a local directory where the pipes would be. Useful if you're running on non-local FS such as NFS.
+        :param pipedir: a local directory where the pipes would be.
+            Useful if you're running on non-local FS such as NFS.
         """
         super(PrefetchDataZMQ, self).__init__(ds)
         try:
```
tensorpack/dataflow/remote.py (+1 −0)

```diff
@@ -51,6 +51,7 @@ class RemoteData(DataFlow):
             dp = loads(self.socket.recv(copy=False))
             yield dp

+
 if __name__ == '__main__':
     import sys
     from tqdm import tqdm
```
tensorpack/dataflow/tf_func.py (+3 −3)

```diff
@@ -53,9 +53,6 @@ class TFFuncMapper(ProxyDataFlow):
 if __name__ == '__main__':
     from .raw import FakeData
-    from .prefetch import PrefetchDataZMQ
-    from .image import AugmentImageComponent
-    from . import imgaug
     ds = FakeData([[224, 224, 3]], 100000, random=False)

     def tf_aug(v):
@@ -69,6 +66,9 @@ if __name__ == '__main__':
         tf_aug,
         lambda dp, f: [f([dp[0]])[0]]
     )
+    # from .prefetch import PrefetchDataZMQ
+    # from .image import AugmentImageComponent
+    # from . import imgaug
     # ds = AugmentImageComponent(ds,
     #     [imgaug.Brightness(0.1, clip=False),
     #      imgaug.Contrast((0.8, 1.2), clip=False),
```
tensorpack/models/_test.py (+1 −1)

```diff
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
-import numpy as np
 import unittest
@@ -29,6 +28,7 @@ def run_test_case(case):
     suite = unittest.TestLoader().loadTestsFromTestCase(case)
     unittest.TextTestRunner(verbosity=2).run(suite)

+
 if __name__ == '__main__':
     import tensorpack
     from tensorpack.utils import logger
```
tensorpack/models/batch_norm.py (+2 −4)

```diff
@@ -6,8 +6,6 @@
 import tensorflow as tf
 from tensorflow.contrib.framework import add_model_variable
 from tensorflow.python.training import moving_averages
-from copy import copy
-import re

 from ..tfutils.common import get_tf_version
 from ..tfutils.tower import get_current_tower_context
@@ -65,7 +63,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     if use_local_stat:
         # training tower
         if ctx.is_training:
-            #reuse = tf.get_variable_scope().reuse
+            # reuse = tf.get_variable_scope().reuse
             with tf.variable_scope(tf.get_variable_scope(), reuse=False):
                 # BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
                 with tf.name_scope(None):
                     # https://github.com/tensorflow/tensorflow/issues/2740
@@ -86,7 +84,6 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
         mean_var_name = ema.average_name(batch_mean)
         var_var_name = ema.average_name(batch_var)
-        sc = tf.get_variable_scope()
         if ctx.is_main_tower:
             # main tower, but needs to use global stat. global stat must be from outside
             # TODO when reuse=True, the desired variable name could
@@ -187,6 +184,7 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     else:
         return tf.identity(xn, name='output')

+
 if get_tf_version() >= 12:
     BatchNorm = BatchNormV2
 else:
```
tensorpack/models/conv2d.py (+2 −4)

```diff
@@ -3,12 +3,9 @@
 # File: conv2d.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-import numpy as np
 import tensorflow as tf
-import math
 from ._common import layer_register, shape2d, shape4d
 from ..utils import logger
-from ..utils.argtools import shape2d

 __all__ = ['Conv2D', 'Deconv2D']
@@ -63,7 +60,8 @@ def Conv2D(x, out_channel, kernel_shape,
         conv = tf.concat(3, outputs)
     if nl is None:
-        logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
+        logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. "
+                    "Please use argscope instead.")
         nl = tf.nn.relu
     return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
```
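The warning message above is split using implicit concatenation of adjacent string literals — the standard way to satisfy a line-length limit (120 here, per the new tox.ini) without changing the logged text. A minimal sketch:

```python
msg = ("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. "
       "Please use argscope instead.")
# Adjacent literals are joined at compile time into one string:
assert "deprecated. Please use argscope" in msg
```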
tensorpack/models/fc.py (+3 −3)

```diff
@@ -4,10 +4,10 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
-import math

 from ._common import layer_register
 from ..tfutils import symbolic_functions as symbf
+from ..utils import logger

 __all__ = ['FullyConnected']
@@ -31,7 +31,6 @@ def FullyConnected(x, out_dim,
     in_dim = x.get_shape().as_list()[1]

     if W_init is None:
-        #W_init = tf.uniform_unit_scaling_initializer(factor=1.43)
         W_init = tf.contrib.layers.variance_scaling_initializer()
     if b_init is None:
         b_init = tf.constant_initializer()
@@ -42,6 +41,7 @@ def FullyConnected(x, out_dim,
     prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
     if nl is None:
-        logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
+        logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated."
+                    " Please use argscope instead.")
         nl = tf.nn.relu
     return nl(prod, name='output')
```
tensorpack/models/image_sample.py (+8 −9)

```diff
@@ -6,6 +6,7 @@
 import tensorflow as tf

 from ._common import layer_register
+from ._test import TestModel

 __all__ = ['ImageSample']
@@ -82,7 +83,7 @@ def ImageSample(inputs, borderMode='repeat'):
     diffy, diffx = tf.split(3, 2, diff)
     neg_diffy, neg_diffx = tf.split(3, 2, neg_diff)

-    #prod = tf.reduce_prod(diff, 3, keep_dims=True)
+    # prod = tf.reduce_prod(diff, 3, keep_dims=True)
     # diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
     #                        tf.reduce_max(diff), diff], summarize=50)
@@ -100,8 +101,6 @@ def ImageSample(inputs, borderMode='repeat'):
         ret = ret * tf.cast(mask, tf.float32)
     return ret

-from ._test import TestModel
-
 class TestSample(TestModel):
@@ -128,9 +127,9 @@ class TestSample(TestModel):
         bimg = np.random.rand(2, h, w, 3).astype('float32')

         # mat = np.array([
-        #[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
-        #[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
-        #], dtype='float32')  #2x2x2x2
+        # [[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
+        # [[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
+        # ], dtype='float32') #2x2x2x2
         mat = (np.random.rand(2, 5, 5, 2) - 0.2) * np.array([h + 3, w + 3])
         true_res = np_sample(bimg, np.floor(mat + 0.5).astype('int32'))
@@ -140,10 +139,10 @@ class TestSample(TestModel):
         self.assertTrue((res == true_res).all())

+
 if __name__ == '__main__':
     import cv2
     import numpy as np
-    import sys
     im = cv2.imread('cat.jpg')
     im = im.reshape((1,) + im.shape).astype('float32')
     imv = tf.Variable(im)
@@ -160,8 +159,8 @@ if __name__ == '__main__':
     sess = tf.Session()
     sess.run(tf.global_variables_initializer())

-    #out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
-    #out = sess.run(output)
+    # out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
+    # out = sess.run(output)
     # print(out[0].min())
     # print(out[0].max())
     # print(out[0].sum())
```
tensorpack/models/model_desc.py (+0 −6)

```diff
@@ -4,25 +4,19 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 from abc import ABCMeta, abstractmethod
-import re
 import tensorflow as tf
-from collections import namedtuple
-import inspect
-import pickle
 import six

 from ..utils import logger, INPUT_VARS_KEY
-from ..tfutils.common import get_tensors_by_names
 from ..tfutils.gradproc import CheckGradient
 from ..tfutils.tower import get_current_tower_context

 __all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']

-#_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])

 class InputVar(object):
     def __init__(self, type, shape, name, sparse=False):
         self.type = type
         self.shape = shape
```
tensorpack/models/nonlin.py (+2 −3)

```diff
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
-from copy import copy

 from ._common import layer_register
 from .batch_norm import BatchNorm
@@ -63,8 +62,8 @@ def LeakyReLU(x, alpha, name=None):
     if name is None:
         name = 'output'
     return tf.maximum(x, alpha * x, name=name)
-    #alpha = float(alpha)
-    #x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
+    # alpha = float(alpha)
+    # x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
     # return tf.mul(x, 0.5, name=name)
```
tensorpack/models/pool.py (+3 −4)

```diff
@@ -8,6 +8,8 @@ import numpy as np
 from ._common import layer_register, shape4d
 from ..utils.argtools import shape2d
 from ..tfutils import symbolic_functions as symbf
+from ._test import TestModel
+

 __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
            'BilinearUpSample']
@@ -131,7 +133,7 @@ def BilinearUpSample(x, shape):
     :param x: input NHWC tensor
     :param shape: an integer, the upsample factor
     """
-    #inp_shape = tf.shape(x)
+    # inp_shape = tf.shape(x)
     # return tf.image.resize_bilinear(x,
     #         tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
     #         align_corners=True)
@@ -172,9 +174,6 @@ def BilinearUpSample(x, shape):
     return deconv


-from ._test import TestModel
-
-
 class TestPool(TestModel):
     def test_fixed_unpooling(self):
```
View file @
37e98945
...
...
@@ -17,6 +17,7 @@ __all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']
def
_log_regularizer
(
name
):
logger
.
info
(
"Apply regularizer for {}"
.
format
(
name
))
l2_regularizer
=
tf
.
contrib
.
layers
.
l2_regularizer
l1_regularizer
=
tf
.
contrib
.
layers
.
l1_regularizer
...
...
tensorpack/predict/base.py (+4 −3)

```diff
@@ -3,11 +3,11 @@
 # File: base.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-from abc import abstractmethod, ABCMeta, abstractproperty
+from abc import abstractmethod, ABCMeta
 import tensorflow as tf
 import six

-from ..utils.naming import *
+from ..utils.naming import PREDICT_TOWER
 from ..utils import logger
 from ..tfutils import get_tensors_by_names, TowerContext
@@ -128,7 +128,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
         self.predictors = []
         with self.graph.as_default():
             # TODO backup summary keys?
-            fn = lambda _: config.model.build_graph(config.model.get_input_vars())
+            def fn(_):
+                config.model.build_graph(config.model.get_input_vars())
             build_multi_tower_prediction_graph(fn, towers)
             self.sess = tf.Session(config=config.session_config)
```
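Replacing `from ..utils.naming import *` with an explicit import of `PREDICT_TOWER` addresses flake8's F403/F405: with a star import the checker cannot tell which names a module defines, so typos and unused names slip through silently. A small sketch with a standard-library module (not tensorpack code):

```python
# Explicit imports keep every name traceable and statically checkable:
from os.path import join

print(join("tensorpack", "predict", "base.py"))

# By contrast, `from os.path import *` would trigger F403, and any name
# used afterwards would be reported as F405 ("may be undefined").
```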
tensorpack/predict/common.py (+2 −7)

```diff
@@ -2,19 +2,14 @@
 # File: common.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-import tensorflow as tf
 from collections import namedtuple
 import six
-from six.moves import zip

 from tensorpack.models import ModelDesc
-from ..utils import logger
 from ..tfutils import get_default_sess_config
 from ..tfutils.sessinit import SessionInit, JustCurrentSession
 from .base import OfflinePredictor
-import multiprocessing

 __all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']

 PredictResult = namedtuple('PredictResult', ['input', 'output'])
@@ -53,7 +48,7 @@ class PredictConfig(object):
         self.input_names = kwargs.pop('input_var_names', None)
         if self.input_names is not None:
             pass
-            #logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
+            # logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
         if self.input_names is None:
             # neither options is set, assume all inputs
             raw_vars = self.model.get_input_vars_desc()
@@ -61,7 +56,7 @@ class PredictConfig(object):
         self.output_names = kwargs.pop('output_names', None)
         if self.output_names is None:
             self.output_names = kwargs.pop('output_var_names')
-            #logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
+            # logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
         assert len(self.input_names), self.input_names
         for v in self.input_names:
             assert_type(v, six.string_types)
```
tensorpack/predict/concurrency.py (+3 −6)

```diff
@@ -5,10 +5,8 @@
 import multiprocessing
 import threading
 import tensorflow as tf
-import time
-import six
-from six.moves import queue, range, zip
+from six.moves import queue, range

 from ..utils.concurrency import DIE
 from ..tfutils.modelutils import describe_model
@@ -49,7 +47,6 @@ class MultiProcessPredictWorker(multiprocessing.Process):
             from tensorpack.models._common import disable_layer_logging
             disable_layer_logging()
         self.predictor = OfflinePredictor(self.config)
-        import sys
         if self.idx == 0:
             with self.predictor.graph.as_default():
                 describe_model()
@@ -136,9 +133,9 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
         """ :param predictors: a list of OnlinePredictor"""
         assert len(predictors)
         for k in predictors:
-            #assert isinstance(k, OnlinePredictor), type(k)
+            # assert isinstance(k, OnlinePredictor), type(k)
             # TODO use predictors.return_input here
-            assert k.return_input == False
+            assert not k.return_input
         self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
         self.threads = [
             PredictorWorkerThread(
```
tensorpack/predict/dataset.py (+1 −1)

```diff
@@ -9,7 +9,7 @@ import multiprocessing
 import os
 import six

-from ..dataflow import DataFlow, BatchData
+from ..dataflow import DataFlow
 from ..dataflow.dftools import dataflow_to_process_queue
 from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
 from ..utils import logger, get_tqdm
```
tensorpack/tfutils/common.py (+4 −2)

```diff
@@ -3,7 +3,7 @@
 # File: common.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-from ..utils.naming import *
+from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME
 import tensorflow as tf
 from copy import copy
 import six
@@ -36,7 +36,7 @@ def get_default_sess_config(mem_fraction=0.99):
     conf.gpu_options.allocator_type = 'BFC'
     conf.gpu_options.allow_growth = True
     conf.allow_soft_placement = True
-    #conf.log_device_placement = True
+    # conf.log_device_placement = True
     return conf
@@ -74,6 +74,7 @@ def get_op_tensor_name(name):
     else:
         return name, name + ':0'

+
 get_op_var_name = get_op_tensor_name
@@ -88,6 +89,7 @@ def get_tensors_by_names(names):
         ret.append(G.get_tensor_by_name(varn))
     return ret

+
 get_vars_by_names = get_tensors_by_names
```
tensorpack/tfutils/gradproc.py (+2 −1)

```diff
@@ -103,6 +103,7 @@ class MapGradient(GradientProcessor):
             ret.append((grad, var))
         return ret

+
 _summaried_gradient = set()
@@ -133,7 +134,7 @@ class CheckGradient(MapGradient):
     def _mapper(self, grad, var):
         # this is very slow.... see #3649
-        #op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
+        # op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
         grad = tf.check_numerics(grad, 'CheckGradient-' + var.op.name)
         return grad
```
tensorpack/tfutils/sessinit.py (+2 −2)

```diff
@@ -5,7 +5,6 @@
 import os
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict
-import re
 import numpy as np
 import tensorflow as tf
 import six
@@ -120,7 +119,8 @@ class SaverRestore(SessionInit):
         ckpt_vars = reader.get_variable_to_shape_map().keys()
         for v in ckpt_vars:
             if v.startswith(PREDICT_TOWER):
-                logger.error("Found {} in checkpoint. But anything from prediction tower shouldn't be saved.".format(v.name))
+                logger.error("Found {} in checkpoint. "
+                             "But anything from prediction tower shouldn't be saved.".format(v.name))
         return set(ckpt_vars)

     def _get_vars_to_restore_multimap(self, vars_available):
```
tensorpack/tfutils/summary.py (+1 −1)

```diff
@@ -7,7 +7,7 @@ import tensorflow as tf
 import re

 from ..utils.argtools import memoized
-from ..utils.naming import *
+from ..utils.naming import MOVING_SUMMARY_VARS_KEY
 from .tower import get_current_tower_context
 from . import get_global_step_var
 from .symbolic_functions import rms
```
tensorpack/tfutils/symbolic_functions.py (+4 −7)

```diff
@@ -4,7 +4,6 @@
 import tensorflow as tf
 import numpy as np
-from ..utils import logger


 def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
@@ -79,12 +78,10 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
     cost = tf.nn.weighted_cross_entropy_with_logits(logits, y, pos_weight)
     cost = tf.reduce_mean(cost * (1 - beta), name=name)

-    #logstable = tf.log(1 + tf.exp(-tf.abs(z)))
-    # loss_pos = -beta * tf.reduce_mean(-y *
-    #(logstable - tf.minimum(0.0, z)))
-    # loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
-    #(logstable + tf.maximum(z, 0.0)))
-    #cost = tf.sub(loss_pos, loss_neg, name=name)
+    # logstable = tf.log(1 + tf.exp(-tf.abs(z)))
+    # loss_pos = -beta * tf.reduce_mean(-y * (logstable - tf.minimum(0.0, z)))
+    # loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) * (logstable + tf.maximum(z, 0.0)))
+    # cost = tf.sub(loss_pos, loss_neg, name=name)
     return cost
```
tensorpack/tfutils/tower.py (+2 −2)

```diff
@@ -5,7 +5,7 @@
 import tensorflow as tf
 import re

-from ..utils.naming import *
+from ..utils.naming import PREDICT_TOWER

 __all__ = ['get_current_tower_context', 'TowerContext']
@@ -49,7 +49,7 @@ class TowerContext(object):
             with tf.variable_scope(self._name) as scope:
                 with tf.variable_scope(scope, reuse=False):
                     scope = tf.get_variable_scope()
-                    assert scope.reuse == False
+                    assert not scope.reuse
                     return tf.get_variable(*args, **kwargs)

     def find_tensor_in_main_tower(self, graph, name):
```
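`assert scope.reuse == False` becomes `assert not scope.reuse` — flake8's E712: equality comparison against `True`/`False` is discouraged in favor of direct truth testing. A minimal illustration:

```python
reuse = False

assert reuse == False  # noqa: E712  -- old style, flagged by flake8
assert not reuse       # new style, same meaning for booleans
```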
tensorpack/tfutils/varmanip.py (+2 −2)

```diff
@@ -10,7 +10,7 @@ from collections import defaultdict
 import re
 import numpy as np
 from ..utils import logger
-from ..utils.naming import *
+from ..utils.naming import PREDICT_TOWER
 from .common import get_op_tensor_name

 __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
@@ -51,7 +51,7 @@ class SessionUpdate(object):
         self.sess = sess
         self.assign_ops = defaultdict(list)
         for v in vars_to_update:
-            #p = tf.placeholder(v.dtype, shape=v.get_shape())
+            # p = tf.placeholder(v.dtype, shape=v.get_shape())
             with tf.device('/cpu:0'):
                 p = tf.placeholder(v.dtype)
             savename = get_savename_from_varname(v.name)
```
tensorpack/train/base.py (+0 −1)

```diff
@@ -3,7 +3,6 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 from abc import ABCMeta, abstractmethod
-import signal
 import re
 import weakref
 import six
```
tensorpack/train/feedfree.py (+4 −3)

```diff
@@ -10,7 +10,7 @@ from ..tfutils import get_global_step_var
 from ..tfutils.tower import TowerContext
 from ..tfutils.gradproc import apply_grad_processors
 from ..tfutils.summary import summary_moving_average, add_moving_summary
-from .input_data import QueueInput, FeedfreeInput, DummyConstantInput
+from .input_data import QueueInput, FeedfreeInput

 from .base import Trainer
 from .trainer import MultiPredictorTowerTrainer
@@ -98,7 +98,7 @@ class SimpleFeedfreeTrainer(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')
         # skip training
-        #self.train_op = tf.group(*self.dequed_inputs)
+        # self.train_op = tf.group(*self.dequed_inputs)


 class QueueInputTrainer(SimpleFeedfreeTrainer):
@@ -114,7 +114,8 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
         """
         config.data = QueueInput(config.dataset, input_queue)
         if predict_tower is not None:
-            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
+            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
+                        "Use TrainConfig.predict_tower instead!")
             config.predict_tower = predict_tower
         assert len(config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
```
tensorpack/train/input_data.py (+1 −1)

```diff
@@ -86,7 +86,7 @@ class EnqueueThread(threading.Thread):
                 feed = dict(zip(self.placehdrs, dp))
                 # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
                 self.op.run(feed_dict=feed)
-        except tf.errors.CancelledError as e:
+        except tf.errors.CancelledError:
             pass
         except Exception:
             logger.exception("Exception in EnqueueThread:")
```
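Dropping `as e` from the `except` clause fixes flake8's F841 (local variable assigned but never used). A sketch of the pattern:

```python
try:
    raise RuntimeError("queue closed")
except RuntimeError:  # was `except RuntimeError as e:` with `e` never read
    pass              # swallow the error, as EnqueueThread does on cancel
```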
tensorpack/train/multigpu.py (+10 −8)

```diff
@@ -9,9 +9,9 @@ import re
 from six.moves import zip, range

 from ..utils import logger
-from ..utils.naming import *
+from ..utils.naming import SUMMARY_BACKUP_KEYS
 from ..utils.concurrency import LoopThread
-from ..tfutils.summary import summary_moving_average, add_moving_summary
+from ..tfutils.summary import summary_moving_average
 from ..tfutils import (backup_collection, restore_collection,
                        get_global_step_var, TowerContext)
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
@@ -36,7 +36,7 @@ class MultiGPUTrainer(Trainer):
         for idx, t in enumerate(towers):
             with tf.device('/gpu:{}'.format(t)), \
                     tf.variable_scope(global_scope, reuse=idx > 0), \
-                    TowerContext('tower{}'.format(idx)) as scope:
+                    TowerContext('tower{}'.format(idx)):
                 logger.info("Building graph for training tower {}...".format(idx))
                 grad_list.append(get_tower_grad_func())
@@ -60,7 +60,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
         assert isinstance(self._input_method, QueueInput)

         if predict_tower is not None:
-            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
+            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
+                        "Use TrainConfig.predict_tower instead!")
             config.predict_tower = predict_tower

         super(SyncMultiGPUTrainer, self).__init__(config)
@@ -82,7 +83,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
             if None in nones and len(nones) != 1:
                 raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(v.name))
             elif nones[0] is None:
-                logger.warn("No Gradient w.r.t {}".format(var.op.name))
+                logger.warn("No Gradient w.r.t {}".format(v.op.name))
                 continue
             try:
                 grad = tf.add_n(all_grad) / float(len(tower_grads))
@@ -98,8 +99,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
             self.config.tower, lambda: self._get_cost_and_grad()[1])

         # debug tower performance:
-        #ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
-        #self.train_op = tf.group(*ops)
+        # ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
+        # self.train_op = tf.group(*ops)
         # return

         grads = SyncMultiGPUTrainer._average_grads(grad_list)
@@ -129,7 +130,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         super(AsyncMultiGPUTrainer, self).__init__(config)

         if predict_tower is not None:
-            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
+            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
+                        "Use TrainConfig.predict_tower instead!")
             config.predict_tower = predict_tower

         self._setup_predictor_factory(config.predict_tower)
```
tensorpack/train/trainer.py (+4 −5)

```diff
@@ -3,18 +3,16 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
-import time
-from six.moves import zip

 from .base import Trainer
-from ..utils import logger, SUMMARY_BACKUP_KEYS, PREDICT_TOWER
+from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
 from ..tfutils import (get_tensors_by_names, freeze_collection,
                        get_global_step_var, TowerContext)
 from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
 from ..tfutils.gradproc import apply_grad_processors
-from .input_data import FeedInput, FeedfreeInput
+from .input_data import FeedInput

 __all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']
@@ -49,7 +47,8 @@ class PredictorFactory(object):
             # build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
             with tf.name_scope(None), \
                     freeze_collection(SUMMARY_BACKUP_KEYS):
-                fn = lambda _: self.model.build_graph(self.model.get_input_vars())
+                def fn(_):
+                    self.model.build_graph(self.model.get_input_vars())
                 build_multi_tower_prediction_graph(fn, self.towers)
             self.tower_built = True
```
tensorpack/utils/argtools.py (+9 −2)

```diff
@@ -9,8 +9,9 @@ import inspect
 import six
 import functools
 import collections
+from . import logger

-__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']
+__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs', 'log_once']


 def map_arg(**maps):
@@ -64,11 +65,12 @@ class memoized(object):
         '''Support instance methods.'''
         return functools.partial(self.__call__, obj)

+
 _MEMOIZED_NOARGS = {}


 def memoized_ignoreargs(func):
-    h = hash(func)  # make sure it is hashable. is it necessary?
+    hash(func)  # make sure it is hashable. TODO is it necessary?

     def wrapper(*args, **kwargs):
         if func not in _MEMOIZED_NOARGS:
@@ -99,3 +101,8 @@ def shape2d(a):
         assert len(a) == 2
         return list(a)
     raise RuntimeError("Illegal shape: {}".format(a))
+
+
+@memoized
+def log_once(message, func):
+    getattr(logger, func)(message)
```
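The new `log_once(message, func)` added above replaces the private copy that lived in `utils/discretize.py`; `func` names a logger method (`'warn'`, `'info'`, ...), and `@memoized` caches on the argument pair so a repeated message is emitted only once. A usage sketch matching the updated call sites in `dataflow/format.py` and `utils/discretize.py` (assuming the package is importable):

```python
from tensorpack.utils.argtools import log_once

for _ in range(1000):
    log_once("UniformDiscretizer1D: value smaller than min!", 'warn')
# The warning appears a single time; later identical calls hit the memo cache.
```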
tensorpack/utils/concurrency.py (+4 −2)

```diff
@@ -11,13 +11,15 @@ from contextlib import contextmanager
 import signal
 import weakref
 import six
+from six.moves import queue
+
+from . import logger
+
 if six.PY2:
     import subprocess32 as subprocess
 else:
     import subprocess
-from six.moves import queue
-from . import logger

 __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
            'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
```
tensorpack/utils/debug.py (+1 −0)

```diff
@@ -28,6 +28,7 @@ def enable_call_trace():
             return
     sys.settrace(tracer)

+
 if __name__ == '__main__':
     enable_call_trace()
```
tensorpack/utils/discretize.py (+5 −12)

```diff
@@ -3,8 +3,7 @@
 # File: discretize.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-from . import logger
-from .argtools import memoized
+from .argtools import log_once
 from abc import abstractmethod, ABCMeta
 import numpy as np
 import six
@@ -13,13 +12,6 @@ from six.moves import range
 __all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']


-@memoized
-def log_once(s):
-    logger.warn(s)
-
-# just a placeholder
-
-
 @six.add_metaclass(ABCMeta)
 class Discretizer(object):
@@ -54,10 +46,10 @@ class UniformDiscretizer1D(Discretizer1D):
     def get_bin(self, v):
         if v < self.minv:
-            log_once("UniformDiscretizer1D: value smaller than min!")
+            log_once("UniformDiscretizer1D: value smaller than min!", 'warn')
             return 0
         if v > self.maxv:
-            log_once("UniformDiscretizer1D: value larger than max!")
+            log_once("UniformDiscretizer1D: value larger than max!", 'warn')
             return self.nr_bin - 1
         return int(np.clip((v - self.minv) / self.spacing,
@@ -126,8 +118,9 @@ class UniformDiscretizerND(Discretizer):
         bin_id_nd = self.get_nd_bin_ids(bin_id)
         return [self.discretizers[k].get_bin_center(bin_id_nd[k]) for k in range(self.n)]

+
 if __name__ == '__main__':
-    #u = UniformDiscretizer1D(-10, 10, 0.12)
+    # u = UniformDiscretizer1D(-10, 10, 0.12)
     u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1))
     import IPython as IP
     IP.embed(config=IP.terminal.ipapp.load_default_config())
```
tensorpack/utils/fs.py (+1 −0)

```diff
@@ -54,5 +54,6 @@ def recursive_walk(rootdir):
         for f in files:
             yield os.path.join(r, f)

+
 if __name__ == '__main__':
     download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.')
```
tensorpack/utils/loadcaffe.py (+2 −7)

```diff
@@ -3,14 +3,9 @@
 # File: loadcaffe.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-from collections import namedtuple, defaultdict
-from abc import abstractmethod
 import numpy as np
-import copy
 import os
-from six.moves import zip

 from .utils import change_env, get_dataset_path
 from .fs import download
 from . import logger
@@ -115,7 +110,7 @@ def get_caffe_pb():
     dir = get_dataset_path('caffe')
     caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
     if not os.path.isfile(caffe_pb_file):
-        proto_path = download(CAFFE_PROTO_URL, dir)
+        download(CAFFE_PROTO_URL, dir)
         assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
         ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
         assert ret == 0, \
@@ -123,6 +118,7 @@ def get_caffe_pb():
     import imp
     return imp.load_source('caffepb', caffe_pb_file)

+
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser()
@@ -132,5 +128,4 @@ if __name__ == '__main__':
     args = parser.parse_args()
     ret = load_caffe(args.model, args.weights)

-    import numpy as np
     np.save(args.output, ret)
```
tensorpack/utils/logger.py (+13 −12)

```diff
@@ -11,11 +11,11 @@ from datetime import datetime
 from six.moves import input
 import sys

-__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', 'warn_dependency']
+__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', 'warn_dependency']


 class _MyFormatter(logging.Formatter):
     def format(self, record):
         date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
         msg = '%(message)s'
@@ -40,12 +40,19 @@ def _getlogger():
     handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
     logger.addHandler(handler)
     return logger

+
 _logger = _getlogger()

+_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']
+# export logger functions
+for func in _LOGGING_METHOD:
+    locals()[func] = getattr(_logger, func)
+

 def get_time_str():
     return datetime.now().strftime('%m%d-%H%M%S')

 # logger file and directory:
 global LOG_FILE, LOG_DIR
 LOG_DIR = None
@@ -55,7 +62,7 @@ def _set_file(path):
     if os.path.isfile(path):
         backup_name = path + '.' + get_time_str()
         shutil.move(path, backup_name)
-        info("Log file '{}' backuped to '{}'".format(path, backup_name))
+        info("Log file '{}' backuped to '{}'".format(path, backup_name))  # noqa: F821
     hdl = logging.FileHandler(
         filename=path, encoding='utf-8', mode='w')
     hdl.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
@@ -83,12 +90,12 @@ If you're resuming from a previous run you can choose to keep it.""")
     if act == 'b':
         backup_name = dirname + get_time_str()
         shutil.move(dirname, backup_name)
-        info("Directory '{}' backuped to '{}'".format(dirname, backup_name))
+        info("Directory '{}' backuped to '{}'".format(dirname, backup_name))  # noqa: F821
     elif act == 'd':
         shutil.rmtree(dirname)
     elif act == 'n':
         dirname = dirname + get_time_str()
-        info("Use a new log directory {}".format(dirname))
+        info("Use a new log directory {}".format(dirname))  # noqa: F821
     elif act == 'k':
         pass
     else:
@@ -100,12 +107,6 @@ If you're resuming from a previous run you can choose to keep it.""")
     _set_file(LOG_FILE)


-_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']
-# export logger functions
-for func in _LOGGING_METHOD:
-    locals()[func] = getattr(_logger, func)
-

 def disable_logger():
     """ disable all logging ability from this moment"""
     for func in _LOGGING_METHOD:
@@ -127,4 +128,4 @@ def auto_set_dir(action=None, overwrite=False):

 def warn_dependency(name, dependencies):
-    warn("Failed to import '{}', {} won't be available'".format(dependencies, name))
+    warn("Failed to import '{}', {} won't be available'".format(dependencies, name))  # noqa: F821
```
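The `# noqa: F821` markers are needed because `info`/`warn` are created dynamically: the loop assigns logger methods into the module namespace, which flake8's static analysis cannot see, so it would otherwise report F821 (undefined name). A minimal reproduction of the pattern:

```python
import logging

_logger = logging.getLogger(__name__)
for func in ['info', 'warn']:
    # at module scope, locals() is the module namespace
    locals()[func] = getattr(_logger, func)

info("defined at runtime; flake8 needs the noqa hint")  # noqa: F821
```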
tensorpack/utils/naming.py (+1 −1)

```diff
@@ -2,6 +2,7 @@
 # File: naming.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

+import tensorflow as tf

 GLOBAL_STEP_OP_NAME = 'global_step'
 GLOBAL_STEP_VAR_NAME = 'global_step:0'
@@ -14,7 +15,6 @@ MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
 # placeholders for input variables
 INPUT_VARS_KEY = 'INPUT_VARIABLES'

-import tensorflow as tf
 SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]

 # export all upper case variables
```
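Moving `import tensorflow as tf` to the top of `naming.py` fixes flake8's E402 (module-level import not at top of file). After the change the module reads, abridged:

```python
import tensorflow as tf

GLOBAL_STEP_OP_NAME = 'global_step'
GLOBAL_STEP_VAR_NAME = 'global_step:0'
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'

SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
```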
tensorpack/utils/serialize.py (+0 −3)

```diff
@@ -6,16 +6,13 @@
 import msgpack
 import msgpack_numpy
 msgpack_numpy.patch()
-#import dill

 __all__ = ['loads', 'dumps']


 def dumps(obj):
-    # return dill.dumps(obj)
     return msgpack.dumps(obj, use_bin_type=True)


 def loads(buf):
-    # return dill.loads(buf)
     return msgpack.loads(buf)
```
tensorpack/utils/timer.py (+2 −0)

```diff
@@ -48,6 +48,7 @@ def timed_operation(msg, log_start=False):
     logger.info('{} finished, time:{:.2f}sec.'.format(
         msg, time.time() - start))

+
 _TOTAL_TIMER_DATA = defaultdict(StatCounter)
@@ -66,4 +67,5 @@ def print_total_timer():
         logger.info("Total Time: {} -> {:.2f} sec, {} times, {:.3g} sec/time".format(
             k, v.sum, v.count, v.average))

+
 atexit.register(print_total_timer)
```
tensorpack/utils/utils.py (+0 −1)

```diff
@@ -8,7 +8,6 @@ from contextlib import contextmanager
 import inspect
 from datetime import datetime
 from tqdm import tqdm
-import time
 import numpy as np

 __all__ = ['change_env',
```
tensorpack/utils/viz.py (+2 −3)

```diff
@@ -106,7 +106,6 @@ def build_patch_list(patch_list,
     ph, pw = patch_list.shape[1:3]
     if border is None:
         border = int(0.1 * min(ph, pw))
-    mh, mw = max(max_height, ph + border), max(max_width, pw + border)
     if nr_row is None:
         nr_row = minnone(nr_row, max_height / (ph + border))
     if nr_col is None:
@@ -204,13 +203,13 @@ def dump_dataflow_images(df, index=0, batched=True,
             if viz is not None:
                 vizlist.append(img)
         if viz is not None and len(vizlist) >= vizsize:
-            patch = next(build_patch_list(
+            next(build_patch_list(
                 vizlist[:vizsize],
                 nr_row=viz[0], nr_col=viz[1], viz=True))
             vizlist = vizlist[vizsize:]

+
 if __name__ == '__main__':
     import cv2
     imglist = []
     for i in range(100):
         fname = "{:03d}.png".format(i)
```
tox.ini (new file, mode 100644, +6 −0)

```ini
[flake8]
max-line-length = 120
exclude = .git,
          __init__.py,
          snippet,
          docs
```
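With this `tox.ini` at the repository root, flake8 reads the `[flake8]` section automatically, so the whole check this commit targets is a single command run from the repo root (assuming flake8 is installed); a Python-side sketch:

```python
import subprocess

# Equivalent to running `flake8 tensorpack/` from the repo root;
# picks up max-line-length = 120 and the exclude list from tox.ini.
subprocess.call(['python', '-m', 'flake8', 'tensorpack/'])
```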