Commit fb2a051c authored by Yuxin Wu

run autopep8 over tensorpack/

parent 59553585
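The commit message does not record the exact autopep8 invocation; a typical way to apply it over the whole package is sketched below. The recursive in-place CLI call in the comment and the max_line_length value are assumptions, not taken from this commit.

    # Shell form (assumed flags): autopep8 --in-place --recursive tensorpack/
    # Equivalent sketch through autopep8's Python API:
    import autopep8

    with open("tensorpack/RL/common.py") as f:   # any of the files touched below
        fixed = autopep8.fix_code(f.read(), options={"max_line_length": 100})  # line length is a guess
    print(fixed)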
@@ -8,6 +8,8 @@ import os
import os.path

__all__ = []

def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
@@ -20,4 +22,3 @@ for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        _global_import(module_name)
@@ -9,7 +9,8 @@ from collections import deque
from .envbase import ProxyPlayer

__all__ = ['PreventStuckPlayer', 'LimitLengthPlayer', 'AutoRestartPlayer',
           'MapPlayerState']

class PreventStuckPlayer(ProxyPlayer):
    """ Prevent the player from getting stuck (repeating a no-op)
@@ -17,6 +18,7 @@ class PreventStuckPlayer(ProxyPlayer):
    where the agent needs to press the 'start' button to start playing.
    """
    # TODO hash the state as well?

    def __init__(self, player, nr_repeat, action):
        """
        It does auto-reset, but doesn't auto-restart the underlying player.
@@ -40,10 +42,12 @@ class PreventStuckPlayer(ProxyPlayer):
        super(PreventStuckPlayer, self).restart_episode()
        self.act_que.clear()

class LimitLengthPlayer(ProxyPlayer):
    """ Limit the total number of actions in an episode.
        Will auto restart the underlying player on timeout
    """

    def __init__(self, player, limit):
        super(LimitLengthPlayer, self).__init__(player)
        self.limit = limit
@@ -64,9 +68,11 @@ class LimitLengthPlayer(ProxyPlayer):
        self.player.restart_episode()
        self.cnt = 0

class AutoRestartPlayer(ProxyPlayer):
    """ Auto-restart the player on episode ends,
        in case some player wasn't designed to do so. """

    def action(self, act):
        r, isOver = self.player.action(act)
        if isOver:
@@ -74,7 +80,9 @@ class AutoRestartPlayer(ProxyPlayer):
            self.player.restart_episode()
        return r, isOver

class MapPlayerState(ProxyPlayer):

    def __init__(self, player, func):
        super(MapPlayerState, self).__init__(player)
        self.func = func
...
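All of the classes above proxy another player, so they compose by nesting. A rough usage sketch, stacking them around the trivial NaiveRLEnvironment from envbase.py (the constants and the lambda are illustrative only, not taken from the repository):

    from tensorpack.RL.envbase import NaiveRLEnvironment
    from tensorpack.RL.common import (MapPlayerState, PreventStuckPlayer,
                                      LimitLengthPlayer, AutoRestartPlayer)

    player = NaiveRLEnvironment()                       # dummy player, "for testing only"
    player = MapPlayerState(player, lambda s: s * 1.0)  # preprocess every state
    player = PreventStuckPlayer(player, 30, 1)          # after 30 repeated actions, force action 1
    player = LimitLengthPlayer(player, 100)             # cap each episode at 100 steps
    player = AutoRestartPlayer(player)                  # restart when an episode ends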
@@ -13,8 +13,10 @@ from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
           'DiscreteActionSpace']

@six.add_metaclass(ABCMeta)
class RLEnvironment(object):
    def __init__(self):
        self.reset_stat()
@@ -60,13 +62,15 @@ class RLEnvironment(object):
            s = self.current_state()
            act = func(s)
            r, isOver = self.action(act)
-            #print r
+            # print r
            if isOver:
                s = [self.stats[k] for k in stat]
                self.reset_stat()
                return s if len(s) > 1 else s[0]

class ActionSpace(object):
    def __init__(self):
        self.rng = get_rng(self)
@@ -77,7 +81,9 @@ class ActionSpace(object):
    def num_actions(self):
        raise NotImplementedError()

class DiscreteActionSpace(ActionSpace):
    def __init__(self, num):
        super(DiscreteActionSpace, self).__init__()
        self.num = num
@@ -94,19 +100,25 @@ class DiscreteActionSpace(ActionSpace):
    def __str__(self):
        return "DiscreteActionSpace({})".format(self.num)

class NaiveRLEnvironment(RLEnvironment):
    """ for testing only"""
    def __init__(self):
        self.k = 0

    def current_state(self):
        self.k += 1
        return self.k

    def action(self, act):
        self.k = act
        return (self.k, self.k > 10)

class ProxyPlayer(RLEnvironment):
    """ Serve as a proxy another player """
    def __init__(self, player):
        self.player = player
...
@@ -10,14 +10,15 @@ import six
from six.moves import queue

from ..dataflow import DataFlow
-from ..utils import logger, get_tqdm
+from ..utils import logger, get_tqdm, get_rng
from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback

__all__ = ['ExpReplay']

Experience = namedtuple('Experience',
                        ['state', 'action', 'reward', 'isOver'])

class ExpReplay(DataFlow, Callback):
    """
@@ -27,19 +28,20 @@ class ExpReplay(DataFlow, Callback):
    This implementation provides the interface as an DataFlow.
    This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
    """

    def __init__(self,
                 predictor_io_names,
                 player,
                 batch_size=32,
                 memory_size=1e6,
                 init_memory_size=50000,
                 exploration=1,
                 end_exploration=0.1,
                 exploration_epoch_anneal=0.002,
                 reward_clip=None,
                 update_frequency=1,
                 history_len=1
                 ):
        """
        :param predictor: a callabale running the up-to-date network.
            called with a state, return a distribution.
@@ -78,10 +80,10 @@ class ExpReplay(DataFlow, Callback):
    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
-        #if len(self.mem):
-        #from copy import deepcopy  # quickly fill the memory for debug
-        #self.mem.append(deepcopy(self.mem[0]))
-        #return
+        # if len(self.mem):
+        # from copy import deepcopy  # quickly fill the memory for debug
+        # self.mem.append(deepcopy(self.mem[0]))
+        # return
        old_s = self.player.current_state()
        if self.rng.rand() <= self.exploration:
            act = self.rng.choice(range(self.num_actions))
@@ -115,19 +117,19 @@ class ExpReplay(DataFlow, Callback):
        while True:
            batch_exp = [self._sample_one() for _ in range(self.batch_size)]
-            #import cv2  # for debug
-            #def view_state(state, next_state):
-                #""" for debugging state representation"""
-                #r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
-                #r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
-                #r = np.concatenate([r, r2], axis=0)
-                #print r.shape
-                #cv2.imshow("state", r)
-                #cv2.waitKey()
-            #exp = batch_exp[0]
-            #print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
-            #if exp[2] or exp[4]:
-                #view_state(exp[0], exp[1])
+            # import cv2  # for debug
+            # def view_state(state, next_state):
+            # """ for debugging state representation"""
+            # r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
+            # r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
+            # r = np.concatenate([r, r2], axis=0)
+            # print r.shape
+            # cv2.imshow("state", r)
+            # cv2.waitKey()
+            # exp = batch_exp[0]
+            # print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
+            # if exp[2] or exp[4]:
+            # view_state(exp[0], exp[1])
            yield self._process_batch(batch_exp)
            self._populate_job_queue.put(1)
@@ -141,9 +143,10 @@ class ExpReplay(DataFlow, Callback):
        # when x.isOver==True, (x+1).state is of a different episode
        idx = self.rng.randint(len(self.mem) - self.history_len - 1)
-        samples = [self.mem[k] for k in range(idx, idx+self.history_len+1)]
+        samples = [self.mem[k] for k in range(idx, idx + self.history_len + 1)]

        def concat(idx):
-            v = [x.state for x in samples[idx:idx+self.history_len]]
+            v = [x.state for x in samples[idx:idx + self.history_len]]
            return np.concatenate(v, axis=2)
        state = concat(0)
        next_state = concat(1)
@@ -155,12 +158,12 @@ class ExpReplay(DataFlow, Callback):
        # zero-fill state before starting
        zero_fill = False
        for k in range(1, self.history_len):
-            if samples[start_idx-k].isOver:
+            if samples[start_idx - k].isOver:
                zero_fill = True
            if zero_fill:
-                state[:,:,-k-1] = 0
+                state[:, :, -k - 1] = 0
                if k + 2 <= self.history_len:
-                    next_state[:,:,-k-2] = 0
+                    next_state[:, :, -k - 2] = 0
        return (state, next_state, reward, action, isOver)

    def _process_batch(self, batch_exp):
@@ -178,6 +181,7 @@ class ExpReplay(DataFlow, Callback):
    def _before_train(self):
        # spawn a separate thread to run policy, can speed up 1.3x
        self._populate_job_queue = queue.Queue(maxsize=1)

        def populate_job_func():
            self._populate_job_queue.get()
            with self.trainer.sess.as_default():
@@ -203,22 +207,23 @@ class ExpReplay(DataFlow, Callback):
            pass
        self.player.reset_stat()

if __name__ == '__main__':
    from .atari import AtariPlayer
    import sys
-    predictor = lambda x: np.array([1,1,1,1])
+    predictor = lambda x: np.array([1, 1, 1, 1])
    player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
    E = ExpReplay(predictor,
                  player=player,
                  num_actions=player.get_action_space().num_actions(),
                  populate_size=1001,
                  history_len=4)
    E._init_memory()
    for k in E.get_data():
-        import IPython as IP;
+        import IPython as IP
        IP.embed(config=IP.terminal.ipapp.load_default_config())
        pass
-        #import IPython;
-        #IPython.embed(config=IPython.terminal.ipapp.load_default_config())
-        #break
+        # import IPython;
+        # IPython.embed(config=IPython.terminal.ipapp.load_default_config())
+        # break
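The `_sample_one`/`concat` code above stacks `history_len` consecutive frames along the channel axis and zero-fills frames that belong to an already-finished episode. A standalone numpy sketch of the same idea (the shapes, the toy memory, and the simplified zero-fill rule are all made up for illustration):

    import numpy as np

    history_len = 4
    # Toy replay memory of (state, isOver) pairs; frame 5 ends an episode.
    mem = [(np.full((84, 84, 1), i, dtype='float32'), i == 5) for i in range(10)]

    idx = 3
    samples = mem[idx:idx + history_len + 1]   # history_len + 1 consecutive entries
    state = np.concatenate([s for s, _ in samples[:history_len]], axis=2)          # (84, 84, 4)
    next_state = np.concatenate([s for s, _ in samples[1:1 + history_len]], axis=2)

    # If frame j ended an episode, every frame up to and including j is stale for `state`.
    for j in range(history_len - 2, -1, -1):
        if samples[j][1]:                      # isOver flag
            state[:, :, :j + 1] = 0
            break
    print(state.shape, next_state.shape)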
@@ -9,7 +9,7 @@ from ..utils import logger
try:
    import gym
    # TODO
-    #gym.undo_logger_setup()
+    # gym.undo_logger_setup()
    # https://github.com/openai/gym/pull/199
    # not sure does it cause other problems
    __all__ = ['GymEnv']
@@ -26,11 +26,13 @@ from .envbase import RLEnvironment, DiscreteActionSpace
_ENV_LOCK = threading.Lock()

class GymEnv(RLEnvironment):
    """
    An OpenAI/gym wrapper. Can optionally auto restart.
    Only support discrete action space now
    """

    def __init__(self, name, dumpdir=None, viz=False, auto_restart=True):
        with _ENV_LOCK:
            self.gymenv = gym.make(name)
@@ -82,7 +84,7 @@ if __name__ == '__main__':
    rng = get_rng(num)
    while True:
        act = rng.choice(range(num))
-        #print act
+        # print act
        r, o = env.action(act)
        env.current_state()
        if r != 0 or o:
...
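A condensed version of the `__main__` block above, using the wrapper through the package-level import set up by `_global_import`. The environment name is illustrative, and the `get_action_space()` call is assumed from how players are queried elsewhere in this diff:

    from tensorpack.RL import GymEnv
    from tensorpack.utils import get_rng     # same helper envbase.py imports

    env = GymEnv('CartPole-v0', viz=False, auto_restart=True)   # env name is an example
    num = env.get_action_space().num_actions()
    rng = get_rng(env)
    for _ in range(100):
        r, is_over = env.action(rng.choice(range(num)))          # random policy, as in the snippet above
        env.current_state()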
@@ -9,10 +9,12 @@ from .envbase import ProxyPlayer
__all__ = ['HistoryFramePlayer']

class HistoryFramePlayer(ProxyPlayer):
    """ Include history frames in state, or use black images
        Assume player will do auto-restart.
    """

    def __init__(self, player, hist_len):
        """
        :param hist_len: total length of the state, including the current
@@ -49,4 +51,3 @@ class HistoryFramePlayer(ProxyPlayer):
        super(HistoryFramePlayer, self).restart_episode()
        self.history.clear()
        self.history.append(self.player.current_state())
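HistoryFramePlayer keeps a fixed-length buffer of recent frames and pads missing slots with black images. A minimal standalone sketch of that buffering idea (not the class's actual implementation; shapes and data are made up):

    import numpy as np
    from collections import deque

    hist_len = 4
    history = deque(maxlen=hist_len)                     # same role as the player's history buffer
    history.append(np.ones((84, 84), dtype='float32'))  # only one real frame observed so far

    # Pad the front with black (zero) frames, then stack along a new last axis.
    pad = [np.zeros((84, 84), dtype='float32')] * (hist_len - len(history))
    state = np.stack(pad + list(history), axis=-1)       # shape (84, 84, 4)
    print(state.shape)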
@@ -25,8 +25,8 @@ from ..utils.serialize import loads, dumps
from ..utils.concurrency import LoopThread, ensure_proc_terminate

__all__ = ['SimulatorProcess', 'SimulatorMaster',
           'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
           'TransitionExperience', 'WeightSync']

try:
    import zmq
@@ -34,8 +34,10 @@ except ImportError:
    logger.warn_dependency('Simulator', 'zmq')
    __all__ = []

class TransitionExperience(object):
    """ A transition of state, or experience"""

    def __init__(self, state, action, reward, **kwargs):
        """ kwargs: whatever other attribute you want to save"""
        self.state = state
@@ -44,6 +46,7 @@ class TransitionExperience(object):
        for k, v in six.iteritems(kwargs):
            setattr(self, k, v)

@six.add_metaclass(ABCMeta)
class SimulatorProcessBase(mp.Process):
@@ -63,6 +66,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
    A process that simulates a player and communicates to master to
    send states and receive the next action
    """

    def __init__(self, idx, pipe_c2s, pipe_s2c):
        """
        :param idx: idx of this process
@@ -81,7 +85,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
        s2c_socket = context.socket(zmq.DEALER)
        s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
-        #s2c_socket.set_hwm(5)
+        # s2c_socket.set_hwm(5)
        s2c_socket.connect(self.s2c)

        state = player.current_state()
@@ -97,12 +101,14 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
# compatibility
SimulatorProcess = SimulatorProcessStateExchange

class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
        It should produce action for each simulator, as well as
        defining callbacks when a transition or an episode is finished.
    """

    class ClientState(object):
        def __init__(self):
            self.memory = []    # list of Experience
@@ -174,9 +180,11 @@ class SimulatorMaster(threading.Thread):
    def __del__(self):
        self.context.destroy(linger=0)

class SimulatorProcessDF(SimulatorProcessBase):
    """ A simulator which contains a forward model itself, allowing
    it to produce data points directly """

    def __init__(self, idx, pipe_c2s):
        super(SimulatorProcessDF, self).__init__(idx)
        self.pipe_c2s = pipe_c2s
@@ -202,12 +210,14 @@ class SimulatorProcessDF(SimulatorProcessBase):
    def get_data(self):
        pass

class SimulatorProcessSharedWeight(SimulatorProcessDF):
    """ A simulator process with an extra thread waiting for event,
    and take shared weight from shm.
    Start me under some CUDA_VISIBLE_DEVICES set!
    """

    def __init__(self, idx, pipe_c2s, condvar, shared_dic, pred_config):
        super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
        self.condvar = condvar
@@ -220,7 +230,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
        with self.predictor.graph.as_default():
            vars_to_update = self._params_to_update()
            self.sess_updater = SessionUpdate(
                self.predictor.session, vars_to_update)
            # TODO setup callback for explore?
            self.predictor.graph.finalize()
@@ -245,8 +255,10 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
        # can be overwritten to update more params
        return tf.trainable_variables()

class WeightSync(Callback):
    """ Sync weight from main process to shared_dic and notify"""

    def __init__(self, condvar, shared_dic):
        self.condvar = condvar
        self.shared_dic = shared_dic
@@ -260,6 +272,7 @@ class WeightSync(Callback):
    def _before_train(self):
        self._sync()

    def _trigger_epoch(self):
        self._sync()
@@ -274,13 +287,18 @@ class WeightSync(Callback):
if __name__ == '__main__':
    import random
    from tensorpack.RL import NaiveRLEnvironment

    class NaiveSimulator(SimulatorProcess):
        def _build_player(self):
            return NaiveRLEnvironment()

    class NaiveActioner(SimulatorActioner):
        def _get_action(self, state):
            time.sleep(1)
            return random.randint(1, 12)

        def _on_episode_over(self, client):
            #print("Over: ", client.memory)
            client.memory = []
@@ -296,4 +314,3 @@ if __name__ == '__main__':
    import time
    time.sleep(100)
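The `_run` excerpt above sets up a ZeroMQ DEALER socket with a per-process identity to exchange states and actions with the master. A minimal standalone sketch of that socket pattern (the endpoint, identity, and payload are invented for illustration):

    import zmq

    context = zmq.Context()
    sock = context.socket(zmq.DEALER)
    sock.setsockopt(zmq.IDENTITY, b'simulator-0')   # per-process identity, as in the excerpt
    sock.connect('ipc://sim-c2s')                   # made-up endpoint; the real pipe names are configured elsewhere
    sock.send(b'serialized-state')                  # a real simulator would send dumps(state)
    # action = sock.recv()                          # ...and then block for the master's reply
    sock.close()
    context.destroy(linger=0)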
@@ -2,7 +2,7 @@
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import numpy  # avoid https://github.com/tensorflow/tensorflow/issues/2034
import cv2  # avoid https://github.com/tensorflow/tensorflow/issues/1924
from tensorpack.train import *
...
@@ -7,6 +7,8 @@ import os
__all__ = []

def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
@@ -23,4 +25,3 @@ for _, module_name, _ in walk_packages(
        continue
    if not module_name.startswith('_'):
        _global_import(module_name)
@@ -11,6 +11,7 @@ import six
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']

@six.add_metaclass(ABCMeta)
class Callback(object):
    """ Base class for all callbacks """
@@ -72,7 +73,9 @@ class Callback(object):
    def __str__(self):
        return type(self).__name__

class ProxyCallback(Callback):

    def __init__(self, cb):
        self.cb = cb
@@ -91,11 +94,13 @@ class ProxyCallback(Callback):
    def __str__(self):
        return "Proxy-" + str(self.cb)

class PeriodicCallback(ProxyCallback):
    """
    A callback to be triggered after every `period` epochs.
    Doesn't work for trigger_step
    """

    def __init__(self, cb, period):
        """
        :param cb: a `Callback`
@@ -111,4 +116,3 @@ class PeriodicCallback(ProxyCallback):
    def __str__(self):
        return "Periodic-" + str(self.cb)
@@ -9,7 +9,9 @@ from ..utils import logger
__all__ = ['StartProcOrThread']

class StartProcOrThread(Callback):

    def __init__(self, procs_threads):
        """
        Start extra threads and processes before training
@@ -20,7 +22,7 @@ class StartProcOrThread(Callback):
        self._procs_threads = procs_threads

    def _before_train(self):
-        logger.info("Starting " + \
+        logger.info("Starting " +
                    ', '.join([k.name for k in self._procs_threads]))
        # avoid sigint get handled by other processes
        start_proc_mask_signal(self._procs_threads)
@@ -6,7 +6,9 @@ from ..tfutils.common import get_op_tensor_name
__all__ = ['OutputTensorDispatcer']

class OutputTensorDispatcer(object):

    def __init__(self):
        self._names = []
        self._idxs = []
...
@@ -12,10 +12,12 @@ from ..tfutils import get_op_var_name
__all__ = ['DumpParamAsImage']

class DumpParamAsImage(Callback):
    """
    Dump a variable to image(s) after every epoch to logger.LOG_DIR.
    """

    def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
        """
        :param var_name: the name of the variable.
@@ -59,4 +61,3 @@ class DumpParamAsImage(Callback):
        if self.clip:
            res = np.clip(res, 0, 255)
        cv2.imwrite(fname, res.astype('uint8'))
@@ -10,8 +10,10 @@ from ..utils import logger
__all__ = ['RunOp']

class RunOp(Callback):
    """ Run an op periodically"""

    def __init__(self, setup_func, run_before=True, run_epoch=True):
        """
        :param setup_func: a function that returns the op in the graph
@@ -34,5 +36,5 @@ class RunOp(Callback):
        if self.run_epoch:
            self._op.run()

-    #def _log(self):
+    # def _log(self):
        #logger.info("Running op {} ...".format(self._op_name))
@@ -12,7 +12,9 @@ from ..utils import logger
__all__ = ['Callbacks']

class CallbackTimeLogger(object):

    def __init__(self):
        self.times = []
        self.tot = 0
@@ -39,10 +41,12 @@ class CallbackTimeLogger(object):
            "Callbacks took {:.3f} sec in total. {}".format(
                self.tot, '; '.join(msgs)))

class Callbacks(Callback):
    """
    A container to hold all callbacks, and execute them in the right order and proper session.
    """

    def __init__(self, cbs):
        """
        :param cbs: a list of `Callbacks`
...
@@ -14,7 +14,8 @@ from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_var_name

__all__ = ['ClassificationError',
           'ScalarStats', 'Inferencer', 'BinaryClassificationStats']

@six.add_metaclass(ABCMeta)
class Inferencer(object):
@@ -59,12 +60,14 @@ class Inferencer(object):
    def _get_output_tensors(self):
        pass

class ScalarStats(Inferencer):
    """
    Write some scalar tensor to both stat and summary.
    The output of the given Ops must be a scalar.
    The value will be averaged over all data points in the inference dataflow.
    """

    def __init__(self, names_to_print, prefix='validation'):
        """
        :param names_to_print: list of names of tensors, or just a name
@@ -96,6 +99,7 @@ class ScalarStats(Inferencer):
            ret[name] = stat
        return ret

class ClassificationError(Inferencer):
    """
    Compute classification error in batch mode, from a `wrong` variable
@@ -109,6 +113,7 @@ class ClassificationError(Inferencer):
    testing (because the size of test set might not be a multiple of batch size).
    Therefore the result is different from averaging the error rate of each batch.
    """

    def __init__(self, wrong_var_name='incorrect_vector', summary_name='val_error'):
        """
        :param wrong_var_name: name of the `wrong` variable
@@ -138,6 +143,7 @@ class ClassificationError(Inferencer):
    def _after_inference(self):
        return {self.summary_name: self.err_stat.ratio}

class BinaryClassificationStats(Inferencer):
    """ Compute precision/recall in binary classification, given the
        prediction vector and the label vector.
...
@@ -18,6 +18,7 @@ from ..train.input_data import FeedfreeInput
__all__ = ['InferenceRunner']

def summary_inferencer(trainer, infs):
    for inf in infs:
        ret = inf.after_inference()
@@ -29,6 +30,7 @@ def summary_inferencer(trainer, infs):
            continue
        trainer.write_scalar_summary(k, v)

class InferenceRunner(Callback):
    """
    A callback that runs different kinds of inferencer.
@@ -54,16 +56,17 @@ class InferenceRunner(Callback):
        self.input_tensors = input_tensors

    def _setup_graph(self):
        self._find_input_tensors()  # these are all tensor names
        self._find_output_tensors()  # may be either tensor name or op name
        self.pred_func = self.trainer.get_predict_func(
            self.input_tensors, self.output_tensors)

    def _find_input_tensors(self):
        if self.input_tensors is None:
            input_vars = self.trainer.model.get_reuse_placehdrs()
            # TODO even if it works here, sparse still is unavailable
            # because get_tensor_by_name doesn't work for sparse

            def get_name(x):
                if isinstance(x, tf.SparseTensor):
                    return x.op.name.split('/')[0]
@@ -79,6 +82,7 @@ class InferenceRunner(Callback):
        IOTensor = InferenceRunner.IOTensor
        self.output_tensors = list(filter(
            lambda x: x not in self.input_tensors, all_names))

        def find_oid(idxs):
            ret = []
            for idx in idxs:
@@ -102,7 +106,7 @@ class InferenceRunner(Callback):
                outputs = self.pred_func(dp)
                for inf, tensormap in zip(self.infs, self.inf_to_tensors):
                    inf_output = [(outputs if k.isOutput else dp)[k.index]
                                  for k in tensormap]
                    inf.datapoint(inf_output)
                pbar.update()
        self._write_summary_after_inference()
@@ -110,6 +114,7 @@ class InferenceRunner(Callback):
    def _write_summary_after_inference(self):
        summary_inferencer(self.trainer, self.infs)

class FeedfreeInferenceRunner(Callback):
    IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
@@ -139,9 +144,9 @@ class FeedfreeInferenceRunner(Callback):
        if self.input_tensor_names is not None:
            assert isinstance(self.input_tensor_names, list)
            self._input_tensors = [k for idx, k in enumerate(self._input_tensors)
                                   if model_placehdrs[idx].name in self.input_tensor_names]
            assert len(self._input_tensors) == len(self.input_tensor_names), \
                "names of input tensors are not defined in the Model"

    def _find_output_tensors(self):
        # doesn't support output an input tensor
@@ -152,6 +157,7 @@ class FeedfreeInferenceRunner(Callback):
        IOTensor = InferenceRunner.IOTensor
        self.output_tensors = all_names

        def find_oid(idxs):
            ret = []
            for idx in idxs:
@@ -161,7 +167,6 @@ class FeedfreeInferenceRunner(Callback):
        self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
        # list of list of (var_name: IOTensor)

    def _trigger_epoch(self):
        for inf in self.infs:
            inf.before_inference()
@@ -170,11 +175,11 @@ class FeedfreeInferenceRunner(Callback):
        sz = self._input_data.size()
        with get_tqdm(total=sz) as pbar:
            for _ in range(sz):
-                #outputs = self.pred_func(dp)
-                #for inf, tensormap in zip(self.infs, self.inf_to_tensors):
-                    #inf_output = [(outputs if k.isOutput else dp)[k.index]
-                                  #for k in tensormap]
-                    #inf.datapoint(inf_output)
+                # outputs = self.pred_func(dp)
+                # for inf, tensormap in zip(self.infs, self.inf_to_tensors):
+                # inf_output = [(outputs if k.isOutput else dp)[k.index]
+                # for k in tensormap]
+                # inf.datapoint(inf_output)
                pbar.update()
        self._write_summary_after_inference()
...
@@ -17,6 +17,8 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
           'ScheduledHyperParamSetter',
           'StatMonitorParamSetter', 'HyperParamSetterWithFunc',
           'HyperParam', 'GraphVarParam', 'ObjAttrParam']

@six.add_metaclass(ABCMeta)
class HyperParam(object):
    """ Base class for a hyper param"""
@@ -35,8 +37,10 @@ class HyperParam(object):
        """ A name to display"""
        return self._readable_name

class GraphVarParam(HyperParam):
    """ a variable in the graph can be a hyperparam"""

    def __init__(self, name, shape=[]):
        self.name = name
        self.shape = shape
@@ -56,13 +60,15 @@ class GraphVarParam(HyperParam):
        self.assign_op = self.var.assign(self.val_holder)

    def set_value(self, v):
-        self.assign_op.eval(feed_dict={self.val_holder:v})
+        self.assign_op.eval(feed_dict={self.val_holder: v})

    def get_value(self):
        return self.var.eval()

class ObjAttrParam(HyperParam):
    """ an attribute of an object can be a hyperparam"""

    def __init__(self, obj, attrname, readable_name=None):
        """ :param readable_name: default to be attrname."""
        self.obj = obj
@@ -78,6 +84,7 @@ class ObjAttrParam(HyperParam):
    def get_value(self, v):
        return getattr(self.obj, self.attrname)

class HyperParamSetter(Callback):
    """
    Base class to set hyperparameters after every epoch.
@@ -126,10 +133,12 @@ class HyperParamSetter(Callback):
        if v is not None:
            self.param.set_value(v)

class HumanHyperParamSetter(HyperParamSetter):
    """
    Set hyperparameters by loading the value from a file each time it get called.
    """

    def __init__(self, param, file_name='hyper.txt'):
        """
        :param file_name: a file containing the value of the variable.
@@ -149,7 +158,7 @@ class HumanHyperParamSetter(HyperParamSetter):
            with open(self.file_name) as f:
                lines = f.readlines()
            lines = [s.strip().split(':') for s in lines]
-            dic = {str(k):float(v) for k, v in lines}
+            dic = {str(k): float(v) for k, v in lines}
            ret = dic[self.param.readable_name]
            return ret
        except:
@@ -158,10 +167,12 @@ class HumanHyperParamSetter(HyperParamSetter):
                self.param.readable_name, self.file_name))
            return None

class ScheduledHyperParamSetter(HyperParamSetter):
    """
    Set hyperparameters by a predefined schedule.
    """

    def __init__(self, param, schedule, interp=None):
        """
        :param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
@@ -196,7 +207,9 @@ class ScheduledHyperParamSetter(HyperParamSetter):
            v = (self.epoch_num - laste) * 1. / (e - laste) * (v - lastv) + lastv
        return v

class HyperParamSetterWithFunc(HyperParamSetter):

    def __init__(self, param, func):
        """Set hyperparameter by a func
        new_value = f(epoch_num, old_value)
@@ -207,10 +220,12 @@ class HyperParamSetterWithFunc(HyperParamSetter):
    def _get_value_to_set(self):
        return self.f(self.epoch_num, self.get_current_value())

class StatMonitorParamSetter(HyperParamSetter):

    def __init__(self, param, stat_name, value_func, threshold,
                 last_k, reverse=False
                 ):
        """
        Set hyperparameter by a func, when a specific stat wasn't
        decreasing/increasing enough in the last $k$ epochs.
@@ -236,22 +251,21 @@ class StatMonitorParamSetter(HyperParamSetter):
    def _get_value_to_set(self):
        holder = self.trainer.stat_holder
        hist = holder.get_stat_history(self.stat_name)
-        if len(hist) < self.last_k+1 or \
+        if len(hist) < self.last_k + 1 or \
                self.epoch_num - self.last_changed_epoch < self.last_k:
            return None
-        hist = hist[-self.last_k-1:]    # len==last_k+1
+        hist = hist[-self.last_k - 1:]  # len==last_k+1
        hist_first = hist[0]
        if not self.reverse:
            hist_min = min(hist)
            if hist_min < hist_first - self.threshold:  # small enough
                return None
        else:
            hist_max = max(hist)
            if hist_max > hist_first + self.threshold:  # large enough
                return None
        self.last_changed_epoch = self.epoch_num
        logger.info("[StatMonitorParamSetter] Triggered, history: " +
                    ','.join(map(str, hist)))
        return self.value_func(self.get_current_value())
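The schedule format in ScheduledHyperParamSetter's docstring is a list of (epoch, value) pairs; HyperParamSetterWithFunc takes f(epoch_num, old_value). A usage sketch: the parameter name and numbers are illustrative, and passing the parameter as a plain string assumes the setter wraps it into a GraphVarParam, which is not shown in this excerpt.

    from tensorpack.callbacks.param import (ScheduledHyperParamSetter,
                                            HyperParamSetterWithFunc)

    # Piecewise-constant schedule: 1e-2 until epoch 30, then 1e-3, then 1e-4.
    lr_schedule = ScheduledHyperParamSetter('learning_rate',
                                            [(0, 1e-2), (30, 1e-3), (60, 1e-4)])

    # The same idea expressed as a function: halve the value every 10 epochs.
    lr_decay = HyperParamSetterWithFunc('learning_rate',
                                        lambda epoch, old: old * 0.5 if epoch % 10 == 0 else old)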
@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
-import os, shutil
+import os
+import shutil
import re

from .base import Callback
@@ -13,12 +14,14 @@ from ..tfutils import get_global_step
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']

class ModelSaver(Callback):
    """
    Save the model to logger directory.
    """

    def __init__(self, keep_recent=10, keep_freq=0.5,
                 var_collections=None):
        """
        :param keep_recent: see `tf.train.Saver` documentation.
        :param keep_freq: see `tf.train.Saver` documentation.
@@ -71,9 +74,9 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
        try:
            if not self.meta_graph_written:
                self.saver.export_meta_graph(
                    os.path.join(logger.LOG_DIR,
                                 'graph-{}.meta'.format(logger.get_time_str())),
                    collection_list=self.graph.get_all_collection_keys())
                self.meta_graph_written = True
            self.saver.save(
                tf.get_default_session(),
@@ -83,7 +86,9 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
        except (OSError, IOError):   # disk error sometimes.. just ignore it
            logger.exception("Exception in ModelSaver.trigger_epoch!")

class MinSaver(Callback):

    def __init__(self, monitor_stat, reverse=True, filename=None):
        self.monitor_stat = monitor_stat
        self.reverse = reverse
@@ -116,15 +121,14 @@ class MinSaver(Callback):
            "Cannot find a checkpoint state. Do you forget to use ModelSaver?")
        path = ckpt.model_checkpoint_path
        newname = os.path.join(logger.LOG_DIR,
                               self.filename or
                               ('max-' if self.reverse else 'min-' + self.monitor_stat + '.tfmodel'))
        shutil.copy(path, newname)
        logger.info("Model with {} '{}' saved.".format(
            'maximum' if self.reverse else 'minimum', self.monitor_stat))

class MaxSaver(MinSaver):

    def __init__(self, monitor_stat):
        super(MaxSaver, self).__init__(monitor_stat, True)
@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
-import re, os
+import re
+import os
import operator
import json
@@ -13,10 +14,12 @@ from ..tfutils.common import get_global_step
__all__ = ['StatHolder', 'StatPrinter', 'SendStat']

class StatHolder(object):
    """
    A holder to keep all statistics aside from tensorflow events.
    """

    def __init__(self, log_dir):
        """
        :param log_dir: directory to save the stats.
@@ -62,9 +65,11 @@ class StatHolder(object):
        ret = []
        for h in self.stat_history:
            v = h.get(key, None)
-            if v is not None: ret.append(v)
+            if v is not None:
+                ret.append(v)
        v = self.stat_now.get(key, None)
-        if v is not None: ret.append(v)
+        if v is not None:
+            ret.append(v)
        return ret

    def finalize(self):
@@ -88,13 +93,15 @@ class StatHolder(object):
            with open(tmp_filename, 'w') as f:
                json.dump(self.stat_history, f)
            os.rename(tmp_filename, self.filename)
        except IOError:  # disk error sometimes..
            logger.exception("Exception in StatHolder.finalize()!")

class StatPrinter(Callback):
    """
    Control what stats to print.
    """

    def __init__(self, print_tag=None):
        """
        :param print_tag: a list of regex to match scalar summary to print.
@@ -116,6 +123,7 @@ class StatPrinter(Callback):
        self._stat_holder.finalize()
        self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)

class SendStat(Callback):
    """
    Execute a command with some specific stats.
@@ -126,6 +134,7 @@ class SendStat(Callback):
            -d body={validation_error} > /dev/null 2>&1',
            'validation_error')
    """

    def __init__(self, command, stats):
        self.command = command
        if not isinstance(stats, list):
...
@@ -12,6 +12,7 @@ from . import imgaug
__all__ = ['dataset', 'imgaug', 'dftools']

def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
@@ -24,6 +25,5 @@ __SKIP = ['dftools', 'dataset', 'imgaug']
for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_') and \
            module_name not in __SKIP:
        _global_import(module_name)
@@ -10,6 +10,7 @@ from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']

@six.add_metaclass(ABCMeta)
class DataFlow(object):
    """ Base class for all DataFlow """
@@ -17,7 +18,6 @@ class DataFlow(object):
    class Infinity:
        pass

    @abstractmethod
    def get_data(self):
        """
@@ -44,11 +44,14 @@ class DataFlow(object):
class RNGDataFlow(DataFlow):
    """ A dataflow with rng"""

    def reset_state(self):
        self.rng = get_rng(self)

class ProxyDataFlow(DataFlow):
    """ Base class for DataFlow that proxies another"""

    def __init__(self, ds):
        """
        :param ds: a :mod:`DataFlow` instance to proxy
...
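A DataFlow only has to implement get_data() (plus, optionally, size() and reset_state()). A toy subclass written against the base classes above, with made-up data, just to show the contract; the assumption that the base reset_state() is a no-op is not visible in this excerpt:

    from tensorpack.dataflow.base import DataFlow

    class CounterFlow(DataFlow):
        """ Hypothetical DataFlow yielding the datapoint [i] for i in range(n). """
        def __init__(self, n=10):
            self._n = n

        def size(self):
            return self._n

        def get_data(self):
            for i in range(self._n):
                yield [i]

    df = CounterFlow()
    df.reset_state()                              # assumed no-op here; RNGDataFlow would seed its rng
    print(sum(dp[0] for dp in df.get_data()))     # 45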
...@@ -15,7 +15,9 @@ __all__ = ['BatchData', 'FixedSizeData', 'MapData', ...@@ -15,7 +15,9 @@ __all__ = ['BatchData', 'FixedSizeData', 'MapData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent', 'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'LocallyShuffleData', 'TestDataSpeed', 'BatchDataByShape'] 'LocallyShuffleData', 'TestDataSpeed', 'BatchDataByShape']
class TestDataSpeed(ProxyDataFlow): class TestDataSpeed(ProxyDataFlow):
def __init__(self, ds, size=1000): def __init__(self, ds, size=1000):
super(TestDataSpeed, self).__init__(ds) super(TestDataSpeed, self).__init__(ds)
self.test_size = size self.test_size = size
...@@ -31,7 +33,9 @@ class TestDataSpeed(ProxyDataFlow): ...@@ -31,7 +33,9 @@ class TestDataSpeed(ProxyDataFlow):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
pbar.update() pbar.update()
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
""" """
Group data in `ds` into batches. Group data in `ds` into batches.
...@@ -91,11 +95,13 @@ class BatchData(ProxyDataFlow): ...@@ -91,11 +95,13 @@ class BatchData(ProxyDataFlow):
raise raise
except: except:
logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?") logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
import IPython as IP; import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config()) IP.embed(config=IP.terminal.ipapp.load_default_config())
return result return result
class BatchDataByShape(BatchData): class BatchDataByShape(BatchData):
def __init__(self, ds, batch_size, idx): def __init__(self, ds, batch_size, idx):
""" Group datapoint of the same shape together to batches """ Group datapoint of the same shape together to batches
...@@ -119,10 +125,12 @@ class BatchDataByShape(BatchData): ...@@ -119,10 +125,12 @@ class BatchDataByShape(BatchData):
yield BatchData._aggregate_batch(holder) yield BatchData._aggregate_batch(holder)
del holder[:] del holder[:]
class FixedSizeData(ProxyDataFlow): class FixedSizeData(ProxyDataFlow):
""" Generate data from another DataFlow, but with a fixed epoch size. """ Generate data from another DataFlow, but with a fixed epoch size.
The state of the underlying DataFlow is maintained across epochs. The state of the underlying DataFlow is maintained across epochs.
""" """
def __init__(self, ds, size): def __init__(self, ds, size):
""" """
:param ds: a :mod:`DataFlow` to produce data :param ds: a :mod:`DataFlow` to produce data
...@@ -154,10 +162,12 @@ class FixedSizeData(ProxyDataFlow): ...@@ -154,10 +162,12 @@ class FixedSizeData(ProxyDataFlow):
if cnt == self._size: if cnt == self._size:
return return
class RepeatedData(ProxyDataFlow): class RepeatedData(ProxyDataFlow):
""" Take data points from another `DataFlow` and produce them until """ Take data points from another `DataFlow` and produce them until
it's exhausted a certain number of times. it's exhausted a certain number of times.
""" """
def __init__(self, ds, nr): def __init__(self, ds, nr):
""" """
:param ds: a :mod:`DataFlow` instance. :param ds: a :mod:`DataFlow` instance.
...@@ -184,8 +194,10 @@ class RepeatedData(ProxyDataFlow): ...@@ -184,8 +194,10 @@ class RepeatedData(ProxyDataFlow):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield dp yield dp
class MapData(ProxyDataFlow): class MapData(ProxyDataFlow):
""" Apply map/filter a function on the datapoint""" """ Apply map/filter a function on the datapoint"""
def __init__(self, ds, func): def __init__(self, ds, func):
""" """
:param ds: a :mod:`DataFlow` instance. :param ds: a :mod:`DataFlow` instance.
...@@ -202,8 +214,10 @@ class MapData(ProxyDataFlow): ...@@ -202,8 +214,10 @@ class MapData(ProxyDataFlow):
if ret is not None: if ret is not None:
yield ret yield ret
class MapDataComponent(ProxyDataFlow): class MapDataComponent(ProxyDataFlow):
""" Apply map/filter on the given index in the datapoint""" """ Apply map/filter on the given index in the datapoint"""
def __init__(self, ds, func, index=0): def __init__(self, ds, func, index=0):
""" """
:param ds: a :mod:`DataFlow` instance. :param ds: a :mod:`DataFlow` instance.
...@@ -222,11 +236,13 @@ class MapDataComponent(ProxyDataFlow): ...@@ -222,11 +236,13 @@ class MapDataComponent(ProxyDataFlow):
dp[self.index] = repl # NOTE modifying dp[self.index] = repl # NOTE modifying
yield dp yield dp
class RandomChooseData(RNGDataFlow): class RandomChooseData(RNGDataFlow):
""" """
Randomly choose from several DataFlow. Stop producing when any of them is Randomly choose from several DataFlow. Stop producing when any of them is
exhausted. exhausted.
""" """
def __init__(self, df_lists): def __init__(self, df_lists):
""" """
:param df_lists: list of dataflow, or list of (dataflow, probability) tuple :param df_lists: list of dataflow, or list of (dataflow, probability) tuple
...@@ -257,10 +273,12 @@ class RandomChooseData(RNGDataFlow): ...@@ -257,10 +273,12 @@ class RandomChooseData(RNGDataFlow):
except StopIteration: except StopIteration:
return return
class RandomMixData(RNGDataFlow): class RandomMixData(RNGDataFlow):
""" """
Randomly choose from several dataflows, and will eventually exhaust all of them. So it's a perfect mix. Randomly choose from several dataflows, and will eventually exhaust all of them. So it's a perfect mix.
""" """
def __init__(self, df_lists): def __init__(self, df_lists):
""" """
:param df_lists: list of dataflow. :param df_lists: list of dataflow.
...@@ -285,14 +303,16 @@ class RandomMixData(RNGDataFlow): ...@@ -285,14 +303,16 @@ class RandomMixData(RNGDataFlow):
idxs = np.array(list(map( idxs = np.array(list(map(
lambda x: np.searchsorted(sums, x, 'right'), idxs))) lambda x: np.searchsorted(sums, x, 'right'), idxs)))
itrs = [k.get_data() for k in self.df_lists] itrs = [k.get_data() for k in self.df_lists]
assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs)-1) assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs) - 1)
for k in idxs: for k in idxs:
yield next(itrs[k]) yield next(itrs[k])
class ConcatData(DataFlow): class ConcatData(DataFlow):
""" """
Concatenate several dataflows. Concatenate several dataflows.
""" """
def __init__(self, df_lists): def __init__(self, df_lists):
""" """
:param df_lists: list of :mod:`DataFlow` instances :param df_lists: list of :mod:`DataFlow` instances
...@@ -311,6 +331,7 @@ class ConcatData(DataFlow): ...@@ -311,6 +331,7 @@ class ConcatData(DataFlow):
for dp in d.get_data(): for dp in d.get_data():
yield dp yield dp
class JoinData(DataFlow): class JoinData(DataFlow):
""" """
Join the components from each DataFlow. Join the components from each DataFlow.
...@@ -321,6 +342,7 @@ class JoinData(DataFlow): ...@@ -321,6 +342,7 @@ class JoinData(DataFlow):
df2: [dp3, dp4] df2: [dp3, dp4]
join: [dp1, dp2, dp3, dp4] join: [dp1, dp2, dp3, dp4]
""" """
def __init__(self, df_lists): def __init__(self, df_lists):
""" """
:param df_lists: list of :mod:`DataFlow` instances :param df_lists: list of :mod:`DataFlow` instances
...@@ -329,7 +351,7 @@ class JoinData(DataFlow): ...@@ -329,7 +351,7 @@ class JoinData(DataFlow):
self._size = self.df_lists[0].size() self._size = self.df_lists[0].size()
for d in self.df_lists: for d in self.df_lists:
assert d.size() == self._size, \ assert d.size() == self._size, \
"All DataFlow must have the same size! {} != {}".format(d.size(), self._size) "All DataFlow must have the same size! {} != {}".format(d.size(), self._size)
def reset_state(self): def reset_state(self):
for d in self.df_lists: for d in self.df_lists:
...@@ -352,7 +374,9 @@ class JoinData(DataFlow): ...@@ -352,7 +374,9 @@ class JoinData(DataFlow):
for itr in itrs: for itr in itrs:
del itr del itr
class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def __init__(self, ds, cache_size, nr_reuse=1): def __init__(self, ds, cache_size, nr_reuse=1):
""" """
Cache a number of datapoints and shuffle them. Cache a number of datapoints and shuffle them.
...@@ -393,10 +417,10 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): ...@@ -393,10 +417,10 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
yield v yield v
return return
def SelectComponent(ds, idxs): def SelectComponent(ds, idxs):
""" """
:param ds: a :mod:`DataFlow` instance :param ds: a :mod:`DataFlow` instance
:param idxs: a list of datapoint component index of the original dataflow :param idxs: a list of datapoint component index of the original dataflow
""" """
return MapData(ds, lambda dp: [dp[i] for i in idxs]) return MapData(ds, lambda dp: [dp[i] for i in idxs])
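These wrappers compose by nesting one DataFlow inside another. Below is a minimal sketch of such a pipeline, assuming the tensorpack package from this commit is importable; ToyPairs is an invented stand-in source, not a class from this repository.

from tensorpack.dataflow import RNGDataFlow, MapData, BatchData, SelectComponent

class ToyPairs(RNGDataFlow):
    """Toy source yielding [image, label] datapoints; not part of the library."""
    def size(self):
        return 32
    def get_data(self):
        for _ in range(self.size()):
            yield [self.rng.rand(28, 28), int(self.rng.randint(10))]

ds = ToyPairs()
ds = MapData(ds, lambda dp: [dp[0] * 255.0, dp[1]])  # transform whole datapoints
ds = SelectComponent(ds, [0])                        # keep only the image component
ds = BatchData(ds, 8)                                # stack 8 datapoints per batch
ds.reset_state()                                     # propagates down to the toy source's rng
for batch in ds.get_data():
    print(batch[0].shape)                            # (8, 28, 28)
    break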
...@@ -7,6 +7,8 @@ import os ...@@ -7,6 +7,8 @@ import os
import os.path import os.path
__all__ = [] __all__ = []
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
...@@ -19,4 +21,3 @@ for _, module_name, _ in walk_packages( ...@@ -19,4 +21,3 @@ for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]): [os.path.dirname(__file__)]):
if not module_name.startswith('_'): if not module_name.startswith('_'):
global_import(module_name) global_import(module_name)
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
# File: bsds500.py # File: bsds500.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os, glob import os
import glob
import cv2 import cv2
import numpy as np import numpy as np
...@@ -21,6 +22,7 @@ except ImportError: ...@@ -21,6 +22,7 @@ except ImportError:
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz" DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W, IMG_H = 481, 321 IMG_W, IMG_H = 481, 321
class BSDS500(RNGDataFlow): class BSDS500(RNGDataFlow):
""" """
`Berkeley Segmentation Data Set and Benchmarks 500 `Berkeley Segmentation Data Set and Benchmarks 500
...@@ -65,7 +67,7 @@ class BSDS500(RNGDataFlow): ...@@ -65,7 +67,7 @@ class BSDS500(RNGDataFlow):
im = cv2.imread(f, cv2.IMREAD_COLOR) im = cv2.imread(f, cv2.IMREAD_COLOR)
assert im is not None assert im is not None
if im.shape[0] > im.shape[1]: if im.shape[0] > im.shape[1]:
im = np.transpose(im, (1,0,2)) im = np.transpose(im, (1, 0, 2))
assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W)) assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W))
imgid = os.path.basename(f).split('.')[0] imgid = os.path.basename(f).split('.')[0]
...@@ -96,5 +98,5 @@ class BSDS500(RNGDataFlow): ...@@ -96,5 +98,5 @@ class BSDS500(RNGDataFlow):
if __name__ == '__main__': if __name__ == '__main__':
a = BSDS500('val') a = BSDS500('val')
for k in a.get_data(): for k in a.get_data():
cv2.imshow("haha", k[1].astype('uint8')*255) cv2.imshow("haha", k[1].astype('uint8') * 255)
cv2.waitKey(1000) cv2.waitKey(1000)
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Yukun Chen <cykustc@gmail.com> # Yukun Chen <cykustc@gmail.com>
import os, sys import os
import sys
import pickle import pickle
import numpy as np import numpy as np
import random import random
...@@ -23,6 +24,7 @@ __all__ = ['Cifar10', 'Cifar100'] ...@@ -23,6 +24,7 @@ __all__ = ['Cifar10', 'Cifar100']
DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
def maybe_download_and_extract(dest_directory, cifar_classnum): def maybe_download_and_extract(dest_directory, cifar_classnum):
"""Download and extract the tarball from Alex's website. """Download and extract the tarball from Alex's website.
copied from tensorflow example """ copied from tensorflow example """
...@@ -42,6 +44,7 @@ def maybe_download_and_extract(dest_directory, cifar_classnum): ...@@ -42,6 +44,7 @@ def maybe_download_and_extract(dest_directory, cifar_classnum):
import tarfile import tarfile
tarfile.open(filepath, 'r:gz').extractall(dest_directory) tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def read_cifar(filenames, cifar_classnum): def read_cifar(filenames, cifar_classnum):
assert cifar_classnum == 10 or cifar_classnum == 100 assert cifar_classnum == 10 or cifar_classnum == 100
ret = [] ret = []
...@@ -54,7 +57,7 @@ def read_cifar(filenames, cifar_classnum): ...@@ -54,7 +57,7 @@ def read_cifar(filenames, cifar_classnum):
data = dic[b'data'] data = dic[b'data']
if cifar_classnum == 10: if cifar_classnum == 10:
label = dic[b'labels'] label = dic[b'labels']
IMG_NUM = 10000 # cifar10 data are split into blocks of 10000 IMG_NUM = 10000 # cifar10 data are split into blocks of 10000
elif cifar_classnum == 100: elif cifar_classnum == 100:
label = dic[b'fine_labels'] label = dic[b'fine_labels']
IMG_NUM = 50000 if 'train' in fname else 10000 IMG_NUM = 50000 if 'train' in fname else 10000
...@@ -65,6 +68,7 @@ def read_cifar(filenames, cifar_classnum): ...@@ -65,6 +68,7 @@ def read_cifar(filenames, cifar_classnum):
ret.append([img, label[k]]) ret.append([img, label[k]])
return ret return ret
def get_filenames(dir, cifar_classnum): def get_filenames(dir, cifar_classnum):
assert cifar_classnum == 10 or cifar_classnum == 100 assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10: if cifar_classnum == 10:
...@@ -77,11 +81,13 @@ def get_filenames(dir, cifar_classnum): ...@@ -77,11 +81,13 @@ def get_filenames(dir, cifar_classnum):
os.path.join(dir, 'cifar-100-python', 'test')] os.path.join(dir, 'cifar-100-python', 'test')]
return filenames return filenames
class CifarBase(RNGDataFlow): class CifarBase(RNGDataFlow):
""" """
Return [image, label], Return [image, label],
image is 32x32x3 in the range [0,255] image is 32x32x3 in the range [0,255]
""" """
def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10): def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10):
""" """
Args: Args:
...@@ -132,13 +138,17 @@ class CifarBase(RNGDataFlow): ...@@ -132,13 +138,17 @@ class CifarBase(RNGDataFlow):
return three values as mean of each channel return three values as mean of each channel
""" """
mean = self.get_per_pixel_mean() mean = self.get_per_pixel_mean()
return np.mean(mean, axis=(0,1)) return np.mean(mean, axis=(0, 1))
class Cifar10(CifarBase): class Cifar10(CifarBase):
def __init__(self, train_or_test, shuffle=True, dir=None): def __init__(self, train_or_test, shuffle=True, dir=None):
super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10) super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10)
class Cifar100(CifarBase): class Cifar100(CifarBase):
def __init__(self, train_or_test, shuffle=True, dir=None): def __init__(self, train_or_test, shuffle=True, dir=None):
super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100) super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)
...@@ -149,7 +159,6 @@ if __name__ == '__main__': ...@@ -149,7 +159,6 @@ if __name__ == '__main__':
print(mean) print(mean)
dump_dataset_images(ds, '/tmp/cifar', 100) dump_dataset_images(ds, '/tmp/cifar', 100)
#for (img, label) in ds.get_data(): # for (img, label) in ds.get_data():
#from IPython import embed; embed() # from IPython import embed; embed()
#break # break
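A hedged usage sketch for the CIFAR reader above; the download-on-first-use behaviour and the exact shape of the per-pixel mean are assumptions based on how get_per_pixel_mean is used in this hunk.

from tensorpack.dataflow.dataset import Cifar10

ds = Cifar10('train', shuffle=False)    # may fetch cifar-10-python.tar.gz into a default dir
print(ds.get_per_pixel_mean().shape)    # per-pixel mean image, one value per channel
ds.reset_state()
for img, label in ds.get_data():
    print(img.shape, label)             # (32, 32, 3) image in [0, 255] and an int label
    break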
...@@ -19,10 +19,12 @@ __all__ = ['ILSVRCMeta', 'ILSVRC12'] ...@@ -19,10 +19,12 @@ __all__ = ['ILSVRCMeta', 'ILSVRC12']
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz" CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
class ILSVRCMeta(object): class ILSVRCMeta(object):
""" """
Some metadata for ILSVRC dataset. Some metadata for ILSVRC dataset.
""" """
def __init__(self, dir=None): def __init__(self, dir=None):
if dir is None: if dir is None:
dir = get_dataset_path('ilsvrc_metadata') dir = get_dataset_path('ilsvrc_metadata')
...@@ -82,14 +84,16 @@ class ILSVRCMeta(object): ...@@ -82,14 +84,16 @@ class ILSVRCMeta(object):
with open(mean_file, 'rb') as f: with open(mean_file, 'rb') as f:
obj.ParseFromString(f.read()) obj.ParseFromString(f.read())
arr = np.array(obj.data).reshape((3, 256, 256)).astype('float32') arr = np.array(obj.data).reshape((3, 256, 256)).astype('float32')
arr = np.transpose(arr, [1,2,0]) arr = np.transpose(arr, [1, 2, 0])
if size is not None: if size is not None:
arr = cv2.resize(arr, size[::-1]) arr = cv2.resize(arr, size[::-1])
return arr return arr
class ILSVRC12(RNGDataFlow): class ILSVRC12(RNGDataFlow):
def __init__(self, dir, name, meta_dir=None, shuffle=True, def __init__(self, dir, name, meta_dir=None, shuffle=True,
dir_structure='original', include_bb=False): dir_structure='original', include_bb=False):
""" """
:param dir: A directory containing a subdir named `name`, where the :param dir: A directory containing a subdir named `name`, where the
original ILSVRC12_`name`.tar gets decompressed. original ILSVRC12_`name`.tar gets decompressed.
...@@ -145,7 +149,7 @@ class ILSVRC12(RNGDataFlow): ...@@ -145,7 +149,7 @@ class ILSVRC12(RNGDataFlow):
if include_bb: if include_bb:
bbdir = os.path.join(dir, 'bbox') if not \ bbdir = os.path.join(dir, 'bbox') if not \
isinstance(include_bb, six.string_types) else include_bb isinstance(include_bb, six.string_types) else include_bb
assert name == 'train', 'Bounding box only available for training' assert name == 'train', 'Bounding box only available for training'
self.bblist = ILSVRC12.get_training_bbox(bbdir, self.imglist) self.bblist = ILSVRC12.get_training_bbox(bbdir, self.imglist)
self.include_bb = include_bb self.include_bb = include_bb
...@@ -171,11 +175,11 @@ class ILSVRC12(RNGDataFlow): ...@@ -171,11 +175,11 @@ class ILSVRC12(RNGDataFlow):
im = cv2.imread(fname.strip(), cv2.IMREAD_COLOR) im = cv2.imread(fname.strip(), cv2.IMREAD_COLOR)
assert im is not None, fname assert im is not None, fname
if im.ndim == 2: if im.ndim == 2:
im = np.expand_dims(im, 2).repeat(3,2) im = np.expand_dims(im, 2).repeat(3, 2)
if self.include_bb: if self.include_bb:
bb = self.bblist[k] bb = self.bblist[k]
if bb is None: if bb is None:
bb = [0, 0, im.shape[1]-1, im.shape[0]-1] bb = [0, 0, im.shape[1] - 1, im.shape[0] - 1]
yield [im, label, bb] yield [im, label, bb]
else: else:
yield [im, label] yield [im, label]
...@@ -216,12 +220,13 @@ class ILSVRC12(RNGDataFlow): ...@@ -216,12 +220,13 @@ class ILSVRC12(RNGDataFlow):
if __name__ == '__main__': if __name__ == '__main__':
meta = ILSVRCMeta() meta = ILSVRCMeta()
#print(meta.get_synset_words_1000()) # print(meta.get_synset_words_1000())
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', include_bb=True, ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', include_bb=True,
shuffle=False) shuffle=False)
ds.reset_state() ds.reset_state()
for k in ds.get_data(): for k in ds.get_data():
from IPython import embed; embed() from IPython import embed
embed()
break break
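A hedged sketch of the metadata helper; get_synset_words_1000 is the only accessor referenced in this hunk, and the download-if-missing behaviour is an assumption.

from tensorpack.dataflow.dataset import ILSVRCMeta

meta = ILSVRCMeta()                    # may fetch caffe_ilsvrc12.tar.gz into a default dir
words = meta.get_synset_words_1000()   # class id -> human-readable synset words for the 1000 classes
print(len(words))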
...@@ -17,6 +17,7 @@ __all__ = ['Mnist'] ...@@ -17,6 +17,7 @@ __all__ = ['Mnist']
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory): def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here.""" """Download the data from Yann's website, unless it's already here."""
filepath = os.path.join(work_directory, filename) filepath = os.path.join(work_directory, filename)
...@@ -25,18 +26,20 @@ def maybe_download(filename, work_directory): ...@@ -25,18 +26,20 @@ def maybe_download(filename, work_directory):
download(SOURCE_URL + filename, work_directory) download(SOURCE_URL + filename, work_directory)
return filepath return filepath
def _read32(bytestream): def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>') dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def extract_images(filename): def extract_images(filename):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
with gzip.open(filename) as bytestream: with gzip.open(filename) as bytestream:
magic = _read32(bytestream) magic = _read32(bytestream)
if magic != 2051: if magic != 2051:
raise ValueError( raise ValueError(
'Invalid magic number %d in MNIST image file: %s' % 'Invalid magic number %d in MNIST image file: %s' %
(magic, filename)) (magic, filename))
num_images = _read32(bytestream) num_images = _read32(bytestream)
rows = _read32(bytestream) rows = _read32(bytestream)
cols = _read32(bytestream) cols = _read32(bytestream)
...@@ -46,24 +49,27 @@ def extract_images(filename): ...@@ -46,24 +49,27 @@ def extract_images(filename):
data = data.astype('float32') / 255.0 data = data.astype('float32') / 255.0
return data return data
def extract_labels(filename): def extract_labels(filename):
"""Extract the labels into a 1D uint8 numpy array [index].""" """Extract the labels into a 1D uint8 numpy array [index]."""
with gzip.open(filename) as bytestream: with gzip.open(filename) as bytestream:
magic = _read32(bytestream) magic = _read32(bytestream)
if magic != 2049: if magic != 2049:
raise ValueError( raise ValueError(
'Invalid magic number %d in MNIST label file: %s' % 'Invalid magic number %d in MNIST label file: %s' %
(magic, filename)) (magic, filename))
num_items = _read32(bytestream) num_items = _read32(bytestream)
buf = bytestream.read(num_items) buf = bytestream.read(num_items)
labels = numpy.frombuffer(buf, dtype=numpy.uint8) labels = numpy.frombuffer(buf, dtype=numpy.uint8)
return labels return labels
class Mnist(RNGDataFlow): class Mnist(RNGDataFlow):
""" """
Return [image, label], Return [image, label],
image is 28x28 in the range [0,1] image is 28x28 in the range [0,1]
""" """
def __init__(self, train_or_test, shuffle=True, dir=None): def __init__(self, train_or_test, shuffle=True, dir=None):
""" """
Args: Args:
...@@ -107,6 +113,6 @@ class Mnist(RNGDataFlow): ...@@ -107,6 +113,6 @@ class Mnist(RNGDataFlow):
if __name__ == '__main__': if __name__ == '__main__':
ds = Mnist('train') ds = Mnist('train')
for (img, label) in ds.get_data(): for (img, label) in ds.get_data():
from IPython import embed; embed() from IPython import embed
embed()
break break
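The dataset readers plug straight into the proxies from common.py. A hedged sketch, assuming the tensorpack package is importable and the MNIST files can be downloaded:

from tensorpack.dataflow import BatchData
from tensorpack.dataflow.dataset import Mnist

ds = BatchData(Mnist('test', shuffle=False), 256, remainder=True)
ds.reset_state()
for imgs, labels in ds.get_data():
    print(imgs.shape, labels.shape)    # (256, 28, 28) float images in [0, 1] and int labels
    break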
...@@ -24,6 +24,7 @@ TRAIN_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.tra ...@@ -24,6 +24,7 @@ TRAIN_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.tra
VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt' VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'
TEST_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt' TEST_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'
@memoized_ignoreargs @memoized_ignoreargs
def get_PennTreeBank(data_dir=None): def get_PennTreeBank(data_dir=None):
if data_dir is None: if data_dir is None:
...@@ -35,6 +36,5 @@ def get_PennTreeBank(data_dir=None): ...@@ -35,6 +36,5 @@ def get_PennTreeBank(data_dir=None):
# TODO these functions in TF might not be available in the future # TODO these functions in TF might not be available in the future
word_to_id = tfreader._build_vocab(os.path.join(data_dir, 'ptb.train.txt')) word_to_id = tfreader._build_vocab(os.path.join(data_dir, 'ptb.train.txt'))
data3 = [np.asarray(tfreader._file_to_word_ids(os.path.join(data_dir, fname), word_to_id)) data3 = [np.asarray(tfreader._file_to_word_ids(os.path.join(data_dir, fname), word_to_id))
for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']] for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']]
return data3, word_to_id return data3, word_to_id
...@@ -19,6 +19,7 @@ except ImportError: ...@@ -19,6 +19,7 @@ except ImportError:
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/" SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
class SVHNDigit(RNGDataFlow): class SVHNDigit(RNGDataFlow):
""" """
SVHN Cropped Digit Dataset. SVHN Cropped Digit Dataset.
...@@ -41,12 +42,12 @@ class SVHNDigit(RNGDataFlow): ...@@ -41,12 +42,12 @@ class SVHNDigit(RNGDataFlow):
assert name in ['train', 'test', 'extra'], name assert name in ['train', 'test', 'extra'], name
filename = os.path.join(data_dir, name + '_32x32.mat') filename = os.path.join(data_dir, name + '_32x32.mat')
assert os.path.isfile(filename), \ assert os.path.isfile(filename), \
"File {} not found! Please download it from {}.".format(filename, SVHN_URL) "File {} not found! Please download it from {}.".format(filename, SVHN_URL)
logger.info("Loading {} ...".format(filename)) logger.info("Loading {} ...".format(filename))
data = scipy.io.loadmat(filename) data = scipy.io.loadmat(filename)
self.X = data['X'].transpose(3,0,1,2) self.X = data['X'].transpose(3, 0, 1, 2)
self.Y = data['y'].reshape((-1)) self.Y = data['y'].reshape((-1))
self.Y[self.Y==10] = 0 self.Y[self.Y == 10] = 0
SVHNDigit._Cache[name] = (self.X, self.Y) SVHNDigit._Cache[name] = (self.X, self.Y)
def size(self): def size(self):
......
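A hedged sketch for the SVHN reader; the exact constructor signature and the [image, label] layout of its datapoints are not fully visible in this hunk, so treat the call below as an assumption patterned on the other dataset classes here.

from tensorpack.dataflow.dataset import SVHNDigit

ds = SVHNDigit('test')          # expects test_32x32.mat under the dataset directory (see SVHN_URL)
ds.reset_state()
for img, label in ds.get_data():
    print(img.shape, label)     # assumed (32, 32, 3) crop; labels '10' are remapped to 0 above
    break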
...@@ -12,6 +12,7 @@ import json ...@@ -12,6 +12,7 @@ import json
__all__ = ['VisualQA'] __all__ = ['VisualQA']
def read_json(fname): def read_json(fname):
f = open(fname) f = open(fname)
ret = json.load(f) ret = json.load(f)
...@@ -19,11 +20,14 @@ def read_json(fname): ...@@ -19,11 +20,14 @@ def read_json(fname):
return ret return ret
# TODO shuffle # TODO shuffle
class VisualQA(DataFlow): class VisualQA(DataFlow):
""" """
Visual QA dataset. See http://visualqa.org/ Visual QA dataset. See http://visualqa.org/
Simply read the q/a json files and produce q/a pairs in their original format. Simply read the q/a json files and produce q/a pairs in their original format.
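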
""" """
def __init__(self, question_file, annotation_file): def __init__(self, question_file, annotation_file):
with timed_operation('Reading VQA JSON file'): with timed_operation('Reading VQA JSON file'):
qobj, aobj = list(map(read_json, [question_file, annotation_file])) qobj, aobj = list(map(read_json, [question_file, annotation_file]))
...@@ -62,7 +66,7 @@ class VisualQA(DataFlow): ...@@ -62,7 +66,7 @@ class VisualQA(DataFlow):
""" Get the n most common words in questions """ Get the n most common words in questions
n=4600 ~= thresh 6 n=4600 ~= thresh 6
""" """
from nltk.tokenize import word_tokenize # will need to download 'punkt' from nltk.tokenize import word_tokenize  # will need to download 'punkt'
cnt = Counter() cnt = Counter()
for q in self.questions: for q in self.questions:
cnt.update(word_tokenize(q['question'].lower())) cnt.update(word_tokenize(q['question'].lower()))
...@@ -72,7 +76,7 @@ class VisualQA(DataFlow): ...@@ -72,7 +76,7 @@ class VisualQA(DataFlow):
if __name__ == '__main__': if __name__ == '__main__':
vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json', vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
'/home/wyx/data/VQA/mscoco_train2014_annotations.json') '/home/wyx/data/VQA/mscoco_train2014_annotations.json')
for k in vqa.get_data(): for k in vqa.get_data():
print(json.dumps(k)) print(json.dumps(k))
break break
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
# File: dftools.py # File: dftools.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import sys, os import sys
import os
import cv2 import cv2
import multiprocessing as mp import multiprocessing as mp
import six import six
...@@ -23,6 +24,8 @@ else: ...@@ -23,6 +24,8 @@ else:
__all__.extend(['dump_dataflow_to_lmdb']) __all__.extend(['dump_dataflow_to_lmdb'])
# TODO pass a name_func to write label as filename? # TODO pass a name_func to write label as filename?
def dump_dataset_images(ds, dirname, max_count=None, index=0): def dump_dataset_images(ds, dirname, max_count=None, index=0):
""" Dump images from a `DataFlow` to a directory. """ Dump images from a `DataFlow` to a directory.
...@@ -43,6 +46,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0): ...@@ -43,6 +46,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
img = dp[index] img = dp[index]
cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img) cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)
def dump_dataflow_to_lmdb(ds, lmdb_path): def dump_dataflow_to_lmdb(ds, lmdb_path):
""" Dump a `Dataflow` ds to a lmdb database, where the key is the index """ Dump a `Dataflow` ds to a lmdb database, where the key is the index
and the data is the serialized datapoint. and the data is the serialized datapoint.
...@@ -56,8 +60,8 @@ def dump_dataflow_to_lmdb(ds, lmdb_path): ...@@ -56,8 +60,8 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
assert not os.path.isfile(lmdb_path), "LMDB file exists!" assert not os.path.isfile(lmdb_path), "LMDB file exists!"
ds.reset_state() ds.reset_state()
db = lmdb.open(lmdb_path, subdir=isdir, db = lmdb.open(lmdb_path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False, map_size=1099511627776 * 2, readonly=False,
meminit=False, map_async=True) # need sync() at the end meminit=False, map_async=True) # need sync() at the end
try: try:
sz = ds.size() sz = ds.size()
except NotImplementedError: except NotImplementedError:
...@@ -87,7 +91,9 @@ def dataflow_to_process_queue(ds, size, nr_consumer): ...@@ -87,7 +91,9 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
the queue once you start it. Each element is (task_id, dp). the queue once you start it. Each element is (task_id, dp).
""" """
q = mp.Queue(size) q = mp.Queue(size)
class EnqueProc(mp.Process): class EnqueProc(mp.Process):
def __init__(self, ds, q, nr_consumer): def __init__(self, ds, q, nr_consumer):
super(EnqueProc, self).__init__() super(EnqueProc, self).__init__()
self.ds = ds self.ds = ds
...@@ -104,4 +110,3 @@ def dataflow_to_process_queue(ds, size, nr_consumer): ...@@ -104,4 +110,3 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
proc = EnqueProc(ds, q, nr_consumer) proc = EnqueProc(ds, q, nr_consumer)
return q, proc return q, proc
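A hedged sketch of the two dump helpers above; the /tmp paths are placeholders, and dump_dataflow_to_lmdb is only exported when the lmdb package imports successfully.

from tensorpack.dataflow import FixedSizeData
from tensorpack.dataflow.dataset import Cifar10
from tensorpack.dataflow.dftools import dump_dataset_images, dump_dataflow_to_lmdb

ds = FixedSizeData(Cifar10('train', shuffle=False), 1000)  # cap the dump at 1000 datapoints
dump_dataset_images(ds, '/tmp/cifar_imgs', max_count=100)  # writes 0.jpg, 1.jpg, ... from component 0
dump_dataflow_to_lmdb(ds, '/tmp/cifar-1k.lmdb')            # keys are indices, values serialized datapoints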
...@@ -40,10 +40,13 @@ Adapters for different data format. ...@@ -40,10 +40,13 @@ Adapters for different data format.
""" """
# TODO lazy load # TODO lazy load
class HDF5Data(RNGDataFlow): class HDF5Data(RNGDataFlow):
""" """
Zip data from different paths in an HDF5 file. Will load all data into memory. Zip data from different paths in an HDF5 file. Will load all data into memory.
""" """
def __init__(self, filename, data_paths, shuffle=True): def __init__(self, filename, data_paths, shuffle=True):
""" """
:param filename: h5 data file. :param filename: h5 data file.
...@@ -54,7 +57,7 @@ class HDF5Data(RNGDataFlow): ...@@ -54,7 +57,7 @@ class HDF5Data(RNGDataFlow):
logger.info("Loading {} to memory...".format(filename)) logger.info("Loading {} to memory...".format(filename))
self.dps = [self.f[k].value for k in data_paths] self.dps = [self.f[k].value for k in data_paths]
lens = [len(k) for k in self.dps] lens = [len(k) for k in self.dps]
assert all([k==lens[0] for k in lens]) assert all([k == lens[0] for k in lens])
self._size = lens[0] self._size = lens[0]
self.shuffle = shuffle self.shuffle = shuffle
...@@ -71,6 +74,7 @@ class HDF5Data(RNGDataFlow): ...@@ -71,6 +74,7 @@ class HDF5Data(RNGDataFlow):
class LMDBData(RNGDataFlow): class LMDBData(RNGDataFlow):
""" Read a lmdb and produce k,v pair """ """ Read a lmdb and produce k,v pair """
def __init__(self, lmdb_path, shuffle=True): def __init__(self, lmdb_path, shuffle=True):
self._lmdb_path = lmdb_path self._lmdb_path = lmdb_path
self._shuffle = shuffle self._shuffle = shuffle
...@@ -78,9 +82,9 @@ class LMDBData(RNGDataFlow): ...@@ -78,9 +82,9 @@ class LMDBData(RNGDataFlow):
def open_lmdb(self): def open_lmdb(self):
self._lmdb = lmdb.open(self._lmdb_path, self._lmdb = lmdb.open(self._lmdb_path,
subdir=os.path.isdir(self._lmdb_path), subdir=os.path.isdir(self._lmdb_path),
readonly=True, lock=False, readahead=False, readonly=True, lock=False, readahead=False,
map_size=1099511627776 * 2, max_readers=100) map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin() self._txn = self._lmdb.begin()
self._size = self._txn.stat()['entries'] self._size = self._txn.stat()['entries']
if self._shuffle: if self._shuffle:
...@@ -116,7 +120,9 @@ class LMDBData(RNGDataFlow): ...@@ -116,7 +120,9 @@ class LMDBData(RNGDataFlow):
v = self._txn.get(k) v = self._txn.get(k)
yield [k, v] yield [k, v]
class LMDBDataDecoder(LMDBData): class LMDBDataDecoder(LMDBData):
def __init__(self, lmdb_path, decoder, shuffle=True): def __init__(self, lmdb_path, decoder, shuffle=True):
""" """
:param decoder: a function taking k, v and return a data point, :param decoder: a function taking k, v and return a data point,
...@@ -128,18 +134,24 @@ class LMDBDataDecoder(LMDBData): ...@@ -128,18 +134,24 @@ class LMDBDataDecoder(LMDBData):
def get_data(self): def get_data(self):
for dp in super(LMDBDataDecoder, self).get_data(): for dp in super(LMDBDataDecoder, self).get_data():
v = self.decoder(dp[0], dp[1]) v = self.decoder(dp[0], dp[1])
if v: yield v if v:
yield v
class LMDBDataPoint(LMDBDataDecoder): class LMDBDataPoint(LMDBDataDecoder):
""" Read a LMDB file where each value is a serialized datapoint""" """ Read a LMDB file where each value is a serialized datapoint"""
def __init__(self, lmdb_path, shuffle=True): def __init__(self, lmdb_path, shuffle=True):
super(LMDBDataPoint, self).__init__( super(LMDBDataPoint, self).__init__(
lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle) lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle)
class CaffeLMDB(LMDBDataDecoder): class CaffeLMDB(LMDBDataDecoder):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """ """ Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
def __init__(self, lmdb_path, shuffle=True): def __init__(self, lmdb_path, shuffle=True):
cpb = get_caffe_pb() cpb = get_caffe_pb()
def decoder(k, v): def decoder(k, v):
try: try:
datum = cpb.Datum() datum = cpb.Datum()
...@@ -152,10 +164,12 @@ class CaffeLMDB(LMDBDataDecoder): ...@@ -152,10 +164,12 @@ class CaffeLMDB(LMDBDataDecoder):
return [img.transpose(1, 2, 0), datum.label] return [img.transpose(1, 2, 0), datum.label]
super(CaffeLMDB, self).__init__( super(CaffeLMDB, self).__init__(
lmdb_path, decoder=decoder, shuffle=shuffle) lmdb_path, decoder=decoder, shuffle=shuffle)
class SVMLightData(RNGDataFlow): class SVMLightData(RNGDataFlow):
""" Read X,y from a svmlight file """ """ Read X,y from a svmlight file """
def __init__(self, filename, shuffle=True): def __init__(self, filename, shuffle=True):
self.X, self.y = sklearn.datasets.load_svmlight_file(filename) self.X, self.y = sklearn.datasets.load_svmlight_file(filename)
self.X = np.asarray(self.X.todense()) self.X = np.asarray(self.X.todense())
...@@ -169,4 +183,4 @@ class SVMLightData(RNGDataFlow): ...@@ -169,4 +183,4 @@ class SVMLightData(RNGDataFlow):
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for id in idxs: for id in idxs:
yield [self.X[id,:], self.y[id]] yield [self.X[id, :], self.y[id]]
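Reading such a dump back is symmetric. A hedged sketch, assuming LMDBDataPoint is exported at the dataflow package level and the placeholder path matches the dump above:

from tensorpack.dataflow import LMDBDataPoint

ds = LMDBDataPoint('/tmp/cifar-1k.lmdb', shuffle=False)
ds.reset_state()
for img, label in ds.get_data():       # each value is deserialized back into [image, label]
    print(img.shape, label)
    break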
...@@ -11,7 +11,9 @@ from .imgaug import AugmentorList ...@@ -11,7 +11,9 @@ from .imgaug import AugmentorList
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents'] __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
class ImageFromFile(RNGDataFlow): class ImageFromFile(RNGDataFlow):
def __init__(self, files, channel=3, resize=None, shuffle=False): def __init__(self, files, channel=3, resize=None, shuffle=False):
""" """
Generate images of 1 channel or 3 channels (in RGB order) from a list of files. Generate images of 1 channel or 3 channels (in RGB order) from a list of files.
...@@ -39,11 +41,12 @@ class ImageFromFile(RNGDataFlow): ...@@ -39,11 +41,12 @@ class ImageFromFile(RNGDataFlow):
if self.resize is not None: if self.resize is not None:
im = cv2.resize(im, self.resize[::-1]) im = cv2.resize(im, self.resize[::-1])
if self.channel == 1: if self.channel == 1:
im = im[:,:,np.newaxis] im = im[:, :, np.newaxis]
yield [im] yield [im]
class AugmentImageComponent(MapDataComponent): class AugmentImageComponent(MapDataComponent):
def __init__(self, ds, augmentors, index=0): def __init__(self, ds, augmentors, index=0):
""" """
Augment the image component of datapoints Augment the image component of datapoints
...@@ -64,7 +67,8 @@ class AugmentImageComponent(MapDataComponent): ...@@ -64,7 +67,8 @@ class AugmentImageComponent(MapDataComponent):
class AugmentImageComponents(MapData): class AugmentImageComponents(MapData):
def __init__(self, ds, augmentors, index=(0,1)):
def __init__(self, ds, augmentors, index=(0, 1)):
""" Augment a list of images of the same shape, with the same parameters """ Augment a list of images of the same shape, with the same parameters
:param ds: a `DataFlow` instance. :param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order. :param augmentors: a list of `ImageAugmentor` instance to be applied in order.
......
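AugmentImageComponent is the bridge between dataflows and the imgaug classes in the following files. A hedged sketch; CenterCrop and Flip are taken from the imgaug modules shown below, and the dataset choice is arbitrary:

from tensorpack.dataflow import AugmentImageComponent, imgaug
from tensorpack.dataflow.dataset import Cifar10

augmentors = [
    imgaug.CenterCrop((28, 28)),
    imgaug.Flip(horiz=True),
]
ds = Cifar10('train')
ds = AugmentImageComponent(ds, augmentors, index=0)  # augment dp[0], leave the label untouched
ds.reset_state()
for img, label in ds.get_data():
    print(img.shape)                                 # (28, 28, 3) after the center crop
    break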
...@@ -7,6 +7,7 @@ from pkgutil import walk_packages ...@@ -7,6 +7,7 @@ from pkgutil import walk_packages
__all__ = [] __all__ = []
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
...@@ -19,4 +20,3 @@ for _, module_name, _ in walk_packages( ...@@ -19,4 +20,3 @@ for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]): [os.path.dirname(__file__)]):
if not module_name.startswith('_'): if not module_name.startswith('_'):
global_import(module_name) global_import(module_name)
...@@ -10,15 +10,15 @@ from .crop import * ...@@ -10,15 +10,15 @@ from .crop import *
from .imgproc import * from .imgproc import *
from .noname import * from .noname import *
from .deform import * from .deform import *
from .noise import SaltPepperNoise from .noise import SaltPepperNoise
anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)] anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([ augmentors = AugmentorList([
Contrast((0.8,1.2)), Contrast((0.8, 1.2)),
Flip(horiz=True), Flip(horiz=True),
GaussianDeform(anchors, (360,480), 0.2, randrange=20), GaussianDeform(anchors, (360, 480), 0.2, randrange=20),
#RandomCropRandomShape(0.3), # RandomCropRandomShape(0.3),
SaltPepperNoise() SaltPepperNoise()
]) ])
......
...@@ -9,6 +9,7 @@ from six.moves import zip ...@@ -9,6 +9,7 @@ from six.moves import zip
__all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList'] __all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class Augmentor(object): class Augmentor(object):
""" Base class for an augmentor""" """ Base class for an augmentor"""
...@@ -58,7 +59,9 @@ class Augmentor(object): ...@@ -58,7 +59,9 @@ class Augmentor(object):
size = [] size = []
return self.rng.uniform(low, high, size) return self.rng.uniform(low, high, size)
class ImageAugmentor(Augmentor): class ImageAugmentor(Augmentor):
def augment(self, img): def augment(self, img):
""" """
Perform augmentation on the image in-place. Perform augmentation on the image in-place.
...@@ -71,10 +74,12 @@ class ImageAugmentor(Augmentor): ...@@ -71,10 +74,12 @@ class ImageAugmentor(Augmentor):
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
return coord return coord
class AugmentorList(ImageAugmentor): class AugmentorList(ImageAugmentor):
""" """
Augment by a list of augmentors Augment by a list of augmentors
""" """
def __init__(self, augmentors): def __init__(self, augmentors):
""" """
:param augmentors: list of `ImageAugmentor` instance to be applied :param augmentors: list of `ImageAugmentor` instance to be applied
...@@ -107,4 +112,3 @@ class AugmentorList(ImageAugmentor): ...@@ -107,4 +112,3 @@ class AugmentorList(ImageAugmentor):
""" Will reset state of each augmentor """ """ Will reset state of each augmentor """
for a in self.augs: for a in self.augs:
a.reset_state() a.reset_state()
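Because AugmentorList is itself an ImageAugmentor, it can be applied to a single image outside any dataflow. A small hedged sketch; the random input image is only a stand-in:

import numpy as np
from tensorpack.dataflow import imgaug

augs = imgaug.AugmentorList([
    imgaug.Brightness(30),
    imgaug.Contrast((0.8, 1.2)),
])
augs.reset_state()                                   # seed the rng of every augmentor in the list
img = (np.random.rand(64, 64, 3) * 255).astype('float32')
out = augs.augment(img)                              # applies the augmentors in order
print(out.shape)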
...@@ -10,10 +10,12 @@ from six.moves import range ...@@ -10,10 +10,12 @@ from six.moves import range
import numpy as np import numpy as np
__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop', __all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop',
'RandomCropRandomShape', 'perturb_BB', 'RandomCropAroundBox'] 'RandomCropRandomShape', 'perturb_BB', 'RandomCropAroundBox']
class RandomCrop(ImageAugmentor): class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """ """ Randomly crop the image into a smaller one """
def __init__(self, crop_shape): def __init__(self, crop_shape):
""" """
:param crop_shape: a shape like (h, w) :param crop_shape: a shape like (h, w)
...@@ -25,7 +27,7 @@ class RandomCrop(ImageAugmentor): ...@@ -25,7 +27,7 @@ class RandomCrop(ImageAugmentor):
def _get_augment_params(self, img): def _get_augment_params(self, img):
orig_shape = img.shape orig_shape = img.shape
assert orig_shape[0] >= self.crop_shape[0] \ assert orig_shape[0] >= self.crop_shape[0] \
and orig_shape[1] >= self.crop_shape[1], orig_shape and orig_shape[1] >= self.crop_shape[1], orig_shape
diffh = orig_shape[0] - self.crop_shape[0] diffh = orig_shape[0] - self.crop_shape[0]
h0 = 0 if diffh == 0 else self.rng.randint(diffh) h0 = 0 if diffh == 0 else self.rng.randint(diffh)
diffw = orig_shape[1] - self.crop_shape[1] diffw = orig_shape[1] - self.crop_shape[1]
...@@ -34,13 +36,15 @@ class RandomCrop(ImageAugmentor): ...@@ -34,13 +36,15 @@ class RandomCrop(ImageAugmentor):
def _augment(self, img, param): def _augment(self, img, param):
h0, w0 = param h0, w0 = param
return img[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]] return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
raise NotImplementedError() raise NotImplementedError()
class CenterCrop(ImageAugmentor): class CenterCrop(ImageAugmentor):
""" Crop the image at the center""" """ Crop the image at the center"""
def __init__(self, crop_shape): def __init__(self, crop_shape):
""" """
:param crop_shape: a shape like (h, w) :param crop_shape: a shape like (h, w)
...@@ -52,13 +56,15 @@ class CenterCrop(ImageAugmentor): ...@@ -52,13 +56,15 @@ class CenterCrop(ImageAugmentor):
orig_shape = img.shape orig_shape = img.shape
h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5) h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5) w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
return img[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]] return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
raise NotImplementedError() raise NotImplementedError()
class FixedCrop(ImageAugmentor): class FixedCrop(ImageAugmentor):
""" Crop a rectangle at a given location""" """ Crop a rectangle at a given location"""
def __init__(self, rect): def __init__(self, rect):
""" """
Two arguments define the range in both axes to crop, min included, max excluded. Two arguments define the range in both axes to crop, min included, max excluded.
...@@ -69,15 +75,16 @@ class FixedCrop(ImageAugmentor): ...@@ -69,15 +75,16 @@ class FixedCrop(ImageAugmentor):
def _augment(self, img, _): def _augment(self, img, _):
orig_shape = img.shape orig_shape = img.shape
return img[self.rect.y0: self.rect.y1+1, return img[self.rect.y0: self.rect.y1 + 1,
self.rect.x0: self.rect.x1+1] self.rect.x0: self.rect.x1 + 1]
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
raise NotImplementedError() raise NotImplementedError()
def perturb_BB(image_shape, bb, max_pertub_pixel, def perturb_BB(image_shape, bb, max_pertub_pixel,
rng=None, max_aspect_ratio_diff=0.3, rng=None, max_aspect_ratio_diff=0.3,
max_try=100): max_try=100):
""" """
Perturb a bounding box. Perturb a bounding box.
:param image_shape: [h, w] :param image_shape: [h, w]
...@@ -113,6 +120,7 @@ class RandomCropAroundBox(ImageAugmentor): ...@@ -113,6 +120,7 @@ class RandomCropAroundBox(ImageAugmentor):
""" """
Crop a box around a bounding box Crop a box around a bounding box
""" """
def __init__(self, perturb_ratio, max_aspect_ratio_diff=0.3): def __init__(self, perturb_ratio, max_aspect_ratio_diff=0.3):
""" """
:param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)] :param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
...@@ -124,9 +132,9 @@ class RandomCropAroundBox(ImageAugmentor): ...@@ -124,9 +132,9 @@ class RandomCropAroundBox(ImageAugmentor):
def _get_augment_params(self, img): def _get_augment_params(self, img):
shape = img.shape[:2] shape = img.shape[:2]
box = Rect(0, 0, shape[1] - 1, shape[0] - 1) box = Rect(0, 0, shape[1] - 1, shape[0] - 1)
dist = self.perturb_ratio * np.sqrt(shape[0]*shape[1]) dist = self.perturb_ratio * np.sqrt(shape[0] * shape[1])
newbox = perturb_BB(shape, box, dist, newbox = perturb_BB(shape, box, dist,
self.rng, self.max_aspect_ratio_diff) self.rng, self.max_aspect_ratio_diff)
return newbox return newbox
def _augment(self, img, newbox): def _augment(self, img, newbox):
...@@ -135,10 +143,12 @@ class RandomCropAroundBox(ImageAugmentor): ...@@ -135,10 +143,12 @@ class RandomCropAroundBox(ImageAugmentor):
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
raise NotImplementedError() raise NotImplementedError()
class RandomCropRandomShape(ImageAugmentor): class RandomCropRandomShape(ImageAugmentor):
def __init__(self, wmin, hmin, def __init__(self, wmin, hmin,
wmax=None, hmax=None, wmax=None, hmax=None,
max_aspect_ratio=None): max_aspect_ratio=None):
""" """
Randomly crop a box of shape (h, w), sampled from [min, max](inclusive). Randomly crop a box of shape (h, w), sampled from [min, max](inclusive).
If max is None, will use the input image shape. If max is None, will use the input image shape.
...@@ -151,18 +161,18 @@ class RandomCropRandomShape(ImageAugmentor): ...@@ -151,18 +161,18 @@ class RandomCropRandomShape(ImageAugmentor):
def _get_augment_params(self, img): def _get_augment_params(self, img):
hmax = self.hmax or img.shape[0] hmax = self.hmax or img.shape[0]
wmax = self.wmax or img.shape[1] wmax = self.wmax or img.shape[1]
h = self.rng.randint(self.hmin, hmax+1) h = self.rng.randint(self.hmin, hmax + 1)
w = self.rng.randint(self.wmin, wmax+1) w = self.rng.randint(self.wmin, wmax + 1)
diffh = img.shape[0] - h diffh = img.shape[0] - h
diffw = img.shape[1] - w diffw = img.shape[1] - w
assert diffh >= 0 and diffw >= 0 assert diffh >= 0 and diffw >= 0
y0 = 0 if diffh == 0 else self.rng.randint(diffh) y0 = 0 if diffh == 0 else self.rng.randint(diffh)
x0 = 0 if diffw == 0 else self.rng.randint(diffw) x0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (y0,x0,h,w) return (y0, x0, h, w)
def _augment(self, img, param): def _augment(self, img, param):
y0, x0, h, w = param y0, x0, h, w = param
return img[y0:y0+h,x0:x0+w] return img[y0:y0 + h, x0:x0 + w]
if __name__ == '__main__': if __name__ == '__main__':
print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50)) print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
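A hedged sketch of the random-shape crop applied to one image; the bounds are arbitrary and the shape comment only restates the sampling range from the docstring above:

import numpy as np
from tensorpack.dataflow import imgaug

crop = imgaug.RandomCropRandomShape(wmin=100, hmin=80)  # wmax/hmax default to the input shape
crop.reset_state()
img = np.zeros((240, 320, 3), dtype='float32')
out = crop.augment(img)
print(out.shape)        # height sampled in [80, 240], width in [100, 320]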
...@@ -10,8 +10,10 @@ __all__ = ['GaussianDeform', 'GaussianMap'] ...@@ -10,8 +10,10 @@ __all__ = ['GaussianDeform', 'GaussianMap']
# TODO really needs speedup # TODO really needs speedup
class GaussianMap(object): class GaussianMap(object):
""" Generate gaussian weighted deformation map""" """ Generate gaussian weighted deformation map"""
def __init__(self, image_shape, sigma=0.5): def __init__(self, image_shape, sigma=0.5):
assert len(image_shape) == 2 assert len(image_shape) == 2
self.shape = image_shape self.shape = image_shape
...@@ -25,17 +27,18 @@ class GaussianMap(object): ...@@ -25,17 +27,18 @@ class GaussianMap(object):
x = x.astype('float32') / ret.shape[1] - anchor[1] x = x.astype('float32') / ret.shape[1] - anchor[1]
g = np.exp(-(x**2 + y ** 2) / self.sigma) g = np.exp(-(x**2 + y ** 2) / self.sigma)
#cv2.imshow(" ", g) #cv2.imshow(" ", g)
#cv2.waitKey() # cv2.waitKey()
return g return g
def np_sample(img, coords): def np_sample(img, coords):
# a numpy implementation of ImageSample layer # a numpy implementation of ImageSample layer
coords = np.maximum(coords, 0) coords = np.maximum(coords, 0)
coords = np.minimum(coords, np.array([img.shape[0]-1, img.shape[1]-1])) coords = np.minimum(coords, np.array([img.shape[0] - 1, img.shape[1] - 1]))
lcoor = np.floor(coords).astype('int32') lcoor = np.floor(coords).astype('int32')
ucoor = lcoor + 1 ucoor = lcoor + 1
ucoor = np.minimum(ucoor, np.array([img.shape[0]-1, img.shape[1]-1])) ucoor = np.minimum(ucoor, np.array([img.shape[0] - 1, img.shape[1] - 1]))
diff = coords - lcoor diff = coords - lcoor
neg_diff = 1.0 - diff neg_diff = 1.0 - diff
...@@ -46,17 +49,20 @@ def np_sample(img, coords): ...@@ -46,17 +49,20 @@ def np_sample(img, coords):
diffy, diffx = np.split(diff, 2, axis=2) diffy, diffx = np.split(diff, 2, axis=2)
ndiffy, ndiffx = np.split(neg_diff, 2, axis=2) ndiffy, ndiffx = np.split(neg_diff, 2, axis=2)
ret = img[lcoory,lcoorx,:] * ndiffx * ndiffy + \ ret = img[lcoory, lcoorx, :] * ndiffx * ndiffy + \
img[ucoory, ucoorx,:] * diffx * diffy + \ img[ucoory, ucoorx, :] * diffx * diffy + \
img[lcoory, ucoorx,:] * ndiffy * diffx + \ img[lcoory, ucoorx, :] * ndiffy * diffx + \
img[ucoory,lcoorx,:] * diffy * ndiffx img[ucoory, lcoorx, :] * diffy * ndiffx
return ret[:,:,0,:] return ret[:, :, 0, :]
# TODO input/output with different shape # TODO input/output with different shape
class GaussianDeform(ImageAugmentor): class GaussianDeform(ImageAugmentor):
""" """
Some kind of deformation. Quite slow. Some kind of deformation. Quite slow.
""" """
def __init__(self, anchors, shape, sigma=0.5, randrange=None): def __init__(self, anchors, shape, sigma=0.5, randrange=None):
""" """
:param anchors: in [0,1] coordinate :param anchors: in [0,1] coordinate
...@@ -69,13 +75,13 @@ class GaussianDeform(ImageAugmentor): ...@@ -69,13 +75,13 @@ class GaussianDeform(ImageAugmentor):
self.anchors = anchors self.anchors = anchors
self.K = len(self.anchors) self.K = len(self.anchors)
self.shape = shape self.shape = shape
self.grid = np.mgrid[0:self.shape[0], 0:self.shape[1]].transpose(1,2,0) self.grid = np.mgrid[0:self.shape[0], 0:self.shape[1]].transpose(1, 2, 0)
self.grid = self.grid.astype('float32') # HxWx2 self.grid = self.grid.astype('float32') # HxWx2
gm = GaussianMap(self.shape, sigma=sigma) gm = GaussianMap(self.shape, sigma=sigma)
self.gws = np.array([gm.get_gaussian_weight(ank) self.gws = np.array([gm.get_gaussian_weight(ank)
for ank in self.anchors], dtype='float32') # KxHxW for ank in self.anchors], dtype='float32') # KxHxW
self.gws = self.gws.transpose(1, 2, 0) #HxWxK self.gws = self.gws.transpose(1, 2, 0) # HxWxK
if randrange is None: if randrange is None:
self.randrange = self.shape[0] / 8 self.randrange = self.shape[0] / 8
else: else:
......
...@@ -10,11 +10,13 @@ import numpy as np ...@@ -10,11 +10,13 @@ import numpy as np
__all__ = ['Rotation', 'RotationAndCropValid'] __all__ = ['Rotation', 'RotationAndCropValid']
class Rotation(ImageAugmentor): class Rotation(ImageAugmentor):
""" Random rotate the image w.r.t a random center""" """ Random rotate the image w.r.t a random center"""
def __init__(self, max_deg, center_range=(0,1),
interp=cv2.INTER_CUBIC, def __init__(self, max_deg, center_range=(0, 1),
border=cv2.BORDER_REPLICATE): interp=cv2.INTER_CUBIC,
border=cv2.BORDER_REPLICATE):
""" """
:param max_deg: max abs value of the rotation degree :param max_deg: max abs value of the rotation degree
:param center_range: the location of the rotation center :param center_range: the location of the rotation center
...@@ -24,19 +26,21 @@ class Rotation(ImageAugmentor): ...@@ -24,19 +26,21 @@ class Rotation(ImageAugmentor):
def _get_augment_params(self, img): def _get_augment_params(self, img):
center = img.shape[1::-1] * self._rand_range( center = img.shape[1::-1] * self._rand_range(
self.center_range[0], self.center_range[1], (2,)) self.center_range[0], self.center_range[1], (2,))
deg = self._rand_range(-self.max_deg, self.max_deg) deg = self._rand_range(-self.max_deg, self.max_deg)
return cv2.getRotationMatrix2D(tuple(center), deg, 1) return cv2.getRotationMatrix2D(tuple(center), deg, 1)
def _augment(self, img, rot_m): def _augment(self, img, rot_m):
ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
flags=self.interp, borderMode=self.border) flags=self.interp, borderMode=self.border)
return ret return ret
class RotationAndCropValid(ImageAugmentor): class RotationAndCropValid(ImageAugmentor):
""" Random rotate and crop the largest possible rect without the border """ Random rotate and crop the largest possible rect without the border
This will produce images of different shapes. This will produce images of different shapes.
""" """
def __init__(self, max_deg, interp=cv2.INTER_CUBIC): def __init__(self, max_deg, interp=cv2.INTER_CUBIC):
super(RotationAndCropValid, self).__init__() super(RotationAndCropValid, self).__init__()
self._init(locals()) self._init(locals())
...@@ -46,39 +50,39 @@ class RotationAndCropValid(ImageAugmentor): ...@@ -46,39 +50,39 @@ class RotationAndCropValid(ImageAugmentor):
return deg return deg
def _augment(self, img, deg): def _augment(self, img, deg):
center = (img.shape[1]*0.5, img.shape[0]*0.5) center = (img.shape[1] * 0.5, img.shape[0] * 0.5)
rot_m = cv2.getRotationMatrix2D(center, deg, 1) rot_m = cv2.getRotationMatrix2D(center, deg, 1)
ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
flags=self.interp, borderMode=cv2.BORDER_CONSTANT) flags=self.interp, borderMode=cv2.BORDER_CONSTANT)
neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg) neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
neww = min(neww, ret.shape[1]) neww = min(neww, ret.shape[1])
newh = min(newh, ret.shape[0]) newh = min(newh, ret.shape[0])
newx = int(center[0] - neww * 0.5) newx = int(center[0] - neww * 0.5)
newy = int(center[1] - newh * 0.5) newy = int(center[1] - newh * 0.5)
#print(ret.shape, deg, newx, newy, neww, newh) #print(ret.shape, deg, newx, newy, neww, newh)
return ret[newy:newy+newh,newx:newx+neww] return ret[newy:newy + newh, newx:newx + neww]
@staticmethod @staticmethod
def largest_rotated_rect(w, h, angle): def largest_rotated_rect(w, h, angle):
""" http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders """ """ http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders """
angle = angle / 180.0 * math.pi angle = angle / 180.0 * math.pi
if w <= 0 or h <= 0: if w <= 0 or h <= 0:
return 0,0 return 0, 0
width_is_longer = w >= h width_is_longer = w >= h
side_long, side_short = (w,h) if width_is_longer else (h,w) side_long, side_short = (w, h) if width_is_longer else (h, w)
# since the solutions for angle, -angle and 180-angle are all the same, # since the solutions for angle, -angle and 180-angle are all the same,
# it suffices to look at the first quadrant and the absolute values of sin,cos: # it suffices to look at the first quadrant and the absolute values of sin,cos:
sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
if side_short <= 2.*sin_a*cos_a*side_long: if side_short <= 2. * sin_a * cos_a * side_long:
# half constrained case: two crop corners touch the longer side, # half constrained case: two crop corners touch the longer side,
# the other two corners are on the mid-line parallel to the longer line # the other two corners are on the mid-line parallel to the longer line
x = 0.5*side_short x = 0.5 * side_short
wr,hr = (x/sin_a,x/cos_a) if width_is_longer else (x/cos_a,x/sin_a) wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
else: else:
# fully constrained case: crop touches all 4 sides # fully constrained case: crop touches all 4 sides
cos_2a = cos_a*cos_a - sin_a*sin_a cos_2a = cos_a * cos_a - sin_a * sin_a
wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
return int(wr), int(hr) return int(wr), int(hr)
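largest_rotated_rect is a plain staticmethod, so it can be probed on its own; a quick hedged check of the crop size it reports (angle in degrees, per the conversion above):

from tensorpack.dataflow import imgaug

# Largest axis-aligned rectangle that fits inside a 400x300 image rotated by 10 degrees,
# i.e. the region RotationAndCropValid would keep after warping.
print(imgaug.RotationAndCropValid.largest_rotated_rect(400, 300, 10))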
...@@ -7,12 +7,14 @@ import numpy as np ...@@ -7,12 +7,14 @@ import numpy as np
import cv2 import cv2
__all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur', __all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur',
'Gamma', 'Clip', 'Saturation', 'Lighting'] 'Gamma', 'Clip', 'Saturation', 'Lighting']
class Brightness(ImageAugmentor): class Brightness(ImageAugmentor):
""" """
Randomly adjust brightness. Randomly adjust brightness.
""" """
def __init__(self, delta, clip=True): def __init__(self, delta, clip=True):
""" """
Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True. Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True.
...@@ -31,11 +33,13 @@ class Brightness(ImageAugmentor): ...@@ -31,11 +33,13 @@ class Brightness(ImageAugmentor):
img = np.clip(img, 0, 255) img = np.clip(img, 0, 255)
return img return img
class Contrast(ImageAugmentor): class Contrast(ImageAugmentor):
""" """
Apply x = (x - mean) * contrast_factor + mean to each channel Apply x = (x - mean) * contrast_factor + mean to each channel
and clip to [0, 255] and clip to [0, 255]
""" """
def __init__(self, factor_range, clip=True): def __init__(self, factor_range, clip=True):
""" """
:param factor_range: an interval to random sample the `contrast_factor`. :param factor_range: an interval to random sample the `contrast_factor`.
...@@ -48,18 +52,20 @@ class Contrast(ImageAugmentor): ...@@ -48,18 +52,20 @@ class Contrast(ImageAugmentor):
return self._rand_range(*self.factor_range) return self._rand_range(*self.factor_range)
def _augment(self, img, r): def _augment(self, img, r):
mean = np.mean(img, axis=(0,1), keepdims=True) mean = np.mean(img, axis=(0, 1), keepdims=True)
img = (img - mean) * r + mean img = (img - mean) * r + mean
if self.clip: if self.clip:
img = np.clip(img, 0, 255) img = np.clip(img, 0, 255)
return img return img
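A minimal numeric sketch of the contrast formula above (toy values, not from the repo): a factor above 1 pushes pixels away from the mean, a factor below 1 pulls them toward it.

    import numpy as np

    img = np.array([[50., 100.], [150., 200.]])   # toy single-channel image, mean = 125
    mean = img.mean()
    print((img - mean) * 1.5 + mean)              # [[ 12.5  87.5] [162.5 237.5]]
    print((img - mean) * 0.5 + mean)              # [[ 87.5 112.5] [137.5 162.5]]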
class MeanVarianceNormalize(ImageAugmentor): class MeanVarianceNormalize(ImageAugmentor):
""" """
Linearly scales the image to have zero mean and unit variance. Linearly scales the image to have zero mean and unit variance.
x = (x - mean) / adjusted_stddev x = (x - mean) / adjusted_stddev
where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels)) where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
""" """
def __init__(self, all_channel=True): def __init__(self, all_channel=True):
""" """
:param all_channel: if True, normalize all channels together. else separately. :param all_channel: if True, normalize all channels together. else separately.
...@@ -71,14 +77,15 @@ class MeanVarianceNormalize(ImageAugmentor): ...@@ -71,14 +77,15 @@ class MeanVarianceNormalize(ImageAugmentor):
mean = np.mean(img) mean = np.mean(img)
std = np.std(img) std = np.std(img)
else: else:
mean = np.mean(img, axis=(0,1), keepdims=True) mean = np.mean(img, axis=(0, 1), keepdims=True)
std = np.std(img, axis=(0,1), keepdims=True) std = np.std(img, axis=(0, 1), keepdims=True)
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape))) std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape)))
img = (img - mean) / std img = (img - mean) / std
return img return img
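The adjusted_stddev floor matters for near-constant images; a small standalone sketch with toy values (not from the repo):

    import numpy as np

    img = np.full((4, 4), 128.0)                              # constant image, stddev == 0
    std = np.maximum(np.std(img), 1.0 / np.sqrt(img.size))    # floored at 1/sqrt(16) = 0.25
    print((img - np.mean(img)) / std)                         # all zeros instead of inf/NaN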
class GaussianBlur(ImageAugmentor): class GaussianBlur(ImageAugmentor):
def __init__(self, max_size=3): def __init__(self, max_size=3):
""":params max_size: (maximum kernel size-1)/2""" """:params max_size: (maximum kernel size-1)/2"""
super(GaussianBlur, self).__init__() super(GaussianBlur, self).__init__()
...@@ -92,10 +99,11 @@ class GaussianBlur(ImageAugmentor): ...@@ -92,10 +99,11 @@ class GaussianBlur(ImageAugmentor):
def _augment(self, img, s): def _augment(self, img, s):
return cv2.GaussianBlur(img, s, sigmaX=0, sigmaY=0, return cv2.GaussianBlur(img, s, sigmaX=0, sigmaY=0,
borderType=cv2.BORDER_REPLICATE) borderType=cv2.BORDER_REPLICATE)
class Gamma(ImageAugmentor): class Gamma(ImageAugmentor):
def __init__(self, range=(-0.5, 0.5)): def __init__(self, range=(-0.5, 0.5)):
super(Gamma, self).__init__() super(Gamma, self).__init__()
self._init(locals()) self._init(locals())
...@@ -109,7 +117,9 @@ class Gamma(ImageAugmentor): ...@@ -109,7 +117,9 @@ class Gamma(ImageAugmentor):
img = cv2.LUT(img, lut).astype('float32') img = cv2.LUT(img, lut).astype('float32')
return img return img
class Clip(ImageAugmentor): class Clip(ImageAugmentor):
def __init__(self, min=0, max=255): def __init__(self, min=0, max=255):
self._init(locals()) self._init(locals())
...@@ -117,7 +127,9 @@ class Clip(ImageAugmentor): ...@@ -117,7 +127,9 @@ class Clip(ImageAugmentor):
img = np.clip(img, self.min, self.max) img = np.clip(img, self.min, self.max)
return img return img
class Saturation(ImageAugmentor): class Saturation(ImageAugmentor):
def __init__(self, alpha=0.4): def __init__(self, alpha=0.4):
""" Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218 """ Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
""" """
...@@ -130,9 +142,11 @@ class Saturation(ImageAugmentor): ...@@ -130,9 +142,11 @@ class Saturation(ImageAugmentor):
def _augment(self, img, v): def _augment(self, img, v):
grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
return img * v + (grey * (1 - v))[:,:,np.newaxis] return img * v + (grey * (1 - v))[:, :, np.newaxis]
class Lighting(ImageAugmentor): class Lighting(ImageAugmentor):
def __init__(self, std, eigval, eigvec): def __init__(self, std, eigval, eigvec):
""" Lighting noise. """ Lighting noise.
See `ImageNet Classification with Deep Convolutional Neural Networks - Alex` See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
...@@ -143,7 +157,7 @@ class Lighting(ImageAugmentor): ...@@ -143,7 +157,7 @@ class Lighting(ImageAugmentor):
eigval = np.asarray(eigval) eigval = np.asarray(eigval)
eigvec = np.asarray(eigvec) eigvec = np.asarray(eigvec)
assert eigval.shape == (3,) assert eigval.shape == (3,)
assert eigvec.shape == (3,3) assert eigvec.shape == (3, 3)
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
...@@ -156,4 +170,3 @@ class Lighting(ImageAugmentor): ...@@ -156,4 +170,3 @@ class Lighting(ImageAugmentor):
inc = np.dot(self.eigvec, v).reshape((3,)) inc = np.dot(self.eigvec, v).reshape((3,))
img += inc img += inc
return img return img
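A hedged sketch of the PCA lighting scheme applied above: one random coefficient per eigenvalue, projected through the eigenvectors, gives a single RGB offset added to every pixel. The eigval/eigvec numbers below are placeholders, not values shipped with the repo.

    import numpy as np

    rng = np.random.RandomState(0)
    std = 0.1
    eigval = np.array([0.22, 0.02, 0.005])      # placeholder eigenvalues
    eigvec = np.eye(3)                          # placeholder eigenvectors
    v = rng.randn(3) * std * eigval             # per-image random scale of each component
    inc = np.dot(eigvec, v).reshape((3,))       # broadcast-added to every pixel: img += inc
    print(inc)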
...@@ -7,14 +7,18 @@ ...@@ -7,14 +7,18 @@
from .base import ImageAugmentor from .base import ImageAugmentor
__all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug', __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
'RandomOrderAug'] 'RandomOrderAug']
class Identity(ImageAugmentor): class Identity(ImageAugmentor):
def _augment(self, img, _): def _augment(self, img, _):
return img return img
class RandomApplyAug(ImageAugmentor): class RandomApplyAug(ImageAugmentor):
""" Randomly apply the augmentor with a prob. Otherwise do nothing""" """ Randomly apply the augmentor with a prob. Otherwise do nothing"""
def __init__(self, aug, prob): def __init__(self, aug, prob):
self._init(locals()) self._init(locals())
super(RandomApplyAug, self).__init__() super(RandomApplyAug, self).__init__()
...@@ -37,7 +41,9 @@ class RandomApplyAug(ImageAugmentor): ...@@ -37,7 +41,9 @@ class RandomApplyAug(ImageAugmentor):
else: else:
return self.aug._augment(img, prm[1]) return self.aug._augment(img, prm[1])
class RandomChooseAug(ImageAugmentor): class RandomChooseAug(ImageAugmentor):
def __init__(self, aug_lists): def __init__(self, aug_lists):
""" """
:param aug_lists: a list of augmentors, or a list of (augmentor, probability) tuples :param aug_lists: a list of augmentors, or a list of (augmentor, probability) tuples
...@@ -65,7 +71,9 @@ class RandomChooseAug(ImageAugmentor): ...@@ -65,7 +71,9 @@ class RandomChooseAug(ImageAugmentor):
idx, prm = prm idx, prm = prm
return self.aug_lists[idx]._augment(img, prm) return self.aug_lists[idx]._augment(img, prm)
class RandomOrderAug(ImageAugmentor): class RandomOrderAug(ImageAugmentor):
def __init__(self, aug_lists): def __init__(self, aug_lists):
""" """
Shuffle the augmentors into random order. Shuffle the augmentors into random order.
...@@ -93,10 +101,12 @@ class RandomOrderAug(ImageAugmentor): ...@@ -93,10 +101,12 @@ class RandomOrderAug(ImageAugmentor):
img = self.aug_lists[k]._augment(img, prms[k]) img = self.aug_lists[k]._augment(img, prms[k])
return img return img
class MapImage(ImageAugmentor): class MapImage(ImageAugmentor):
""" """
Map the image array by a function. Map the image array by a function.
""" """
def __init__(self, func): def __init__(self, func):
""" """
:param func: a function which takes an image array and returns an augmented one :param func: a function which takes an image array and returns an augmented one
...@@ -105,4 +115,3 @@ class MapImage(ImageAugmentor): ...@@ -105,4 +115,3 @@ class MapImage(ImageAugmentor):
def _augment(self, img, _): def _augment(self, img, _):
return self.func(img) return self.func(img)
...@@ -9,7 +9,9 @@ import cv2 ...@@ -9,7 +9,9 @@ import cv2
__all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise'] __all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
class JpegNoise(ImageAugmentor): class JpegNoise(ImageAugmentor):
def __init__(self, quality_range=(40, 100)): def __init__(self, quality_range=(40, 100)):
super(JpegNoise, self).__init__() super(JpegNoise, self).__init__()
self._init(locals()) self._init(locals())
...@@ -23,6 +25,7 @@ class JpegNoise(ImageAugmentor): ...@@ -23,6 +25,7 @@ class JpegNoise(ImageAugmentor):
class GaussianNoise(ImageAugmentor): class GaussianNoise(ImageAugmentor):
def __init__(self, sigma=1, clip=True): def __init__(self, sigma=1, clip=True):
""" """
Add Gaussian noise N(0, sigma^2) of the same shape as img. Add Gaussian noise N(0, sigma^2) of the same shape as img.
...@@ -39,7 +42,9 @@ class GaussianNoise(ImageAugmentor): ...@@ -39,7 +42,9 @@ class GaussianNoise(ImageAugmentor):
ret = np.clip(ret, 0, 255) ret = np.clip(ret, 0, 255)
return ret return ret
class SaltPepperNoise(ImageAugmentor): class SaltPepperNoise(ImageAugmentor):
def __init__(self, white_prob=0.05, black_prob=0.05): def __init__(self, white_prob=0.05, black_prob=0.05):
""" Salt and pepper noise. """ Salt and pepper noise.
Randomly set some elements in img to 0 or 255, regardless of its channels. Randomly set some elements in img to 0 or 255, regardless of its channels.
......
...@@ -10,10 +10,12 @@ import cv2 ...@@ -10,10 +10,12 @@ import cv2
__all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge'] __all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge']
class Flip(ImageAugmentor): class Flip(ImageAugmentor):
""" """
Random flip. Random flip.
""" """
def __init__(self, horiz=False, vert=False, prob=0.5): def __init__(self, horiz=False, vert=False, prob=0.5):
""" """
Only one of horiz, vert can be set. Only one of horiz, vert can be set.
...@@ -45,8 +47,10 @@ class Flip(ImageAugmentor): ...@@ -45,8 +47,10 @@ class Flip(ImageAugmentor):
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
raise NotImplementedError() raise NotImplementedError()
class Resize(ImageAugmentor): class Resize(ImageAugmentor):
""" Resize image to a target size""" """ Resize image to a target size"""
def __init__(self, shape, interp=cv2.INTER_CUBIC): def __init__(self, shape, interp=cv2.INTER_CUBIC):
""" """
:param shape: shape in (h, w) :param shape: shape in (h, w)
...@@ -59,13 +63,15 @@ class Resize(ImageAugmentor): ...@@ -59,13 +63,15 @@ class Resize(ImageAugmentor):
img, self.shape[::-1], img, self.shape[::-1],
interpolation=self.interp) interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2: if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis] ret = ret[:, :, np.newaxis]
return ret return ret
class ResizeShortestEdge(ImageAugmentor): class ResizeShortestEdge(ImageAugmentor):
""" Resize the shortest edge to a certain number while """ Resize the shortest edge to a certain number while
keeping the aspect ratio keeping the aspect ratio
""" """
def __init__(self, size): def __init__(self, size):
size = size * 1.0 size = size * 1.0
self._init(locals()) self._init(locals())
...@@ -76,13 +82,15 @@ class ResizeShortestEdge(ImageAugmentor): ...@@ -76,13 +82,15 @@ class ResizeShortestEdge(ImageAugmentor):
desSize = map(int, [scale * w, scale * h]) desSize = map(int, [scale * w, scale * h])
ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC) ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC)
if img.ndim == 3 and ret.ndim == 2: if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis] ret = ret[:, :, np.newaxis]
return ret return ret
class RandomResize(ImageAugmentor): class RandomResize(ImageAugmentor):
""" randomly rescale w and h of the image""" """ randomly rescale w and h of the image"""
def __init__(self, xrange, yrange, minimum=(0,0), aspect_ratio_thres=0.15,
interp=cv2.INTER_CUBIC): def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15,
interp=cv2.INTER_CUBIC):
""" """
:param xrange: (min, max) scaling ratio :param xrange: (min, max) scaling ratio
:param yrange: (min, max) scaling ratio :param yrange: (min, max) scaling ratio
...@@ -112,6 +120,5 @@ class RandomResize(ImageAugmentor): ...@@ -112,6 +120,5 @@ class RandomResize(ImageAugmentor):
def _augment(self, img, dsize): def _augment(self, img, dsize):
ret = cv2.resize(img, dsize, interpolation=self.interp) ret = cv2.resize(img, dsize, interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2: if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis] ret = ret[:, :, np.newaxis]
return ret return ret
...@@ -9,11 +9,12 @@ from abc import abstractmethod ...@@ -9,11 +9,12 @@ from abc import abstractmethod
import numpy as np import numpy as np
__all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller', __all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
'RandomPaste'] 'RandomPaste']
class BackgroundFiller(object): class BackgroundFiller(object):
""" Base class for all BackgroundFiller""" """ Base class for all BackgroundFiller"""
def fill(self, background_shape, img): def fill(self, background_shape, img):
""" """
Return a proper background image of background_shape, given img Return a proper background image of background_shape, given img
...@@ -28,8 +29,10 @@ class BackgroundFiller(object): ...@@ -28,8 +29,10 @@ class BackgroundFiller(object):
def _fill(self, background_shape, img): def _fill(self, background_shape, img):
pass pass
class ConstantBackgroundFiller(BackgroundFiller): class ConstantBackgroundFiller(BackgroundFiller):
""" Fill the background by a constant """ """ Fill the background by a constant """
def __init__(self, value): def __init__(self, value):
""" """
:param value: the value to fill the background. :param value: the value to fill the background.
...@@ -44,10 +47,12 @@ class ConstantBackgroundFiller(BackgroundFiller): ...@@ -44,10 +47,12 @@ class ConstantBackgroundFiller(BackgroundFiller):
return_shape = background_shape return_shape = background_shape
return np.zeros(return_shape) + self.value return np.zeros(return_shape) + self.value
class CenterPaste(ImageAugmentor): class CenterPaste(ImageAugmentor):
""" """
Paste the image onto the center of a background canvas. Paste the image onto the center of a background canvas.
""" """
def __init__(self, background_shape, background_filler=None): def __init__(self, background_shape, background_filler=None):
""" """
:param background_shape: shape of the background canvas. :param background_shape: shape of the background canvas.
...@@ -66,16 +71,18 @@ class CenterPaste(ImageAugmentor): ...@@ -66,16 +71,18 @@ class CenterPaste(ImageAugmentor):
self.background_shape, img) self.background_shape, img)
y0 = int((self.background_shape[0] - img_shape[0]) * 0.5) y0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
x0 = int((self.background_shape[1] - img_shape[1]) * 0.5) x0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
background[y0:y0+img_shape[0], x0:x0+img_shape[1]] = img background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
return background return background
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
raise NotImplementedError() raise NotImplementedError()
class RandomPaste(CenterPaste): class RandomPaste(CenterPaste):
""" """
Randomly paste the image onto a background canvas Randomly paste the image onto a background canvas
""" """
def _get_augment_params(self, img): def _get_augment_params(self, img):
img_shape = img.shape[:2] img_shape = img.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1] assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
...@@ -89,5 +96,5 @@ class RandomPaste(CenterPaste): ...@@ -89,5 +96,5 @@ class RandomPaste(CenterPaste):
img_shape = img.shape[:2] img_shape = img.shape[:2]
background = self.background_filler.fill( background = self.background_filler.fill(
self.background_shape, img) self.background_shape, img)
background[y0:y0+img_shape[0], x0:x0+img_shape[1]] = img background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
return background return background
...@@ -13,7 +13,7 @@ import os ...@@ -13,7 +13,7 @@ import os
from .base import ProxyDataFlow from .base import ProxyDataFlow
from ..utils.concurrency import (ensure_proc_terminate, from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal) mask_sigint, start_proc_mask_signal)
from ..utils.serialize import loads, dumps from ..utils.serialize import loads, dumps
from ..utils import logger from ..utils import logger
from ..utils.gpu import change_gpu from ..utils.gpu import change_gpu
...@@ -28,6 +28,7 @@ else: ...@@ -28,6 +28,7 @@ else:
class PrefetchProcess(mp.Process): class PrefetchProcess(mp.Process):
def __init__(self, ds, queue, reset_after_spawn=True): def __init__(self, ds, queue, reset_after_spawn=True):
""" """
:param ds: ds to take data from :param ds: ds to take data from
...@@ -46,10 +47,12 @@ class PrefetchProcess(mp.Process): ...@@ -46,10 +47,12 @@ class PrefetchProcess(mp.Process):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
self.queue.put(dp) self.queue.put(dp)
class PrefetchData(ProxyDataFlow): class PrefetchData(ProxyDataFlow):
""" """
Prefetch data from a `DataFlow` using multiprocessing Prefetch data from a `DataFlow` using multiprocessing
""" """
def __init__(self, ds, nr_prefetch, nr_proc=1): def __init__(self, ds, nr_prefetch, nr_proc=1):
""" """
:param ds: a `DataFlow` instance. :param ds: a `DataFlow` instance.
...@@ -82,6 +85,7 @@ class PrefetchData(ProxyDataFlow): ...@@ -82,6 +85,7 @@ class PrefetchData(ProxyDataFlow):
# do nothing. all ds are reset once and only once in spawned processes # do nothing. all ds are reset once and only once in spawned processes
pass pass
def BlockParallel(ds, queue_size): def BlockParallel(ds, queue_size):
# TODO more doc # TODO more doc
""" """
...@@ -92,7 +96,9 @@ def BlockParallel(ds, queue_size): ...@@ -92,7 +96,9 @@ def BlockParallel(ds, queue_size):
""" """
return PrefetchData(ds, queue_size, 1) return PrefetchData(ds, queue_size, 1)
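A minimal usage sketch, assuming `ds` is an existing DataFlow producing datapoints (the names below follow the classes defined in this file):

    # wrap an existing DataFlow `ds` with 4 prefetching worker processes and a queue of 256
    ds = PrefetchData(ds, nr_prefetch=256, nr_proc=4)
    ds.reset_state()                  # a no-op here: children reset themselves once when spawned
    for dp in ds.get_data():
        pass                          # datapoints arrive from the shared queue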
class PrefetchProcessZMQ(mp.Process): class PrefetchProcessZMQ(mp.Process):
def __init__(self, ds, conn_name): def __init__(self, ds, conn_name):
""" """
:param ds: a `DataFlow` instance. :param ds: a `DataFlow` instance.
...@@ -112,8 +118,10 @@ class PrefetchProcessZMQ(mp.Process): ...@@ -112,8 +118,10 @@ class PrefetchProcessZMQ(mp.Process):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
self.socket.send(dumps(dp), copy=False) self.socket.send(dumps(dp), copy=False)
class PrefetchDataZMQ(ProxyDataFlow): class PrefetchDataZMQ(ProxyDataFlow):
""" Work the same as `PrefetchData`, but faster. """ """ Work the same as `PrefetchData`, but faster. """
def __init__(self, ds, nr_proc=1, pipedir=None): def __init__(self, ds, nr_proc=1, pipedir=None):
""" """
:param ds: a `DataFlow` instance. :param ds: a `DataFlow` instance.
...@@ -176,9 +184,11 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -176,9 +184,11 @@ class PrefetchDataZMQ(ProxyDataFlow):
except: except:
pass pass
class PrefetchOnGPUs(PrefetchDataZMQ): class PrefetchOnGPUs(PrefetchDataZMQ):
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES """ Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
variable""" variable"""
def __init__(self, ds, gpus, pipedir=None): def __init__(self, ds, gpus, pipedir=None):
self.gpus = gpus self.gpus = gpus
super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir) super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
...@@ -188,4 +198,3 @@ class PrefetchOnGPUs(PrefetchDataZMQ): ...@@ -188,4 +198,3 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
for gpu, proc in zip(self.gpus, self.procs): for gpu, proc in zip(self.gpus, self.procs):
with change_gpu(gpu): with change_gpu(gpu):
proc.start() proc.start()
...@@ -17,8 +17,10 @@ except: ...@@ -17,8 +17,10 @@ except:
else: else:
__all__.append('DataFromSocket') __all__.append('DataFromSocket')
class FakeData(RNGDataFlow): class FakeData(RNGDataFlow):
""" Generate fake fixed data of given shapes""" """ Generate fake fixed data of given shapes"""
def __init__(self, shapes, size, random=True, dtype='float32'): def __init__(self, shapes, size, random=True, dtype='float32'):
""" """
:param shapes: a list of lists/tuples :param shapes: a list of lists/tuples
...@@ -44,8 +46,10 @@ class FakeData(RNGDataFlow): ...@@ -44,8 +46,10 @@ class FakeData(RNGDataFlow):
for _ in range(self._size): for _ in range(self._size):
yield copy.deepcopy(v) yield copy.deepcopy(v)
class DataFromQueue(DataFlow): class DataFromQueue(DataFlow):
""" Produce data from a queue """ """ Produce data from a queue """
def __init__(self, queue): def __init__(self, queue):
self.queue = queue self.queue = queue
...@@ -53,8 +57,10 @@ class DataFromQueue(DataFlow): ...@@ -53,8 +57,10 @@ class DataFromQueue(DataFlow):
while True: while True:
yield self.queue.get() yield self.queue.get()
class DataFromList(RNGDataFlow): class DataFromList(RNGDataFlow):
""" Produce data from a list""" """ Produce data from a list"""
def __init__(self, lst, shuffle=True): def __init__(self, lst, shuffle=True):
super(DataFromList, self).__init__() super(DataFromList, self).__init__()
self.lst = lst self.lst = lst
...@@ -73,8 +79,10 @@ class DataFromList(RNGDataFlow): ...@@ -73,8 +79,10 @@ class DataFromList(RNGDataFlow):
for k in idxs: for k in idxs:
yield self.lst[k] yield self.lst[k]
class DataFromSocket(DataFlow): class DataFromSocket(DataFlow):
""" Produce data from a zmq socket""" """ Produce data from a zmq socket"""
def __init__(self, socket_name): def __init__(self, socket_name):
self._name = socket_name self._name = socket_name
...@@ -89,4 +97,3 @@ class DataFromSocket(DataFlow): ...@@ -89,4 +97,3 @@ class DataFromSocket(DataFlow):
yield dp yield dp
finally: finally:
ctx.destroy(linger=0) ctx.destroy(linger=0)
...@@ -17,6 +17,7 @@ from .common import RepeatedData ...@@ -17,6 +17,7 @@ from .common import RepeatedData
from ..utils import logger from ..utils import logger
from ..utils.serialize import dumps, loads from ..utils.serialize import dumps, loads
def serve_data(ds, addr): def serve_data(ds, addr):
ctx = zmq.Context() ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH) socket = ctx.socket(zmq.PUSH)
...@@ -36,7 +37,9 @@ def serve_data(ds, addr): ...@@ -36,7 +37,9 @@ def serve_data(ds, addr):
if not ctx.closed: if not ctx.closed:
ctx.destroy(0) ctx.destroy(0)
class RemoteData(DataFlow): class RemoteData(DataFlow):
def __init__(self, addr): def __init__(self, addr):
self.ctx = zmq.Context() self.ctx = zmq.Context()
self.socket = self.ctx.socket(zmq.PULL) self.socket = self.ctx.socket(zmq.PULL)
...@@ -54,7 +57,7 @@ if __name__ == '__main__': ...@@ -54,7 +57,7 @@ if __name__ == '__main__':
from .raw import FakeData from .raw import FakeData
addr = "tcp://127.0.0.1:8877" addr = "tcp://127.0.0.1:8877"
if sys.argv[1] == 'serve': if sys.argv[1] == 'serve':
ds = FakeData([(128,244,244,3)], 1000) ds = FakeData([(128, 244, 244, 3)], 1000)
serve_data(ds, addr) serve_data(ds, addr)
else: else:
ds = RemoteData(addr) ds = RemoteData(addr)
...@@ -62,4 +65,3 @@ if __name__ == '__main__': ...@@ -62,4 +65,3 @@ if __name__ == '__main__':
with tqdm(total=10000) as pbar: with tqdm(total=10000) as pbar:
for k in ds.get_data(): for k in ds.get_data():
pbar.update() pbar.update()
...@@ -14,9 +14,11 @@ except ImportError: ...@@ -14,9 +14,11 @@ except ImportError:
else: else:
__all__ = ['TFFuncMapper'] __all__ = ['TFFuncMapper']
class TFFuncMapper(ProxyDataFlow): class TFFuncMapper(ProxyDataFlow):
def __init__(self, ds, def __init__(self, ds,
get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'): get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'):
""" """
:param get_placeholders: a function returning the placeholders :param get_placeholders: a function returning the placeholders
:param symbf: a symbolic function taking the placeholders :param symbf: a symbolic function taking the placeholders
...@@ -39,7 +41,7 @@ class TFFuncMapper(ProxyDataFlow): ...@@ -39,7 +41,7 @@ class TFFuncMapper(ProxyDataFlow):
def run_func(vals): def run_func(vals):
return self.sess.run(self.output_vars, return self.sess.run(self.output_vars,
feed_dict=dict(zip(self.placeholders, vals))) feed_dict=dict(zip(self.placeholders, vals)))
self.run_func = run_func self.run_func = run_func
def get_data(self): def get_data(self):
...@@ -63,16 +65,16 @@ if __name__ == '__main__': ...@@ -63,16 +65,16 @@ if __name__ == '__main__':
v = tf.image.random_flip_left_right(v) v = tf.image.random_flip_left_right(v)
return v return v
ds = TFFuncMapper(ds, ds = TFFuncMapper(ds,
lambda: [tf.placeholder(tf.float32, [224, 224, 3], name='img')], lambda: [tf.placeholder(tf.float32, [224, 224, 3], name='img')],
tf_aug, tf_aug,
lambda dp, f: [f([dp[0]])[0]] lambda dp, f: [f([dp[0]])[0]]
) )
#ds = AugmentImageComponent(ds, # ds = AugmentImageComponent(ds,
#[imgaug.Brightness(0.1, clip=False), # [imgaug.Brightness(0.1, clip=False),
#imgaug.Contrast((0.8, 1.2), clip=False), # imgaug.Contrast((0.8, 1.2), clip=False),
#imgaug.Flip(horiz=True) # imgaug.Flip(horiz=True)
#]) # ])
#ds = PrefetchDataZMQ(ds, 4) # ds = PrefetchDataZMQ(ds, 4)
ds.reset_state() ds.reset_state()
import tqdm import tqdm
......
...@@ -12,6 +12,7 @@ from ..utils import logger ...@@ -12,6 +12,7 @@ from ..utils import logger
__all__ = ['LinearWrap'] __all__ = ['LinearWrap']
def _global_import(name): def _global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
...@@ -32,6 +33,7 @@ class LinearWrap(object): ...@@ -32,6 +33,7 @@ class LinearWrap(object):
""" """
class TFModuleFunc(object): class TFModuleFunc(object):
def __init__(self, mod, tensor): def __init__(self, mod, tensor):
self._mod = mod self._mod = mod
self._t = tensor self._t = tensor
...@@ -88,4 +90,3 @@ class LinearWrap(object): ...@@ -88,4 +90,3 @@ class LinearWrap(object):
def print_tensor(self): def print_tensor(self):
print(self._t) print(self._t)
return self return self
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
import tensorflow as tf import tensorflow as tf
from functools import wraps from functools import wraps
import six import six
import copy, os import copy
import os
from ..tfutils.argscope import get_arg_scope from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str from ..tfutils.modelutils import get_shape_str
...@@ -16,13 +17,16 @@ from ..utils.argtools import shape2d ...@@ -16,13 +17,16 @@ from ..utils.argtools import shape2d
# make sure each layer is only logged once # make sure each layer is only logged once
_layer_logged = set() _layer_logged = set()
def disable_layer_logging(): def disable_layer_logging():
class ContainEverything: class ContainEverything:
def __contains__(self, x): def __contains__(self, x):
return True return True
# can use nonlocal in python3, but how # can use nonlocal in python3, but how
globals()['_layer_logged'] = ContainEverything() globals()['_layer_logged'] = ContainEverything()
def layer_register( def layer_register(
summary_activation=False, summary_activation=False,
log_shape=True, log_shape=True,
...@@ -42,13 +46,13 @@ def layer_register( ...@@ -42,13 +46,13 @@ def layer_register(
def wrapped_func(*args, **kwargs): def wrapped_func(*args, **kwargs):
if use_scope: if use_scope:
name, inputs = args[0], args[1] name, inputs = args[0], args[1]
args = args[1:] # actual positional args used to call func args = args[1:] # actual positional args used to call func
assert isinstance(name, six.string_types), name assert isinstance(name, six.string_types), name
else: else:
assert not log_shape and not summary_activation assert not log_shape and not summary_activation
if isinstance(args[0], six.string_types): if isinstance(args[0], six.string_types):
name, inputs = args[0], args[1] name, inputs = args[0], args[1]
args = args[1:] # actual positional args used to call func args = args[1:] # actual positional args used to call func
else: else:
inputs = args[0] inputs = args[0]
name = None name = None
...@@ -97,13 +101,14 @@ def layer_register( ...@@ -97,13 +101,14 @@ def layer_register(
# need some special handling for sphinx to work with the arguments # need some special handling for sphinx to work with the arguments
on_doc = os.environ.get('READTHEDOCS') == 'True' \ on_doc = os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING') or os.environ.get('TENSORPACK_DOC_BUILDING')
if on_doc: if on_doc:
from decorator import decorator from decorator import decorator
wrapper = decorator(wrapper) wrapper = decorator(wrapper)
return wrapper return wrapper
def shape4d(a): def shape4d(a):
# for use with tensorflow NHWC ops # for use with tensorflow NHWC ops
return [1] + shape2d(a) + [1] return [1] + shape2d(a) + [1]
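A hedged usage sketch of the decorator above; `MyScale` is a hypothetical layer, and the call pattern assumes use_scope=True, in which case the first positional argument is the variable-scope name and the second is the input tensor:

    @layer_register(log_shape=True, use_scope=True)
    def MyScale(x, factor):
        """ a hypothetical registered layer: scales its input """
        return x * factor

    # out = MyScale('scale1', some_tensor, factor=2.0)
    # -> builds the op under variable scope 'scale1' and logs its shape only once,
    #    as guaranteed by the _layer_logged set above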
...@@ -7,7 +7,9 @@ import tensorflow as tf ...@@ -7,7 +7,9 @@ import tensorflow as tf
import numpy as np import numpy as np
import unittest import unittest
class TestModel(unittest.TestCase): class TestModel(unittest.TestCase):
def run_variable(self, var): def run_variable(self, var):
sess = tf.Session() sess = tf.Session()
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
...@@ -22,6 +24,7 @@ class TestModel(unittest.TestCase): ...@@ -22,6 +24,7 @@ class TestModel(unittest.TestCase):
else: else:
return tf.Variable(args[0]) return tf.Variable(args[0])
def run_test_case(case): def run_test_case(case):
suite = unittest.TestLoader().loadTestsFromTestCase(case) suite = unittest.TestLoader().loadTestsFromTestCase(case)
unittest.TextTestRunner(verbosity=2).run(suite) unittest.TextTestRunner(verbosity=2).run(suite)
...@@ -34,5 +37,3 @@ if __name__ == '__main__': ...@@ -34,5 +37,3 @@ if __name__ == '__main__':
subs = tensorpack.models._test.TestModel.__subclasses__() subs = tensorpack.models._test.TestModel.__subclasses__()
for cls in subs: for cls in subs:
run_test_case(cls) run_test_case(cls)
...@@ -18,6 +18,8 @@ __all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2'] ...@@ -18,6 +18,8 @@ __all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
# decay: being too close to 1 leads to slow start-up. torch uses 0.9. # decay: being too close to 1 leads to slow start-up. torch uses 0.9.
# eps: torch: 1e-5. Lasagne: 1e-4 # eps: torch: 1e-5. Lasagne: 1e-4
@layer_register(log_shape=False) @layer_register(log_shape=False)
def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5): def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
""" """
...@@ -41,9 +43,9 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5): ...@@ -41,9 +43,9 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
n_out = shape[-1] # channel n_out = shape[-1] # channel
assert n_out is not None assert n_out is not None
beta = tf.get_variable('beta', [n_out], beta = tf.get_variable('beta', [n_out],
initializer=tf.constant_initializer()) initializer=tf.constant_initializer())
gamma = tf.get_variable('gamma', [n_out], gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0)) initializer=tf.constant_initializer(1.0))
if len(shape) == 2: if len(shape) == 2:
batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False) batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
...@@ -66,7 +68,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5): ...@@ -66,7 +68,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
#reuse = tf.get_variable_scope().reuse #reuse = tf.get_variable_scope().reuse
with tf.variable_scope(tf.get_variable_scope(), reuse=False): with tf.variable_scope(tf.get_variable_scope(), reuse=False):
# BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused # BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740 with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
# TODO if reuse=True, try to find and use the existing statistics # TODO if reuse=True, try to find and use the existing statistics
# how to use multiple tensors to update one EMA? seems impossible # how to use multiple tensors to update one EMA? seems impossible
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname) ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
...@@ -93,7 +95,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5): ...@@ -93,7 +95,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
ema_mean = tf.get_variable('mean/' + emaname, [n_out]) ema_mean = tf.get_variable('mean/' + emaname, [n_out])
ema_var = tf.get_variable('variance/' + emaname, [n_out]) ema_var = tf.get_variable('variance/' + emaname, [n_out])
else: else:
## use statistics in another tower # use statistics in another tower
G = tf.get_default_graph() G = tf.get_default_graph()
ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name + ':0') ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name + ':0')
ema_var = ctx.find_tensor_in_main_tower(G, var_var_name + ':0') ema_var = ctx.find_tensor_in_main_tower(G, var_var_name + ':0')
...@@ -111,6 +113,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5): ...@@ -111,6 +113,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
return tf.nn.batch_normalization( return tf.nn.batch_normalization(
x, ema_mean, ema_var, beta, gamma, epsilon, 'output') x, ema_mean, ema_var, beta, gamma, epsilon, 'output')
@layer_register(log_shape=False) @layer_register(log_shape=False)
def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5): def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
""" """
...@@ -135,9 +138,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5): ...@@ -135,9 +138,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
x = tf.reshape(x, [-1, 1, 1, n_out]) x = tf.reshape(x, [-1, 1, 1, n_out])
beta = tf.get_variable('beta', [n_out], beta = tf.get_variable('beta', [n_out],
initializer=tf.constant_initializer()) initializer=tf.constant_initializer())
gamma = tf.get_variable('gamma', [n_out], gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0)) initializer=tf.constant_initializer(1.0))
# x * gamma + beta # x * gamma + beta
ctx = get_current_tower_context() ctx = get_current_tower_context()
...@@ -147,22 +150,22 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5): ...@@ -147,22 +150,22 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
logger.warn("[BatchNorm] use_local_stat != is_training") logger.warn("[BatchNorm] use_local_stat != is_training")
moving_mean = tf.get_variable('mean/EMA', [n_out], moving_mean = tf.get_variable('mean/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False) initializer=tf.constant_initializer(), trainable=False)
moving_var = tf.get_variable('variance/EMA', [n_out], moving_var = tf.get_variable('variance/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False) initializer=tf.constant_initializer(), trainable=False)
if use_local_stat: if use_local_stat:
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta, xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
epsilon=epsilon, is_training=True) epsilon=epsilon, is_training=True)
# maintain EMA only in the main training tower # maintain EMA only in the main training tower
if ctx.is_main_training_tower: if ctx.is_main_training_tower:
update_op1 = moving_averages.assign_moving_average( update_op1 = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay, zero_debias=False, moving_mean, batch_mean, decay, zero_debias=False,
name='mean_ema_op') name='mean_ema_op')
update_op2 = moving_averages.assign_moving_average( update_op2 = moving_averages.assign_moving_average(
moving_var, batch_var, decay, zero_debias=False, moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op') name='var_ema_op')
add_model_variable(moving_mean) add_model_variable(moving_mean)
add_model_variable(moving_var) add_model_variable(moving_var)
else: else:
...@@ -171,9 +174,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5): ...@@ -171,9 +174,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
# consider some fixed-param tasks, such as loading a model and fine-tuning one layer # consider some fixed-param tasks, such as loading a model and fine-tuning one layer
# fused seems slower in inference # fused seems slower in inference
#xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta, # xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
#moving_mean, moving_var, # moving_mean, moving_var,
#epsilon=epsilon, is_training=False, name='output') # epsilon=epsilon, is_training=False, name='output')
xn = tf.nn.batch_normalization( xn = tf.nn.batch_normalization(
x, moving_mean, moving_var, beta, gamma, epsilon) x, moving_mean, moving_var, beta, gamma, epsilon)
......
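The moving averages above follow the usual EMA update new = decay * old + (1 - decay) * batch. A tiny standalone sketch with toy numbers of why a decay close to 1 starts up slowly, as the comment at the top of this file notes:

    decay = 0.9
    moving_mean = 0.0
    for batch_mean in [1.0, 1.0, 1.0]:
        # what moving_averages.assign_moving_average computes with zero_debias=False
        moving_mean = decay * moving_mean + (1 - decay) * batch_mean
    print(moving_mean)   # 0.271 after three steps, still far from the true mean of 1.0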
...@@ -12,6 +12,7 @@ from ..utils.argtools import shape2d ...@@ -12,6 +12,7 @@ from ..utils.argtools import shape2d
__all__ = ['Conv2D', 'Deconv2D'] __all__ = ['Conv2D', 'Deconv2D']
@layer_register() @layer_register()
def Conv2D(x, out_channel, kernel_shape, def Conv2D(x, out_channel, kernel_shape,
padding='SAME', stride=1, padding='SAME', stride=1,
...@@ -61,14 +62,18 @@ def Conv2D(x, out_channel, kernel_shape, ...@@ -61,14 +62,18 @@ def Conv2D(x, out_channel, kernel_shape,
for i, k in zip(inputs, kernels)] for i, k in zip(inputs, kernels)]
conv = tf.concat(3, outputs) conv = tf.concat(3, outputs)
if nl is None: if nl is None:
logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.") logger.warn(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
nl = tf.nn.relu nl = tf.nn.relu
return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output') return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
class StaticDynamicShape(object): class StaticDynamicShape(object):
def __init__(self, static, dynamic): def __init__(self, static, dynamic):
self.static = static self.static = static
self.dynamic = dynamic self.dynamic = dynamic
def apply(self, f): def apply(self, f):
try: try:
st = f(self.static) st = f(self.static)
...@@ -76,11 +81,12 @@ class StaticDynamicShape(object): ...@@ -76,11 +81,12 @@ class StaticDynamicShape(object):
except: except:
return StaticDynamicShape(None, f(self.dynamic)) return StaticDynamicShape(None, f(self.dynamic))
@layer_register() @layer_register()
def Deconv2D(x, out_shape, kernel_shape, def Deconv2D(x, out_shape, kernel_shape,
stride, padding='SAME', stride, padding='SAME',
W_init=None, b_init=None, W_init=None, b_init=None,
nl=tf.identity, use_bias=True): nl=tf.identity, use_bias=True):
""" """
2D deconvolution on 4D inputs. 2D deconvolution on 4D inputs.
......
...@@ -11,6 +11,7 @@ from ..tfutils import symbolic_functions as symbf ...@@ -11,6 +11,7 @@ from ..tfutils import symbolic_functions as symbf
__all__ = ['FullyConnected'] __all__ = ['FullyConnected']
@layer_register() @layer_register()
def FullyConnected(x, out_dim, def FullyConnected(x, out_dim,
W_init=None, b_init=None, W_init=None, b_init=None,
...@@ -40,6 +41,7 @@ def FullyConnected(x, out_dim, ...@@ -40,6 +41,7 @@ def FullyConnected(x, out_dim,
b = tf.get_variable('b', [out_dim], initializer=b_init) b = tf.get_variable('b', [out_dim], initializer=b_init)
prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W) prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
if nl is None: if nl is None:
logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.") logger.warn(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
nl = tf.nn.relu nl = tf.nn.relu
return nl(prod, name='output') return nl(prod, name='output')
...@@ -12,6 +12,8 @@ __all__ = ['ImageSample'] ...@@ -12,6 +12,8 @@ __all__ = ['ImageSample']
# XXX TODO ugly. # XXX TODO ugly.
# really need to fix this after tensorflow supports advanced indexing # really need to fix this after tensorflow supports advanced indexing
# See github:tensorflow#418,#206 # See github:tensorflow#418,#206
def sample(img, coords): def sample(img, coords):
""" """
:param img: bxhxwxc :param img: bxhxwxc
...@@ -33,14 +35,15 @@ def sample(img, coords): ...@@ -33,14 +35,15 @@ def sample(img, coords):
# bxh2xw2 # bxh2xw2
batch_add = tf.range(tf.shape(img)[0]) * (shape[0] * shape[1]) batch_add = tf.range(tf.shape(img)[0]) * (shape[0] * shape[1])
batch_add = tf.reshape(batch_add, [-1, 1, 1]) #bx1x1 batch_add = tf.reshape(batch_add, [-1, 1, 1]) # bx1x1
flat_coords = coords + batch_add flat_coords = coords + batch_add
img = tf.reshape(img, [-1, shape[2]]) #bhw x c img = tf.reshape(img, [-1, shape[2]]) # bhw x c
sampled = tf.gather(img, flat_coords) sampled = tf.gather(img, flat_coords)
return sampled return sampled
@layer_register() @layer_register()
def ImageSample(inputs, borderMode='repeat'): def ImageSample(inputs, borderMode='repeat'):
""" """
...@@ -59,7 +62,7 @@ def ImageSample(inputs, borderMode='repeat'): ...@@ -59,7 +62,7 @@ def ImageSample(inputs, borderMode='repeat'):
assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4 assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
input_shape = template.get_shape().as_list()[1:] input_shape = template.get_shape().as_list()[1:]
assert None not in input_shape, \ assert None not in input_shape, \
"Images in ImageSample layer must have fully-defined shape" "Images in ImageSample layer must have fully-defined shape"
assert borderMode in ['repeat', 'constant'] assert borderMode in ['repeat', 'constant']
orig_mapping = mapping orig_mapping = mapping
...@@ -68,7 +71,7 @@ def ImageSample(inputs, borderMode='repeat'): ...@@ -68,7 +71,7 @@ def ImageSample(inputs, borderMode='repeat'):
ucoor = lcoor + 1 ucoor = lcoor + 1
diff = mapping - lcoor diff = mapping - lcoor
neg_diff = 1.0 - diff #bxh2xw2x2 neg_diff = 1.0 - diff # bxh2xw2x2
lcoory, lcoorx = tf.split(3, 2, lcoor) lcoory, lcoorx = tf.split(3, 2, lcoor)
ucoory, ucoorx = tf.split(3, 2, ucoor) ucoory, ucoorx = tf.split(3, 2, ucoor)
...@@ -80,55 +83,59 @@ def ImageSample(inputs, borderMode='repeat'): ...@@ -80,55 +83,59 @@ def ImageSample(inputs, borderMode='repeat'):
neg_diffy, neg_diffx = tf.split(3, 2, neg_diff) neg_diffy, neg_diffx = tf.split(3, 2, neg_diff)
#prod = tf.reduce_prod(diff, 3, keep_dims=True) #prod = tf.reduce_prod(diff, 3, keep_dims=True)
#diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod), # diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
#tf.reduce_max(diff), diff], summarize=50) # tf.reduce_max(diff), diff], summarize=50)
ret = tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy, ret = tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy,
sample(template, ucoor) * diffx * diffy, sample(template, ucoor) * diffx * diffy,
sample(template, lyux) * neg_diffy * diffx, sample(template, lyux) * neg_diffy * diffx,
sample(template, uylx) * diffy * neg_diffx], name='sampled') sample(template, uylx) * diffy * neg_diffx], name='sampled')
if borderMode == 'constant': if borderMode == 'constant':
max_coor = tf.constant([input_shape[0] - 1, input_shape[1] - 1], dtype=tf.float32) max_coor = tf.constant([input_shape[0] - 1, input_shape[1] - 1], dtype=tf.float32)
mask = tf.greater_equal(orig_mapping, 0.0) mask = tf.greater_equal(orig_mapping, 0.0)
mask2 = tf.less_equal(orig_mapping, max_coor) mask2 = tf.less_equal(orig_mapping, max_coor)
mask = tf.logical_and(mask, mask2) #bxh2xw2x2 mask = tf.logical_and(mask, mask2) # bxh2xw2x2
mask = tf.reduce_all(mask, [3]) # bxh2xw2 boolean mask = tf.reduce_all(mask, [3]) # bxh2xw2 boolean
mask = tf.expand_dims(mask, 3) mask = tf.expand_dims(mask, 3)
ret = ret * tf.cast(mask, tf.float32) ret = ret * tf.cast(mask, tf.float32)
return ret return ret
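The add_n above is plain bilinear interpolation: each output pixel is a weighted sum of its four integer neighbours, with weights given by diff and neg_diff. A standalone numeric sketch with toy values:

    import numpy as np

    img = np.array([[10., 20.], [30., 40.]])              # 2x2 single-channel template
    y, x = 0.3, 0.6                                        # fractional sampling coordinate
    ly, lx = int(np.floor(y)), int(np.floor(x))            # lcoor
    dy, dx = y - ly, x - lx                                # diff; neg_diff = 1 - diff
    val = (img[ly, lx]         * (1 - dy) * (1 - dx) +
           img[ly + 1, lx]     * dy       * (1 - dx) +
           img[ly, lx + 1]     * (1 - dy) * dx +
           img[ly + 1, lx + 1] * dy       * dx)
    print(val)                                             # 22.0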
from ._test import TestModel from ._test import TestModel
class TestSample(TestModel): class TestSample(TestModel):
def test_sample(self): def test_sample(self):
import numpy as np import numpy as np
h, w = 3, 4 h, w = 3, 4
def np_sample(img, coords): def np_sample(img, coords):
# a reference implementation # a reference implementation
coords = np.maximum(coords, 0) coords = np.maximum(coords, 0)
coords = np.minimum(coords, coords = np.minimum(coords,
np.array([img.shape[1]-1, img.shape[2]-1])) np.array([img.shape[1] - 1, img.shape[2] - 1]))
xs = coords[:,:,:,1].reshape((img.shape[0], -1)) xs = coords[:, :, :, 1].reshape((img.shape[0], -1))
ys = coords[:,:,:,0].reshape((img.shape[0], -1)) ys = coords[:, :, :, 0].reshape((img.shape[0], -1))
ret = np.zeros((img.shape[0], coords.shape[1], coords.shape[2], ret = np.zeros((img.shape[0], coords.shape[1], coords.shape[2],
img.shape[3]), dtype='float32') img.shape[3]), dtype='float32')
for k in range(img.shape[0]): for k in range(img.shape[0]):
xss, yss = xs[k], ys[k] xss, yss = xs[k], ys[k]
ret[k,:,:,:] = img[k,yss,xss,:].reshape((coords.shape[1], ret[k, :, :, :] = img[k, yss, xss, :].reshape((coords.shape[1],
coords.shape[2], 3)) coords.shape[2], 3))
return ret return ret
bimg = np.random.rand(2, h, w, 3).astype('float32') bimg = np.random.rand(2, h, w, 3).astype('float32')
#mat = np.array([ # mat = np.array([
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]], #[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]] #[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
#], dtype='float32') #2x2x2x2 #], dtype='float32') #2x2x2x2
mat = (np.random.rand(2, 5, 5, 2) - 0.2) * np.array([h + 3, w + 3]) mat = (np.random.rand(2, 5, 5, 2) - 0.2) * np.array([h + 3, w + 3])
true_res = np_sample(bimg, np.floor(mat + 0.5).astype('int32')) true_res = np_sample(bimg, np.floor(mat + 0.5).astype('int32'))
inp, mapping = self.make_variable(bimg, mat) inp, mapping = self.make_variable(bimg, mat)
output = sample(inp, tf.cast(tf.floor(mapping+0.5), tf.int32)) output = sample(inp, tf.cast(tf.floor(mapping + 0.5), tf.int32))
res = self.run_variable(output) res = self.run_variable(output)
self.assertTrue((res == true_res).all()) self.assertTrue((res == true_res).all())
...@@ -146,7 +153,7 @@ if __name__ == '__main__': ...@@ -146,7 +153,7 @@ if __name__ == '__main__':
diff = 200 diff = 200
for x in range(w): for x in range(w):
for y in range(h): for y in range(h):
mapping[0,y,x,:] = np.array([y-diff+0.4, x-diff+0.5]) mapping[0, y, x, :] = np.array([y - diff + 0.4, x - diff + 0.5])
mapv = tf.Variable(mapping) mapv = tf.Variable(mapping)
output = ImageSample('sample', [imv, mapv], borderMode='constant') output = ImageSample('sample', [imv, mapv], borderMode='constant')
...@@ -155,12 +162,10 @@ if __name__ == '__main__': ...@@ -155,12 +162,10 @@ if __name__ == '__main__':
#out = sess.run(tf.gradients(tf.reduce_sum(output), mapv)) #out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#out = sess.run(output) #out = sess.run(output)
#print(out[0].min()) # print(out[0].min())
#print(out[0].max()) # print(out[0].max())
#print(out[0].sum()) # print(out[0].sum())
out = sess.run([output])[0] out = sess.run([output])[0]
im = out[0] im = out[0]
cv2.imwrite('sampled.jpg', im) cv2.imwrite('sampled.jpg', im)
...@@ -16,21 +16,27 @@ from ..tfutils.common import get_tensors_by_names ...@@ -16,21 +16,27 @@ from ..tfutils.common import get_tensors_by_names
from ..tfutils.gradproc import CheckGradient from ..tfutils.gradproc import CheckGradient
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph' ] __all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
#_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse']) #_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])
class InputVar(object): class InputVar(object):
def __init__(self, type, shape, name, sparse=False): def __init__(self, type, shape, name, sparse=False):
self.type = type self.type = type
self.shape = shape self.shape = shape
self.name = name self.name = name
self.sparse = sparse self.sparse = sparse
def dumps(self): def dumps(self):
return pickle.dumps(self) return pickle.dumps(self)
@staticmethod @staticmethod
def loads(buf): def loads(buf):
return pickle.loads(buf) return pickle.loads(buf)
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class ModelDesc(object): class ModelDesc(object):
""" Base class for a model description """ """ Base class for a model description """
...@@ -99,22 +105,24 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i ...@@ -99,22 +105,24 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
def get_gradient_processor(self): def get_gradient_processor(self):
""" Return a list of GradientProcessor. They will be executed in order""" """ Return a list of GradientProcessor. They will be executed in order"""
return [#SummaryGradient(), return [ # SummaryGradient(),
CheckGradient() CheckGradient()
] ]
class ModelFromMetaGraph(ModelDesc): class ModelFromMetaGraph(ModelDesc):
""" """
Load the whole exact TF graph from a saved meta_graph. Load the whole exact TF graph from a saved meta_graph.
Only useful for inference. Only useful for inference.
""" """
def __init__(self, filename): def __init__(self, filename):
tf.train.import_meta_graph(filename) tf.train.import_meta_graph(filename)
all_coll = tf.get_default_graph().get_all_collection_keys() all_coll = tf.get_default_graph().get_all_collection_keys()
for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES, for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
tf.GraphKeys().VARIABLES]: tf.GraphKeys().VARIABLES]:
assert k in all_coll, \ assert k in all_coll, \
"Collection {} not found in metagraph!".format(k) "Collection {} not found in metagraph!".format(k)
def _get_input_vars(self): def _get_input_vars(self):
col = tf.get_collection(INPUT_VARS_KEY) col = tf.get_collection(INPUT_VARS_KEY)
......
...@@ -11,6 +11,7 @@ from .batch_norm import BatchNorm ...@@ -11,6 +11,7 @@ from .batch_norm import BatchNorm
__all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU'] __all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']
@layer_register() @layer_register()
def Maxout(x, num_unit): def Maxout(x, num_unit):
""" """
...@@ -31,6 +32,7 @@ def Maxout(x, num_unit): ...@@ -31,6 +32,7 @@ def Maxout(x, num_unit):
x = tf.reshape(x, [-1, ch / num_unit, num_unit]) x = tf.reshape(x, [-1, ch / num_unit, num_unit])
return tf.reduce_max(x, ndim, name='output') return tf.reduce_max(x, ndim, name='output')
@layer_register(log_shape=False) @layer_register(log_shape=False)
def PReLU(x, init=tf.constant_initializer(0.001), name=None): def PReLU(x, init=tf.constant_initializer(0.001), name=None):
""" """
...@@ -47,6 +49,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None): ...@@ -47,6 +49,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
name = 'output' name = 'output'
return tf.mul(x, 0.5, name=name) return tf.mul(x, 0.5, name=name)
@layer_register(use_scope=False, log_shape=False) @layer_register(use_scope=False, log_shape=False)
def LeakyReLU(x, alpha, name=None): def LeakyReLU(x, alpha, name=None):
""" """
...@@ -62,7 +65,8 @@ def LeakyReLU(x, alpha, name=None): ...@@ -62,7 +65,8 @@ def LeakyReLU(x, alpha, name=None):
return tf.maximum(x, alpha * x, name=name) return tf.maximum(x, alpha * x, name=name)
#alpha = float(alpha) #alpha = float(alpha)
#x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x)) #x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
#return tf.mul(x, 0.5, name=name) # return tf.mul(x, 0.5, name=name)
@layer_register(log_shape=False, use_scope=False) @layer_register(log_shape=False, use_scope=False)
def BNReLU(x, name=None): def BNReLU(x, name=None):
......
...@@ -12,6 +12,7 @@ from ..tfutils import symbolic_functions as symbf ...@@ -12,6 +12,7 @@ from ..tfutils import symbolic_functions as symbf
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling', __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
'BilinearUpSample'] 'BilinearUpSample']
@layer_register() @layer_register()
def MaxPooling(x, shape, stride=None, padding='VALID'): def MaxPooling(x, shape, stride=None, padding='VALID'):
""" """
...@@ -32,6 +33,7 @@ def MaxPooling(x, shape, stride=None, padding='VALID'): ...@@ -32,6 +33,7 @@ def MaxPooling(x, shape, stride=None, padding='VALID'):
return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding) return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding)
@layer_register() @layer_register()
def AvgPooling(x, shape, stride=None, padding='VALID'): def AvgPooling(x, shape, stride=None, padding='VALID'):
""" """
...@@ -52,6 +54,7 @@ def AvgPooling(x, shape, stride=None, padding='VALID'): ...@@ -52,6 +54,7 @@ def AvgPooling(x, shape, stride=None, padding='VALID'):
return tf.nn.avg_pool(x, ksize=shape, strides=stride, padding=padding) return tf.nn.avg_pool(x, ksize=shape, strides=stride, padding=padding)
@layer_register() @layer_register()
def GlobalAvgPooling(x): def GlobalAvgPooling(x):
""" """
...@@ -65,6 +68,8 @@ def GlobalAvgPooling(x): ...@@ -65,6 +68,8 @@ def GlobalAvgPooling(x):
return tf.reduce_mean(x, [1, 2]) return tf.reduce_mean(x, [1, 2])
# https://github.com/tensorflow/tensorflow/issues/2169 # https://github.com/tensorflow/tensorflow/issues/2169
def UnPooling2x2ZeroFilled(x): def UnPooling2x2ZeroFilled(x):
out = tf.concat(3, [x, tf.zeros_like(x)]) out = tf.concat(3, [x, tf.zeros_like(x)])
out = tf.concat(2, [out, tf.zeros_like(out)]) out = tf.concat(2, [out, tf.zeros_like(out)])
...@@ -79,6 +84,7 @@ def UnPooling2x2ZeroFilled(x): ...@@ -79,6 +84,7 @@ def UnPooling2x2ZeroFilled(x):
ret.set_shape([None, None, None, sh[3]]) ret.set_shape([None, None, None, sh[3]])
return ret return ret
@layer_register() @layer_register()
def FixedUnPooling(x, shape, unpool_mat=None): def FixedUnPooling(x, shape, unpool_mat=None):
""" """
...@@ -108,8 +114,8 @@ def FixedUnPooling(x, shape, unpool_mat=None): ...@@ -108,8 +114,8 @@ def FixedUnPooling(x, shape, unpool_mat=None):
# perform a tensor-matrix kronecker product # perform a tensor-matrix kronecker product
fx = symbf.flatten(tf.transpose(x, [0, 3, 1, 2])) fx = symbf.flatten(tf.transpose(x, [0, 3, 1, 2]))
fx = tf.expand_dims(fx, -1) # (bchw)x1 fx = tf.expand_dims(fx, -1) # (bchw)x1
mat = tf.expand_dims(symbf.flatten(unpool_mat), 0) #1x(shxsw) mat = tf.expand_dims(symbf.flatten(unpool_mat), 0) # 1x(shxsw)
prod = tf.matmul(fx, mat) #(bchw) x(shxsw) prod = tf.matmul(fx, mat) # (bchw) x(shxsw)
prod = tf.reshape(prod, tf.pack( prod = tf.reshape(prod, tf.pack(
[-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]])) [-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]]))
prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1]) prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1])
...@@ -117,6 +123,7 @@ def FixedUnPooling(x, shape, unpool_mat=None): ...@@ -117,6 +123,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
[-1, input_shape[1] * shape[0], input_shape[2] * shape[1], input_shape[3]])) [-1, input_shape[1] * shape[0], input_shape[2] * shape[1], input_shape[3]]))
return prod return prod
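The flatten/matmul/transpose sequence above implements a tensor-matrix Kronecker product; numpy's np.kron shows the intended result when unpool_mat behaves like a matrix with a single 1 at the top-left corner, which is what the unit test further down checks:

    import numpy as np

    x = np.array([[1., 2.], [3., 4.]])            # one channel of a 2x2 feature map
    unpool_mat = np.array([[1., 0.], [0., 0.]])   # keep each value at the corner of its block
    print(np.kron(x, unpool_mat))                 # values land on even positions, zeros elsewhere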
@layer_register() @layer_register()
def BilinearUpSample(x, shape): def BilinearUpSample(x, shape):
""" """
...@@ -125,9 +132,9 @@ def BilinearUpSample(x, shape): ...@@ -125,9 +132,9 @@ def BilinearUpSample(x, shape):
:param shape: an integer, the upsample factor :param shape: an integer, the upsample factor
""" """
#inp_shape = tf.shape(x) #inp_shape = tf.shape(x)
#return tf.image.resize_bilinear(x, # return tf.image.resize_bilinear(x,
#tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]), # tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
#align_corners=True) # align_corners=True)
inp_shape = x.get_shape().as_list() inp_shape = x.get_shape().as_list()
ch = inp_shape[3] ch = inp_shape[3]
...@@ -136,7 +143,6 @@ def BilinearUpSample(x, shape): ...@@ -136,7 +143,6 @@ def BilinearUpSample(x, shape):
shape = int(shape) shape = int(shape)
filter_shape = 2 * shape filter_shape = 2 * shape
def bilinear_conv_filler(s): def bilinear_conv_filler(s):
""" """
s: width, height of the conv filter s: width, height of the conv filter
...@@ -147,7 +153,7 @@ def BilinearUpSample(x, shape): ...@@ -147,7 +153,7 @@ def BilinearUpSample(x, shape):
ret = np.zeros((s, s), dtype='float32') ret = np.zeros((s, s), dtype='float32')
for x in range(s): for x in range(s):
for y in range(s): for y in range(s):
ret[x,y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) ret[x, y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
return ret return ret
w = bilinear_conv_filler(filter_shape) w = bilinear_conv_filler(filter_shape)
w = np.repeat(w, ch * ch).reshape((filter_shape, filter_shape, ch, ch)) w = np.repeat(w, ch * ch).reshape((filter_shape, filter_shape, ch, ch))
...@@ -155,17 +161,22 @@ def BilinearUpSample(x, shape): ...@@ -155,17 +161,22 @@ def BilinearUpSample(x, shape):
shape=(filter_shape, filter_shape, ch, ch), shape=(filter_shape, filter_shape, ch, ch),
name='bilinear_upsample_filter') name='bilinear_upsample_filter')
deconv = tf.nn.conv2d_transpose(x, weight_var, deconv = tf.nn.conv2d_transpose(x, weight_var,
tf.shape(x) * tf.constant([1, shape, shape, 1], tf.int32), tf.shape(x) * tf.constant([1, shape, shape, 1], tf.int32),
[1,shape,shape,1], 'SAME') [1, shape, shape, 1], 'SAME')
if inp_shape[1]: inp_shape[1] *= shape if inp_shape[1]:
if inp_shape[2]: inp_shape[2] *= shape inp_shape[1] *= shape
if inp_shape[2]:
inp_shape[2] *= shape
deconv.set_shape(inp_shape) deconv.set_shape(inp_shape)
return deconv return deconv
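The "tensor-matrix kronecker product" comment in FixedUnPooling can be illustrated per channel with np.kron; this is only an illustration of the idea, not the layer's API. With an unpool_mat that has a single 1 at the top-left corner (the default), the result is the same zero-filled unpooling as above.

import numpy as np

def fixed_unpool_np(x, unpool_mat):
    # x: HxW single-channel map, unpool_mat: sh x sw kernel;
    # each input pixel is multiplied by unpool_mat and pasted into its block
    return np.kron(x, unpool_mat)

x = np.array([[1., 2.], [3., 4.]], dtype=np.float32)
mat = np.zeros((2, 2), dtype=np.float32)
mat[0, 0] = 1.0                      # default unpool_mat: single 1 at the corner
print(fixed_unpool_np(x, mat))
# [[1. 0. 2. 0.]
#  [0. 0. 0. 0.]
#  [3. 0. 4. 0.]
#  [0. 0. 0. 0.]]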
from ._test import TestModel


class TestPool(TestModel):
    def test_fixed_unpooling(self):
        h, w = 3, 4
        mat = np.random.rand(h, w, 3).astype('float32')

...@@ -173,13 +184,13 @@ class TestPool(TestModel):
        inp = tf.reshape(inp, [1, h, w, 3])
        output = FixedUnPooling('unpool', inp, 2)
        res = self.run_variable(output)
        self.assertEqual(res.shape, (1, 2 * h, 2 * w, 3))
        # mat is on the corner
        ele = res[0, ::2, ::2, 0]
        self.assertTrue((ele == mat[:, :, 0]).all())
        # the rest are zeros
        res[0, ::2, ::2, :] = 0
        self.assertTrue((res == 0).all())

    def test_upsample(self):

...@@ -191,7 +202,7 @@ class TestPool(TestModel):
        inp = tf.reshape(inp, [1, h, w, 1])
        output = BilinearUpSample('upsample', inp, scale)
        res = self.run_variable(output)[0, :, :, 0]
        from skimage.transform import rescale
        res2 = rescale(mat, scale)

...@@ -199,9 +210,9 @@ class TestPool(TestModel):
        diff = np.abs(res2 - res)
        # not equivalent to rescale on the edges?
        diff[0, :] = 0
        diff[:, 0] = 0
        if not diff.max() < 1e-4:
            import IPython
            IPython.embed(config=IPython.terminal.ipapp.load_default_config())
        self.assertTrue(diff.max() < 1e-4)
...@@ -12,6 +12,7 @@ from ._common import layer_register
__all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']


@memoized
def _log_regularizer(name):
    logger.info("Apply regularizer for {}".format(name))

...@@ -19,6 +20,7 @@ def _log_regularizer(name):
l2_regularizer = tf.contrib.layers.l2_regularizer
l1_regularizer = tf.contrib.layers.l1_regularizer


def regularize_cost(regex, func, name=None):
    """
    Apply a regularizer on every trainable variable matching the regex.

...@@ -48,4 +50,3 @@ def Dropout(x, keep_prob=0.5, is_training=None):
        is_training = get_current_tower_context().is_training
    keep_prob = tf.constant(keep_prob if is_training else 1.0)
    return tf.nn.dropout(x, keep_prob)
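A framework-free sketch of the quantity regularize_cost computes when paired with l2_regularizer: the sum of func(var) over trainable variables whose name matches the regex. The 0.5 factor follows tf.nn.l2_loss's convention and should be treated as an assumption here, as should the helper and variable names.

import re
import numpy as np

def l2_cost(named_weights, regex, weight_decay):
    # named_weights: dict of {variable name: numpy array}
    total = 0.0
    for name, w in named_weights.items():
        if re.search(regex, name):
            total += 0.5 * weight_decay * np.sum(np.square(w))
    return total

weights = {'conv1/W': np.ones((3, 3)), 'conv1/b': np.ones(3)}
print(l2_cost(weights, r'.*/W', 1e-4))   # 0.5 * 1e-4 * 9 = 0.00045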
...@@ -8,6 +8,7 @@ from ._common import layer_register
__all__ = ['ConcatWith']


@layer_register(use_scope=False, log_shape=False)
def ConcatWith(x, dim, tensor):
    """
......
...@@ -8,6 +8,7 @@ from ._common import layer_register
__all__ = ['SoftMax']


@layer_register()
def SoftMax(x, use_temperature=False, temperature_init=1.0):
    """

...@@ -16,6 +17,6 @@ def SoftMax(x, use_temperature=False, temperature_init=1.0):
    """
    if use_temperature:
        t = tf.get_variable('invtemp', [],
                            initializer=tf.constant_initializer(1.0 / float(temperature_init)))
        x = x * t
    return tf.nn.softmax(x, name='output')
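A quick NumPy illustration of the temperature trick used in SoftMax above: the logits are scaled by a (learnable) inverse temperature, so a larger temperature flattens the distribution.

import numpy as np

def softmax_with_temperature(logits, temperature):
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max()                       # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = [2.0, 1.0, 0.1]
print(softmax_with_temperature(logits, 1.0))    # sharp:   ~[0.66 0.24 0.10]
print(softmax_with_temperature(logits, 10.0))   # flatter: ~[0.37 0.33 0.30]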
...@@ -8,6 +8,7 @@ import os.path
__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)

...@@ -25,4 +26,3 @@ for _, module_name, _ in walk_packages(
    if module_name.startswith('_'):
        continue
    global_import(module_name)
...@@ -12,9 +12,10 @@ from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext

__all__ = ['OnlinePredictor', 'OfflinePredictor',
           'AsyncPredictorBase',
           'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
           'DataParallelOfflinePredictor']


@six.add_metaclass(ABCMeta)
class PredictorBase(object):

...@@ -46,7 +47,9 @@ class PredictorBase(object):
        :return: output as defined by the config
        """


class AsyncPredictorBase(PredictorBase):
    @abstractmethod
    def put_task(self, dp, callback=None):
        """

...@@ -67,7 +70,9 @@ class AsyncPredictorBase(PredictorBase):
        # in Tornado, Future.result() doesn't wait
        return fut.result()


class OnlinePredictor(PredictorBase):
    def __init__(self, sess, input_tensors, output_tensors, return_input=False):
        self.session = sess
        self.return_input = return_input

...@@ -85,6 +90,7 @@ class OnlinePredictor(PredictorBase):
class OfflinePredictor(OnlinePredictor):
    """ Build a predictor from a given config, in an independent graph"""

    def __init__(self, config):
        self.graph = tf.Graph()
        with self.graph.as_default():

...@@ -98,7 +104,7 @@ class OfflinePredictor(OnlinePredictor):
            sess = tf.Session(config=config.session_config)
            config.session_init.init(sess)
            super(OfflinePredictor, self).__init__(
                sess, input_vars, output_vars, config.return_input)


def build_multi_tower_prediction_graph(build_tower_fn, towers):

...@@ -108,13 +114,15 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
    """
    for k in towers:
        logger.info(
            "Building graph for predictor tower {}...".format(k))
        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                TowerContext('{}{}'.format(PREDICT_TOWER, k)):
            build_tower_fn(k)
            tf.get_variable_scope().reuse_variables()


class MultiTowerOfflinePredictor(OnlinePredictor):
    def __init__(self, config, towers):
        self.graph = tf.Graph()
        self.predictors = []

...@@ -130,8 +138,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
            for k in towers:
                output_vars = get_tensors_by_names(
                    ['{}{}/'.format(PREDICT_TOWER, k) + n
                     for n in config.output_names])
                self.predictors.append(OnlinePredictor(
                    self.sess, input_vars, output_vars, config.return_input))

...@@ -142,7 +150,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
    def get_predictors(self, n):
        return [self.predictors[k % len(self.predictors)] for k in range(n)]


class DataParallelOfflinePredictor(OnlinePredictor):
    def __init__(self, config, towers):
        self.graph = tf.Graph()
        with self.graph.as_default():

...@@ -152,19 +162,19 @@ class DataParallelOfflinePredictor(OnlinePredictor):
            for k in towers:
                towername = PREDICT_TOWER + str(k)
                input_vars = config.model.build_placeholders(
                    prefix=towername + '-')
                logger.info(
                    "Building graph for predictor tower {}...".format(k))
                with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                        TowerContext(towername, is_training=False):
                    config.model.build_graph(input_vars)
                    tf.get_variable_scope().reuse_variables()
                input_var_names.extend([k.name for k in input_vars])
                output_vars.extend(get_tensors_by_names(
                    [towername + '/' + n
                     for n in config.output_names]))
            input_vars = get_tensors_by_names(input_var_names)
            config.session_init.init(sess)
            super(DataParallelOfflinePredictor, self).__init__(
                sess, input_vars, output_vars, config.return_input)
...@@ -15,11 +15,13 @@ from .base import OfflinePredictor
import multiprocessing

__all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']

PredictResult = namedtuple('PredictResult', ['input', 'output'])


class PredictConfig(object):
    def __init__(self, **kwargs):
        """
        The config used by `get_predict_func`.

...@@ -61,12 +63,14 @@ class PredictConfig(object):
            self.output_names = kwargs.pop('output_var_names')
            #logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
        assert len(self.input_names), self.input_names
        for v in self.input_names:
            assert_type(v, six.string_types)
        assert len(self.output_names), self.output_names
        self.return_input = kwargs.pop('return_input', False)
        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))


def get_predict_func(config):
    """
    Produce an offline predictor that runs inside a new session.

...@@ -76,4 +80,3 @@ def get_predict_func(config):
        a list of output values defined in ``config.output_var_names``.
    """
    return OfflinePredictor(config)
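A hedged usage sketch of the API above. MyModel, the checkpoint path, the tensor names and the import paths are all assumptions for illustration; only PredictConfig's keyword arguments follow the code shown here.

# hypothetical usage; requires tensorpack and a user-defined ModelDesc subclass
from tensorpack.predict import PredictConfig, get_predict_func   # import path assumed
from tensorpack.tfutils import SaverRestore                      # import path assumed

config = PredictConfig(
    model=MyModel(),                                      # MyModel: hypothetical ModelDesc
    session_init=SaverRestore('train_log/model-10000'),   # hypothetical checkpoint
    input_names=['input'],                                # placeholder names in the graph
    output_names=['output'],                              # tensors to fetch
)
predict = get_predict_func(config)     # an OfflinePredictor, callable on a datapoint
outputs = predict([input_batch])       # input_batch: hypothetical numpy array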
...@@ -3,7 +3,8 @@
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import multiprocessing
import threading
import tensorflow as tf
import time
import six

...@@ -25,10 +26,12 @@ except ImportError:
    __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
else:
    __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
               'MultiThreadAsyncPredictor']


class MultiProcessPredictWorker(multiprocessing.Process):
    """ Base class for predict worker that runs offline in multiprocess"""

    def __init__(self, idx, config):
        """
        :param idx: index of the worker. the 0th worker will print log.

...@@ -51,8 +54,10 @@ class MultiProcessPredictWorker(multiprocessing.Process):
            with self.predictor.graph.as_default():
                describe_model()


class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
    """ An offline predictor worker that takes input and produces output by queue"""

    def __init__(self, idx, inqueue, outqueue, config):
        """
        :param inqueue: input queue to get data point. elements are (task_id, dp)

...@@ -76,6 +81,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
class PredictorWorkerThread(threading.Thread):
    def __init__(self, queue, pred_func, id, batch_size=5):
        super(PredictorWorkerThread, self).__init__()
        self.queue = queue

...@@ -88,13 +94,13 @@ class PredictorWorkerThread(threading.Thread):
        while True:
            batched, futures = self.fetch_batch()
            outputs = self.func(batched)
            # print "Worker {} batched {} Queue {}".format(
            # self.id, len(futures), self.queue.qsize())
            # debug, for speed testing
            # if not hasattr(self, 'xxx'):
            # self.xxx = outputs = self.func(batched)
            # else:
            # outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
            for idx, f in enumerate(futures):
                f.set_result([k[idx] for k in outputs])

...@@ -119,11 +125,13 @@ class PredictorWorkerThread(threading.Thread):
            cnt += 1
        return batched, futures


class MultiThreadAsyncPredictor(AsyncPredictorBase):
    """
    A multithreaded online async predictor which runs a list of PredictorBase.
    It does an extra batching internally.
    """

    def __init__(self, predictors, batch_size=5):
        """ :param predictors: a list of OnlinePredictor"""
        assert len(predictors)

...@@ -131,7 +139,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
            #assert isinstance(k, OnlinePredictor), type(k)
            # TODO use predictors.return_input here
            assert k.return_input == False
        self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
        self.threads = [
            PredictorWorkerThread(
                self.input_queue, f, id, batch_size=batch_size)
......
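A simplified, framework-free sketch of the batching pattern in PredictorWorkerThread / MultiThreadAsyncPredictor (single worker and a plain function standing in for the predictor; every name here is made up): clients enqueue (input, Future) pairs, the worker drains up to batch_size of them, runs one batched call, and fulfils each Future with its slice of the output.

import threading
import queue
from concurrent.futures import Future

def start_batching_worker(func, batch_size=5):
    q = queue.Queue()

    def loop():
        while True:
            item = q.get()                     # block for the first task
            inputs, futures = [item[0]], [item[1]]
            while len(inputs) < batch_size:
                try:
                    item = q.get_nowait()      # opportunistically batch more tasks
                except queue.Empty:
                    break
                inputs.append(item[0])
                futures.append(item[1])
            outputs = func(inputs)             # one batched call
            for f, out in zip(futures, outputs):
                f.set_result(out)              # deliver each result to its caller

    threading.Thread(target=loop, daemon=True).start()

    def put_task(x):
        f = Future()
        q.put((x, f))
        return f

    return put_task

put_task = start_batching_worker(lambda xs: [x * 2 for x in xs])
print(put_task(21).result())   # 42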
...@@ -20,10 +20,12 @@ from .common import PredictConfig
from .base import OfflinePredictor

__all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
           'MultiProcessDatasetPredictor']


@six.add_metaclass(ABCMeta)
class DatasetPredictorBase(object):
    def __init__(self, config, dataset):
        """
        :param config: a `PredictConfig` instance.

...@@ -45,10 +47,12 @@ class DatasetPredictorBase(object):
        """
        return list(self.get_result())


class SimpleDatasetPredictor(DatasetPredictorBase):
    """
    Run the predict_config on a given `DataFlow`.
    """

    def __init__(self, config, dataset):
        super(SimpleDatasetPredictor, self).__init__(config, dataset)
        self.predictor = OfflinePredictor(config)

...@@ -60,14 +64,17 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
            sz = self.dataset.size()
        except NotImplementedError:
            sz = 0
        with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
            for dp in self.dataset.get_data():
                res = self.predictor(dp)
                yield res
                pbar.update()


# TODO allow unordered
class MultiProcessDatasetPredictor(DatasetPredictorBase):
    def __init__(self, config, dataset, nr_proc, use_gpu=True, ordered=True):
        """
        Run prediction in multiple processes, on either CPU or GPU. Mixed mode is not supported.

...@@ -87,14 +94,14 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        self.ordered = ordered
        self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
            self.dataset, nr_proc * 2, self.nr_proc)  # put (idx, dp) to inqueue
        if use_gpu:
            try:
                gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
                assert len(gpus) >= self.nr_proc, \
                    "nr_proc={} while only {} gpus available".format(
                        self.nr_proc, len(gpus))
            except KeyError:
                # TODO number of GPUs not checked
                gpus = list(range(self.nr_proc))

...@@ -103,8 +110,8 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        # worker produces (idx, result) to outqueue
        self.outqueue = multiprocessing.Queue()
        self.workers = [MultiProcessQueuePredictWorker(
            i, self.inqueue, self.outqueue, self.config)
            for i in range(self.nr_proc)]
        # start inqueue and workers
        self.inqueue_proc.start()

...@@ -118,7 +125,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        if ordered:
            self.result_queue = OrderedResultGatherProc(
                self.outqueue, nr_producer=self.nr_proc)
            self.result_queue.start()
            ensure_proc_terminate(self.result_queue)
        else:

...@@ -130,7 +137,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
            sz = self.dataset.size()
        except NotImplementedError:
            sz = 0
        with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
            die_cnt = 0
            while True:
                res = self.result_queue.get()

...@@ -147,4 +154,5 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
        self.result_queue.join()
        self.result_queue.terminate()
        for p in self.workers:
            p.join()
            p.terminate()
...@@ -12,6 +12,7 @@ __all__ = ['argscope', 'get_arg_scope']
_ArgScopeStack = []


@contextmanager
def argscope(layers, **param):
    if not isinstance(layers, list):

...@@ -33,6 +34,7 @@ def argscope(layers, **param):
    yield
    del _ArgScopeStack[-1]


def get_arg_scope():
    """
    :returns: the current argscope.
......
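A hedged usage sketch of argscope: the layer name, its arguments and the import path follow the usual tensorpack examples and are assumptions here. Every Conv2D call inside the with block picks up the scoped defaults unless the call site overrides them.

import tensorflow as tf
from tensorpack import argscope, Conv2D     # import path assumed

image = tf.placeholder(tf.float32, [None, 32, 32, 3], name='input')
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
    l = Conv2D('conv1', image)               # uses the scoped defaults
    l = Conv2D('conv2', l, out_channel=64)   # overrides out_channel only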
...@@ -22,6 +22,7 @@ __all__ = ['get_default_sess_config',
           'freeze_collection',
           'get_tf_version']


def get_default_sess_config(mem_fraction=0.99):
    """
    Return a better session config to use as default.

...@@ -38,6 +39,7 @@ def get_default_sess_config(mem_fraction=0.99):
    #conf.log_device_placement = True
    return conf


def get_global_step_var():
    """ :returns: the global_step variable in the current graph; create it if it doesn't exist"""
    try:

...@@ -45,19 +47,21 @@ def get_global_step_var():
    except KeyError:
        scope = tf.get_variable_scope()
        assert scope.name == '', \
            "Creating global_step_var under a variable scope would cause problems!"
        with tf.variable_scope(scope, reuse=False):
            var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
                                  initializer=tf.constant_initializer(dtype=tf.int32),
                                  trainable=False, dtype=tf.int32)
        return var


def get_global_step():
    """ :returns: global_step value in the current graph and session"""
    return tf.train.global_step(
        tf.get_default_session(),
        get_global_step_var())


def get_op_tensor_name(name):
    """
    Tensor name is assumed to be ``op_name + ':0'``

...@@ -72,6 +76,7 @@ def get_op_tensor_name(name):
get_op_var_name = get_op_tensor_name


def get_tensors_by_names(names):
    """
    Get a list of tensors in the default graph by a list of names

...@@ -85,26 +90,31 @@ def get_tensors_by_names(names):
get_vars_by_names = get_tensors_by_names


def backup_collection(keys):
    ret = {}
    for k in keys:
        ret[k] = copy(tf.get_collection(k))
    return ret


def restore_collection(backup):
    for k, v in six.iteritems(backup):
        del tf.get_collection_ref(k)[:]
        tf.get_collection_ref(k).extend(v)


def clear_collection(keys):
    for k in keys:
        del tf.get_collection_ref(k)[:]


@contextmanager
def freeze_collection(keys):
    backup = backup_collection(keys)
    yield
    restore_collection(backup)


def get_tf_version():
    return int(tf.__version__.split('.')[1])
...@@ -16,6 +16,7 @@ __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
           'ScaleGradient', 'MapGradient', 'apply_grad_processors',
           'GlobalNormClip']


def apply_grad_processors(grads, gradprocs):
    """
    :param grads: list of (grad, var).

...@@ -32,6 +33,7 @@ def apply_grad_processors(grads, gradprocs):
        g = proc.process(g)
    return g


@six.add_metaclass(ABCMeta)
class GradientProcessor(object):

...@@ -51,6 +53,7 @@ class GradientProcessor(object):
class GlobalNormClip(GradientProcessor):
    def __init__(self, global_norm):
        """ Clip by global norm.
        Note that the global norm is the sum of norm for **all** gradients

...@@ -63,11 +66,13 @@ class GlobalNormClip(GradientProcessor):
        g, _ = tf.clip_by_global_norm(g, self._norm, name='clip_by_global_norm')
        return list(zip(g, v))


class MapGradient(GradientProcessor):
    """
    Apply a function to all gradients whose variable name matches the regex.
    Keep the other gradients unchanged.
    """

    def __init__(self, func, regex='.*'):
        """
        :param func: takes a grad or (grad, var) pair and returns a grad. If it returns None, the

...@@ -77,7 +82,7 @@ class MapGradient(GradientProcessor):
        args = inspect.getargspec(func).args
        arg_num = len(args) - inspect.ismethod(func)
        assert arg_num in [1, 2], \
            "The function must take 1 or 2 arguments! ({})".format(args)
        if arg_num == 1:
            self.func = lambda grad, var: func(grad)
        else:

...@@ -100,10 +105,12 @@ class MapGradient(GradientProcessor):
_summaried_gradient = set()


class SummaryGradient(MapGradient):
    """
    Summarize the histogram and RMS of each gradient variable
    """

    def __init__(self):
        super(SummaryGradient, self).__init__(self._mapper)

...@@ -115,10 +122,12 @@ class SummaryGradient(MapGradient):
        add_moving_summary(rms(grad, name=name + '/rms'))
        return grad


class CheckGradient(MapGradient):
    """
    Check for numeric issues.
    """

    def __init__(self):
        super(CheckGradient, self).__init__(self._mapper)

...@@ -128,10 +137,12 @@ class CheckGradient(MapGradient):
        grad = tf.check_numerics(grad, 'CheckGradient-' + var.op.name)
        return grad


class ScaleGradient(MapGradient):
    """
    Scale certain gradients by multipliers.
    """

    def __init__(self, multipliers, log=True):
        """
        :param multipliers: list of (regex, float)
......
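A hedged sketch of how these gradient processors are typically chained; the optimizer, regexes and multiplier values are made up, and the import path is assumed.

import tensorflow as tf
from tensorpack.tfutils.gradproc import (       # import path assumed
    apply_grad_processors, GlobalNormClip, ScaleGradient, SummaryGradient)

opt = tf.train.AdamOptimizer(1e-3)
grads = opt.compute_gradients(cost)             # cost: hypothetical scalar loss tensor
grads = apply_grad_processors(grads, [
    GlobalNormClip(5.0),                        # clip the global norm of all gradients
    ScaleGradient([('fc.*/W', 0.1)]),           # hypothetical 10x smaller step for fc weights
    SummaryGradient(),                          # add histogram/RMS summaries
])
train_op = opt.apply_gradients(grads)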
...@@ -9,6 +9,7 @@ from ..utils import logger
__all__ = ['describe_model', 'get_shape_str']


def describe_model():
    """ print a description of the current model parameters """
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

...@@ -40,5 +41,3 @@ def get_shape_str(tensors):
        assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
        shape_str = str(tensors.get_shape().as_list())
    return shape_str
...@@ -20,6 +20,7 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
# TODO they initialize_all at the beginning by default.


@six.add_metaclass(ABCMeta)
class SessionInit(object):
    """ Base class for utilities to initialize a session"""

...@@ -35,23 +36,29 @@ class SessionInit(object):
    def _init(self, sess):
        pass


class JustCurrentSession(SessionInit):
    """ Just use the current default session. This is a no-op placeholder"""

    def _init(self, sess):
        pass


class NewSession(SessionInit):
    """
    Create a new session. All variables will be initialized by their
    initializer.
    """

    def _init(self, sess):
        sess.run(tf.global_variables_initializer())


class SaverRestore(SessionInit):
    """
    Restore an old model saved by `ModelSaver`.
    """

    def __init__(self, model_path, prefix=None):
        """
        :param model_path: a model name (model-xxxx) or a ``checkpoint`` file.

...@@ -71,7 +78,7 @@ class SaverRestore(SessionInit):
            new_path = model_path.split('.index')[0]
            if new_path != model_path:
                logger.warn(
                    "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
                model_path = new_path
            assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
        self.set_path(model_path)

...@@ -146,10 +153,12 @@ class SaverRestore(SessionInit):
                logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
        return var_dict


class ParamRestore(SessionInit):
    """
    Restore variables from a dictionary.
    """

    def __init__(self, param_dict):
        """
        :param param_dict: a dict of {name: value}

...@@ -158,7 +167,7 @@ class ParamRestore(SessionInit):
        self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}

    def _init(self, sess):
        variables = tf.get_collection(tf.GraphKeys().VARIABLES)  # TODO
        variable_names = set([get_savename_from_varname(k.name) for k in variables])
        param_names = set(six.iterkeys(self.prms))

...@@ -174,14 +183,15 @@ class ParamRestore(SessionInit):
            logger.warn("Variable {} in the dict not found in the graph!".format(k))
        upd = SessionUpdate(sess,
                            [v for v in variables if
                             get_savename_from_varname(v.name) in intersect])
        logger.info("Restoring from dict ...")
        upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})


class ChainInit(SessionInit):
    """ Initialize a session by a list of SessionInit instances."""

    def __init__(self, sess_inits, new_session=True):
        """
        :params sess_inits: list of `SessionInit` instances.
......
...@@ -15,6 +15,7 @@ from .symbolic_functions import rms
__all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
           'add_moving_summary', 'summary_moving_average']


def create_summary(name, v):
    """
    Return a tf.Summary object with name and simple scalar value v

...@@ -25,6 +26,7 @@ def create_summary(name, v):
    s.value.add(tag=name, simple_value=v)
    return s


def add_activation_summary(x, name=None):
    """
    Add summary to graph for an activation tensor x.

...@@ -44,6 +46,7 @@ def add_activation_summary(x, name=None):
    tf.summary.scalar(name + '-sparsity', tf.nn.zero_fraction(x))
    tf.summary.scalar(name + '-rms', rms(x))


def add_param_summary(summary_lists):
    """
    Add summary for all trainable variables matching the regex

...@@ -54,6 +57,7 @@ def add_param_summary(summary_lists):
    ctx = get_current_tower_context()
    if ctx is not None and not ctx.is_main_training_tower:
        return

    def perform(var, action):
        ndim = var.get_shape().ndims
        name = var.name.replace(':0', '')

...@@ -87,6 +91,7 @@ def add_param_summary(summary_lists):
        for act in actions:
            perform(p, act)


def add_moving_summary(v, *args):
    """
    :param v: tensor or list of tensor to summary

...@@ -102,6 +107,7 @@ def add_moving_summary(v, *args):
        assert x.get_shape().ndims == 0, x.get_shape()
        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)


@memoized
def summary_moving_average(tensors=None):
    """

...@@ -121,4 +127,3 @@ def summary_moving_average(tensors=None):
        name = re.sub('tower[p0-9]+/', '', c.op.name)
        tf.summary.scalar(name + '-summary', averager.average(c))
    return avg_maintain_op
...@@ -6,6 +6,7 @@ import tensorflow as tf
import numpy as np
from ..utils import logger


def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
    """
    :param logits: NxC

...@@ -13,7 +14,8 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
    :returns: a float32 vector of length N with 0/1 values. 1 means incorrect prediction
    """
    return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
                   tf.float32, name=name)


def flatten(x):
    """

...@@ -21,6 +23,7 @@ def flatten(x):
    """
    return tf.reshape(x, [-1])


def batch_flatten(x):
    """
    Flatten the tensor except the first dimension.

...@@ -30,6 +33,7 @@ def batch_flatten(x):
        return tf.reshape(x, [-1, int(np.prod(shape))])
    return tf.reshape(x, tf.pack([tf.shape(x)[0], -1]))


def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
    """
    The class-balanced cross entropy loss,

...@@ -53,6 +57,7 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
    cost = tf.sub(loss_pos, loss_neg, name=name)
    return cost


def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
    """
    The class-balanced cross entropy loss,

...@@ -75,13 +80,14 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
    cost = tf.reduce_mean(cost * (1 - beta), name=name)

    #logstable = tf.log(1 + tf.exp(-tf.abs(z)))
    # loss_pos = -beta * tf.reduce_mean(-y *
    #(logstable - tf.minimum(0.0, z)))
    # loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
    #(logstable + tf.maximum(z, 0.0)))
    #cost = tf.sub(loss_pos, loss_neg, name=name)
    return cost
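A framework-free sketch of the class-balancing idea (HED-style) used by the two losses above: with beta equal to the fraction of negative labels, the rare positive class is up-weighted and the frequent negative class down-weighted. The exact constants in the TF code are taken on trust; this only illustrates the weighting.

import numpy as np

def class_balanced_bce(prob, label, eps=1e-12):
    # prob, label: 1-D arrays of predicted probabilities and 0/1 labels
    prob = np.clip(prob, eps, 1 - eps)
    beta = np.mean(label == 0)                       # fraction of negatives
    pos_term = -beta * np.sum(label * np.log(prob))
    neg_term = -(1 - beta) * np.sum((1 - label) * np.log(1 - prob))
    return (pos_term + neg_term) / label.size

label = np.array([0, 0, 0, 0, 0, 0, 0, 1], dtype=np.float32)   # rare positive class
prob = np.full(8, 0.5, dtype=np.float32)
print(class_balanced_bce(prob, label))   # each positive carries 7x the weight of a negative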
def print_stat(x, message=None):
    """ a simple print op.
    Use it like: x = print_stat(x)

...@@ -89,7 +95,8 @@ def print_stat(x, message=None):
    if message is None:
        message = x.op.name
    return tf.Print(x, [tf.shape(x), tf.reduce_mean(x), x], summarize=20,
                    message=message, name='print_' + x.op.name)


def rms(x, name=None):
    if name is None:

...@@ -98,14 +105,16 @@ def rms(x, name=None):
            return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
    return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)


def huber_loss(x, delta=1, name='huber_loss'):
    sqrcost = tf.square(x)
    abscost = tf.abs(x)
    return tf.reduce_sum(
        tf.select(abscost < delta,
                  sqrcost * 0.5,
                  abscost * delta - 0.5 * delta ** 2),
        name=name)
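A NumPy rendering of the huber_loss above: quadratic for |x| < delta and linear beyond, matching the tf.select expression.

import numpy as np

def huber_loss_np(x, delta=1.0):
    x = np.asarray(x, dtype=np.float64)
    quadratic = 0.5 * np.square(x)                     # used where |x| < delta
    linear = delta * np.abs(x) - 0.5 * delta ** 2      # used elsewhere
    return np.sum(np.where(np.abs(x) < delta, quadratic, linear))

print(huber_loss_np([0.5, 2.0]))   # 0.5 * 0.25 + (2.0 - 0.5) = 1.625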
def get_scalar_var(name, init_value, summary=False, trainable=False):
    """

...@@ -113,8 +122,8 @@ def get_scalar_var(name, init_value, summary=False, trainable=False):
    :param summary: summary this variable
    """
    ret = tf.get_variable(name, shape=[],
                          initializer=tf.constant_initializer(init_value),
                          trainable=trainable)
    if summary:
        # this is recognized in callbacks.StatHolder
        tf.summary.scalar(name + '-summary', ret)
......
...@@ -11,7 +11,9 @@ __all__ = ['get_current_tower_context', 'TowerContext']
_CurrentTowerContext = None


class TowerContext(object):
    def __init__(self, tower_name, is_training=None):
        """ tower_name: 'tower0', 'towerp0', or '' """
        self._name = tower_name

...@@ -65,7 +67,7 @@ class TowerContext(object):
    def __enter__(self):
        global _CurrentTowerContext
        assert _CurrentTowerContext is None, \
            "Nesting TowerContext!"
        _CurrentTowerContext = self
        if len(self._name):
            self._scope = tf.name_scope(self._name)

...@@ -78,7 +80,7 @@ class TowerContext(object):
            self._scope.__exit__(exc_type, exc_val, exc_tb)
        return False


def get_current_tower_context():
    global _CurrentTowerContext
    return _CurrentTowerContext
...@@ -3,7 +3,8 @@
# File: varmanip.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import six
import os
import tensorflow as tf
from collections import defaultdict
import re

...@@ -13,7 +14,8 @@ from ..utils.naming import *
from .common import get_op_tensor_name

__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
           'get_savename_from_varname', 'is_training_name']


def get_savename_from_varname(
        varname, varname_prefix=None,

...@@ -33,13 +35,15 @@ def get_savename_from_varname(
        name = re.sub('tower[p0-9]+/', '', name)
    if varname_prefix is not None \
            and name.startswith(varname_prefix):
        name = name[len(varname_prefix) + 1:]
    if savename_prefix is not None:
        name = savename_prefix + '/' + name
    return name


class SessionUpdate(object):
    """ Update the variables in a session """

    def __init__(self, sess, vars_to_update):
        """
        :param vars_to_update: a collection of variables to update

...@@ -66,11 +70,12 @@ class SessionUpdate(object):
            if varshape != value.shape:
                # TODO only allow reshape when shape different by empty axis
                assert np.prod(varshape) == np.prod(value.shape), \
                    "{}: {}!={}".format(name, varshape, value.shape)
                logger.warn("Param {} is reshaped during assigning".format(name))
                value = value.reshape(varshape)
            self.sess.run(op, feed_dict={p: value})


def dump_session_params(path):
    """ Dump value of all trainable + to_save variables to a dict and save to `path` in
    npy format, loadable by ParamRestore

...@@ -90,6 +95,7 @@ the same name".format(v.name))
    logger.info(str(result.keys()))
    np.save(path, result)


def dump_chkpt_vars(model_path):
    """ Dump all variables from a checkpoint to a dict"""
    if os.path.basename(model_path) == model_path:

...@@ -101,6 +107,7 @@ def dump_chkpt_vars(model_path):
        result[n] = reader.get_tensor(n)
    return result


def is_training_name(name):
    """
    This is only used to improve logging.
......
...@@ -8,6 +8,7 @@ import os.path
__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else []

...@@ -25,4 +26,3 @@ for _, module_name, _ in walk_packages(
    if module_name.startswith('_'):
        continue
    global_import(module_name)
...@@ -21,8 +21,11 @@ from ..tfutils.summary import create_summary
__all__ = ['Trainer', 'StopTraining']


class StopTraining(BaseException):
    pass


@six.add_metaclass(ABCMeta)
class Trainer(object):
    """ Base class for a trainer."""

...@@ -91,7 +94,7 @@ class Trainer(object):
        for val in summary.value:
            if val.WhichOneof('value') == 'simple_value':
                val.tag = re.sub('tower[p0-9]+/', '', val.tag)  # TODO move to subclasses
                suffix = '-summary'  # issue#6150
                if val.tag.endswith(suffix):
                    val.tag = val.tag[:-len(suffix)]
                self.stat_holder.add_stat(val.tag, val.simple_value)

...@@ -99,7 +102,7 @@ class Trainer(object):
    def write_scalar_summary(self, name, val):
        self.summary_writer.add_summary(
            create_summary(name, val), get_global_step())
        self.stat_holder.add_stat(name, val)

    def setup(self):

...@@ -138,7 +141,7 @@ class Trainer(object):
        callbacks.before_train()
        logger.info("Start training with global_step={}".format(get_global_step()))
        for epoch_num in range(
                self.config.starting_epoch, self.config.max_epoch + 1):
            with timed_operation(
                'Epoch {} (global_step {})'.format(
                    epoch_num, get_global_step() + self.config.step_per_epoch)):

...@@ -147,7 +150,7 @@ class Trainer(object):
                        **get_tqdm_kwargs(leave=True)):
                    if self.coord.should_stop():
                        return
                    self.run_step()  # implemented by subclass
                    callbacks.trigger_step()  # not useful?
                # trigger epoch outside the timing region.
                self.trigger_epoch()
......
...@@ -9,15 +9,17 @@ from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
from ..tfutils import (JustCurrentSession,
                        get_default_sess_config, SessionInit)
from .input_data import InputData

__all__ = ['TrainConfig']


class TrainConfig(object):
    """
    Config for training a model with a single loss
    """

    def __init__(self, **kwargs):
        """
        :param dataset: the dataset to train. a `DataFlow` instance.
......
...@@ -17,8 +17,10 @@ from .trainer import MultiPredictorTowerTrainer ...@@ -17,8 +17,10 @@ from .trainer import MultiPredictorTowerTrainer
__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer', 'SimpleFeedfreeTrainer', 'QueueInputTrainer'] __all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer', 'SimpleFeedfreeTrainer', 'QueueInputTrainer']
class FeedfreeTrainer(Trainer): class FeedfreeTrainer(Trainer):
""" A trainer which runs iteration without feed_dict (therefore faster) """ """ A trainer which runs iteration without feed_dict (therefore faster) """
def _trigger_epoch(self): def _trigger_epoch(self):
# need to run summary_op every epoch # need to run summary_op every epoch
# note that summary_op will take a data from the queue # note that summary_op will take a data from the queue
...@@ -33,7 +35,9 @@ class FeedfreeTrainer(Trainer): ...@@ -33,7 +35,9 @@ class FeedfreeTrainer(Trainer):
assert isinstance(self._input_method, FeedfreeInput), type(self._input_method) assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
self._input_method._setup(self) self._input_method._setup(self)
class SingleCostFeedfreeTrainer(FeedfreeTrainer): class SingleCostFeedfreeTrainer(FeedfreeTrainer):
def _get_cost_and_grad(self): def _get_cost_and_grad(self):
""" get the cost and gradient on a new tower""" """ get the cost and gradient on a new tower"""
actual_inputs = self._get_input_tensors() actual_inputs = self._get_input_tensors()
...@@ -41,35 +45,37 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer): ...@@ -41,35 +45,37 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
cost_var = self.model.get_cost() cost_var = self.model.get_cost()
# GATE_NONE faster? # GATE_NONE faster?
grads = self.config.optimizer.compute_gradients( grads = self.config.optimizer.compute_gradients(
cost_var, cost_var,
gate_gradients=tf.train.Optimizer.GATE_NONE, gate_gradients=tf.train.Optimizer.GATE_NONE,
colocate_gradients_with_ops=False) colocate_gradients_with_ops=False)
add_moving_summary(cost_var) add_moving_summary(cost_var)
return cost_var, grads return cost_var, grads
def run_step(self): def run_step(self):
""" Simply run self.train_op""" """ Simply run self.train_op"""
self.sess.run(self.train_op) self.sess.run(self.train_op)
#if not hasattr(self, 'cnt'): # if not hasattr(self, 'cnt'):
#self.cnt = 0 # self.cnt = 0
#else: # else:
#self.cnt += 1 # self.cnt += 1
#if self.cnt % 10 == 0: # if self.cnt % 10 == 0:
## debug-benchmark code: # # debug-benchmark code:
#run_metadata = tf.RunMetadata() # run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op], # self.sess.run([self.train_op],
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE), # options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#run_metadata=run_metadata # run_metadata=run_metadata
#) # )
#from tensorflow.python.client import timeline # from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats) # trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w') # trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format()) # trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit() # import sys; sys.exit()
class SimpleFeedfreeTrainer( class SimpleFeedfreeTrainer(
MultiPredictorTowerTrainer, MultiPredictorTowerTrainer,
SingleCostFeedfreeTrainer): SingleCostFeedfreeTrainer):
def __init__(self, config): def __init__(self, config):
""" """
A trainer with single cost, single training tower and feed-free input A trainer with single cost, single training tower and feed-free input
...@@ -80,7 +86,7 @@ class SimpleFeedfreeTrainer( ...@@ -80,7 +86,7 @@ class SimpleFeedfreeTrainer(
super(SimpleFeedfreeTrainer, self).__init__(config) super(SimpleFeedfreeTrainer, self).__init__(config)
self._setup_predictor_factory(config.predict_tower) self._setup_predictor_factory(config.predict_tower)
assert len(self.config.tower) == 1, \ assert len(self.config.tower) == 1, \
"SimpleFeedfreeTrainer doesn't support multigpu!" "SimpleFeedfreeTrainer doesn't support multigpu!"
def _setup(self): def _setup(self):
super(SimpleFeedfreeTrainer, self)._setup() super(SimpleFeedfreeTrainer, self)._setup()
...@@ -94,6 +100,7 @@ class SimpleFeedfreeTrainer( ...@@ -94,6 +100,7 @@ class SimpleFeedfreeTrainer(
# skip training # skip training
#self.train_op = tf.group(*self.dequed_inputs) #self.train_op = tf.group(*self.dequed_inputs)
class QueueInputTrainer(SimpleFeedfreeTrainer): class QueueInputTrainer(SimpleFeedfreeTrainer):
def __init__(self, config, input_queue=None, predict_tower=None): def __init__(self, config, input_queue=None, predict_tower=None):
...@@ -110,5 +117,5 @@ class QueueInputTrainer(SimpleFeedfreeTrainer): ...@@ -110,5 +117,5 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!") logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower config.predict_tower = predict_tower
assert len(config.tower) == 1, \ assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead." "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super(QueueInputTrainer, self).__init__(config) super(QueueInputTrainer, self).__init__(config)
...@@ -14,13 +14,16 @@ from ..utils import logger ...@@ -14,13 +14,16 @@ from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread from ..callbacks.concurrency import StartProcOrThread
__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput', __all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput',
'DummyConstantInput'] 'DummyConstantInput']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class InputData(object): class InputData(object):
pass pass
class FeedInput(InputData): class FeedInput(InputData):
def __init__(self, ds): def __init__(self, ds):
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
self.ds = ds self.ds = ds
...@@ -39,7 +42,9 @@ class FeedInput(InputData): ...@@ -39,7 +42,9 @@ class FeedInput(InputData):
feed = dict(zip(self.input_vars, data)) feed = dict(zip(self.input_vars, data))
return feed return feed
class FeedfreeInput(InputData): class FeedfreeInput(InputData):
def get_input_tensors(self): def get_input_tensors(self):
return self._get_input_tensors() return self._get_input_tensors()
...@@ -49,7 +54,9 @@ class FeedfreeInput(InputData): ...@@ -49,7 +54,9 @@ class FeedfreeInput(InputData):
always create and return a list of new input tensors always create and return a list of new input tensors
""" """
class EnqueueThread(threading.Thread): class EnqueueThread(threading.Thread):
def __init__(self, trainer, queue, ds, input_placehdrs): def __init__(self, trainer, queue, ds, input_placehdrs):
super(EnqueueThread, self).__init__() super(EnqueueThread, self).__init__()
self.name = 'EnqueueThread' self.name = 'EnqueueThread'
...@@ -77,7 +84,7 @@ class EnqueueThread(threading.Thread): ...@@ -77,7 +84,7 @@ class EnqueueThread(threading.Thread):
if self.coord.should_stop(): if self.coord.should_stop():
return return
feed = dict(zip(self.placehdrs, dp)) feed = dict(zip(self.placehdrs, dp))
#print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1] # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed) self.op.run(feed_dict=feed)
except tf.errors.CancelledError as e: except tf.errors.CancelledError as e:
pass pass
...@@ -91,7 +98,9 @@ class EnqueueThread(threading.Thread): ...@@ -91,7 +98,9 @@ class EnqueueThread(threading.Thread):
pass pass
logger.info("Enqueue Thread Exited.") logger.info("Enqueue Thread Exited.")
class QueueInput(FeedfreeInput): class QueueInput(FeedfreeInput):
def __init__(self, ds, queue=None): def __init__(self, ds, queue=None):
""" """
:param ds: a `DataFlow` instance :param ds: a `DataFlow` instance
...@@ -108,32 +117,34 @@ class QueueInput(FeedfreeInput): ...@@ -108,32 +117,34 @@ class QueueInput(FeedfreeInput):
def _setup(self, trainer): def _setup(self, trainer):
self.input_placehdrs = trainer.model.get_input_vars() self.input_placehdrs = trainer.model.get_input_vars()
assert len(self.input_placehdrs) > 0, \ assert len(self.input_placehdrs) > 0, \
"QueueInput can only be used with input placeholders!" "QueueInput can only be used with input placeholders!"
if self.queue is None: if self.queue is None:
self.queue = tf.FIFOQueue( self.queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_placehdrs], 50, [x.dtype for x in self.input_placehdrs],
name='input_queue') name='input_queue')
self.thread = EnqueueThread( self.thread = EnqueueThread(
trainer, self.queue, self.ds, self.input_placehdrs) trainer, self.queue, self.ds, self.input_placehdrs)
trainer.config.callbacks.append(StartProcOrThread(self.thread)) trainer.config.callbacks.append(StartProcOrThread(self.thread))
def _get_input_tensors(self): def _get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque') ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input if isinstance(ret, tf.Tensor): # only one input
ret = [ret] ret = [ret]
assert len(ret) == len(self.input_placehdrs) assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs): for qv, v in zip(ret, self.input_placehdrs):
qv.set_shape(v.get_shape()) qv.set_shape(v.get_shape())
# test the overhead of queue # test the overhead of queue
#with tf.device('/gpu:0'): # with tf.device('/gpu:0'):
#ret = [tf.Variable(tf.random_normal([128,224,224,3], # ret = [tf.Variable(tf.random_normal([128,224,224,3],
#dtype=tf.float32), trainable=False), # dtype=tf.float32), trainable=False),
#tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)] # tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
return ret return ret
class DummyConstantInput(QueueInput): class DummyConstantInput(QueueInput):
""" only for debugging performance issues """ """ only for debugging performance issues """
def __init__(self, ds, shapes): def __init__(self, ds, shapes):
super(DummyConstantInput, self).__init__(ds) super(DummyConstantInput, self).__init__(ds)
self.shapes = shapes self.shapes = shapes
...@@ -146,11 +157,13 @@ class DummyConstantInput(QueueInput): ...@@ -146,11 +157,13 @@ class DummyConstantInput(QueueInput):
for idx, p in enumerate(placehdrs): for idx, p in enumerate(placehdrs):
with tf.device('/gpu:0'): with tf.device('/gpu:0'):
ret.append(tf.get_variable('dummy-' + p.op.name, ret.append(tf.get_variable('dummy-' + p.op.name,
shape=self.shapes[idx], dtype=p.dtype, trainable=False, shape=self.shapes[idx], dtype=p.dtype, trainable=False,
initializer=tf.constant_initializer())) initializer=tf.constant_initializer()))
return ret return ret
class TensorInput(FeedfreeInput): class TensorInput(FeedfreeInput):
def __init__(self, get_tensor_fn, size=None): def __init__(self, get_tensor_fn, size=None):
self.get_tensor_fn = get_tensor_fn self.get_tensor_fn = get_tensor_fn
self._size = size self._size = size
......
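The QueueInput/EnqueueThread changes above are formatting only; the pattern they implement is easier to see in isolation. A minimal sketch using TF1-style APIs, with made-up placeholder shapes (not the actual tensorpack code):
    import tensorflow as tf
    image = tf.placeholder(tf.float32, [None, 28, 28], name='image')
    label = tf.placeholder(tf.int32, [None], name='label')
    queue = tf.FIFOQueue(50, [image.dtype, label.dtype], name='input_queue')
    enqueue_op = queue.enqueue([image, label])   # a background thread runs this with a feed_dict
    tensors = queue.dequeue(name='input_deque')  # the training graph consumes these instead of feeds
    for qv, v in zip(tensors, [image, label]):
        qv.set_shape(v.get_shape())              # dequeue loses static shapes; restore them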
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf import tensorflow as tf
import itertools, re import itertools
import re
from six.moves import zip, range from six.moves import zip, range
from ..utils import logger from ..utils import logger
...@@ -12,7 +13,7 @@ from ..utils.naming import * ...@@ -12,7 +13,7 @@ from ..utils.naming import *
from ..utils.concurrency import LoopThread from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average, add_moving_summary from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils import (backup_collection, restore_collection, from ..tfutils import (backup_collection, restore_collection,
get_global_step_var, TowerContext) get_global_step_var, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .base import Trainer from .base import Trainer
...@@ -22,6 +23,7 @@ from .input_data import QueueInput ...@@ -22,6 +23,7 @@ from .input_data import QueueInput
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer'] __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
class MultiGPUTrainer(Trainer): class MultiGPUTrainer(Trainer):
""" Base class for multi-gpu training""" """ Base class for multi-gpu training"""
@staticmethod @staticmethod
...@@ -45,9 +47,11 @@ class MultiGPUTrainer(Trainer): ...@@ -45,9 +47,11 @@ class MultiGPUTrainer(Trainer):
restore_collection(backup) restore_collection(backup)
return grad_list return grad_list
class SyncMultiGPUTrainer(MultiGPUTrainer, class SyncMultiGPUTrainer(MultiGPUTrainer,
SingleCostFeedfreeTrainer, SingleCostFeedfreeTrainer,
MultiPredictorTowerTrainer): MultiPredictorTowerTrainer):
def __init__(self, config, input_queue=None, predict_tower=None): def __init__(self, config, input_queue=None, predict_tower=None):
if hasattr(config, 'dataset'): if hasattr(config, 'dataset'):
self._input_method = QueueInput(config.dataset, input_queue) self._input_method = QueueInput(config.dataset, input_queue)
...@@ -64,7 +68,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -64,7 +68,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU." assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
assert tf.test.is_gpu_available() assert tf.test.is_gpu_available()
@staticmethod @staticmethod
def _average_grads(tower_grads): def _average_grads(tower_grads):
if len(tower_grads) == 1: if len(tower_grads) == 1:
...@@ -92,12 +95,12 @@ class SyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -92,12 +95,12 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
def _setup(self): def _setup(self):
super(SyncMultiGPUTrainer, self)._setup() super(SyncMultiGPUTrainer, self)._setup()
grad_list = MultiGPUTrainer._multi_tower_grads( grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1]) self.config.tower, lambda: self._get_cost_and_grad()[1])
# debug tower performance: # debug tower performance:
#ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]] #ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
#self.train_op = tf.group(*ops) #self.train_op = tf.group(*ops)
#return # return
grads = SyncMultiGPUTrainer._average_grads(grad_list) grads = SyncMultiGPUTrainer._average_grads(grad_list)
grads = apply_grad_processors(grads, self.model.get_gradient_processor()) grads = apply_grad_processors(grads, self.model.get_gradient_processor())
...@@ -109,13 +112,15 @@ class SyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -109,13 +112,15 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
def run_step(self): def run_step(self):
self.sess.run(self.train_op) self.sess.run(self.train_op)
class AsyncMultiGPUTrainer(MultiGPUTrainer, class AsyncMultiGPUTrainer(MultiGPUTrainer,
SingleCostFeedfreeTrainer, SingleCostFeedfreeTrainer,
MultiPredictorTowerTrainer): MultiPredictorTowerTrainer):
def __init__(self, config, def __init__(self, config,
input_queue=None, input_queue=None,
average_gradient=True, average_gradient=True,
predict_tower=None): predict_tower=None):
if hasattr(config, 'dataset'): if hasattr(config, 'dataset'):
self._input_method = QueueInput(config.dataset, input_queue) self._input_method = QueueInput(config.dataset, input_queue)
else: else:
...@@ -134,7 +139,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -134,7 +139,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
def _setup(self): def _setup(self):
super(AsyncMultiGPUTrainer, self)._setup() super(AsyncMultiGPUTrainer, self)._setup()
grad_list = MultiGPUTrainer._multi_tower_grads( grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1]) self.config.tower, lambda: self._get_cost_and_grad()[1])
gradprocs = self.model.get_gradient_processor() gradprocs = self.model.get_gradient_processor()
if self._average_gradient and self.config.nr_tower > 1: if self._average_gradient and self.config.nr_tower > 1:
# pretend to average the grads, in order to make async and # pretend to average the grads, in order to make async and
...@@ -157,7 +162,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -157,7 +162,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
self.training_threads = [] self.training_threads = []
for k in range(1, len(self.config.tower)): for k in range(1, len(self.config.tower)):
train_op = self.config.optimizer.apply_gradients(grad_list[k]) train_op = self.config.optimizer.apply_gradients(grad_list[k])
def f(op=train_op): # avoid late-binding
def f(op=train_op): # avoid late-binding
self.sess.run([op]) self.sess.run([op])
next(self.async_step_counter) next(self.async_step_counter)
th = LoopThread(f) th = LoopThread(f)
...@@ -169,7 +175,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -169,7 +175,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
def run_step(self): def run_step(self):
if not self.async_running: if not self.async_running:
self.async_running = True self.async_running = True
for th in self.training_threads: # resume all threads for th in self.training_threads: # resume all threads
th.resume() th.resume()
next(self.async_step_counter) next(self.async_step_counter)
self.sess.run(self.train_op) self.sess.run(self.train_op)
...@@ -183,7 +189,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -183,7 +189,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
async_step_total_cnt = int(re.findall( async_step_total_cnt = int(re.findall(
'[0-9]+', self.async_step_counter.__str__())[0]) '[0-9]+', self.async_step_counter.__str__())[0])
self.write_scalar_summary( self.write_scalar_summary(
'async_global_step', async_step_total_cnt) 'async_global_step', async_step_total_cnt)
except: except:
logger.exception("Cannot log async_global_step") logger.exception("Cannot log async_global_step")
super(AsyncMultiGPUTrainer, self)._trigger_epoch() super(AsyncMultiGPUTrainer, self)._trigger_epoch()
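Most of _average_grads is elided in this diff; a sketch of per-variable gradient averaging across towers in the same spirit (not the exact implementation):
    import tensorflow as tf
    def average_grads(tower_grads):
        # tower_grads: one list of (grad, var) pairs per tower, same variable order
        if len(tower_grads) == 1:
            return tower_grads[0]
        averaged = []
        for grads_and_var in zip(*tower_grads):
            grads = [g for g, _ in grads_and_var]
            var = grads_and_var[0][1]
            averaged.append((tf.multiply(tf.add_n(grads), 1.0 / len(grads)), var))
        return averaged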
...@@ -10,13 +10,14 @@ from .base import Trainer ...@@ -10,13 +10,14 @@ from .base import Trainer
from ..utils import logger, SUMMARY_BACKUP_KEYS, PREDICT_TOWER from ..utils import logger, SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import (get_tensors_by_names, freeze_collection, from ..tfutils import (get_tensors_by_names, freeze_collection,
get_global_step_var, TowerContext) get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..tfutils.gradproc import apply_grad_processors from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput, FeedfreeInput from .input_data import FeedInput, FeedfreeInput
__all__ = ['SimpleTrainer','MultiPredictorTowerTrainer'] __all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']
class PredictorFactory(object): class PredictorFactory(object):
""" Make predictors for a trainer""" """ Make predictors for a trainer"""
...@@ -52,8 +53,10 @@ class PredictorFactory(object): ...@@ -52,8 +53,10 @@ class PredictorFactory(object):
build_multi_tower_prediction_graph(fn, self.towers) build_multi_tower_prediction_graph(fn, self.towers)
self.tower_built = True self.tower_built = True
class SimpleTrainer(Trainer): class SimpleTrainer(Trainer):
""" A naive demo trainer """ """ A naive demo trainer """
def __init__(self, config): def __init__(self, config):
super(SimpleTrainer, self).__init__(config) super(SimpleTrainer, self).__init__(config)
self._predictor_factory = PredictorFactory(self.sess, self.model, [0]) self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
...@@ -78,7 +81,7 @@ class SimpleTrainer(Trainer): ...@@ -78,7 +81,7 @@ class SimpleTrainer(Trainer):
grads = self.config.optimizer.compute_gradients(cost_var) grads = self.config.optimizer.compute_gradients(cost_var)
grads = apply_grad_processors(grads, grads = apply_grad_processors(grads,
self.model.get_gradient_processor()) self.model.get_gradient_processor())
self.train_op = tf.group( self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()), self.config.optimizer.apply_gradients(grads, get_global_step_var()),
...@@ -93,13 +96,15 @@ class SimpleTrainer(Trainer): ...@@ -93,13 +96,15 @@ class SimpleTrainer(Trainer):
def get_predict_func(self, input_names, output_names): def get_predict_func(self, input_names, output_names):
return self._predictor_factory.get_predictor(input_names, output_names, 0) return self._predictor_factory.get_predictor(input_names, output_names, 0)
class MultiPredictorTowerTrainer(Trainer): class MultiPredictorTowerTrainer(Trainer):
""" A trainer with possibly multiple prediction tower """ """ A trainer with possibly multiple prediction tower """
def _setup_predictor_factory(self, predict_tower): def _setup_predictor_factory(self, predict_tower):
# by default, use the first training gpu for prediction # by default, use the first training gpu for prediction
predict_tower = predict_tower or [0] predict_tower = predict_tower or [0]
self._predictor_factory = PredictorFactory( self._predictor_factory = PredictorFactory(
self.sess, self.model, predict_tower) self.sess, self.model, predict_tower)
def get_predict_func(self, input_names, output_names, tower=0): def get_predict_func(self, input_names, output_names, tower=0):
""" """
......
...@@ -12,6 +12,7 @@ These utils should be irrelevant to tensorflow. ...@@ -12,6 +12,7 @@ These utils should be irrelevant to tensorflow.
__all__ = [] __all__ = []
def _global_import(name): def _global_import(name):
p = __import__(name, globals(), None, level=1) p = __import__(name, globals(), None, level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
...@@ -23,7 +24,7 @@ _TO_IMPORT = set([ ...@@ -23,7 +24,7 @@ _TO_IMPORT = set([
'naming', 'naming',
'utils', 'utils',
'gpu' 'gpu'
]) ])
_CURR_DIR = os.path.dirname(__file__) _CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in walk_packages( for _, module_name, _ in walk_packages(
...@@ -36,5 +37,3 @@ for _, module_name, _ in walk_packages( ...@@ -36,5 +37,3 @@ for _, module_name, _ in walk_packages(
if module_name in _TO_IMPORT: if module_name in _TO_IMPORT:
_global_import(module_name) _global_import(module_name)
__all__.append(module_name) __all__.append(module_name)
...@@ -5,10 +5,13 @@ ...@@ -5,10 +5,13 @@
import operator import operator
import inspect, six, functools import inspect
import six
import functools
import collections import collections
__all__ = [ 'map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs'] __all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']
def map_arg(**maps): def map_arg(**maps):
""" """
...@@ -26,11 +29,13 @@ def map_arg(**maps): ...@@ -26,11 +29,13 @@ def map_arg(**maps):
return wrapper return wrapper
return deco return deco
class memoized(object): class memoized(object):
'''Decorator. Caches a function's return value each time it is called. '''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned If called later with the same arguments, the cached value is returned
(not reevaluated). (not reevaluated).
''' '''
def __init__(self, func): def __init__(self, func):
self.func = func self.func = func
self.cache = {} self.cache = {}
...@@ -60,8 +65,11 @@ class memoized(object): ...@@ -60,8 +65,11 @@ class memoized(object):
return functools.partial(self.__call__, obj) return functools.partial(self.__call__, obj)
_MEMOIZED_NOARGS = {} _MEMOIZED_NOARGS = {}
def memoized_ignoreargs(func): def memoized_ignoreargs(func):
h = hash(func) # make sure it is hashable. is it necessary? h = hash(func) # make sure it is hashable. is it necessary?
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if func not in _MEMOIZED_NOARGS: if func not in _MEMOIZED_NOARGS:
res = func(*args, **kwargs) res = func(*args, **kwargs)
...@@ -70,15 +78,16 @@ def memoized_ignoreargs(func): ...@@ -70,15 +78,16 @@ def memoized_ignoreargs(func):
return _MEMOIZED_NOARGS[func] return _MEMOIZED_NOARGS[func]
return wrapper return wrapper
#_GLOBAL_MEMOIZED_CACHE = dict() # _GLOBAL_MEMOIZED_CACHE = dict()
#def global_memoized(func): # def global_memoized(func):
#""" Make sure that the same `memoized` object is returned on different # """ Make sure that the same `memoized` object is returned on different
#calls to global_memoized(func) # calls to global_memoized(func)
#""" # """
#ret = _GLOBAL_MEMOIZED_CACHE.get(func, None) # ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
#if ret is None: # if ret is None:
#ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func) # ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
#return ret # return ret
def shape2d(a): def shape2d(a):
""" """
......
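A hypothetical use of the memoized decorator reformatted above (the import path tensorpack.utils.argtools is assumed):
    from tensorpack.utils.argtools import memoized

    @memoized
    def fib(n):
        return n if n < 2 else fib(n - 1) + fib(n - 2)

    fib(30)   # each distinct n is computed once; repeated calls return the cached value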
...@@ -23,10 +23,12 @@ __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate', ...@@ -23,10 +23,12 @@ __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE', 'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
'mask_sigint', 'start_proc_mask_signal'] 'mask_sigint', 'start_proc_mask_signal']
class StoppableThread(threading.Thread): class StoppableThread(threading.Thread):
""" """
A thread that has a 'stop' event. A thread that has a 'stop' event.
""" """
def __init__(self): def __init__(self):
super(StoppableThread, self).__init__() super(StoppableThread, self).__init__()
self._stop_evt = threading.Event() self._stop_evt = threading.Event()
...@@ -56,8 +58,10 @@ class StoppableThread(threading.Thread): ...@@ -56,8 +58,10 @@ class StoppableThread(threading.Thread):
except queue.Empty: except queue.Empty:
pass pass
class LoopThread(StoppableThread): class LoopThread(StoppableThread):
""" A pausable thread that simply runs a loop""" """ A pausable thread that simply runs a loop"""
def __init__(self, func, pausable=True): def __init__(self, func, pausable=True):
""" """
:param func: the function to run :param func: the function to run
...@@ -89,6 +93,7 @@ class DIE(object): ...@@ -89,6 +93,7 @@ class DIE(object):
""" A placeholder class indicating end of queue """ """ A placeholder class indicating end of queue """
pass pass
def ensure_proc_terminate(proc): def ensure_proc_terminate(proc):
if isinstance(proc, list): if isinstance(proc, list):
for p in proc: for p in proc:
...@@ -114,6 +119,7 @@ def mask_sigint(): ...@@ -114,6 +119,7 @@ def mask_sigint():
yield yield
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
def start_proc_mask_signal(proc): def start_proc_mask_signal(proc):
if not isinstance(proc, list): if not isinstance(proc, list):
proc = [proc] proc = [proc]
...@@ -122,11 +128,12 @@ def start_proc_mask_signal(proc): ...@@ -122,11 +128,12 @@ def start_proc_mask_signal(proc):
for p in proc: for p in proc:
p.start() p.start()
def subproc_call(cmd, timeout=None): def subproc_call(cmd, timeout=None):
try: try:
output = subprocess.check_output( output = subprocess.check_output(
cmd, stderr=subprocess.STDOUT, cmd, stderr=subprocess.STDOUT,
shell=True, timeout=timeout) shell=True, timeout=timeout)
return output return output
except subprocess.TimeoutExpired as e: except subprocess.TimeoutExpired as e:
logger.warn("Command timeout!") logger.warn("Command timeout!")
...@@ -135,10 +142,12 @@ def subproc_call(cmd, timeout=None): ...@@ -135,10 +142,12 @@ def subproc_call(cmd, timeout=None):
logger.warn("Commnad failed: {}".format(e.returncode)) logger.warn("Commnad failed: {}".format(e.returncode))
logger.warn(e.output) logger.warn(e.output)
class OrderedContainer(object): class OrderedContainer(object):
""" """
Like a priority queue, but will always wait for item with index (x+1) before producing (x+2). Like a priority queue, but will always wait for item with index (x+1) before producing (x+2).
""" """
def __init__(self, start=0): def __init__(self, start=0):
self.ranks = [] self.ranks = []
self.data = [] self.data = []
...@@ -163,11 +172,13 @@ class OrderedContainer(object): ...@@ -163,11 +172,13 @@ class OrderedContainer(object):
self.wait_for += 1 self.wait_for += 1
return rank, ret return rank, ret
class OrderedResultGatherProc(multiprocessing.Process): class OrderedResultGatherProc(multiprocessing.Process):
""" """
Gather indexed data from a data queue, and produce results with the Gather indexed data from a data queue, and produce results with the
original index-based order. original index-based order.
""" """
def __init__(self, data_queue, nr_producer, start=0): def __init__(self, data_queue, nr_producer, start=0):
""" """
:param data_queue: a multiprocessing.Queue to produce input dp :param data_queue: a multiprocessing.Queue to produce input dp
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
import sys import sys
__all__ = ['enable_call_trace'] __all__ = ['enable_call_trace']
def enable_call_trace(): def enable_call_trace():
def tracer(frame, event, arg): def tracer(frame, event, arg):
if event == 'call': if event == 'call':
...@@ -21,9 +22,9 @@ def enable_call_trace(): ...@@ -21,9 +22,9 @@ def enable_call_trace():
if caller: if caller:
caller_line_no = caller.f_lineno caller_line_no = caller.f_lineno
caller_filename = caller.f_code.co_filename caller_filename = caller.f_code.co_filename
print('Call to `%s` on line %s:%s from %s:%s' % \ print('Call to `%s` on line %s:%s from %s:%s' %
(func_name, func_filename, func_line_no, (func_name, func_filename, func_line_no,
caller_filename, caller_line_no)) caller_filename, caller_line_no))
return return
sys.settrace(tracer) sys.settrace(tracer)
...@@ -32,6 +33,7 @@ if __name__ == '__main__': ...@@ -32,6 +33,7 @@ if __name__ == '__main__':
def b(a): def b(a):
print(2) print(2)
def a(): def a():
print(1) print(1)
b(1) b(1)
......
...@@ -12,11 +12,14 @@ from six.moves import range ...@@ -12,11 +12,14 @@ from six.moves import range
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND'] __all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']
@memoized @memoized
def log_once(s): def log_once(s):
logger.warn(s) logger.warn(s)
# just a placeholder # just a placeholder
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class Discretizer(object): class Discretizer(object):
...@@ -28,10 +31,13 @@ class Discretizer(object): ...@@ -28,10 +31,13 @@ class Discretizer(object):
def get_bin(self, v): def get_bin(self, v):
pass pass
class Discretizer1D(Discretizer): class Discretizer1D(Discretizer):
pass pass
class UniformDiscretizer1D(Discretizer1D): class UniformDiscretizer1D(Discretizer1D):
def __init__(self, minv, maxv, spacing): def __init__(self, minv, maxv, spacing):
""" """
:params minv: minimum value of the first bin :params minv: minimum value of the first bin
...@@ -54,8 +60,8 @@ class UniformDiscretizer1D(Discretizer1D): ...@@ -54,8 +60,8 @@ class UniformDiscretizer1D(Discretizer1D):
log_once("UniformDiscretizer1D: value larger than max!") log_once("UniformDiscretizer1D: value larger than max!")
return self.nr_bin - 1 return self.nr_bin - 1
return int(np.clip( return int(np.clip(
(v - self.minv) / self.spacing, (v - self.minv) / self.spacing,
0, self.nr_bin - 1)) 0, self.nr_bin - 1))
def get_bin_center(self, bin_id): def get_bin_center(self, bin_id):
return self.minv + self.spacing * (bin_id + 0.5) return self.minv + self.spacing * (bin_id + 0.5)
...@@ -69,17 +75,18 @@ class UniformDiscretizer1D(Discretizer1D): ...@@ -69,17 +75,18 @@ class UniformDiscretizer1D(Discretizer1D):
if v >= self.maxv or v <= self.minv: if v >= self.maxv or v <= self.minv:
return ret return ret
try: try:
for k in range(1, smooth_radius+1): for k in range(1, smooth_radius + 1):
ret[b+k] = smooth_factor ** k ret[b + k] = smooth_factor ** k
except IndexError: except IndexError:
pass pass
for k in range(1, min(smooth_radius+1, b+1)): for k in range(1, min(smooth_radius + 1, b + 1)):
ret[b-k] = smooth_factor ** k ret[b - k] = smooth_factor ** k
ret /= ret.sum() ret /= ret.sum()
return ret return ret
class UniformDiscretizerND(Discretizer): class UniformDiscretizerND(Discretizer):
def __init__(self, *min_max_spacing): def __init__(self, *min_max_spacing):
""" """
:params min_max_spacing: (minv, maxv, spacing) for each dimension :params min_max_spacing: (minv, maxv, spacing) for each dimension
...@@ -122,6 +129,5 @@ class UniformDiscretizerND(Discretizer): ...@@ -122,6 +129,5 @@ class UniformDiscretizerND(Discretizer):
if __name__ == '__main__': if __name__ == '__main__':
#u = UniformDiscretizer1D(-10, 10, 0.12) #u = UniformDiscretizer1D(-10, 10, 0.12)
u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1)) u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1))
import IPython as IP; import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config()) IP.embed(config=IP.terminal.ipapp.load_default_config())
...@@ -3,13 +3,15 @@ ...@@ -3,13 +3,15 @@
# File: fs.py # File: fs.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os, sys import os
import sys
from six.moves import urllib from six.moves import urllib
import errno import errno
from . import logger from . import logger
__all__ = ['mkdir_p', 'download', 'recursive_walk'] __all__ = ['mkdir_p', 'download', 'recursive_walk']
def mkdir_p(dirname): def mkdir_p(dirname):
""" make a dir recursively, but do nothing if the dir exists""" """ make a dir recursively, but do nothing if the dir exists"""
assert dirname is not None assert dirname is not None
...@@ -21,6 +23,7 @@ def mkdir_p(dirname): ...@@ -21,6 +23,7 @@ def mkdir_p(dirname):
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise e raise e
def download(url, dir): def download(url, dir):
mkdir_p(dir) mkdir_p(dir)
fname = url.split('/')[-1] fname = url.split('/')[-1]
...@@ -29,7 +32,7 @@ def download(url, dir): ...@@ -29,7 +32,7 @@ def download(url, dir):
def _progress(count, block_size, total_size): def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % sys.stdout.write('\r>> Downloading %s %.1f%%' %
(fname, (fname,
min(float(count * block_size)/ total_size, min(float(count * block_size) / total_size,
1.0) * 100.0)) 1.0) * 100.0))
sys.stdout.flush() sys.stdout.flush()
try: try:
...@@ -45,6 +48,7 @@ def download(url, dir): ...@@ -45,6 +48,7 @@ def download(url, dir):
print('Successfully downloaded ' + fname + " " + str(size) + ' bytes.') print('Successfully downloaded ' + fname + " " + str(size) + ' bytes.')
return fpath return fpath
def recursive_walk(rootdir): def recursive_walk(rootdir):
for r, dirs, files in os.walk(rootdir): for r, dirs, files in os.walk(rootdir):
for f in files: for f in files:
......
...@@ -9,13 +9,15 @@ import argparse ...@@ -9,13 +9,15 @@ import argparse
__all__ = ['globalns', 'use_global_argument'] __all__ = ['globalns', 'use_global_argument']
if six.PY2: if six.PY2:
class NS: pass class NS:
pass
else: else:
import types import types
NS = types.SimpleNamespace NS = types.SimpleNamespace
globalns = NS() globalns = NS()
def use_global_argument(args): def use_global_argument(args):
""" """
Add the content of argparse.Namespace to globalns Add the content of argparse.Namespace to globalns
......
...@@ -8,20 +8,22 @@ from .utils import change_env ...@@ -8,20 +8,22 @@ from .utils import change_env
__all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus'] __all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus']
def change_gpu(val): def change_gpu(val):
val = str(val) val = str(val)
if val == '-1': if val == '-1':
val = '' val = ''
return change_env('CUDA_VISIBLE_DEVICES', val) return change_env('CUDA_VISIBLE_DEVICES', val)
def get_nr_gpu(): def get_nr_gpu():
env = os.environ.get('CUDA_VISIBLE_DEVICES', None) env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO assert env is not None, 'gpu not set!' # TODO
return len(env.split(',')) return len(env.split(','))
def get_gpus(): def get_gpus():
""" return a list of GPU physical id""" """ return a list of GPU physical id"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None) env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO assert env is not None, 'gpu not set!' # TODO
return map(int, env.strip().split(',')) return map(int, env.strip().split(','))
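The helpers above only parse CUDA_VISIBLE_DEVICES; a hypothetical session showing the same parsing:
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    len(env.split(','))                      # 2, what get_nr_gpu() would return
    list(map(int, env.strip().split(',')))   # [0, 2], what get_gpus() would yield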
...@@ -19,7 +19,9 @@ __all__ = ['load_caffe', 'get_caffe_pb'] ...@@ -19,7 +19,9 @@ __all__ = ['load_caffe', 'get_caffe_pb']
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto" CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"
class CaffeLayerProcessor(object): class CaffeLayerProcessor(object):
def __init__(self, net): def __init__(self, net):
self.net = net self.net = net
self.layer_names = net._layer_names self.layer_names = net._layer_names
...@@ -42,14 +44,14 @@ class CaffeLayerProcessor(object): ...@@ -42,14 +44,14 @@ class CaffeLayerProcessor(object):
self.param_dict.update(dic) self.param_dict.update(dic)
elif len(layer.blobs) != 0: elif len(layer.blobs) != 0:
logger.warn( logger.warn(
"{} layer contains parameters but is not supported!".format(layer.type)) "{} layer contains parameters but is not supported!".format(layer.type))
return self.param_dict return self.param_dict
def proc_conv(self, idx, name, param): def proc_conv(self, idx, name, param):
assert len(param) <= 2 assert len(param) <= 2
assert param[0].data.ndim == 4 assert param[0].data.ndim == 4
# caffe: ch_out, ch_in, h, w # caffe: ch_out, ch_in, h, w
W = param[0].data.transpose(2,3,1,0) W = param[0].data.transpose(2, 3, 1, 0)
if len(param) == 1: if len(param) == 1:
return {name + '/W': W} return {name + '/W': W}
else: else:
...@@ -65,7 +67,7 @@ class CaffeLayerProcessor(object): ...@@ -65,7 +67,7 @@ class CaffeLayerProcessor(object):
logger.info("FC layer {} takes spatial data.".format(name)) logger.info("FC layer {} takes spatial data.".format(name))
W = param[0].data W = param[0].data
# original: outx(CxHxW) # original: outx(CxHxW)
W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2,3,1,0) W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2, 3, 1, 0)
# become: (HxWxC)xout # become: (HxWxC)xout
else: else:
W = param[0].data.transpose() W = param[0].data.transpose()
...@@ -74,8 +76,8 @@ class CaffeLayerProcessor(object): ...@@ -74,8 +76,8 @@ class CaffeLayerProcessor(object):
def proc_bn(self, idx, name, param): def proc_bn(self, idx, name, param):
assert param[2].data[0] == 1.0 assert param[2].data[0] == 1.0
return {name +'/mean/EMA': param[0].data, return {name + '/mean/EMA': param[0].data,
name +'/variance/EMA': param[1].data } name + '/variance/EMA': param[1].data}
def proc_scale(self, idx, name, param): def proc_scale(self, idx, name, param):
bottom_name = self.net.bottom_names[name][0] bottom_name = self.net.bottom_names[name][0]
...@@ -89,7 +91,7 @@ class CaffeLayerProcessor(object): ...@@ -89,7 +91,7 @@ class CaffeLayerProcessor(object):
logger.info("Merge {} and {} into one BatchNorm layer".format( logger.info("Merge {} and {} into one BatchNorm layer".format(
name, name2)) name, name2))
return {name2 + '/beta': param[1].data, return {name2 + '/beta': param[1].data,
name2 + '/gamma': param[0].data } name2 + '/gamma': param[0].data}
# assume this scaling layer is part of some BN # assume this scaling layer is part of some BN
logger.error("Could not find a BN layer corresponding to this Scale layer!") logger.error("Could not find a BN layer corresponding to this Scale layer!")
raise ValueError() raise ValueError()
...@@ -104,10 +106,11 @@ def load_caffe(model_desc, model_file): ...@@ -104,10 +106,11 @@ def load_caffe(model_desc, model_file):
caffe.set_mode_cpu() caffe.set_mode_cpu()
net = caffe.Net(model_desc, model_file, caffe.TEST) net = caffe.Net(model_desc, model_file, caffe.TEST)
param_dict = CaffeLayerProcessor(net).process() param_dict = CaffeLayerProcessor(net).process()
logger.info("Model loaded from caffe. Params: " + \ logger.info("Model loaded from caffe. Params: " +
" ".join(sorted(param_dict.keys()))) " ".join(sorted(param_dict.keys())))
return param_dict return param_dict
def get_caffe_pb(): def get_caffe_pb():
dir = get_dataset_path('caffe') dir = get_dataset_path('caffe')
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py') caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
...@@ -116,7 +119,7 @@ def get_caffe_pb(): ...@@ -116,7 +119,7 @@ def get_caffe_pb():
assert os.path.isfile(os.path.join(dir, 'caffe.proto')) assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir)) ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
assert ret == 0, \ assert ret == 0, \
"Command `protoc caffe.proto --python_out .` failed!" "Command `protoc caffe.proto --python_out .` failed!"
import imp import imp
return imp.load_source('caffepb', caffe_pb_file) return imp.load_source('caffepb', caffe_pb_file)
...@@ -131,4 +134,3 @@ if __name__ == '__main__': ...@@ -131,4 +134,3 @@ if __name__ == '__main__':
import numpy as np import numpy as np
np.save(args.output, ret) np.save(args.output, ret)
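The transposes cleaned up above convert Caffe's (out, in, h, w) kernel layout to TensorFlow's (h, w, in, out); a toy numpy check with a made-up kernel size:
    import numpy as np
    W_caffe = np.zeros((64, 3, 7, 7))    # hypothetical 7x7 conv mapping 3 -> 64 channels
    W_tf = W_caffe.transpose(2, 3, 1, 0)
    W_tf.shape                           # (7, 7, 3, 64)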
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import logging import logging
import os, shutil import os
import shutil
import os.path import os.path
from termcolor import colored from termcolor import colored
from datetime import datetime from datetime import datetime
...@@ -12,7 +13,9 @@ import sys ...@@ -12,7 +13,9 @@ import sys
__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', 'warn_dependency'] __all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', 'warn_dependency']
class _MyFormatter(logging.Formatter): class _MyFormatter(logging.Formatter):
def format(self, record): def format(self, record):
date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green') date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
msg = '%(message)s' msg = '%(message)s'
...@@ -28,6 +31,7 @@ class _MyFormatter(logging.Formatter): ...@@ -28,6 +31,7 @@ class _MyFormatter(logging.Formatter):
self._fmt = fmt self._fmt = fmt
return super(_MyFormatter, self).format(record) return super(_MyFormatter, self).format(record)
def _getlogger(): def _getlogger():
logger = logging.getLogger('tensorpack') logger = logging.getLogger('tensorpack')
logger.propagate = False logger.propagate = False
...@@ -45,6 +49,8 @@ def get_time_str(): ...@@ -45,6 +49,8 @@ def get_time_str():
# logger file and directory: # logger file and directory:
global LOG_FILE, LOG_DIR global LOG_FILE, LOG_DIR
LOG_DIR = None LOG_DIR = None
def _set_file(path): def _set_file(path):
if os.path.isfile(path): if os.path.isfile(path):
backup_name = path + '.' + get_time_str() backup_name = path + '.' + get_time_str()
...@@ -56,6 +62,7 @@ def _set_file(path): ...@@ -56,6 +62,7 @@ def _set_file(path):
_logger.addHandler(hdl) _logger.addHandler(hdl)
_logger.info("Argv: " + ' '.join(sys.argv)) _logger.info("Argv: " + ' '.join(sys.argv))
def set_logger_dir(dirname, action=None): def set_logger_dir(dirname, action=None):
""" """
Set the directory for global logging. Set the directory for global logging.
...@@ -98,11 +105,13 @@ _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', ...@@ -98,11 +105,13 @@ _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception',
for func in _LOGGING_METHOD: for func in _LOGGING_METHOD:
locals()[func] = getattr(_logger, func) locals()[func] = getattr(_logger, func)
def disable_logger(): def disable_logger():
""" disable all logging ability from this moment""" """ disable all logging ability from this moment"""
for func in _LOGGING_METHOD: for func in _LOGGING_METHOD:
globals()[func] = lambda x: None globals()[func] = lambda x: None
def auto_set_dir(action=None, overwrite=False): def auto_set_dir(action=None, overwrite=False):
""" set log directory to a subdir inside 'train_log', with the name being """ set log directory to a subdir inside 'train_log', with the name being
the main python file currently running""" the main python file currently running"""
...@@ -112,9 +121,10 @@ def auto_set_dir(action=None, overwrite=False): ...@@ -112,9 +121,10 @@ def auto_set_dir(action=None, overwrite=False):
mod = sys.modules['__main__'] mod = sys.modules['__main__']
basename = os.path.basename(mod.__file__) basename = os.path.basename(mod.__file__)
set_logger_dir( set_logger_dir(
os.path.join('train_log', os.path.join('train_log',
basename[:basename.rfind('.')]), basename[:basename.rfind('.')]),
action=action) action=action)
def warn_dependency(name, dependencies): def warn_dependency(name, dependencies):
warn("Failed to import '{}', {} won't be available'".format(dependencies, name)) warn("Failed to import '{}', {} won't be available'".format(dependencies, name))
...@@ -7,10 +7,12 @@ import six ...@@ -7,10 +7,12 @@ import six
__all__ = ['LookUpTable'] __all__ = ['LookUpTable']
class LookUpTable(object): class LookUpTable(object):
def __init__(self, objlist): def __init__(self, objlist):
self.idx2obj = dict(enumerate(objlist)) self.idx2obj = dict(enumerate(objlist))
self.obj2idx = {v : k for k, v in six.iteritems(self.idx2obj)} self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)}
def size(self): def size(self):
return len(self.idx2obj) return len(self.idx2obj)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import numpy as np import numpy as np
class Rect(object): class Rect(object):
""" """
A Rectangle. A Rectangle.
...@@ -68,7 +69,7 @@ class Rect(object): ...@@ -68,7 +69,7 @@ class Rect(object):
def roi(self, img): def roi(self, img):
assert self.validate(img.shape[:2]), "{} vs {}".format(self, img.shape[:2]) assert self.validate(img.shape[:2]), "{} vs {}".format(self, img.shape[:2])
return img[self.y0:self.y1+1, self.x0:self.x1+1] return img[self.y0:self.y1 + 1, self.x0:self.x1 + 1]
def expand(self, frac): def expand(self, frac):
assert frac > 1.0, frac assert frac > 1.0, frac
...@@ -92,7 +93,7 @@ class Rect(object): ...@@ -92,7 +93,7 @@ class Rect(object):
xmax = min(self.x1, img.shape[1]) xmax = min(self.x1, img.shape[1])
ymax = min(self.y1, img.shape[0]) ymax = min(self.y1, img.shape[0])
patch = img[ymin:ymax, xmin:xmax] patch = img[ymin:ymax, xmin:xmax]
ret[ystart:ystart+patch.shape[0],xstart:xstart+patch.shape[1]] = patch ret[ystart:ystart + patch.shape[0], xstart:xstart + patch.shape[1]] = patch
return ret return ret
__repr__ = __str__ __repr__ = __str__
...@@ -101,6 +102,6 @@ class Rect(object): ...@@ -101,6 +102,6 @@ class Rect(object):
if __name__ == '__main__': if __name__ == '__main__':
x = Rect(2, 1, 3, 3, allow_neg=True) x = Rect(2, 1, 3, 3, allow_neg=True)
img = np.random.rand(3,3) img = np.random.rand(3, 3)
print(img) print(img)
print(x.roi_zeropad(img)) print(x.roi_zeropad(img))
...@@ -10,10 +10,12 @@ msgpack_numpy.patch() ...@@ -10,10 +10,12 @@ msgpack_numpy.patch()
__all__ = ['loads', 'dumps'] __all__ = ['loads', 'dumps']
def dumps(obj): def dumps(obj):
#return dill.dumps(obj) # return dill.dumps(obj)
return msgpack.dumps(obj, use_bin_type=True) return msgpack.dumps(obj, use_bin_type=True)
def loads(buf): def loads(buf):
#return dill.loads(buf) # return dill.loads(buf)
return msgpack.loads(buf) return msgpack.loads(buf)
...@@ -14,10 +14,12 @@ from .stats import StatCounter ...@@ -14,10 +14,12 @@ from .stats import StatCounter
from . import logger from . import logger
__all__ = ['total_timer', 'timed_operation', __all__ = ['total_timer', 'timed_operation',
'print_total_timer', 'IterSpeedCounter'] 'print_total_timer', 'IterSpeedCounter']
class IterSpeedCounter(object): class IterSpeedCounter(object):
""" To count how often some code gets reached""" """ To count how often some code gets reached"""
def __init__(self, print_every, name=None): def __init__(self, print_every, name=None):
self.cnt = 0 self.cnt = 0
self.print_every = int(print_every) self.print_every = int(print_every)
...@@ -36,6 +38,7 @@ class IterSpeedCounter(object): ...@@ -36,6 +38,7 @@ class IterSpeedCounter(object):
logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format( logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
self.name, t, self.cnt, t / self.cnt)) self.name, t, self.cnt, t / self.cnt))
@contextmanager @contextmanager
def timed_operation(msg, log_start=False): def timed_operation(msg, log_start=False):
if log_start: if log_start:
...@@ -47,6 +50,7 @@ def timed_operation(msg, log_start=False): ...@@ -47,6 +50,7 @@ def timed_operation(msg, log_start=False):
_TOTAL_TIMER_DATA = defaultdict(StatCounter) _TOTAL_TIMER_DATA = defaultdict(StatCounter)
@contextmanager @contextmanager
def total_timer(msg): def total_timer(msg):
start = time.time() start = time.time()
...@@ -54,6 +58,7 @@ def total_timer(msg): ...@@ -54,6 +58,7 @@ def total_timer(msg):
t = time.time() - start t = time.time() - start
_TOTAL_TIMER_DATA[msg].feed(t) _TOTAL_TIMER_DATA[msg].feed(t)
def print_total_timer(): def print_total_timer():
if len(_TOTAL_TIMER_DATA) == 0: if len(_TOTAL_TIMER_DATA) == 0:
return return
......
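Hypothetical usage of the timing helpers above (the import path tensorpack.utils.timer is assumed):
    import time
    from tensorpack.utils.timer import timed_operation, total_timer, print_total_timer

    with timed_operation('sleep demo'):
        time.sleep(0.1)                  # logs the elapsed time when the block exits
    for _ in range(3):
        with total_timer('inner work'):
            time.sleep(0.05)
    print_total_timer()                  # logs count / average / total time per message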