Commit 37e98945 authored by Yuxin Wu's avatar Yuxin Wu

fix flake8 style in tensorpack/

parent 233b3b90
......@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
from collections import deque
from .envbase import ProxyPlayer
......
......@@ -7,7 +7,6 @@
from abc import abstractmethod, ABCMeta
from collections import defaultdict
import six
import random
from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
......
......@@ -211,7 +211,9 @@ class ExpReplay(DataFlow, Callback):
if __name__ == '__main__':
from .atari import AtariPlayer
import sys
predictor = lambda x: np.array([1, 1, 1, 1])
def predictor(x):
np.array([1, 1, 1, 1])
player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
E = ExpReplay(predictor,
player=player,
......
......@@ -76,6 +76,7 @@ class GymEnv(RLEnvironment):
assert isinstance(spc, gym.spaces.discrete.Discrete)
return DiscreteActionSpace(spc.n)
if __name__ == '__main__':
env = GymEnv('Breakout-v0', viz=0.1)
num = env.get_action_space().num_actions()
......
......@@ -7,10 +7,8 @@ import tensorflow as tf
import multiprocessing as mp
import time
import threading
import weakref
from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
import numpy as np
from collections import defaultdict
import six
from six.moves import queue
......@@ -20,7 +18,6 @@ from ..callbacks import Callback
from ..tfutils.varmanip import SessionUpdate
from ..predict import OfflinePredictor
from ..utils import logger
#from ..utils.timer import *
from ..utils.serialize import loads, dumps
from ..utils.concurrency import LoopThread, ensure_proc_terminate
......@@ -98,6 +95,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
reward, isOver = player.action(action)
state = player.current_state()
# compatibility
SimulatorProcess = SimulatorProcessStateExchange
......@@ -284,6 +282,7 @@ class WeightSync(Callback):
self.condvar.notify_all()
self.condvar.release()
if __name__ == '__main__':
import random
from tensorpack.RL import NaiveRLEnvironment
......@@ -293,14 +292,13 @@ if __name__ == '__main__':
def _build_player(self):
return NaiveRLEnvironment()
class NaiveActioner(SimulatorActioner):
class NaiveActioner(SimulatorMaster):
def _get_action(self, state):
time.sleep(1)
return random.randint(1, 12)
def _on_episode_over(self, client):
#print("Over: ", client.memory)
# print("Over: ", client.memory)
client.memory = []
client.state = 0
......@@ -312,5 +310,4 @@ if __name__ == '__main__':
ensure_proc_terminate(procs)
th.start()
import time
time.sleep(100)
......@@ -3,10 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import sys
import os
import time
from abc import abstractmethod, ABCMeta
from abc import ABCMeta
import six
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']
......
......@@ -6,7 +6,6 @@
""" Graph related callbacks"""
from .base import Callback
from ..utils import logger
__all__ = ['RunOp']
......@@ -26,7 +25,7 @@ class RunOp(Callback):
def _setup_graph(self):
self._op = self.setup_func()
#self._op_name = self._op.name
# self._op_name = self._op.name
def _before_train(self):
if self.run_before:
......@@ -35,6 +34,3 @@ class RunOp(Callback):
def _trigger_epoch(self):
if self.run_epoch:
self._op.run()
# def _log(self):
#logger.info("Running op {} ...".format(self._op_name))
......@@ -86,7 +86,6 @@ class Callbacks(Callback):
def _trigger_epoch(self):
tm = CallbackTimeLogger()
test_sess_restored = False
for cb in self.cbs:
display_name = str(cb)
with tm.timed_callback(display_name):
......
......@@ -2,14 +2,13 @@
# File: inference.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import numpy as np
from abc import ABCMeta, abstractmethod
import sys
import six
from six.moves import zip
from ..utils import logger, execute_only_once
from ..utils import logger
from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_var_name
......
......@@ -12,7 +12,6 @@ from ..dataflow import DataFlow
from .base import Callback
from .inference import Inferencer
from .dispatcher import OutputTensorDispatcer
from ..tfutils import get_op_tensor_name
from ..utils import logger, get_tqdm
from ..train.input_data import FeedfreeInput
......@@ -99,7 +98,6 @@ class InferenceRunner(Callback):
for inf in self.infs:
inf.before_inference()
sess = tf.get_default_session()
self.ds.reset_state()
with get_tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data():
......@@ -171,7 +169,6 @@ class FeedfreeInferenceRunner(Callback):
for inf in self.infs:
inf.before_inference()
sess = tf.get_default_session()
sz = self._input_data.size()
with get_tqdm(total=sz) as pbar:
for _ in range(sz):
......
......@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from abc import abstractmethod, ABCMeta, abstractproperty
from abc import abstractmethod, ABCMeta
import operator
import six
import os
......
......@@ -5,7 +5,6 @@
import tensorflow as tf
import os
import shutil
import re
from .base import Callback
from ..utils import logger
......
......@@ -2,8 +2,6 @@
# File: stats.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import re
import os
import operator
import json
......
......@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from __future__ import division
import copy
import numpy as np
from collections import deque, defaultdict
from six.moves import range, map
......@@ -48,7 +47,6 @@ class BatchData(ProxyDataFlow):
super(BatchData, self).__init__(ds)
if not remainder:
try:
s = ds.size()
assert batch_size <= ds.size()
except NotImplementedError:
pass
......
......@@ -8,7 +8,7 @@ import glob
import cv2
import numpy as np
from ...utils import logger, get_rng, get_dataset_path
from ...utils import logger, get_dataset_path
from ...utils.fs import download
from ..base import RNGDataFlow
......
......@@ -5,16 +5,13 @@
# Yukun Chen <cykustc@gmail.com>
import os
import sys
import pickle
import numpy as np
import random
import six
from six.moves import urllib, range
from six.moves import range
import copy
import logging
from ...utils import logger, get_rng, get_dataset_path
from ...utils import logger, get_dataset_path
from ...utils.fs import download
from ..base import RNGDataFlow
......@@ -152,6 +149,7 @@ class Cifar100(CifarBase):
def __init__(self, train_or_test, shuffle=True, dir=None):
super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)
if __name__ == '__main__':
ds = Cifar10('train')
from tensorpack.dataflow.dftools import dump_dataset_images
......
......@@ -5,11 +5,11 @@
import os
import tarfile
import cv2
import six
import numpy as np
from six.moves import range
import xml.etree.ElementTree as ET
from ...utils import logger, get_rng, get_dataset_path
from ...utils import logger, get_dataset_path
from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download
from ...utils.timer import timed_operation
......@@ -195,10 +195,10 @@ class ILSVRC12(RNGDataFlow):
box = root.find('object').find('bndbox').getchildren()
box = map(lambda x: float(x.text), box)
#box[0] /= size[0]
#box[1] /= size[1]
#box[2] /= size[0]
#box[3] /= size[1]
# box[0] /= size[0]
# box[1] /= size[1]
# box[2] /= size[0]
# box[3] /= size[1]
return np.asarray(box, dtype='float32')
with timed_operation('Loading Bounding Boxes ...'):
......@@ -218,6 +218,7 @@ class ILSVRC12(RNGDataFlow):
logger.info("{}/{} images have bounding box.".format(cnt, len(imglist)))
return ret
if __name__ == '__main__':
meta = ILSVRCMeta()
# print(meta.get_synset_words_1000())
......
......@@ -5,9 +5,8 @@
import os
import gzip
import random
import numpy
from six.moves import urllib, range
from six.moves import range
from ...utils import logger, get_dataset_path
from ...utils.fs import download
......@@ -110,6 +109,7 @@ class Mnist(RNGDataFlow):
label = self.labels[k]
yield [img, label]
if __name__ == '__main__':
ds = Mnist('train')
for (img, label) in ds.get_data():
......
......@@ -9,9 +9,7 @@ import numpy as np
from ...utils import logger, get_dataset_path
from ...utils.fs import download
from ...utils.argtools import memoized_ignoreargs
from ..base import RNGDataFlow
try:
import tensorflow
from tensorflow.models.rnn.ptb import reader as tfreader
except ImportError:
logger.warn_dependency('PennTreeBank', 'tensorflow')
......
......@@ -5,9 +5,8 @@
import os
import numpy as np
from six.moves import range
from ...utils import logger, get_rng, get_dataset_path
from ...utils import logger, get_dataset_path
from ..base import RNGDataFlow
try:
......@@ -71,6 +70,7 @@ class SVHNDigit(RNGDataFlow):
c = SVHNDigit('extra')
return np.concatenate((a.X, b.X, c.X)).mean(axis=0)
if __name__ == '__main__':
a = SVHNDigit('train')
b = SVHNDigit.get_per_pixel_mean()
......@@ -4,8 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..base import DataFlow
from ...utils import *
from ...utils.timer import *
from ...utils.timer import timed_operation
from six.moves import zip, map
from collections import Counter
import json
......@@ -74,12 +73,11 @@ class VisualQA(DataFlow):
ret = cnt.most_common(n)
return [k[0] for k in ret]
if __name__ == '__main__':
vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
'/home/wyx/data/VQA/mscoco_train2014_annotations.json')
for k in vqa.get_data():
print(json.dumps(k))
break
# vqa.get_common_question_words(100)
vqa.get_common_answer(100)
#from IPython import embed; embed()
......@@ -6,10 +6,11 @@ import numpy as np
from six.moves import range
import os
from ..utils import logger, get_rng, get_tqdm
from ..utils import logger, get_tqdm
from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb
from ..utils.serialize import loads
from ..utils.argtools import log_once
from .base import RNGDataFlow
try:
......@@ -114,7 +115,6 @@ class LMDBData(RNGDataFlow):
if k != '__keys__':
yield [k, v]
else:
s = self.size()
self.rng.shuffle(self.keys)
for k in self.keys:
v = self._txn.get(k)
......@@ -159,7 +159,7 @@ class CaffeLMDB(LMDBDataDecoder):
img = np.fromstring(datum.data, dtype=np.uint8)
img = img.reshape(datum.channels, datum.height, datum.width)
except Exception:
log_once("Cannot read key {}".format(k))
log_once("Cannot read key {}".format(k), 'warn')
return None
return [img.transpose(1, 2, 0), datum.label]
......
......@@ -4,8 +4,7 @@
import numpy as np
import cv2
import copy
from .base import RNGDataFlow, DataFlow, ProxyDataFlow
from .base import RNGDataFlow
from .common import MapDataComponent, MapData
from .imgaug import AugmentorList
......@@ -52,7 +51,8 @@ class AugmentImageComponent(MapDataComponent):
Augment the image component of datapoints
:param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
:param index: the index (or list of indices) of the image component in the produced datapoints by `ds`. default to be 0
:param index: the index (or list of indices) of the image component
in the produced datapoints by `ds`. default to be 0
"""
if isinstance(augmentors, AugmentorList):
self.augs = augmentors
......
......@@ -55,7 +55,7 @@ class Augmentor(object):
def _rand_range(self, low=1.0, high=None, size=None):
if high is None:
low, high = 0, low
if size == None:
if size is None:
size = []
return self.rng.uniform(low, high, size)
......
......@@ -74,7 +74,6 @@ class FixedCrop(ImageAugmentor):
self._init(locals())
def _augment(self, img, _):
orig_shape = img.shape
return img[self.rect.y0: self.rect.y1 + 1,
self.rect.x0: self.rect.x0 + 1]
......@@ -174,5 +173,6 @@ class RandomCropRandomShape(ImageAugmentor):
y0, x0, h, w = param
return img[y0:y0 + h, x0:x0 + w]
if __name__ == '__main__':
print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
......@@ -26,7 +26,7 @@ class GaussianMap(object):
y = y.astype('float32') / ret.shape[0] - anchor[0]
x = x.astype('float32') / ret.shape[1] - anchor[1]
g = np.exp(-(x**2 + y ** 2) / self.sigma)
#cv2.imshow(" ", g)
# cv2.imshow(" ", g)
# cv2.waitKey()
return g
......
......@@ -6,7 +6,6 @@
from .base import ImageAugmentor
import math
import cv2
import numpy as np
__all__ = ['Rotation', 'RotationAndCropValid']
......@@ -59,7 +58,7 @@ class RotationAndCropValid(ImageAugmentor):
newh = min(newh, ret.shape[0])
newx = int(center[0] - neww * 0.5)
newy = int(center[1] - newh * 0.5)
#print(ret.shape, deg, newx, newy, neww, newh)
# print(ret.shape, deg, newx, newy, neww, newh)
return ret[newy:newy + newh, newx:newx + neww]
@staticmethod
......
......@@ -131,7 +131,8 @@ class Clip(ImageAugmentor):
class Saturation(ImageAugmentor):
def __init__(self, alpha=0.4):
""" Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
""" Saturation,
see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
"""
super(Saturation, self).__init__()
assert alpha < 1
......@@ -150,7 +151,8 @@ class Lighting(ImageAugmentor):
def __init__(self, std, eigval, eigvec):
""" Lighting noise.
See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
The implementation follows 'fb.resnet.torch': https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184
The implementation follows 'fb.resnet.torch':
https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184
:param eigvec: each column is one eigen vector
"""
......
......@@ -4,10 +4,8 @@
from __future__ import print_function
import multiprocessing as mp
from threading import Thread
import itertools
from six.moves import range, zip
from six.moves.queue import Queue
import uuid
import os
......@@ -127,7 +125,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
:param ds: a `DataFlow` instance.
:param nr_proc: number of processes to use. When larger than 1, order
of datapoints will be random.
:param pipedir: a local directory where the pipes would be. Useful if you're running on non-local FS such as NFS.
:param pipedir: a local directory where the pipes would be.
Useful if you're running on non-local FS such as NFS.
"""
super(PrefetchDataZMQ, self).__init__(ds)
try:
......
......@@ -51,6 +51,7 @@ class RemoteData(DataFlow):
dp = loads(self.socket.recv(copy=False))
yield dp
if __name__ == '__main__':
import sys
from tqdm import tqdm
......
......@@ -53,9 +53,6 @@ class TFFuncMapper(ProxyDataFlow):
if __name__ == '__main__':
from .raw import FakeData
from .prefetch import PrefetchDataZMQ
from .image import AugmentImageComponent
from . import imgaug
ds = FakeData([[224, 224, 3]], 100000, random=False)
def tf_aug(v):
......@@ -69,6 +66,9 @@ if __name__ == '__main__':
tf_aug,
lambda dp, f: [f([dp[0]])[0]]
)
# from .prefetch import PrefetchDataZMQ
# from .image import AugmentImageComponent
# from . import imgaug
# ds = AugmentImageComponent(ds,
# [imgaug.Brightness(0.1, clip=False),
# imgaug.Contrast((0.8, 1.2), clip=False),
......
......@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import numpy as np
import unittest
......@@ -29,6 +28,7 @@ def run_test_case(case):
suite = unittest.TestLoader().loadTestsFromTestCase(case)
unittest.TextTestRunner(verbosity=2).run(suite)
if __name__ == '__main__':
import tensorpack
from tensorpack.utils import logger
......
......@@ -6,8 +6,6 @@
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from copy import copy
import re
from ..tfutils.common import get_tf_version
from ..tfutils.tower import get_current_tower_context
......@@ -65,7 +63,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
if use_local_stat:
# training tower
if ctx.is_training:
#reuse = tf.get_variable_scope().reuse
# reuse = tf.get_variable_scope().reuse
with tf.variable_scope(tf.get_variable_scope(), reuse=False):
# BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
......@@ -86,7 +84,6 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
mean_var_name = ema.average_name(batch_mean)
var_var_name = ema.average_name(batch_var)
sc = tf.get_variable_scope()
if ctx.is_main_tower:
# main tower, but needs to use global stat. global stat must be from outside
# TODO when reuse=True, the desired variable name could
......@@ -187,6 +184,7 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
else:
return tf.identity(xn, name='output')
if get_tf_version() >= 12:
BatchNorm = BatchNormV2
else:
......
......@@ -3,12 +3,9 @@
# File: conv2d.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import tensorflow as tf
import math
from ._common import layer_register, shape2d, shape4d
from ..utils import logger
from ..utils.argtools import shape2d
__all__ = ['Conv2D', 'Deconv2D']
......@@ -63,7 +60,8 @@ def Conv2D(x, out_channel, kernel_shape,
conv = tf.concat(3, outputs)
if nl is None:
logger.warn(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. "
"Please use argscope instead.")
nl = tf.nn.relu
return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
......
......@@ -4,10 +4,10 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import math
from ._common import layer_register
from ..tfutils import symbolic_functions as symbf
from ..utils import logger
__all__ = ['FullyConnected']
......@@ -31,7 +31,6 @@ def FullyConnected(x, out_dim,
in_dim = x.get_shape().as_list()[1]
if W_init is None:
#W_init = tf.uniform_unit_scaling_initializer(factor=1.43)
W_init = tf.contrib.layers.variance_scaling_initializer()
if b_init is None:
b_init = tf.constant_initializer()
......@@ -42,6 +41,7 @@ def FullyConnected(x, out_dim,
prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
if nl is None:
logger.warn(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated."
" Please use argscope instead.")
nl = tf.nn.relu
return nl(prod, name='output')
......@@ -6,6 +6,7 @@
import tensorflow as tf
from ._common import layer_register
from ._test import TestModel
__all__ = ['ImageSample']
......@@ -82,7 +83,7 @@ def ImageSample(inputs, borderMode='repeat'):
diffy, diffx = tf.split(3, 2, diff)
neg_diffy, neg_diffx = tf.split(3, 2, neg_diff)
#prod = tf.reduce_prod(diff, 3, keep_dims=True)
# prod = tf.reduce_prod(diff, 3, keep_dims=True)
# diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
# tf.reduce_max(diff), diff], summarize=50)
......@@ -100,8 +101,6 @@ def ImageSample(inputs, borderMode='repeat'):
ret = ret * tf.cast(mask, tf.float32)
return ret
from ._test import TestModel
class TestSample(TestModel):
......@@ -128,9 +127,9 @@ class TestSample(TestModel):
bimg = np.random.rand(2, h, w, 3).astype('float32')
# mat = np.array([
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
#], dtype='float32') #2x2x2x2
# [[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
# [[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
# ], dtype='float32') #2x2x2x2
mat = (np.random.rand(2, 5, 5, 2) - 0.2) * np.array([h + 3, w + 3])
true_res = np_sample(bimg, np.floor(mat + 0.5).astype('int32'))
......@@ -140,10 +139,10 @@ class TestSample(TestModel):
self.assertTrue((res == true_res).all())
if __name__ == '__main__':
import cv2
import numpy as np
import sys
im = cv2.imread('cat.jpg')
im = im.reshape((1,) + im.shape).astype('float32')
imv = tf.Variable(im)
......@@ -160,8 +159,8 @@ if __name__ == '__main__':
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#out = sess.run(output)
# out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
# out = sess.run(output)
# print(out[0].min())
# print(out[0].max())
# print(out[0].sum())
......
......@@ -4,25 +4,19 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from abc import ABCMeta, abstractmethod
import re
import tensorflow as tf
from collections import namedtuple
import inspect
import pickle
import six
from ..utils import logger, INPUT_VARS_KEY
from ..tfutils.common import get_tensors_by_names
from ..tfutils.gradproc import CheckGradient
from ..tfutils.tower import get_current_tower_context
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
#_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])
class InputVar(object):
def __init__(self, type, shape, name, sparse=False):
self.type = type
self.shape = shape
......
......@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from copy import copy
from ._common import layer_register
from .batch_norm import BatchNorm
......@@ -63,8 +62,8 @@ def LeakyReLU(x, alpha, name=None):
if name is None:
name = 'output'
return tf.maximum(x, alpha * x, name=name)
#alpha = float(alpha)
#x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
# alpha = float(alpha)
# x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
# return tf.mul(x, 0.5, name=name)
......
......@@ -8,6 +8,8 @@ import numpy as np
from ._common import layer_register, shape4d
from ..utils.argtools import shape2d
from ..tfutils import symbolic_functions as symbf
from ._test import TestModel
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
'BilinearUpSample']
......@@ -131,7 +133,7 @@ def BilinearUpSample(x, shape):
:param x: input NHWC tensor
:param shape: an integer, the upsample factor
"""
#inp_shape = tf.shape(x)
# inp_shape = tf.shape(x)
# return tf.image.resize_bilinear(x,
# tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
# align_corners=True)
......@@ -172,9 +174,6 @@ def BilinearUpSample(x, shape):
return deconv
from ._test import TestModel
class TestPool(TestModel):
def test_fixed_unpooling(self):
......
......@@ -17,6 +17,7 @@ __all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']
def _log_regularizer(name):
logger.info("Apply regularizer for {}".format(name))
l2_regularizer = tf.contrib.layers.l2_regularizer
l1_regularizer = tf.contrib.layers.l1_regularizer
......
......@@ -3,11 +3,11 @@
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import abstractmethod, ABCMeta, abstractproperty
from abc import abstractmethod, ABCMeta
import tensorflow as tf
import six
from ..utils.naming import *
from ..utils.naming import PREDICT_TOWER
from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext
......@@ -128,7 +128,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self.predictors = []
with self.graph.as_default():
# TODO backup summary keys?
fn = lambda _: config.model.build_graph(config.model.get_input_vars())
def fn(_):
config.model.build_graph(config.model.get_input_vars())
build_multi_tower_prediction_graph(fn, towers)
self.sess = tf.Session(config=config.session_config)
......
......@@ -2,19 +2,14 @@
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from collections import namedtuple
import six
from six.moves import zip
from tensorpack.models import ModelDesc
from ..utils import logger
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from .base import OfflinePredictor
import multiprocessing
__all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']
PredictResult = namedtuple('PredictResult', ['input', 'output'])
......@@ -53,7 +48,7 @@ class PredictConfig(object):
self.input_names = kwargs.pop('input_var_names', None)
if self.input_names is not None:
pass
#logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
# logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
if self.input_names is None:
# neither options is set, assume all inputs
raw_vars = self.model.get_input_vars_desc()
......@@ -61,7 +56,7 @@ class PredictConfig(object):
self.output_names = kwargs.pop('output_names', None)
if self.output_names is None:
self.output_names = kwargs.pop('output_var_names')
#logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
# logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
assert len(self.input_names), self.input_names
for v in self.input_names:
assert_type(v, six.string_types)
......
......@@ -5,10 +5,8 @@
import multiprocessing
import threading
import tensorflow as tf
import time
import six
from six.moves import queue, range, zip
from six.moves import queue, range
from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model
......@@ -49,7 +47,6 @@ class MultiProcessPredictWorker(multiprocessing.Process):
from tensorpack.models._common import disable_layer_logging
disable_layer_logging()
self.predictor = OfflinePredictor(self.config)
import sys
if self.idx == 0:
with self.predictor.graph.as_default():
describe_model()
......@@ -136,9 +133,9 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
""" :param predictors: a list of OnlinePredictor"""
assert len(predictors)
for k in predictors:
#assert isinstance(k, OnlinePredictor), type(k)
# assert isinstance(k, OnlinePredictor), type(k)
# TODO use predictors.return_input here
assert k.return_input == False
assert not k.return_input
self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
self.threads = [
PredictorWorkerThread(
......
......@@ -9,7 +9,7 @@ import multiprocessing
import os
import six
from ..dataflow import DataFlow, BatchData
from ..dataflow import DataFlow
from ..dataflow.dftools import dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from ..utils import logger, get_tqdm
......
......@@ -3,7 +3,7 @@
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from ..utils.naming import *
from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME
import tensorflow as tf
from copy import copy
import six
......@@ -36,7 +36,7 @@ def get_default_sess_config(mem_fraction=0.99):
conf.gpu_options.allocator_type = 'BFC'
conf.gpu_options.allow_growth = True
conf.allow_soft_placement = True
#conf.log_device_placement = True
# conf.log_device_placement = True
return conf
......@@ -74,6 +74,7 @@ def get_op_tensor_name(name):
else:
return name, name + ':0'
get_op_var_name = get_op_tensor_name
......@@ -88,6 +89,7 @@ def get_tensors_by_names(names):
ret.append(G.get_tensor_by_name(varn))
return ret
get_vars_by_names = get_tensors_by_names
......
......@@ -103,6 +103,7 @@ class MapGradient(GradientProcessor):
ret.append((grad, var))
return ret
_summaried_gradient = set()
......@@ -133,7 +134,7 @@ class CheckGradient(MapGradient):
def _mapper(self, grad, var):
# this is very slow.... see #3649
#op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
# op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
grad = tf.check_numerics(grad, 'CheckGradient-' + var.op.name)
return grad
......
......@@ -5,7 +5,6 @@
import os
from abc import abstractmethod, ABCMeta
from collections import defaultdict
import re
import numpy as np
import tensorflow as tf
import six
......@@ -120,7 +119,8 @@ class SaverRestore(SessionInit):
ckpt_vars = reader.get_variable_to_shape_map().keys()
for v in ckpt_vars:
if v.startswith(PREDICT_TOWER):
logger.error("Found {} in checkpoint. But anything from prediction tower shouldn't be saved.".format(v.name))
logger.error("Found {} in checkpoint. "
"But anything from prediction tower shouldn't be saved.".format(v.name))
return set(ckpt_vars)
def _get_vars_to_restore_multimap(self, vars_available):
......
......@@ -7,7 +7,7 @@ import tensorflow as tf
import re
from ..utils.argtools import memoized
from ..utils.naming import *
from ..utils.naming import MOVING_SUMMARY_VARS_KEY
from .tower import get_current_tower_context
from . import get_global_step_var
from .symbolic_functions import rms
......
......@@ -4,7 +4,6 @@
import tensorflow as tf
import numpy as np
from ..utils import logger
def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
......@@ -79,12 +78,10 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
cost = tf.nn.weighted_cross_entropy_with_logits(logits, y, pos_weight)
cost = tf.reduce_mean(cost * (1 - beta), name=name)
#logstable = tf.log(1 + tf.exp(-tf.abs(z)))
# loss_pos = -beta * tf.reduce_mean(-y *
#(logstable - tf.minimum(0.0, z)))
# loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
#(logstable + tf.maximum(z, 0.0)))
#cost = tf.sub(loss_pos, loss_neg, name=name)
# logstable = tf.log(1 + tf.exp(-tf.abs(z)))
# loss_pos = -beta * tf.reduce_mean(-y * (logstable - tf.minimum(0.0, z)))
# loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) * (logstable + tf.maximum(z, 0.0)))
# cost = tf.sub(loss_pos, loss_neg, name=name)
return cost
......
......@@ -5,7 +5,7 @@
import tensorflow as tf
import re
from ..utils.naming import *
from ..utils.naming import PREDICT_TOWER
__all__ = ['get_current_tower_context', 'TowerContext']
......@@ -49,7 +49,7 @@ class TowerContext(object):
with tf.variable_scope(self._name) as scope:
with tf.variable_scope(scope, reuse=False):
scope = tf.get_variable_scope()
assert scope.reuse == False
assert not scope.reuse
return tf.get_variable(*args, **kwargs)
def find_tensor_in_main_tower(self, graph, name):
......
......@@ -10,7 +10,7 @@ from collections import defaultdict
import re
import numpy as np
from ..utils import logger
from ..utils.naming import *
from ..utils.naming import PREDICT_TOWER
from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
......@@ -51,7 +51,7 @@ class SessionUpdate(object):
self.sess = sess
self.assign_ops = defaultdict(list)
for v in vars_to_update:
#p = tf.placeholder(v.dtype, shape=v.get_shape())
# p = tf.placeholder(v.dtype, shape=v.get_shape())
with tf.device('/cpu:0'):
p = tf.placeholder(v.dtype)
savename = get_savename_from_varname(v.name)
......
......@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import ABCMeta, abstractmethod
import signal
import re
import weakref
import six
......
......@@ -10,7 +10,7 @@ from ..tfutils import get_global_step_var
from ..tfutils.tower import TowerContext
from ..tfutils.gradproc import apply_grad_processors
from ..tfutils.summary import summary_moving_average, add_moving_summary
from .input_data import QueueInput, FeedfreeInput, DummyConstantInput
from .input_data import QueueInput, FeedfreeInput
from .base import Trainer
from .trainer import MultiPredictorTowerTrainer
......@@ -98,7 +98,7 @@ class SimpleFeedfreeTrainer(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average(), name='train_op')
# skip training
#self.train_op = tf.group(*self.dequed_inputs)
# self.train_op = tf.group(*self.dequed_inputs)
class QueueInputTrainer(SimpleFeedfreeTrainer):
......@@ -114,7 +114,8 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
"""
config.data = QueueInput(config.dataset, input_queue)
if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
......
......@@ -86,7 +86,7 @@ class EnqueueThread(threading.Thread):
feed = dict(zip(self.placehdrs, dp))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed)
except tf.errors.CancelledError as e:
except tf.errors.CancelledError:
pass
except Exception:
logger.exception("Exception in EnqueueThread:")
......
......@@ -9,9 +9,9 @@ import re
from six.moves import zip, range
from ..utils import logger
from ..utils.naming import *
from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils.summary import summary_moving_average
from ..tfutils import (backup_collection, restore_collection,
get_global_step_var, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
......@@ -36,7 +36,7 @@ class MultiGPUTrainer(Trainer):
for idx, t in enumerate(towers):
with tf.device('/gpu:{}'.format(t)), \
tf.variable_scope(global_scope, reuse=idx > 0), \
TowerContext('tower{}'.format(idx)) as scope:
TowerContext('tower{}'.format(idx)):
logger.info("Building graph for training tower {}...".format(idx))
grad_list.append(get_tower_grad_func())
......@@ -60,7 +60,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
assert isinstance(self._input_method, QueueInput)
if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
super(SyncMultiGPUTrainer, self).__init__(config)
......@@ -82,7 +83,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
if None in nones and len(nones) != 1:
raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(v.name))
elif nones[0] is None:
logger.warn("No Gradient w.r.t {}".format(var.op.name))
logger.warn("No Gradient w.r.t {}".format(v.op.name))
continue
try:
grad = tf.add_n(all_grad) / float(len(tower_grads))
......@@ -98,8 +99,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
self.config.tower, lambda: self._get_cost_and_grad()[1])
# debug tower performance:
#ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
#self.train_op = tf.group(*ops)
# ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
# self.train_op = tf.group(*ops)
# return
grads = SyncMultiGPUTrainer._average_grads(grad_list)
......@@ -129,7 +130,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
super(AsyncMultiGPUTrainer, self).__init__(config)
if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
self._setup_predictor_factory(config.predict_tower)
......
......@@ -3,18 +3,16 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import time
from six.moves import zip
from .base import Trainer
from ..utils import logger, SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import (get_tensors_by_names, freeze_collection,
get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput, FeedfreeInput
from .input_data import FeedInput
__all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']
......@@ -49,7 +47,8 @@ class PredictorFactory(object):
# build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
with tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS):
fn = lambda _: self.model.build_graph(self.model.get_input_vars())
def fn(_):
self.model.build_graph(self.model.get_input_vars())
build_multi_tower_prediction_graph(fn, self.towers)
self.tower_built = True
......
......@@ -9,8 +9,9 @@ import inspect
import six
import functools
import collections
from . import logger
__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']
__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs', 'log_once']
def map_arg(**maps):
......@@ -64,11 +65,12 @@ class memoized(object):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
_MEMOIZED_NOARGS = {}
def memoized_ignoreargs(func):
h = hash(func) # make sure it is hashable. is it necessary?
hash(func) # make sure it is hashable. TODO is it necessary?
def wrapper(*args, **kwargs):
if func not in _MEMOIZED_NOARGS:
......@@ -99,3 +101,8 @@ def shape2d(a):
assert len(a) == 2
return list(a)
raise RuntimeError("Illegal shape: {}".format(a))
@memoized
def log_once(message, func):
getattr(logger, func)(message)
......@@ -11,13 +11,15 @@ from contextlib import contextmanager
import signal
import weakref
import six
from six.moves import queue
from . import logger
if six.PY2:
import subprocess32 as subprocess
else:
import subprocess
from six.moves import queue
from . import logger
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
......
......@@ -28,6 +28,7 @@ def enable_call_trace():
return
sys.settrace(tracer)
if __name__ == '__main__':
enable_call_trace()
......
......@@ -3,8 +3,7 @@
# File: discretize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from . import logger
from .argtools import memoized
from .argtools import log_once
from abc import abstractmethod, ABCMeta
import numpy as np
import six
......@@ -13,13 +12,6 @@ from six.moves import range
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']
@memoized
def log_once(s):
logger.warn(s)
# just a placeholder
@six.add_metaclass(ABCMeta)
class Discretizer(object):
......@@ -54,10 +46,10 @@ class UniformDiscretizer1D(Discretizer1D):
def get_bin(self, v):
if v < self.minv:
log_once("UniformDiscretizer1D: value smaller than min!")
log_once("UniformDiscretizer1D: value smaller than min!", 'warn')
return 0
if v > self.maxv:
log_once("UniformDiscretizer1D: value larger than max!")
log_once("UniformDiscretizer1D: value larger than max!", 'warn')
return self.nr_bin - 1
return int(np.clip(
(v - self.minv) / self.spacing,
......@@ -126,8 +118,9 @@ class UniformDiscretizerND(Discretizer):
bin_id_nd = self.get_nd_bin_ids(bin_id)
return [self.discretizers[k].get_bin_center(bin_id_nd[k]) for k in range(self.n)]
if __name__ == '__main__':
#u = UniformDiscretizer1D(-10, 10, 0.12)
# u = UniformDiscretizer1D(-10, 10, 0.12)
u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1))
import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config())
......@@ -54,5 +54,6 @@ def recursive_walk(rootdir):
for f in files:
yield os.path.join(r, f)
if __name__ == '__main__':
download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.')
......@@ -3,14 +3,9 @@
# File: loadcaffe.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from collections import namedtuple, defaultdict
from abc import abstractmethod
import numpy as np
import copy
import os
from six.moves import zip
from .utils import change_env, get_dataset_path
from .fs import download
from . import logger
......@@ -115,7 +110,7 @@ def get_caffe_pb():
dir = get_dataset_path('caffe')
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
if not os.path.isfile(caffe_pb_file):
proto_path = download(CAFFE_PROTO_URL, dir)
download(CAFFE_PROTO_URL, dir)
assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
assert ret == 0, \
......@@ -123,6 +118,7 @@ def get_caffe_pb():
import imp
return imp.load_source('caffepb', caffe_pb_file)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
......@@ -132,5 +128,4 @@ if __name__ == '__main__':
args = parser.parse_args()
ret = load_caffe(args.model, args.weights)
import numpy as np
np.save(args.output, ret)
......@@ -11,11 +11,11 @@ from datetime import datetime
from six.moves import input
import sys
__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', 'warn_dependency']
__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir',
'warn_dependency']
class _MyFormatter(logging.Formatter):
def format(self, record):
date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
msg = '%(message)s'
......@@ -40,12 +40,19 @@ def _getlogger():
handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
logger.addHandler(handler)
return logger
_logger = _getlogger()
_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']
# export logger functions
for func in _LOGGING_METHOD:
locals()[func] = getattr(_logger, func)
def get_time_str():
return datetime.now().strftime('%m%d-%H%M%S')
# logger file and directory:
global LOG_FILE, LOG_DIR
LOG_DIR = None
......@@ -55,7 +62,7 @@ def _set_file(path):
if os.path.isfile(path):
backup_name = path + '.' + get_time_str()
shutil.move(path, backup_name)
info("Log file '{}' backuped to '{}'".format(path, backup_name))
info("Log file '{}' backuped to '{}'".format(path, backup_name)) # noqa: F821
hdl = logging.FileHandler(
filename=path, encoding='utf-8', mode='w')
hdl.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
......@@ -83,12 +90,12 @@ If you're resuming from a previous run you can choose to keep it.""")
if act == 'b':
backup_name = dirname + get_time_str()
shutil.move(dirname, backup_name)
info("Directory '{}' backuped to '{}'".format(dirname, backup_name))
info("Directory '{}' backuped to '{}'".format(dirname, backup_name)) # noqa: F821
elif act == 'd':
shutil.rmtree(dirname)
elif act == 'n':
dirname = dirname + get_time_str()
info("Use a new log directory {}".format(dirname))
info("Use a new log directory {}".format(dirname)) # noqa: F821
elif act == 'k':
pass
else:
......@@ -100,12 +107,6 @@ If you're resuming from a previous run you can choose to keep it.""")
_set_file(LOG_FILE)
_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']
# export logger functions
for func in _LOGGING_METHOD:
locals()[func] = getattr(_logger, func)
def disable_logger():
""" disable all logging ability from this moment"""
for func in _LOGGING_METHOD:
......@@ -127,4 +128,4 @@ def auto_set_dir(action=None, overwrite=False):
def warn_dependency(name, dependencies):
warn("Failed to import '{}', {} won't be available'".format(dependencies, name))
warn("Failed to import '{}', {} won't be available'".format(dependencies, name)) # noqa: F821
......@@ -2,6 +2,7 @@
# File: naming.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
GLOBAL_STEP_OP_NAME = 'global_step'
GLOBAL_STEP_VAR_NAME = 'global_step:0'
......@@ -14,7 +15,6 @@ MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
# placeholders for input variables
INPUT_VARS_KEY = 'INPUT_VARIABLES'
import tensorflow as tf
SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
# export all upper case variables
......
......@@ -6,16 +6,13 @@
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
#import dill
__all__ = ['loads', 'dumps']
def dumps(obj):
# return dill.dumps(obj)
return msgpack.dumps(obj, use_bin_type=True)
def loads(buf):
# return dill.loads(buf)
return msgpack.loads(buf)
......@@ -48,6 +48,7 @@ def timed_operation(msg, log_start=False):
logger.info('{} finished, time:{:.2f}sec.'.format(
msg, time.time() - start))
_TOTAL_TIMER_DATA = defaultdict(StatCounter)
......@@ -66,4 +67,5 @@ def print_total_timer():
logger.info("Total Time: {} -> {:.2f} sec, {} times, {:.3g} sec/time".format(
k, v.sum, v.count, v.average))
atexit.register(print_total_timer)
......@@ -8,7 +8,6 @@ from contextlib import contextmanager
import inspect
from datetime import datetime
from tqdm import tqdm
import time
import numpy as np
__all__ = ['change_env',
......
......@@ -106,7 +106,6 @@ def build_patch_list(patch_list,
ph, pw = patch_list.shape[1:3]
if border is None:
border = int(0.1 * min(ph, pw))
mh, mw = max(max_height, ph + border), max(max_width, pw + border)
if nr_row is None:
nr_row = minnone(nr_row, max_height / (ph + border))
if nr_col is None:
......@@ -204,13 +203,13 @@ def dump_dataflow_images(df, index=0, batched=True,
if viz is not None:
vizlist.append(img)
if viz is not None and len(vizlist) >= vizsize:
patch = next(build_patch_list(
next(build_patch_list(
vizlist[:vizsize],
nr_row=viz[0], nr_col=viz[1], viz=True))
vizlist = vizlist[vizsize:]
if __name__ == '__main__':
import cv2
imglist = []
for i in range(100):
fname = "{:03d}.png".format(i)
......
[flake8]
max-line-length = 120
exclude = .git,
__init__.py,
snippet,
docs
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment