Commit 37e98945 authored by Yuxin Wu

fix flake8 style in tensorpack/

parent 233b3b90
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-import numpy as np
 from collections import deque
 from .envbase import ProxyPlayer
...
@@ -7,7 +7,6 @@
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict
 import six
-import random
 from ..utils import get_rng
 __all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
...
@@ -211,7 +211,9 @@ class ExpReplay(DataFlow, Callback):
 if __name__ == '__main__':
     from .atari import AtariPlayer
     import sys
-    predictor = lambda x: np.array([1, 1, 1, 1])
+
+    def predictor(x):
+        np.array([1, 1, 1, 1])
     player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
     E = ExpReplay(predictor,
                   player=player,
...
@@ -76,6 +76,7 @@ class GymEnv(RLEnvironment):
         assert isinstance(spc, gym.spaces.discrete.Discrete)
         return DiscreteActionSpace(spc.n)

+
 if __name__ == '__main__':
     env = GymEnv('Breakout-v0', viz=0.1)
     num = env.get_action_space().num_actions()
...
@@ -7,10 +7,8 @@ import tensorflow as tf
 import multiprocessing as mp
 import time
 import threading
-import weakref
 from abc import abstractmethod, ABCMeta
-from collections import defaultdict, namedtuple
-import numpy as np
+from collections import defaultdict
 import six
 from six.moves import queue
@@ -20,7 +18,6 @@ from ..callbacks import Callback
 from ..tfutils.varmanip import SessionUpdate
 from ..predict import OfflinePredictor
 from ..utils import logger
-#from ..utils.timer import *
 from ..utils.serialize import loads, dumps
 from ..utils.concurrency import LoopThread, ensure_proc_terminate
@@ -98,6 +95,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
             reward, isOver = player.action(action)
             state = player.current_state()

+
 # compatibility
 SimulatorProcess = SimulatorProcessStateExchange
@@ -284,6 +282,7 @@ class WeightSync(Callback):
         self.condvar.notify_all()
         self.condvar.release()

+
 if __name__ == '__main__':
     import random
     from tensorpack.RL import NaiveRLEnvironment
@@ -293,14 +292,13 @@ if __name__ == '__main__':
         def _build_player(self):
             return NaiveRLEnvironment()

-    class NaiveActioner(SimulatorActioner):
+    class NaiveActioner(SimulatorMaster):
         def _get_action(self, state):
             time.sleep(1)
             return random.randint(1, 12)

         def _on_episode_over(self, client):
-            #print("Over: ", client.memory)
+            # print("Over: ", client.memory)
             client.memory = []
             client.state = 0
@@ -312,5 +310,4 @@ if __name__ == '__main__':
     ensure_proc_terminate(procs)
     th.start()
-    import time
     time.sleep(100)
...
@@ -3,10 +3,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-import sys
-import os
-import time
-from abc import abstractmethod, ABCMeta
+from abc import ABCMeta
 import six
 __all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']
...
@@ -6,7 +6,6 @@
 """ Graph related callbacks"""
 from .base import Callback
-from ..utils import logger
 __all__ = ['RunOp']
@@ -26,7 +25,7 @@ class RunOp(Callback):
     def _setup_graph(self):
         self._op = self.setup_func()
-        #self._op_name = self._op.name
+        # self._op_name = self._op.name

     def _before_train(self):
         if self.run_before:
@@ -35,6 +34,3 @@ class RunOp(Callback):
     def _trigger_epoch(self):
         if self.run_epoch:
             self._op.run()
-
-    # def _log(self):
-    #logger.info("Running op {} ...".format(self._op_name))
@@ -86,7 +86,6 @@ class Callbacks(Callback):
     def _trigger_epoch(self):
         tm = CallbackTimeLogger()
-        test_sess_restored = False
         for cb in self.cbs:
             display_name = str(cb)
             with tm.timed_callback(display_name):
...
@@ -2,14 +2,13 @@
 # File: inference.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
-import tensorflow as tf
 import numpy as np
 from abc import ABCMeta, abstractmethod
 import sys
 import six
 from six.moves import zip
-from ..utils import logger, execute_only_once
+from ..utils import logger
 from ..utils.stats import RatioCounter, BinaryStatistics
 from ..tfutils import get_op_var_name
...
@@ -12,7 +12,6 @@ from ..dataflow import DataFlow
 from .base import Callback
 from .inference import Inferencer
 from .dispatcher import OutputTensorDispatcer
-from ..tfutils import get_op_tensor_name
 from ..utils import logger, get_tqdm
 from ..train.input_data import FeedfreeInput
@@ -99,7 +98,6 @@ class InferenceRunner(Callback):
         for inf in self.infs:
             inf.before_inference()
-        sess = tf.get_default_session()
         self.ds.reset_state()
         with get_tqdm(total=self.ds.size()) as pbar:
             for dp in self.ds.get_data():
@@ -171,7 +169,6 @@ class FeedfreeInferenceRunner(Callback):
         for inf in self.infs:
             inf.before_inference()
-        sess = tf.get_default_session()
         sz = self._input_data.size()
         with get_tqdm(total=sz) as pbar:
             for _ in range(sz):
...
@@ -4,7 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-from abc import abstractmethod, ABCMeta, abstractproperty
+from abc import abstractmethod, ABCMeta
 import operator
 import six
 import os
...
@@ -5,7 +5,6 @@
 import tensorflow as tf
 import os
 import shutil
-import re
 from .base import Callback
 from ..utils import logger
...
@@ -2,8 +2,6 @@
 # File: stats.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-import tensorflow as tf
-import re
 import os
 import operator
 import json
...
@@ -3,7 +3,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 from __future__ import division
-import copy
 import numpy as np
 from collections import deque, defaultdict
 from six.moves import range, map
@@ -48,7 +47,6 @@ class BatchData(ProxyDataFlow):
         super(BatchData, self).__init__(ds)
         if not remainder:
             try:
-                s = ds.size()
                 assert batch_size <= ds.size()
             except NotImplementedError:
                 pass
...
@@ -8,7 +8,7 @@ import glob
 import cv2
 import numpy as np
-from ...utils import logger, get_rng, get_dataset_path
+from ...utils import logger, get_dataset_path
 from ...utils.fs import download
 from ..base import RNGDataFlow
...
@@ -5,16 +5,13 @@
 # Yukun Chen <cykustc@gmail.com>
 import os
-import sys
 import pickle
 import numpy as np
-import random
 import six
-from six.moves import urllib, range
+from six.moves import range
 import copy
-import logging
-from ...utils import logger, get_rng, get_dataset_path
+from ...utils import logger, get_dataset_path
 from ...utils.fs import download
 from ..base import RNGDataFlow
@@ -152,6 +149,7 @@ class Cifar100(CifarBase):
     def __init__(self, train_or_test, shuffle=True, dir=None):
         super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)

+
 if __name__ == '__main__':
     ds = Cifar10('train')
     from tensorpack.dataflow.dftools import dump_dataset_images
...
@@ -5,11 +5,11 @@
 import os
 import tarfile
 import cv2
+import six
 import numpy as np
-from six.moves import range
 import xml.etree.ElementTree as ET
-from ...utils import logger, get_rng, get_dataset_path
+from ...utils import logger, get_dataset_path
 from ...utils.loadcaffe import get_caffe_pb
 from ...utils.fs import mkdir_p, download
 from ...utils.timer import timed_operation
@@ -195,10 +195,10 @@ class ILSVRC12(RNGDataFlow):
             box = root.find('object').find('bndbox').getchildren()
             box = map(lambda x: float(x.text), box)
-            #box[0] /= size[0]
-            #box[1] /= size[1]
-            #box[2] /= size[0]
-            #box[3] /= size[1]
+            # box[0] /= size[0]
+            # box[1] /= size[1]
+            # box[2] /= size[0]
+            # box[3] /= size[1]
             return np.asarray(box, dtype='float32')
         with timed_operation('Loading Bounding Boxes ...'):
@@ -218,6 +218,7 @@ class ILSVRC12(RNGDataFlow):
         logger.info("{}/{} images have bounding box.".format(cnt, len(imglist)))
         return ret

+
 if __name__ == '__main__':
     meta = ILSVRCMeta()
     # print(meta.get_synset_words_1000())
...
@@ -5,9 +5,8 @@
 import os
 import gzip
-import random
 import numpy
-from six.moves import urllib, range
+from six.moves import range
 from ...utils import logger, get_dataset_path
 from ...utils.fs import download
@@ -110,6 +109,7 @@ class Mnist(RNGDataFlow):
             label = self.labels[k]
             yield [img, label]

+
 if __name__ == '__main__':
     ds = Mnist('train')
     for (img, label) in ds.get_data():
...
@@ -9,9 +9,7 @@ import numpy as np
 from ...utils import logger, get_dataset_path
 from ...utils.fs import download
 from ...utils.argtools import memoized_ignoreargs
-from ..base import RNGDataFlow
 try:
-    import tensorflow
     from tensorflow.models.rnn.ptb import reader as tfreader
 except ImportError:
     logger.warn_dependency('PennTreeBank', 'tensorflow')
...
@@ -5,9 +5,8 @@
 import os
 import numpy as np
-from six.moves import range
-from ...utils import logger, get_rng, get_dataset_path
+from ...utils import logger, get_dataset_path
 from ..base import RNGDataFlow
 try:
@@ -71,6 +70,7 @@ class SVHNDigit(RNGDataFlow):
         c = SVHNDigit('extra')
         return np.concatenate((a.X, b.X, c.X)).mean(axis=0)

+
 if __name__ == '__main__':
     a = SVHNDigit('train')
     b = SVHNDigit.get_per_pixel_mean()
@@ -4,8 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 from ..base import DataFlow
-from ...utils import *
-from ...utils.timer import *
+from ...utils.timer import timed_operation
 from six.moves import zip, map
 from collections import Counter
 import json
@@ -74,12 +73,11 @@ class VisualQA(DataFlow):
         ret = cnt.most_common(n)
         return [k[0] for k in ret]

+
 if __name__ == '__main__':
     vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
                    '/home/wyx/data/VQA/mscoco_train2014_annotations.json')
     for k in vqa.get_data():
         print(json.dumps(k))
         break
-    # vqa.get_common_question_words(100)
     vqa.get_common_answer(100)
-    #from IPython import embed; embed()
@@ -6,10 +6,11 @@ import numpy as np
 from six.moves import range
 import os
-from ..utils import logger, get_rng, get_tqdm
+from ..utils import logger, get_tqdm
 from ..utils.timer import timed_operation
 from ..utils.loadcaffe import get_caffe_pb
 from ..utils.serialize import loads
+from ..utils.argtools import log_once
 from .base import RNGDataFlow
 try:
@@ -114,7 +115,6 @@ class LMDBData(RNGDataFlow):
                 if k != '__keys__':
                     yield [k, v]
         else:
-            s = self.size()
             self.rng.shuffle(self.keys)
             for k in self.keys:
                 v = self._txn.get(k)
@@ -159,7 +159,7 @@ class CaffeLMDB(LMDBDataDecoder):
             img = np.fromstring(datum.data, dtype=np.uint8)
             img = img.reshape(datum.channels, datum.height, datum.width)
         except Exception:
-            log_once("Cannot read key {}".format(k))
+            log_once("Cannot read key {}".format(k), 'warn')
             return None
         return [img.transpose(1, 2, 0), datum.label]
...
@@ -4,8 +4,7 @@
 import numpy as np
 import cv2
-import copy
-from .base import RNGDataFlow, DataFlow, ProxyDataFlow
+from .base import RNGDataFlow
 from .common import MapDataComponent, MapData
 from .imgaug import AugmentorList
@@ -52,7 +51,8 @@ class AugmentImageComponent(MapDataComponent):
         Augment the image component of datapoints
         :param ds: a `DataFlow` instance.
         :param augmentors: a list of `ImageAugmentor` instance to be applied in order.
-        :param index: the index (or list of indices) of the image component in the produced datapoints by `ds`. default to be 0
+        :param index: the index (or list of indices) of the image component
+            in the produced datapoints by `ds`. default to be 0
         """
         if isinstance(augmentors, AugmentorList):
             self.augs = augmentors
...
@@ -55,7 +55,7 @@ class Augmentor(object):
     def _rand_range(self, low=1.0, high=None, size=None):
         if high is None:
             low, high = 0, low
-        if size == None:
+        if size is None:
             size = []
         return self.rng.uniform(low, high, size)
...
@@ -74,7 +74,6 @@ class FixedCrop(ImageAugmentor):
         self._init(locals())

     def _augment(self, img, _):
-        orig_shape = img.shape
         return img[self.rect.y0: self.rect.y1 + 1,
                    self.rect.x0: self.rect.x0 + 1]
@@ -174,5 +173,6 @@ class RandomCropRandomShape(ImageAugmentor):
         y0, x0, h, w = param
         return img[y0:y0 + h, x0:x0 + w]

+
 if __name__ == '__main__':
     print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
@@ -26,7 +26,7 @@ class GaussianMap(object):
         y = y.astype('float32') / ret.shape[0] - anchor[0]
         x = x.astype('float32') / ret.shape[1] - anchor[1]
         g = np.exp(-(x**2 + y ** 2) / self.sigma)
-        #cv2.imshow(" ", g)
+        # cv2.imshow(" ", g)
         # cv2.waitKey()
         return g
...
@@ -6,7 +6,6 @@
 from .base import ImageAugmentor
 import math
 import cv2
-import numpy as np
 __all__ = ['Rotation', 'RotationAndCropValid']
@@ -59,7 +58,7 @@ class RotationAndCropValid(ImageAugmentor):
         newh = min(newh, ret.shape[0])
         newx = int(center[0] - neww * 0.5)
         newy = int(center[1] - newh * 0.5)
-        #print(ret.shape, deg, newx, newy, neww, newh)
+        # print(ret.shape, deg, newx, newy, neww, newh)
         return ret[newy:newy + newh, newx:newx + neww]

     @staticmethod
...
@@ -131,7 +131,8 @@ class Clip(ImageAugmentor):
 class Saturation(ImageAugmentor):
     def __init__(self, alpha=0.4):
-        """ Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
+        """ Saturation,
+        see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
         """
         super(Saturation, self).__init__()
         assert alpha < 1
@@ -150,7 +151,8 @@ class Lighting(ImageAugmentor):
     def __init__(self, std, eigval, eigvec):
         """ Lighting noise.
         See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
-        The implementation follows 'fb.resnet.torch': https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184
+        The implementation follows 'fb.resnet.torch':
+        https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184
         :param eigvec: each column is one eigen vector
         """
...
@@ -4,10 +4,8 @@
 from __future__ import print_function
 import multiprocessing as mp
-from threading import Thread
 import itertools
 from six.moves import range, zip
-from six.moves.queue import Queue
 import uuid
 import os
@@ -127,7 +125,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
         :param ds: a `DataFlow` instance.
         :param nr_proc: number of processes to use. When larger than 1, order
             of datapoints will be random.
-        :param pipedir: a local directory where the pipes would be. Useful if you're running on non-local FS such as NFS.
+        :param pipedir: a local directory where the pipes would be.
+            Useful if you're running on non-local FS such as NFS.
         """
         super(PrefetchDataZMQ, self).__init__(ds)
         try:
...
@@ -51,6 +51,7 @@ class RemoteData(DataFlow):
             dp = loads(self.socket.recv(copy=False))
             yield dp

+
 if __name__ == '__main__':
     import sys
     from tqdm import tqdm
...
@@ -53,9 +53,6 @@ class TFFuncMapper(ProxyDataFlow):
 if __name__ == '__main__':
     from .raw import FakeData
-    from .prefetch import PrefetchDataZMQ
-    from .image import AugmentImageComponent
-    from . import imgaug
     ds = FakeData([[224, 224, 3]], 100000, random=False)

     def tf_aug(v):
@@ -69,6 +66,9 @@ if __name__ == '__main__':
         tf_aug,
         lambda dp, f: [f([dp[0]])[0]]
     )
+    # from .prefetch import PrefetchDataZMQ
+    # from .image import AugmentImageComponent
+    # from . import imgaug
     # ds = AugmentImageComponent(ds,
     #     [imgaug.Brightness(0.1, clip=False),
     #      imgaug.Contrast((0.8, 1.2), clip=False),
...
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-import numpy as np
 import unittest
@@ -29,6 +28,7 @@ def run_test_case(case):
     suite = unittest.TestLoader().loadTestsFromTestCase(case)
     unittest.TextTestRunner(verbosity=2).run(suite)

+
 if __name__ == '__main__':
     import tensorpack
     from tensorpack.utils import logger
...
@@ -6,8 +6,6 @@
 import tensorflow as tf
 from tensorflow.contrib.framework import add_model_variable
 from tensorflow.python.training import moving_averages
-from copy import copy
-import re
 from ..tfutils.common import get_tf_version
 from ..tfutils.tower import get_current_tower_context
@@ -65,7 +63,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     if use_local_stat:
         # training tower
         if ctx.is_training:
-            #reuse = tf.get_variable_scope().reuse
+            # reuse = tf.get_variable_scope().reuse
             with tf.variable_scope(tf.get_variable_scope(), reuse=False):
                 # BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
                 with tf.name_scope(None):  # https://github.com/tensorflow/tensorflow/issues/2740
@@ -86,7 +84,6 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
         mean_var_name = ema.average_name(batch_mean)
         var_var_name = ema.average_name(batch_var)
-        sc = tf.get_variable_scope()
         if ctx.is_main_tower:
             # main tower, but needs to use global stat. global stat must be from outside
             # TODO when reuse=True, the desired variable name could
@@ -187,6 +184,7 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     else:
         return tf.identity(xn, name='output')

+
 if get_tf_version() >= 12:
     BatchNorm = BatchNormV2
 else:
...
@@ -3,12 +3,9 @@
 # File: conv2d.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
-import numpy as np
 import tensorflow as tf
-import math
 from ._common import layer_register, shape2d, shape4d
 from ..utils import logger
-from ..utils.argtools import shape2d
 __all__ = ['Conv2D', 'Deconv2D']
@@ -63,7 +60,8 @@ def Conv2D(x, out_channel, kernel_shape,
         conv = tf.concat(3, outputs)
     if nl is None:
         logger.warn(
-            "[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
+            "[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. "
+            "Please use argscope instead.")
         nl = tf.nn.relu
     return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
...
@@ -4,10 +4,10 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-import math
 from ._common import layer_register
 from ..tfutils import symbolic_functions as symbf
+from ..utils import logger
 __all__ = ['FullyConnected']
@@ -31,7 +31,6 @@ def FullyConnected(x, out_dim,
     in_dim = x.get_shape().as_list()[1]
     if W_init is None:
-        #W_init = tf.uniform_unit_scaling_initializer(factor=1.43)
         W_init = tf.contrib.layers.variance_scaling_initializer()
     if b_init is None:
         b_init = tf.constant_initializer()
@@ -42,6 +41,7 @@ def FullyConnected(x, out_dim,
     prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
     if nl is None:
         logger.warn(
-            "[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
+            "[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated."
+            " Please use argscope instead.")
         nl = tf.nn.relu
     return nl(prod, name='output')
@@ -6,6 +6,7 @@
 import tensorflow as tf
 from ._common import layer_register
+from ._test import TestModel
 __all__ = ['ImageSample']
@@ -82,7 +83,7 @@ def ImageSample(inputs, borderMode='repeat'):
     diffy, diffx = tf.split(3, 2, diff)
     neg_diffy, neg_diffx = tf.split(3, 2, neg_diff)
-    #prod = tf.reduce_prod(diff, 3, keep_dims=True)
+    # prod = tf.reduce_prod(diff, 3, keep_dims=True)
     # diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
     #                        tf.reduce_max(diff), diff], summarize=50)
@@ -100,8 +101,6 @@ def ImageSample(inputs, borderMode='repeat'):
         ret = ret * tf.cast(mask, tf.float32)
     return ret

-from ._test import TestModel
-
 class TestSample(TestModel):
@@ -128,9 +127,9 @@ class TestSample(TestModel):
         bimg = np.random.rand(2, h, w, 3).astype('float32')

         # mat = np.array([
-        #[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
-        #[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
-        #], dtype='float32')  #2x2x2x2
+        # [[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
+        # [[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
+        # ], dtype='float32')  #2x2x2x2
         mat = (np.random.rand(2, 5, 5, 2) - 0.2) * np.array([h + 3, w + 3])
         true_res = np_sample(bimg, np.floor(mat + 0.5).astype('int32'))
@@ -140,10 +139,10 @@ class TestSample(TestModel):
         self.assertTrue((res == true_res).all())

+
 if __name__ == '__main__':
     import cv2
     import numpy as np
-    import sys
     im = cv2.imread('cat.jpg')
     im = im.reshape((1,) + im.shape).astype('float32')
     imv = tf.Variable(im)
@@ -160,8 +159,8 @@ if __name__ == '__main__':
     sess = tf.Session()
     sess.run(tf.global_variables_initializer())
-    #out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
-    #out = sess.run(output)
+    # out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
+    # out = sess.run(output)
     # print(out[0].min())
     # print(out[0].max())
     # print(out[0].sum())
...
@@ -4,25 +4,19 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 from abc import ABCMeta, abstractmethod
-import re
 import tensorflow as tf
-from collections import namedtuple
 import inspect
 import pickle
 import six

 from ..utils import logger, INPUT_VARS_KEY
-from ..tfutils.common import get_tensors_by_names
 from ..tfutils.gradproc import CheckGradient
 from ..tfutils.tower import get_current_tower_context

 __all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']

-#_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])

 class InputVar(object):
     def __init__(self, type, shape, name, sparse=False):
         self.type = type
         self.shape = shape
...
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
-from copy import copy
 from ._common import layer_register
 from .batch_norm import BatchNorm
@@ -63,8 +62,8 @@ def LeakyReLU(x, alpha, name=None):
     if name is None:
         name = 'output'
     return tf.maximum(x, alpha * x, name=name)
-    #alpha = float(alpha)
-    #x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
+    # alpha = float(alpha)
+    # x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
     # return tf.mul(x, 0.5, name=name)
...
@@ -8,6 +8,8 @@ import numpy as np
 from ._common import layer_register, shape4d
 from ..utils.argtools import shape2d
 from ..tfutils import symbolic_functions as symbf
+from ._test import TestModel
+
 __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
            'BilinearUpSample']
@@ -131,7 +133,7 @@ def BilinearUpSample(x, shape):
     :param x: input NHWC tensor
     :param shape: an integer, the upsample factor
     """
-    #inp_shape = tf.shape(x)
+    # inp_shape = tf.shape(x)
     # return tf.image.resize_bilinear(x,
     #     tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
     #     align_corners=True)
@@ -172,9 +174,6 @@ def BilinearUpSample(x, shape):
     return deconv

-
-from ._test import TestModel
-
 class TestPool(TestModel):
     def test_fixed_unpooling(self):
...
@@ -17,6 +17,7 @@ __all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']
 def _log_regularizer(name):
     logger.info("Apply regularizer for {}".format(name))

+
 l2_regularizer = tf.contrib.layers.l2_regularizer
 l1_regularizer = tf.contrib.layers.l1_regularizer
...
@@ -3,11 +3,11 @@
 # File: base.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from abc import abstractmethod, ABCMeta, abstractproperty
+from abc import abstractmethod, ABCMeta
 import tensorflow as tf
 import six
-from ..utils.naming import *
+from ..utils.naming import PREDICT_TOWER
 from ..utils import logger
 from ..tfutils import get_tensors_by_names, TowerContext
@@ -128,7 +128,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
         self.predictors = []
         with self.graph.as_default():
             # TODO backup summary keys?
-            fn = lambda _: config.model.build_graph(config.model.get_input_vars())
+            def fn(_):
+                config.model.build_graph(config.model.get_input_vars())
             build_multi_tower_prediction_graph(fn, towers)
             self.sess = tf.Session(config=config.session_config)
...
@@ -2,19 +2,14 @@
 # File: common.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
-import tensorflow as tf
 from collections import namedtuple
 import six
-from six.moves import zip
 from tensorpack.models import ModelDesc
-from ..utils import logger
 from ..tfutils import get_default_sess_config
 from ..tfutils.sessinit import SessionInit, JustCurrentSession
 from .base import OfflinePredictor
-import multiprocessing
 __all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']

 PredictResult = namedtuple('PredictResult', ['input', 'output'])
@@ -53,7 +48,7 @@ class PredictConfig(object):
             self.input_names = kwargs.pop('input_var_names', None)
             if self.input_names is not None:
                 pass
-                #logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
+                # logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
         if self.input_names is None:
             # neither options is set, assume all inputs
             raw_vars = self.model.get_input_vars_desc()
@@ -61,7 +56,7 @@ class PredictConfig(object):
         self.output_names = kwargs.pop('output_names', None)
         if self.output_names is None:
             self.output_names = kwargs.pop('output_var_names')
-            #logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
+            # logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
         assert len(self.input_names), self.input_names
         for v in self.input_names:
             assert_type(v, six.string_types)
...
@@ -5,10 +5,8 @@
 import multiprocessing
 import threading
-import tensorflow as tf
-import time
 import six
-from six.moves import queue, range, zip
+from six.moves import queue, range
 from ..utils.concurrency import DIE
 from ..tfutils.modelutils import describe_model
@@ -49,7 +47,6 @@ class MultiProcessPredictWorker(multiprocessing.Process):
         from tensorpack.models._common import disable_layer_logging
         disable_layer_logging()
         self.predictor = OfflinePredictor(self.config)
-        import sys
         if self.idx == 0:
             with self.predictor.graph.as_default():
                 describe_model()
@@ -136,9 +133,9 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
         """ :param predictors: a list of OnlinePredictor"""
         assert len(predictors)
         for k in predictors:
-            #assert isinstance(k, OnlinePredictor), type(k)
+            # assert isinstance(k, OnlinePredictor), type(k)
             # TODO use predictors.return_input here
-            assert k.return_input == False
+            assert not k.return_input
         self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
         self.threads = [
             PredictorWorkerThread(
...
@@ -9,7 +9,7 @@ import multiprocessing
 import os
 import six
-from ..dataflow import DataFlow, BatchData
+from ..dataflow import DataFlow
 from ..dataflow.dftools import dataflow_to_process_queue
 from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
 from ..utils import logger, get_tqdm
...
@@ -3,7 +3,7 @@
 # File: common.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
-from ..utils.naming import *
+from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME
 import tensorflow as tf
 from copy import copy
 import six
@@ -36,7 +36,7 @@ def get_default_sess_config(mem_fraction=0.99):
     conf.gpu_options.allocator_type = 'BFC'
     conf.gpu_options.allow_growth = True
     conf.allow_soft_placement = True
-    #conf.log_device_placement = True
+    # conf.log_device_placement = True
     return conf
@@ -74,6 +74,7 @@ def get_op_tensor_name(name):
     else:
         return name, name + ':0'

+
 get_op_var_name = get_op_tensor_name
@@ -88,6 +89,7 @@ def get_tensors_by_names(names):
         ret.append(G.get_tensor_by_name(varn))
     return ret

+
 get_vars_by_names = get_tensors_by_names
...
@@ -103,6 +103,7 @@ class MapGradient(GradientProcessor):
             ret.append((grad, var))
         return ret

+
 _summaried_gradient = set()
@@ -133,7 +134,7 @@ class CheckGradient(MapGradient):
     def _mapper(self, grad, var):
         # this is very slow.... see #3649
-        #op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
+        # op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
         grad = tf.check_numerics(grad, 'CheckGradient-' + var.op.name)
         return grad
...
@@ -5,7 +5,6 @@
 import os
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict
-import re
 import numpy as np
 import tensorflow as tf
 import six
@@ -120,7 +119,8 @@ class SaverRestore(SessionInit):
         ckpt_vars = reader.get_variable_to_shape_map().keys()
         for v in ckpt_vars:
             if v.startswith(PREDICT_TOWER):
-                logger.error("Found {} in checkpoint. But anything from prediction tower shouldn't be saved.".format(v.name))
+                logger.error("Found {} in checkpoint. "
+                             "But anything from prediction tower shouldn't be saved.".format(v.name))
         return set(ckpt_vars)

     def _get_vars_to_restore_multimap(self, vars_available):
...
@@ -7,7 +7,7 @@ import tensorflow as tf
 import re
 from ..utils.argtools import memoized
-from ..utils.naming import *
+from ..utils.naming import MOVING_SUMMARY_VARS_KEY
 from .tower import get_current_tower_context
 from . import get_global_step_var
 from .symbolic_functions import rms
...
@@ -4,7 +4,6 @@
 import tensorflow as tf
 import numpy as np
-from ..utils import logger

 def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
@@ -79,12 +78,10 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
     cost = tf.nn.weighted_cross_entropy_with_logits(logits, y, pos_weight)
     cost = tf.reduce_mean(cost * (1 - beta), name=name)

-    #logstable = tf.log(1 + tf.exp(-tf.abs(z)))
-    # loss_pos = -beta * tf.reduce_mean(-y *
-    #(logstable - tf.minimum(0.0, z)))
-    # loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
-    #(logstable + tf.maximum(z, 0.0)))
-    #cost = tf.sub(loss_pos, loss_neg, name=name)
+    # logstable = tf.log(1 + tf.exp(-tf.abs(z)))
+    # loss_pos = -beta * tf.reduce_mean(-y * (logstable - tf.minimum(0.0, z)))
+    # loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) * (logstable + tf.maximum(z, 0.0)))
+    # cost = tf.sub(loss_pos, loss_neg, name=name)
     return cost
...
@@ -5,7 +5,7 @@
 import tensorflow as tf
 import re
-from ..utils.naming import *
+from ..utils.naming import PREDICT_TOWER
 __all__ = ['get_current_tower_context', 'TowerContext']
@@ -49,7 +49,7 @@ class TowerContext(object):
             with tf.variable_scope(self._name) as scope:
                 with tf.variable_scope(scope, reuse=False):
                     scope = tf.get_variable_scope()
-                    assert scope.reuse == False
+                    assert not scope.reuse
                     return tf.get_variable(*args, **kwargs)

     def find_tensor_in_main_tower(self, graph, name):
...
@@ -10,7 +10,7 @@ from collections import defaultdict
 import re
 import numpy as np
 from ..utils import logger
-from ..utils.naming import *
+from ..utils.naming import PREDICT_TOWER
 from .common import get_op_tensor_name
 __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
@@ -51,7 +51,7 @@ class SessionUpdate(object):
         self.sess = sess
         self.assign_ops = defaultdict(list)
         for v in vars_to_update:
-            #p = tf.placeholder(v.dtype, shape=v.get_shape())
+            # p = tf.placeholder(v.dtype, shape=v.get_shape())
             with tf.device('/cpu:0'):
                 p = tf.placeholder(v.dtype)
             savename = get_savename_from_varname(v.name)
...
@@ -3,7 +3,6 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 from abc import ABCMeta, abstractmethod
-import signal
 import re
 import weakref
 import six
...
@@ -10,7 +10,7 @@ from ..tfutils import get_global_step_var
 from ..tfutils.tower import TowerContext
 from ..tfutils.gradproc import apply_grad_processors
 from ..tfutils.summary import summary_moving_average, add_moving_summary
-from .input_data import QueueInput, FeedfreeInput, DummyConstantInput
+from .input_data import QueueInput, FeedfreeInput
 from .base import Trainer
 from .trainer import MultiPredictorTowerTrainer
@@ -98,7 +98,7 @@ class SimpleFeedfreeTrainer(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')
         # skip training
-        #self.train_op = tf.group(*self.dequed_inputs)
+        # self.train_op = tf.group(*self.dequed_inputs)

 class QueueInputTrainer(SimpleFeedfreeTrainer):
@@ -114,7 +114,8 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
         """
         config.data = QueueInput(config.dataset, input_queue)
         if predict_tower is not None:
-            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
+            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
+                        "Use TrainConfig.predict_tower instead!")
             config.predict_tower = predict_tower
         assert len(config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
...
@@ -86,7 +86,7 @@ class EnqueueThread(threading.Thread):
                 feed = dict(zip(self.placehdrs, dp))
                 # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
                 self.op.run(feed_dict=feed)
-        except tf.errors.CancelledError as e:
+        except tf.errors.CancelledError:
             pass
         except Exception:
             logger.exception("Exception in EnqueueThread:")
...
...@@ -9,9 +9,9 @@ import re ...@@ -9,9 +9,9 @@ import re
from six.moves import zip, range from six.moves import zip, range
from ..utils import logger from ..utils import logger
from ..utils.naming import * from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..utils.concurrency import LoopThread from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average, add_moving_summary from ..tfutils.summary import summary_moving_average
from ..tfutils import (backup_collection, restore_collection, from ..tfutils import (backup_collection, restore_collection,
get_global_step_var, TowerContext) get_global_step_var, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
...@@ -36,7 +36,7 @@ class MultiGPUTrainer(Trainer): ...@@ -36,7 +36,7 @@ class MultiGPUTrainer(Trainer):
for idx, t in enumerate(towers): for idx, t in enumerate(towers):
with tf.device('/gpu:{}'.format(t)), \ with tf.device('/gpu:{}'.format(t)), \
tf.variable_scope(global_scope, reuse=idx > 0), \ tf.variable_scope(global_scope, reuse=idx > 0), \
TowerContext('tower{}'.format(idx)) as scope: TowerContext('tower{}'.format(idx)):
logger.info("Building graph for training tower {}...".format(idx)) logger.info("Building graph for training tower {}...".format(idx))
grad_list.append(get_tower_grad_func()) grad_list.append(get_tower_grad_func())
...@@ -60,7 +60,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -60,7 +60,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
assert isinstance(self._input_method, QueueInput) assert isinstance(self._input_method, QueueInput)
if predict_tower is not None: if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!") logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower config.predict_tower = predict_tower
super(SyncMultiGPUTrainer, self).__init__(config) super(SyncMultiGPUTrainer, self).__init__(config)
@@ -82,7 +83,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
            if None in nones and len(nones) != 1:
                raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(v.name))
            elif nones[0] is None:
-                logger.warn("No Gradient w.r.t {}".format(var.op.name))
+                logger.warn("No Gradient w.r.t {}".format(v.op.name))
                continue
            try:
                grad = tf.add_n(all_grad) / float(len(tower_grads))
@@ -98,8 +99,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
            self.config.tower, lambda: self._get_cost_and_grad()[1])
        # debug tower performance:
-        #ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
-        #self.train_op = tf.group(*ops)
+        # ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
+        # self.train_op = tf.group(*ops)
        # return
        grads = SyncMultiGPUTrainer._average_grads(grad_list)
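Note: a hedged sketch of what `_average_grads` does with `grad_list`, based on the fragments visible above. Each element of `tower_grads` is one tower's list of (gradient, variable) pairs; per-variable gradients are summed with tf.add_n and divided by the number of towers (the None-gradient handling from the earlier hunk is omitted here):

    import tensorflow as tf

    def average_grads(tower_grads):
        ret = []
        for grad_and_vars in zip(*tower_grads):
            v = grad_and_vars[0][1]            # same variable in every tower
            all_grad = [g for g, _ in grad_and_vars]
            grad = tf.add_n(all_grad) / float(len(tower_grads))
            ret.append((grad, v))
        return ret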
@@ -129,7 +130,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
        super(AsyncMultiGPUTrainer, self).__init__(config)
        if predict_tower is not None:
-            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
+            logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
+                        "Use TrainConfig.predict_tower instead!")
            config.predict_tower = predict_tower
        self._setup_predictor_factory(config.predict_tower)
...
@@ -3,18 +3,16 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
-import time
-from six.moves import zip
from .base import Trainer
-from ..utils import logger, SUMMARY_BACKUP_KEYS, PREDICT_TOWER
+from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import (get_tensors_by_names, freeze_collection,
                       get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..tfutils.gradproc import apply_grad_processors
-from .input_data import FeedInput, FeedfreeInput
+from .input_data import FeedInput
__all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']
@@ -49,7 +47,8 @@ class PredictorFactory(object):
        # build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
        with tf.name_scope(None), \
                freeze_collection(SUMMARY_BACKUP_KEYS):
-            fn = lambda _: self.model.build_graph(self.model.get_input_vars())
+            def fn(_):
+                self.model.build_graph(self.model.get_input_vars())
            build_multi_tower_prediction_graph(fn, self.towers)
        self.tower_built = True
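Note: this rewrite (and the similar ones elsewhere in the commit) addresses flake8 E731, which discourages assigning a lambda to a name. One subtlety: the one-statement `def` drops the lambda's implicit return value, which is harmless only because the callers here ignore it. The general transformation, with a hypothetical do_something():

    def do_something(x):
        return x + 1

    # before (E731: do not assign a lambda expression):
    # fn = lambda _: do_something(0)

    # after (flake8-clean; note the implicit return is dropped):
    def fn(_):
        do_something(0)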
...
@@ -9,8 +9,9 @@ import inspect
import six
import functools
import collections
+from . import logger
-__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']
+__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs', 'log_once']
def map_arg(**maps):
@@ -64,11 +65,12 @@ class memoized(object):
        '''Support instance methods.'''
        return functools.partial(self.__call__, obj)
_MEMOIZED_NOARGS = {}
def memoized_ignoreargs(func):
-    h = hash(func)  # make sure it is hashable. is it necessary?
+    hash(func)  # make sure it is hashable. TODO is it necessary?
    def wrapper(*args, **kwargs):
        if func not in _MEMOIZED_NOARGS:
@@ -99,3 +101,8 @@ def shape2d(a):
        assert len(a) == 2
        return list(a)
    raise RuntimeError("Illegal shape: {}".format(a))
+@memoized
+def log_once(message, func):
+    getattr(logger, func)(message)
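Note: `log_once` now lives in argtools and takes the logger method name as a second argument. Because it is wrapped in @memoized, each distinct (message, func) pair is emitted only once, however often it is called:

    from tensorpack.utils.argtools import log_once

    for _ in range(100):
        log_once("value clipped to valid range", 'warn')  # logged a single time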
@@ -11,13 +11,15 @@ from contextlib import contextmanager
import signal
import weakref
import six
+from six.moves import queue
+from . import logger
if six.PY2:
    import subprocess32 as subprocess
else:
    import subprocess
-from six.moves import queue
-from . import logger
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
           'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
...
@@ -28,6 +28,7 @@ def enable_call_trace():
        return
    sys.settrace(tracer)
if __name__ == '__main__':
    enable_call_trace()
...
@@ -3,8 +3,7 @@
# File: discretize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from . import logger
-from .argtools import memoized
+from .argtools import log_once
from abc import abstractmethod, ABCMeta
import numpy as np
import six
@@ -13,13 +12,6 @@ from six.moves import range
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']
-@memoized
-def log_once(s):
-    logger.warn(s)
-# just a placeholder
@six.add_metaclass(ABCMeta)
class Discretizer(object):
@@ -54,10 +46,10 @@ class UniformDiscretizer1D(Discretizer1D):
    def get_bin(self, v):
        if v < self.minv:
-            log_once("UniformDiscretizer1D: value smaller than min!")
+            log_once("UniformDiscretizer1D: value smaller than min!", 'warn')
            return 0
        if v > self.maxv:
-            log_once("UniformDiscretizer1D: value larger than max!")
+            log_once("UniformDiscretizer1D: value larger than max!", 'warn')
            return self.nr_bin - 1
        return int(np.clip(
            (v - self.minv) / self.spacing,
@@ -126,8 +118,9 @@ class UniformDiscretizerND(Discretizer):
        bin_id_nd = self.get_nd_bin_ids(bin_id)
        return [self.discretizers[k].get_bin_center(bin_id_nd[k]) for k in range(self.n)]
if __name__ == '__main__':
-    #u = UniformDiscretizer1D(-10, 10, 0.12)
+    # u = UniformDiscretizer1D(-10, 10, 0.12)
    u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1))
    import IPython as IP
    IP.embed(config=IP.terminal.ipapp.load_default_config())
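Note: a hedged usage sketch of the updated UniformDiscretizer1D, assuming the constructor takes (minv, maxv, spacing) as the commented-out example above suggests. Out-of-range values are clipped to the boundary bins, with a one-time warning via log_once:

    from tensorpack.utils.discretize import UniformDiscretizer1D

    d = UniformDiscretizer1D(-10, 10, 0.12)
    d.get_bin(-42)   # -> 0, warns "value smaller than min!" once
    d.get_bin(42)    # -> d.nr_bin - 1, warns "value larger than max!" once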
@@ -54,5 +54,6 @@ def recursive_walk(rootdir):
        for f in files:
            yield os.path.join(r, f)
if __name__ == '__main__':
    download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.')
@@ -3,14 +3,9 @@
# File: loadcaffe.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from collections import namedtuple, defaultdict
-from abc import abstractmethod
import numpy as np
-import copy
import os
-from six.moves import zip
from .utils import change_env, get_dataset_path
from .fs import download
from . import logger
@@ -115,7 +110,7 @@ def get_caffe_pb():
    dir = get_dataset_path('caffe')
    caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
    if not os.path.isfile(caffe_pb_file):
-        proto_path = download(CAFFE_PROTO_URL, dir)
+        download(CAFFE_PROTO_URL, dir)
        assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
        ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
        assert ret == 0, \
@@ -123,6 +118,7 @@ def get_caffe_pb():
    import imp
    return imp.load_source('caffepb', caffe_pb_file)
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
@@ -132,5 +128,4 @@ if __name__ == '__main__':
    args = parser.parse_args()
    ret = load_caffe(args.model, args.weights)
-    import numpy as np
    np.save(args.output, ret)
@@ -11,11 +11,11 @@ from datetime import datetime
from six.moves import input
import sys
-__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', 'warn_dependency']
+__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir',
+           'warn_dependency']
class _MyFormatter(logging.Formatter):
    def format(self, record):
        date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
        msg = '%(message)s'
@@ -40,12 +40,19 @@ def _getlogger():
    handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
    logger.addHandler(handler)
    return logger
_logger = _getlogger()
+_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']
+# export logger functions
+for func in _LOGGING_METHOD:
+    locals()[func] = getattr(_logger, func)
def get_time_str():
    return datetime.now().strftime('%m%d-%H%M%S')
# logger file and directory:
global LOG_FILE, LOG_DIR
LOG_DIR = None
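Note: the block moved up in this hunk generates module-level functions (info, warn, ...) at import time; at module scope locals() is the module's global namespace, so the loop effectively re-exports the underlying logger's methods. flake8 cannot see names created this way, which is why the later info()/warn() call sites in this file need `# noqa: F821` (undefined name). A minimal standalone sketch of the pattern:

    import logging

    _logger = logging.getLogger(__name__)
    _LOGGING_METHOD = ['info', 'warning', 'error', 'critical',
                       'warn', 'exception', 'debug']
    # at module scope, locals() is globals(), so these become module functions
    for func in _LOGGING_METHOD:
        locals()[func] = getattr(_logger, func)

    warning("now callable as a module-level function")  # noqa: F821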
@@ -55,7 +62,7 @@ def _set_file(path):
    if os.path.isfile(path):
        backup_name = path + '.' + get_time_str()
        shutil.move(path, backup_name)
-        info("Log file '{}' backuped to '{}'".format(path, backup_name))
+        info("Log file '{}' backuped to '{}'".format(path, backup_name))  # noqa: F821
    hdl = logging.FileHandler(
        filename=path, encoding='utf-8', mode='w')
    hdl.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
@@ -83,12 +90,12 @@ If you're resuming from a previous run you can choose to keep it.""")
        if act == 'b':
            backup_name = dirname + get_time_str()
            shutil.move(dirname, backup_name)
-            info("Directory '{}' backuped to '{}'".format(dirname, backup_name))
+            info("Directory '{}' backuped to '{}'".format(dirname, backup_name))  # noqa: F821
        elif act == 'd':
            shutil.rmtree(dirname)
        elif act == 'n':
            dirname = dirname + get_time_str()
-            info("Use a new log directory {}".format(dirname))
+            info("Use a new log directory {}".format(dirname))  # noqa: F821
        elif act == 'k':
            pass
        else:
@@ -100,12 +107,6 @@ If you're resuming from a previous run you can choose to keep it.""")
    _set_file(LOG_FILE)
-_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']
-# export logger functions
-for func in _LOGGING_METHOD:
-    locals()[func] = getattr(_logger, func)
def disable_logger():
    """ disable all logging ability from this moment"""
    for func in _LOGGING_METHOD:
@@ -127,4 +128,4 @@ def auto_set_dir(action=None, overwrite=False):
def warn_dependency(name, dependencies):
-    warn("Failed to import '{}', {} won't be available'".format(dependencies, name))
+    warn("Failed to import '{}', {} won't be available'".format(dependencies, name))  # noqa: F821
@@ -2,6 +2,7 @@
# File: naming.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
+import tensorflow as tf
GLOBAL_STEP_OP_NAME = 'global_step'
GLOBAL_STEP_VAR_NAME = 'global_step:0'
@@ -14,7 +15,6 @@ MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
# placeholders for input variables
INPUT_VARS_KEY = 'INPUT_VARIABLES'
-import tensorflow as tf
SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
# export all upper case variables
...
@@ -6,16 +6,13 @@
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
-#import dill
__all__ = ['loads', 'dumps']
def dumps(obj):
-    # return dill.dumps(obj)
    return msgpack.dumps(obj, use_bin_type=True)
def loads(buf):
-    # return dill.loads(buf)
    return msgpack.loads(buf)
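Note: with msgpack_numpy.patch() applied, numpy arrays survive the dumps/loads round trip. A small usage sketch (on this msgpack version string keys may come back as bytes, since loads() passes no encoding):

    import numpy as np
    from tensorpack.utils.serialize import dumps, loads

    buf = dumps({'step': 3, 'state': np.zeros((2, 2), dtype='float32')})
    obj = loads(buf)  # the array round-trips intact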
@@ -48,6 +48,7 @@ def timed_operation(msg, log_start=False):
        logger.info('{} finished, time:{:.2f}sec.'.format(
            msg, time.time() - start))
_TOTAL_TIMER_DATA = defaultdict(StatCounter)
@@ -66,4 +67,5 @@ def print_total_timer():
        logger.info("Total Time: {} -> {:.2f} sec, {} times, {:.3g} sec/time".format(
            k, v.sum, v.count, v.average))
atexit.register(print_total_timer)
@@ -8,7 +8,6 @@ from contextlib import contextmanager
import inspect
from datetime import datetime
from tqdm import tqdm
-import time
import numpy as np
__all__ = ['change_env',
...
@@ -106,7 +106,6 @@ def build_patch_list(patch_list,
    ph, pw = patch_list.shape[1:3]
    if border is None:
        border = int(0.1 * min(ph, pw))
-    mh, mw = max(max_height, ph + border), max(max_width, pw + border)
    if nr_row is None:
        nr_row = minnone(nr_row, max_height / (ph + border))
    if nr_col is None:
@@ -204,13 +203,13 @@ def dump_dataflow_images(df, index=0, batched=True,
            if viz is not None:
                vizlist.append(img)
        if viz is not None and len(vizlist) >= vizsize:
-            patch = next(build_patch_list(
+            next(build_patch_list(
                vizlist[:vizsize],
                nr_row=viz[0], nr_col=viz[1], viz=True))
            vizlist = vizlist[vizsize:]
if __name__ == '__main__':
-    import cv2
    imglist = []
    for i in range(100):
        fname = "{:03d}.png".format(i)
...
[flake8]
max-line-length = 120
exclude = .git,
__init__.py,
snippet,
docs
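Note: this new flake8 configuration backs the commit: it raises the line-length limit to 120 and skips re-exporting __init__.py files plus non-package directories, so the whole package can be checked from the repo root with e.g. `flake8 tensorpack/`.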