Commit fb2a051c authored by Yuxin Wu

run autopep8 over tensorpack/

parent 59553585
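The commit message only states that autopep8 was run over the package; the exact invocation is not recorded, so the following is an assumed sketch of how such a tree-wide, in-place reformat is usually produced (roughly equivalent to the command line: autopep8 --in-place --recursive tensorpack/):

import io
import os
import autopep8

# Walk the package and rewrite each .py file with autopep8's default fixes
# (whitespace around operators, comment spacing, blank lines before defs, etc.).
for root, _, files in os.walk('tensorpack'):
    for name in files:
        if not name.endswith('.py'):
            continue
        path = os.path.join(root, name)
        with io.open(path, encoding='utf-8') as f:
            original = f.read()
        fixed = autopep8.fix_code(original)      # default options, max line length 79
        if fixed != original:
            with io.open(path, 'w', encoding='utf-8') as f:
                f.write(fixed)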
......@@ -8,6 +8,8 @@ import os
import os.path
__all__ = []
def _global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
......@@ -20,4 +22,3 @@ for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
_global_import(module_name)
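The hunks in this commit are all the same kind of mechanical PEP 8 fix: two blank lines before top-level classes and functions, a space after '#' in comments, spaces around binary operators and after commas (including inside slices), and continuation lines re-aligned under the opening bracket. A condensed before/after illustration, assembled from the hunks below rather than taken from any single file:

# before
#print r
state[:,:,-k-1] = 0
__all__ = ['PreventStuckPlayer', 'LimitLengthPlayer', 'AutoRestartPlayer',
        'MapPlayerState']

# after
# print r
state[:, :, -k - 1] = 0
__all__ = ['PreventStuckPlayer', 'LimitLengthPlayer', 'AutoRestartPlayer',
           'MapPlayerState']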
......@@ -9,7 +9,8 @@ from collections import deque
from .envbase import ProxyPlayer
__all__ = ['PreventStuckPlayer', 'LimitLengthPlayer', 'AutoRestartPlayer',
'MapPlayerState']
'MapPlayerState']
class PreventStuckPlayer(ProxyPlayer):
""" Prevent the player from getting stuck (repeating a no-op)
......@@ -17,6 +18,7 @@ class PreventStuckPlayer(ProxyPlayer):
where the agent needs to press the 'start' button to start playing.
"""
# TODO hash the state as well?
def __init__(self, player, nr_repeat, action):
"""
It does auto-reset, but doesn't auto-restart the underlying player.
......@@ -40,10 +42,12 @@ class PreventStuckPlayer(ProxyPlayer):
super(PreventStuckPlayer, self).restart_episode()
self.act_que.clear()
class LimitLengthPlayer(ProxyPlayer):
""" Limit the total number of actions in an episode.
Will auto restart the underlying player on timeout
"""
def __init__(self, player, limit):
super(LimitLengthPlayer, self).__init__(player)
self.limit = limit
......@@ -64,9 +68,11 @@ class LimitLengthPlayer(ProxyPlayer):
self.player.restart_episode()
self.cnt = 0
class AutoRestartPlayer(ProxyPlayer):
""" Auto-restart the player on episode ends,
in case some player wasn't designed to do so. """
def action(self, act):
r, isOver = self.player.action(act)
if isOver:
......@@ -74,7 +80,9 @@ class AutoRestartPlayer(ProxyPlayer):
self.player.restart_episode()
return r, isOver
class MapPlayerState(ProxyPlayer):
def __init__(self, player, func):
super(MapPlayerState, self).__init__(player)
self.func = func
......
......@@ -13,8 +13,10 @@ from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
'DiscreteActionSpace']
@six.add_metaclass(ABCMeta)
class RLEnvironment(object):
def __init__(self):
self.reset_stat()
......@@ -60,13 +62,15 @@ class RLEnvironment(object):
s = self.current_state()
act = func(s)
r, isOver = self.action(act)
#print r
# print r
if isOver:
s = [self.stats[k] for k in stat]
self.reset_stat()
return s if len(s) > 1 else s[0]
class ActionSpace(object):
def __init__(self):
self.rng = get_rng(self)
......@@ -77,7 +81,9 @@ class ActionSpace(object):
def num_actions(self):
raise NotImplementedError()
class DiscreteActionSpace(ActionSpace):
def __init__(self, num):
super(DiscreteActionSpace, self).__init__()
self.num = num
......@@ -94,19 +100,25 @@ class DiscreteActionSpace(ActionSpace):
def __str__(self):
return "DiscreteActionSpace({})".format(self.num)
class NaiveRLEnvironment(RLEnvironment):
""" for testing only"""
def __init__(self):
self.k = 0
def current_state(self):
self.k += 1
return self.k
def action(self, act):
self.k = act
return (self.k, self.k > 10)
class ProxyPlayer(RLEnvironment):
""" Serve as a proxy another player """
def __init__(self, player):
self.player = player
......
......@@ -10,14 +10,15 @@ import six
from six.moves import queue
from ..dataflow import DataFlow
from ..utils import logger, get_tqdm
from ..utils import logger, get_tqdm, get_rng
from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback
__all__ = ['ExpReplay']
Experience = namedtuple('Experience',
['state', 'action', 'reward', 'isOver'])
['state', 'action', 'reward', 'isOver'])
class ExpReplay(DataFlow, Callback):
"""
......@@ -27,19 +28,20 @@ class ExpReplay(DataFlow, Callback):
This implementation provides the interface as a DataFlow.
This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
"""
def __init__(self,
predictor_io_names,
player,
batch_size=32,
memory_size=1e6,
init_memory_size=50000,
exploration=1,
end_exploration=0.1,
exploration_epoch_anneal=0.002,
reward_clip=None,
update_frequency=1,
history_len=1
):
predictor_io_names,
player,
batch_size=32,
memory_size=1e6,
init_memory_size=50000,
exploration=1,
end_exploration=0.1,
exploration_epoch_anneal=0.002,
reward_clip=None,
update_frequency=1,
history_len=1
):
"""
:param predictor: a callable running the up-to-date network.
called with a state, return a distribution.
......@@ -78,10 +80,10 @@ class ExpReplay(DataFlow, Callback):
def _populate_exp(self):
""" populate a transition by epsilon-greedy"""
#if len(self.mem):
#from copy import deepcopy # quickly fill the memory for debug
#self.mem.append(deepcopy(self.mem[0]))
#return
# if len(self.mem):
# from copy import deepcopy # quickly fill the memory for debug
# self.mem.append(deepcopy(self.mem[0]))
# return
old_s = self.player.current_state()
if self.rng.rand() <= self.exploration:
act = self.rng.choice(range(self.num_actions))
......@@ -115,19 +117,19 @@ class ExpReplay(DataFlow, Callback):
while True:
batch_exp = [self._sample_one() for _ in range(self.batch_size)]
#import cv2 # for debug
#def view_state(state, next_state):
#""" for debugging state representation"""
#r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
#r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
#r = np.concatenate([r, r2], axis=0)
#print r.shape
#cv2.imshow("state", r)
#cv2.waitKey()
#exp = batch_exp[0]
#print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
#if exp[2] or exp[4]:
#view_state(exp[0], exp[1])
# import cv2 # for debug
# def view_state(state, next_state):
# """ for debugging state representation"""
# r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
# r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
# r = np.concatenate([r, r2], axis=0)
# print r.shape
# cv2.imshow("state", r)
# cv2.waitKey()
# exp = batch_exp[0]
# print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
# if exp[2] or exp[4]:
# view_state(exp[0], exp[1])
yield self._process_batch(batch_exp)
self._populate_job_queue.put(1)
......@@ -141,9 +143,10 @@ class ExpReplay(DataFlow, Callback):
# when x.isOver==True, (x+1).state is of a different episode
idx = self.rng.randint(len(self.mem) - self.history_len - 1)
samples = [self.mem[k] for k in range(idx, idx+self.history_len+1)]
samples = [self.mem[k] for k in range(idx, idx + self.history_len + 1)]
def concat(idx):
v = [x.state for x in samples[idx:idx+self.history_len]]
v = [x.state for x in samples[idx:idx + self.history_len]]
return np.concatenate(v, axis=2)
state = concat(0)
next_state = concat(1)
......@@ -155,12 +158,12 @@ class ExpReplay(DataFlow, Callback):
# zero-fill state before starting
zero_fill = False
for k in range(1, self.history_len):
if samples[start_idx-k].isOver:
if samples[start_idx - k].isOver:
zero_fill = True
if zero_fill:
state[:,:,-k-1] = 0
state[:, :, -k - 1] = 0
if k + 2 <= self.history_len:
next_state[:,:,-k-2] = 0
next_state[:, :, -k - 2] = 0
return (state, next_state, reward, action, isOver)
def _process_batch(self, batch_exp):
......@@ -178,6 +181,7 @@ class ExpReplay(DataFlow, Callback):
def _before_train(self):
# spawn a separate thread to run policy, can speed up 1.3x
self._populate_job_queue = queue.Queue(maxsize=1)
def populate_job_func():
self._populate_job_queue.get()
with self.trainer.sess.as_default():
......@@ -203,22 +207,23 @@ class ExpReplay(DataFlow, Callback):
pass
self.player.reset_stat()
if __name__ == '__main__':
from .atari import AtariPlayer
import sys
predictor = lambda x: np.array([1,1,1,1])
predictor = lambda x: np.array([1, 1, 1, 1])
player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
E = ExpReplay(predictor,
player=player,
num_actions=player.get_action_space().num_actions(),
populate_size=1001,
history_len=4)
player=player,
num_actions=player.get_action_space().num_actions(),
populate_size=1001,
history_len=4)
E._init_memory()
for k in E.get_data():
import IPython as IP;
import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config())
pass
#import IPython;
#IPython.embed(config=IPython.terminal.ipapp.load_default_config())
#break
# import IPython;
# IPython.embed(config=IPython.terminal.ipapp.load_default_config())
# break
......@@ -9,7 +9,7 @@ from ..utils import logger
try:
import gym
# TODO
#gym.undo_logger_setup()
# gym.undo_logger_setup()
# https://github.com/openai/gym/pull/199
# not sure does it cause other problems
__all__ = ['GymEnv']
......@@ -26,11 +26,13 @@ from .envbase import RLEnvironment, DiscreteActionSpace
_ENV_LOCK = threading.Lock()
class GymEnv(RLEnvironment):
"""
An OpenAI/gym wrapper. Can optionally auto restart.
Only support discrete action space now
"""
def __init__(self, name, dumpdir=None, viz=False, auto_restart=True):
with _ENV_LOCK:
self.gymenv = gym.make(name)
......@@ -82,7 +84,7 @@ if __name__ == '__main__':
rng = get_rng(num)
while True:
act = rng.choice(range(num))
#print act
# print act
r, o = env.action(act)
env.current_state()
if r != 0 or o:
......
......@@ -9,10 +9,12 @@ from .envbase import ProxyPlayer
__all__ = ['HistoryFramePlayer']
class HistoryFramePlayer(ProxyPlayer):
""" Include history frames in state, or use black images
Assume player will do auto-restart.
"""
def __init__(self, player, hist_len):
"""
:param hist_len: total length of the state, including the current
......@@ -49,4 +51,3 @@ class HistoryFramePlayer(ProxyPlayer):
super(HistoryFramePlayer, self).restart_episode()
self.history.clear()
self.history.append(self.player.current_state())
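The RL wrappers touched above are designed to be chained around a base player; a hypothetical composition, using only constructor signatures that appear in these hunks (the import path follows tensorpack.RL's auto-import shown above; the ROM path, preprocessing function and numeric limits are made-up values):

from tensorpack.RL import (AtariPlayer, MapPlayerState, HistoryFramePlayer,
                           PreventStuckPlayer, LimitLengthPlayer)

player = AtariPlayer('breakout.bin', frame_skip=4)     # base emulator; argument values are assumptions
player = MapPlayerState(player, lambda s: s / 255.0)   # map every state through a preprocessing func
player = HistoryFramePlayer(player, 4)                 # stack the last 4 frames as the state
player = PreventStuckPlayer(player, 30, 1)             # inject action 1 after 30 identical actions in a row
player = LimitLengthPlayer(player, 40000)              # cap episode length, auto-restart on timeout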
......@@ -25,8 +25,8 @@ from ..utils.serialize import loads, dumps
from ..utils.concurrency import LoopThread, ensure_proc_terminate
__all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
'TransitionExperience', 'WeightSync']
'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
'TransitionExperience', 'WeightSync']
try:
import zmq
......@@ -34,8 +34,10 @@ except ImportError:
logger.warn_dependency('Simulator', 'zmq')
__all__ = []
class TransitionExperience(object):
""" A transition of state, or experience"""
def __init__(self, state, action, reward, **kwargs):
""" kwargs: whatever other attribute you want to save"""
self.state = state
......@@ -44,6 +46,7 @@ class TransitionExperience(object):
for k, v in six.iteritems(kwargs):
setattr(self, k, v)
@six.add_metaclass(ABCMeta)
class SimulatorProcessBase(mp.Process):
......@@ -63,6 +66,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
A process that simulates a player and communicates to master to
send states and receive the next action
"""
def __init__(self, idx, pipe_c2s, pipe_s2c):
"""
:param idx: idx of this process
......@@ -81,7 +85,7 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
s2c_socket = context.socket(zmq.DEALER)
s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
#s2c_socket.set_hwm(5)
# s2c_socket.set_hwm(5)
s2c_socket.connect(self.s2c)
state = player.current_state()
......@@ -97,12 +101,14 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
# compatibility
SimulatorProcess = SimulatorProcessStateExchange
class SimulatorMaster(threading.Thread):
""" A base thread to communicate with all StateExchangeSimulatorProcess.
It should produce action for each simulator, as well as
defining callbacks when a transition or an episode is finished.
"""
class ClientState(object):
def __init__(self):
self.memory = [] # list of Experience
......@@ -174,9 +180,11 @@ class SimulatorMaster(threading.Thread):
def __del__(self):
self.context.destroy(linger=0)
class SimulatorProcessDF(SimulatorProcessBase):
""" A simulator which contains a forward model itself, allowing
it to produce data points directly """
def __init__(self, idx, pipe_c2s):
super(SimulatorProcessDF, self).__init__(idx)
self.pipe_c2s = pipe_c2s
......@@ -202,12 +210,14 @@ class SimulatorProcessDF(SimulatorProcessBase):
def get_data(self):
pass
class SimulatorProcessSharedWeight(SimulatorProcessDF):
""" A simulator process with an extra thread waiting for event,
and take shared weight from shm.
Start me under some CUDA_VISIBLE_DEVICES set!
"""
def __init__(self, idx, pipe_c2s, condvar, shared_dic, pred_config):
super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
self.condvar = condvar
......@@ -220,7 +230,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
with self.predictor.graph.as_default():
vars_to_update = self._params_to_update()
self.sess_updater = SessionUpdate(
self.predictor.session, vars_to_update)
self.predictor.session, vars_to_update)
# TODO setup callback for explore?
self.predictor.graph.finalize()
......@@ -245,8 +255,10 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
# can be overwritten to update more params
return tf.trainable_variables()
class WeightSync(Callback):
""" Sync weight from main process to shared_dic and notify"""
def __init__(self, condvar, shared_dic):
self.condvar = condvar
self.shared_dic = shared_dic
......@@ -260,6 +272,7 @@ class WeightSync(Callback):
def _before_train(self):
self._sync()
def _trigger_epoch(self):
self._sync()
......@@ -274,13 +287,18 @@ class WeightSync(Callback):
if __name__ == '__main__':
import random
from tensorpack.RL import NaiveRLEnvironment
class NaiveSimulator(SimulatorProcess):
def _build_player(self):
return NaiveRLEnvironment()
class NaiveActioner(SimulatorActioner):
def _get_action(self, state):
time.sleep(1)
return random.randint(1, 12)
def _on_episode_over(self, client):
#print("Over: ", client.memory)
client.memory = []
......@@ -296,4 +314,3 @@ if __name__ == '__main__':
import time
time.sleep(100)
......@@ -2,7 +2,7 @@
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy # avoid https://github.com/tensorflow/tensorflow/issues/2034
import numpy # avoid https://github.com/tensorflow/tensorflow/issues/2034
import cv2 # avoid https://github.com/tensorflow/tensorflow/issues/1924
from tensorpack.train import *
......
......@@ -7,6 +7,8 @@ import os
__all__ = []
def _global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
......@@ -23,4 +25,3 @@ for _, module_name, _ in walk_packages(
continue
if not module_name.startswith('_'):
_global_import(module_name)
......@@ -11,6 +11,7 @@ import six
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']
@six.add_metaclass(ABCMeta)
class Callback(object):
""" Base class for all callbacks """
......@@ -72,7 +73,9 @@ class Callback(object):
def __str__(self):
return type(self).__name__
class ProxyCallback(Callback):
def __init__(self, cb):
self.cb = cb
......@@ -91,11 +94,13 @@ class ProxyCallback(Callback):
def __str__(self):
return "Proxy-" + str(self.cb)
class PeriodicCallback(ProxyCallback):
"""
A callback to be triggered after every `period` epochs.
Doesn't work for trigger_step
"""
def __init__(self, cb, period):
"""
:param cb: a `Callback`
......@@ -111,4 +116,3 @@ class PeriodicCallback(ProxyCallback):
def __str__(self):
return "Periodic-" + str(self.cb)
......@@ -9,7 +9,9 @@ from ..utils import logger
__all__ = ['StartProcOrThread']
class StartProcOrThread(Callback):
def __init__(self, procs_threads):
"""
Start extra threads and processes before training
......@@ -20,7 +22,7 @@ class StartProcOrThread(Callback):
self._procs_threads = procs_threads
def _before_train(self):
logger.info("Starting " + \
', '.join([k.name for k in self._procs_threads]))
logger.info("Starting " +
', '.join([k.name for k in self._procs_threads]))
# avoid sigint get handled by other processes
start_proc_mask_signal(self._procs_threads)
......@@ -6,7 +6,9 @@ from ..tfutils.common import get_op_tensor_name
__all__ = ['OutputTensorDispatcer']
class OutputTensorDispatcer(object):
def __init__(self):
self._names = []
self._idxs = []
......
......@@ -12,10 +12,12 @@ from ..tfutils import get_op_var_name
__all__ = ['DumpParamAsImage']
class DumpParamAsImage(Callback):
"""
Dump a variable to image(s) after every epoch to logger.LOG_DIR.
"""
def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
"""
:param var_name: the name of the variable.
......@@ -59,4 +61,3 @@ class DumpParamAsImage(Callback):
if self.clip:
res = np.clip(res, 0, 255)
cv2.imwrite(fname, res.astype('uint8'))
......@@ -10,8 +10,10 @@ from ..utils import logger
__all__ = ['RunOp']
class RunOp(Callback):
""" Run an op periodically"""
def __init__(self, setup_func, run_before=True, run_epoch=True):
"""
:param setup_func: a function that returns the op in the graph
......@@ -34,5 +36,5 @@ class RunOp(Callback):
if self.run_epoch:
self._op.run()
#def _log(self):
# def _log(self):
#logger.info("Running op {} ...".format(self._op_name))
......@@ -12,7 +12,9 @@ from ..utils import logger
__all__ = ['Callbacks']
class CallbackTimeLogger(object):
def __init__(self):
self.times = []
self.tot = 0
......@@ -39,10 +41,12 @@ class CallbackTimeLogger(object):
"Callbacks took {:.3f} sec in total. {}".format(
self.tot, '; '.join(msgs)))
class Callbacks(Callback):
"""
A container to hold all callbacks, and execute them in the right order and proper session.
"""
def __init__(self, cbs):
"""
:param cbs: a list of `Callbacks`
......
......@@ -14,7 +14,8 @@ from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_var_name
__all__ = ['ClassificationError',
'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
@six.add_metaclass(ABCMeta)
class Inferencer(object):
......@@ -59,12 +60,14 @@ class Inferencer(object):
def _get_output_tensors(self):
pass
class ScalarStats(Inferencer):
"""
Write some scalar tensor to both stat and summary.
The output of the given Ops must be a scalar.
The value will be averaged over all data points in the inference dataflow.
"""
def __init__(self, names_to_print, prefix='validation'):
"""
:param names_to_print: list of names of tensors, or just a name
......@@ -96,6 +99,7 @@ class ScalarStats(Inferencer):
ret[name] = stat
return ret
class ClassificationError(Inferencer):
"""
Compute classification error in batch mode, from a `wrong` variable
......@@ -109,6 +113,7 @@ class ClassificationError(Inferencer):
testing (because the size of test set might not be a multiple of batch size).
Therefore the result is different from averaging the error rate of each batch.
"""
def __init__(self, wrong_var_name='incorrect_vector', summary_name='val_error'):
"""
:param wrong_var_name: name of the `wrong` variable
......@@ -138,6 +143,7 @@ class ClassificationError(Inferencer):
def _after_inference(self):
return {self.summary_name: self.err_stat.ratio}
class BinaryClassificationStats(Inferencer):
""" Compute precision/recall in binary classification, given the
prediction vector and the label vector.
......
......@@ -18,6 +18,7 @@ from ..train.input_data import FeedfreeInput
__all__ = ['InferenceRunner']
def summary_inferencer(trainer, infs):
for inf in infs:
ret = inf.after_inference()
......@@ -29,6 +30,7 @@ def summary_inferencer(trainer, infs):
continue
trainer.write_scalar_summary(k, v)
class InferenceRunner(Callback):
"""
A callback that runs different kinds of inferencer.
......@@ -54,16 +56,17 @@ class InferenceRunner(Callback):
self.input_tensors = input_tensors
def _setup_graph(self):
self._find_input_tensors() # these are all tensor names
self._find_output_tensors() # may be either tensor name or op name
self._find_input_tensors() # these are all tensor names
self._find_output_tensors() # may be either tensor name or op name
self.pred_func = self.trainer.get_predict_func(
self.input_tensors, self.output_tensors)
self.input_tensors, self.output_tensors)
def _find_input_tensors(self):
if self.input_tensors is None:
input_vars = self.trainer.model.get_reuse_placehdrs()
# TODO even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
def get_name(x):
if isinstance(x, tf.SparseTensor):
return x.op.name.split('/')[0]
......@@ -79,6 +82,7 @@ class InferenceRunner(Callback):
IOTensor = InferenceRunner.IOTensor
self.output_tensors = list(filter(
lambda x: x not in self.input_tensors, all_names))
def find_oid(idxs):
ret = []
for idx in idxs:
......@@ -102,7 +106,7 @@ class InferenceRunner(Callback):
outputs = self.pred_func(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap]
for k in tensormap]
inf.datapoint(inf_output)
pbar.update()
self._write_summary_after_inference()
......@@ -110,6 +114,7 @@ class InferenceRunner(Callback):
def _write_summary_after_inference(self):
summary_inferencer(self.trainer, self.infs)
class FeedfreeInferenceRunner(Callback):
IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
......@@ -139,9 +144,9 @@ class FeedfreeInferenceRunner(Callback):
if self.input_tensor_names is not None:
assert isinstance(self.input_tensor_names, list)
self._input_tensors = [k for idx, k in enumerate(self._input_tensors)
if model_placehdrs[idx].name in self.input_tensor_names]
if model_placehdrs[idx].name in self.input_tensor_names]
assert len(self._input_tensors) == len(self.input_tensor_names), \
"names of input tensors are not defined in the Model"
"names of input tensors are not defined in the Model"
def _find_output_tensors(self):
# doesn't support output an input tensor
......@@ -152,6 +157,7 @@ class FeedfreeInferenceRunner(Callback):
IOTensor = InferenceRunner.IOTensor
self.output_tensors = all_names
def find_oid(idxs):
ret = []
for idx in idxs:
......@@ -161,7 +167,6 @@ class FeedfreeInferenceRunner(Callback):
self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
# list of list of (var_name: IOTensor)
def _trigger_epoch(self):
for inf in self.infs:
inf.before_inference()
......@@ -170,11 +175,11 @@ class FeedfreeInferenceRunner(Callback):
sz = self._input_data.size()
with get_tqdm(total=sz) as pbar:
for _ in range(sz):
#outputs = self.pred_func(dp)
#for inf, tensormap in zip(self.infs, self.inf_to_tensors):
#inf_output = [(outputs if k.isOutput else dp)[k.index]
#for k in tensormap]
#inf.datapoint(inf_output)
# outputs = self.pred_func(dp)
# for inf, tensormap in zip(self.infs, self.inf_to_tensors):
# inf_output = [(outputs if k.isOutput else dp)[k.index]
# for k in tensormap]
# inf.datapoint(inf_output)
pbar.update()
self._write_summary_after_inference()
......
......@@ -17,6 +17,8 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter',
'StatMonitorParamSetter', 'HyperParamSetterWithFunc',
'HyperParam', 'GraphVarParam', 'ObjAttrParam']
@six.add_metaclass(ABCMeta)
class HyperParam(object):
""" Base class for a hyper param"""
......@@ -35,8 +37,10 @@ class HyperParam(object):
""" A name to display"""
return self._readable_name
class GraphVarParam(HyperParam):
""" a variable in the graph can be a hyperparam"""
def __init__(self, name, shape=[]):
self.name = name
self.shape = shape
......@@ -56,13 +60,15 @@ class GraphVarParam(HyperParam):
self.assign_op = self.var.assign(self.val_holder)
def set_value(self, v):
self.assign_op.eval(feed_dict={self.val_holder:v})
self.assign_op.eval(feed_dict={self.val_holder: v})
def get_value(self):
return self.var.eval()
class ObjAttrParam(HyperParam):
""" an attribute of an object can be a hyperparam"""
def __init__(self, obj, attrname, readable_name=None):
""" :param readable_name: default to be attrname."""
self.obj = obj
......@@ -78,6 +84,7 @@ class ObjAttrParam(HyperParam):
def get_value(self, v):
return getattr(self.obj, self.attrname)
class HyperParamSetter(Callback):
"""
Base class to set hyperparameters after every epoch.
......@@ -126,10 +133,12 @@ class HyperParamSetter(Callback):
if v is not None:
self.param.set_value(v)
class HumanHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters by loading the value from a file each time it gets called.
"""
def __init__(self, param, file_name='hyper.txt'):
"""
:param file_name: a file containing the value of the variable.
......@@ -149,7 +158,7 @@ class HumanHyperParamSetter(HyperParamSetter):
with open(self.file_name) as f:
lines = f.readlines()
lines = [s.strip().split(':') for s in lines]
dic = {str(k):float(v) for k, v in lines}
dic = {str(k): float(v) for k, v in lines}
ret = dic[self.param.readable_name]
return ret
except:
......@@ -158,10 +167,12 @@ class HumanHyperParamSetter(HyperParamSetter):
self.param.readable_name, self.file_name))
return None
class ScheduledHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters by a predefined schedule.
"""
def __init__(self, param, schedule, interp=None):
"""
:param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
......@@ -196,7 +207,9 @@ class ScheduledHyperParamSetter(HyperParamSetter):
v = (self.epoch_num - laste) * 1. / (e - laste) * (v - lastv) + lastv
return v
class HyperParamSetterWithFunc(HyperParamSetter):
def __init__(self, param, func):
"""Set hyperparameter by a func
new_value = f(epoch_num, old_value)
......@@ -207,10 +220,12 @@ class HyperParamSetterWithFunc(HyperParamSetter):
def _get_value_to_set(self):
return self.f(self.epoch_num, self.get_current_value())
class StatMonitorParamSetter(HyperParamSetter):
def __init__(self, param, stat_name, value_func, threshold,
last_k, reverse=False
):
last_k, reverse=False
):
"""
Set hyperparameter by a func, when a specific stat wasn't
decreasing/increasing enough in the last $k$ epochs.
......@@ -236,22 +251,21 @@ class StatMonitorParamSetter(HyperParamSetter):
def _get_value_to_set(self):
holder = self.trainer.stat_holder
hist = holder.get_stat_history(self.stat_name)
if len(hist) < self.last_k+1 or \
if len(hist) < self.last_k + 1 or \
self.epoch_num - self.last_changed_epoch < self.last_k:
return None
hist = hist[-self.last_k-1:] # len==last_k+1
hist = hist[-self.last_k - 1:] # len==last_k+1
hist_first = hist[0]
if not self.reverse:
hist_min = min(hist)
if hist_min < hist_first - self.threshold: # small enough
if hist_min < hist_first - self.threshold: # small enough
return None
else:
hist_max = max(hist)
if hist_max > hist_first + self.threshold: # large enough
if hist_max > hist_first + self.threshold: # large enough
return None
self.last_changed_epoch = self.epoch_num
logger.info("[StatMonitorParamSetter] Triggered, history: " +
','.join(map(str, hist)))
','.join(map(str, hist)))
return self.value_func(self.get_current_value())
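A hypothetical use of the setters above, with a graph variable as the hyperparameter (the variable name, epochs and values are made up; the schedule format follows the docstring of ScheduledHyperParamSetter):

from tensorpack.callbacks import GraphVarParam, ScheduledHyperParamSetter

# drop the learning rate at epoch 30 and again at epoch 60
lr = GraphVarParam('learning_rate')
setter = ScheduledHyperParamSetter(lr, [(1, 1e-2), (30, 1e-3), (60, 1e-4)])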
......@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import os, shutil
import os
import shutil
import re
from .base import Callback
......@@ -13,12 +14,14 @@ from ..tfutils import get_global_step
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Callback):
"""
Save the model to logger directory.
"""
def __init__(self, keep_recent=10, keep_freq=0.5,
var_collections=None):
var_collections=None):
"""
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
......@@ -71,9 +74,9 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
try:
if not self.meta_graph_written:
self.saver.export_meta_graph(
os.path.join(logger.LOG_DIR,
'graph-{}.meta'.format(logger.get_time_str())),
collection_list=self.graph.get_all_collection_keys())
os.path.join(logger.LOG_DIR,
'graph-{}.meta'.format(logger.get_time_str())),
collection_list=self.graph.get_all_collection_keys())
self.meta_graph_written = True
self.saver.save(
tf.get_default_session(),
......@@ -83,7 +86,9 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
except (OSError, IOError): # disk error sometimes.. just ignore it
logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Callback):
def __init__(self, monitor_stat, reverse=True, filename=None):
self.monitor_stat = monitor_stat
self.reverse = reverse
......@@ -116,15 +121,14 @@ class MinSaver(Callback):
"Cannot find a checkpoint state. Do you forget to use ModelSaver?")
path = ckpt.model_checkpoint_path
newname = os.path.join(logger.LOG_DIR,
self.filename or
('max-' if self.reverse else 'min-' + self.monitor_stat + '.tfmodel'))
self.filename or
('max-' if self.reverse else 'min-' + self.monitor_stat + '.tfmodel'))
shutil.copy(path, newname)
logger.info("Model with {} '{}' saved.".format(
'maximum' if self.reverse else 'minimum', self.monitor_stat))
class MaxSaver(MinSaver):
def __init__(self, monitor_stat):
super(MaxSaver, self).__init__(monitor_stat, True)
......@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import re, os
import re
import os
import operator
import json
......@@ -13,10 +14,12 @@ from ..tfutils.common import get_global_step
__all__ = ['StatHolder', 'StatPrinter', 'SendStat']
class StatHolder(object):
"""
A holder to keep all statistics aside from tensorflow events.
"""
def __init__(self, log_dir):
"""
:param log_dir: directory to save the stats.
......@@ -62,9 +65,11 @@ class StatHolder(object):
ret = []
for h in self.stat_history:
v = h.get(key, None)
if v is not None: ret.append(v)
if v is not None:
ret.append(v)
v = self.stat_now.get(key, None)
if v is not None: ret.append(v)
if v is not None:
ret.append(v)
return ret
def finalize(self):
......@@ -88,13 +93,15 @@ class StatHolder(object):
with open(tmp_filename, 'w') as f:
json.dump(self.stat_history, f)
os.rename(tmp_filename, self.filename)
except IOError: # disk error sometimes..
except IOError: # disk error sometimes..
logger.exception("Exception in StatHolder.finalize()!")
class StatPrinter(Callback):
"""
Control what stats to print.
"""
def __init__(self, print_tag=None):
"""
:param print_tag: a list of regex to match scalar summary to print.
......@@ -116,6 +123,7 @@ class StatPrinter(Callback):
self._stat_holder.finalize()
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
class SendStat(Callback):
"""
Execute a command with some specific stats.
......@@ -126,6 +134,7 @@ class SendStat(Callback):
-d body={validation_error} > /dev/null 2>&1',
'validation_error')
"""
def __init__(self, command, stats):
self.command = command
if not isinstance(stats, list):
......
......@@ -12,6 +12,7 @@ from . import imgaug
__all__ = ['dataset', 'imgaug', 'dftools']
def _global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
......@@ -24,6 +25,5 @@ __SKIP = ['dftools', 'dataset', 'imgaug']
for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_') and \
module_name not in __SKIP:
module_name not in __SKIP:
_global_import(module_name)
......@@ -10,6 +10,7 @@ from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
@six.add_metaclass(ABCMeta)
class DataFlow(object):
""" Base class for all DataFlow """
......@@ -17,7 +18,6 @@ class DataFlow(object):
class Infinity:
pass
@abstractmethod
def get_data(self):
"""
......@@ -44,11 +44,14 @@ class DataFlow(object):
class RNGDataFlow(DataFlow):
""" A dataflow with rng"""
def reset_state(self):
self.rng = get_rng(self)
class ProxyDataFlow(DataFlow):
""" Base class for DataFlow that proxies another"""
def __init__(self, ds):
"""
:param ds: a :mod:`DataFlow` instance to proxy
......
......@@ -15,7 +15,9 @@ __all__ = ['BatchData', 'FixedSizeData', 'MapData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'LocallyShuffleData', 'TestDataSpeed', 'BatchDataByShape']
class TestDataSpeed(ProxyDataFlow):
def __init__(self, ds, size=1000):
super(TestDataSpeed, self).__init__(ds)
self.test_size = size
......@@ -31,7 +33,9 @@ class TestDataSpeed(ProxyDataFlow):
for dp in self.ds.get_data():
pbar.update()
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
"""
Group data in `ds` into batches.
......@@ -91,11 +95,13 @@ class BatchData(ProxyDataFlow):
raise
except:
logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
import IPython as IP;
import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config())
return result
class BatchDataByShape(BatchData):
def __init__(self, ds, batch_size, idx):
""" Group datapoint of the same shape together to batches
......@@ -119,10 +125,12 @@ class BatchDataByShape(BatchData):
yield BatchData._aggregate_batch(holder)
del holder[:]
class FixedSizeData(ProxyDataFlow):
""" Generate data from another DataFlow, but with a fixed epoch size.
The state of the underlying DataFlow is maintained among each epoch.
"""
def __init__(self, ds, size):
"""
:param ds: a :mod:`DataFlow` to produce data
......@@ -154,10 +162,12 @@ class FixedSizeData(ProxyDataFlow):
if cnt == self._size:
return
class RepeatedData(ProxyDataFlow):
""" Take data points from another `DataFlow` and produce them until
it's exhausted for certain amount of times.
"""
def __init__(self, ds, nr):
"""
:param ds: a :mod:`DataFlow` instance.
......@@ -184,8 +194,10 @@ class RepeatedData(ProxyDataFlow):
for dp in self.ds.get_data():
yield dp
class MapData(ProxyDataFlow):
""" Apply map/filter a function on the datapoint"""
def __init__(self, ds, func):
"""
:param ds: a :mod:`DataFlow` instance.
......@@ -202,8 +214,10 @@ class MapData(ProxyDataFlow):
if ret is not None:
yield ret
class MapDataComponent(ProxyDataFlow):
""" Apply map/filter on the given index in the datapoint"""
def __init__(self, ds, func, index=0):
"""
:param ds: a :mod:`DataFlow` instance.
......@@ -222,11 +236,13 @@ class MapDataComponent(ProxyDataFlow):
dp[self.index] = repl # NOTE modifying
yield dp
class RandomChooseData(RNGDataFlow):
"""
Randomly choose from several DataFlow. Stop producing when any of them is
exhausted.
"""
def __init__(self, df_lists):
"""
:param df_lists: list of dataflow, or list of (dataflow, probability) tuple
......@@ -257,10 +273,12 @@ class RandomChooseData(RNGDataFlow):
except StopIteration:
return
class RandomMixData(RNGDataFlow):
"""
Randomly choose from several dataflow, and will eventually exhaust all dataflow. So it's a perfect mix.
"""
def __init__(self, df_lists):
"""
:param df_lists: list of dataflow.
......@@ -285,14 +303,16 @@ class RandomMixData(RNGDataFlow):
idxs = np.array(list(map(
lambda x: np.searchsorted(sums, x, 'right'), idxs)))
itrs = [k.get_data() for k in self.df_lists]
assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs)-1)
assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs) - 1)
for k in idxs:
yield next(itrs[k])
class ConcatData(DataFlow):
"""
Concatenate several dataflows.
"""
def __init__(self, df_lists):
"""
:param df_lists: list of :mod:`DataFlow` instances
......@@ -311,6 +331,7 @@ class ConcatData(DataFlow):
for dp in d.get_data():
yield dp
class JoinData(DataFlow):
"""
Join the components from each DataFlow.
......@@ -321,6 +342,7 @@ class JoinData(DataFlow):
df2: [dp3, dp4]
join: [dp1, dp2, dp3, dp4]
"""
def __init__(self, df_lists):
"""
:param df_lists: list of :mod:`DataFlow` instances
......@@ -329,7 +351,7 @@ class JoinData(DataFlow):
self._size = self.df_lists[0].size()
for d in self.df_lists:
assert d.size() == self._size, \
"All DataFlow must have the same size! {} != {}".format(d.size(), self._size)
"All DataFlow must have the same size! {} != {}".format(d.size(), self._size)
def reset_state(self):
for d in self.df_lists:
......@@ -352,7 +374,9 @@ class JoinData(DataFlow):
for itr in itrs:
del itr
class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def __init__(self, ds, cache_size, nr_reuse=1):
"""
Cache a number of datapoints and shuffle them.
......@@ -393,10 +417,10 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
yield v
return
def SelectComponent(ds, idxs):
"""
:param ds: a :mod:`DataFlow` instance
:param idxs: a list of datapoint component index of the original dataflow
"""
return MapData(ds, lambda dp: [dp[i] for i in idxs])
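A hypothetical chain of the dataflows above, using the Mnist dataset defined later in this commit (the mapping function and batch size are made-up values):

from tensorpack.dataflow import BatchData, MapDataComponent
from tensorpack.dataflow.dataset import Mnist

ds = Mnist('train')                                          # yields [image, label], image in [0, 1]
ds = MapDataComponent(ds, lambda img: img * 2 - 1, index=0)  # rescale images to [-1, 1]
ds = BatchData(ds, 128)                                      # group datapoints into batches of 128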
......@@ -7,6 +7,8 @@ import os
import os.path
__all__ = []
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
......@@ -19,4 +21,3 @@ for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
global_import(module_name)
......@@ -3,7 +3,8 @@
# File: bsds500.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os, glob
import os
import glob
import cv2
import numpy as np
......@@ -21,6 +22,7 @@ except ImportError:
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W, IMG_H = 481, 321
class BSDS500(RNGDataFlow):
"""
`Berkeley Segmentation Data Set and Benchmarks 500
......@@ -65,7 +67,7 @@ class BSDS500(RNGDataFlow):
im = cv2.imread(f, cv2.IMREAD_COLOR)
assert im is not None
if im.shape[0] > im.shape[1]:
im = np.transpose(im, (1,0,2))
im = np.transpose(im, (1, 0, 2))
assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W))
imgid = os.path.basename(f).split('.')[0]
......@@ -96,5 +98,5 @@ class BSDS500(RNGDataFlow):
if __name__ == '__main__':
a = BSDS500('val')
for k in a.get_data():
cv2.imshow("haha", k[1].astype('uint8')*255)
cv2.imshow("haha", k[1].astype('uint8') * 255)
cv2.waitKey(1000)
......@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Yukun Chen <cykustc@gmail.com>
import os, sys
import os
import sys
import pickle
import numpy as np
import random
......@@ -23,6 +24,7 @@ __all__ = ['Cifar10', 'Cifar100']
DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
def maybe_download_and_extract(dest_directory, cifar_classnum):
"""Download and extract the tarball from Alex's website.
copied from tensorflow example """
......@@ -42,6 +44,7 @@ def maybe_download_and_extract(dest_directory, cifar_classnum):
import tarfile
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def read_cifar(filenames, cifar_classnum):
assert cifar_classnum == 10 or cifar_classnum == 100
ret = []
......@@ -54,7 +57,7 @@ def read_cifar(filenames, cifar_classnum):
data = dic[b'data']
if cifar_classnum == 10:
label = dic[b'labels']
IMG_NUM = 10000 # cifar10 data are split into blocks of 10000
IMG_NUM = 10000 # cifar10 data are split into blocks of 10000
elif cifar_classnum == 100:
label = dic[b'fine_labels']
IMG_NUM = 50000 if 'train' in fname else 10000
......@@ -65,6 +68,7 @@ def read_cifar(filenames, cifar_classnum):
ret.append([img, label[k]])
return ret
def get_filenames(dir, cifar_classnum):
assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10:
......@@ -77,11 +81,13 @@ def get_filenames(dir, cifar_classnum):
os.path.join(dir, 'cifar-100-python', 'test')]
return filenames
class CifarBase(RNGDataFlow):
"""
Return [image, label],
image is 32x32x3 in the range [0,255]
"""
def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10):
"""
Args:
......@@ -132,13 +138,17 @@ class CifarBase(RNGDataFlow):
return three values as mean of each channel
"""
mean = self.get_per_pixel_mean()
return np.mean(mean, axis=(0,1))
return np.mean(mean, axis=(0, 1))
class Cifar10(CifarBase):
def __init__(self, train_or_test, shuffle=True, dir=None):
super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10)
class Cifar100(CifarBase):
def __init__(self, train_or_test, shuffle=True, dir=None):
super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)
......@@ -149,7 +159,6 @@ if __name__ == '__main__':
print(mean)
dump_dataset_images(ds, '/tmp/cifar', 100)
#for (img, label) in ds.get_data():
#from IPython import embed; embed()
#break
# for (img, label) in ds.get_data():
# from IPython import embed; embed()
# break
......@@ -19,10 +19,12 @@ __all__ = ['ILSVRCMeta', 'ILSVRC12']
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
class ILSVRCMeta(object):
"""
Some metadata for ILSVRC dataset.
"""
def __init__(self, dir=None):
if dir is None:
dir = get_dataset_path('ilsvrc_metadata')
......@@ -82,14 +84,16 @@ class ILSVRCMeta(object):
with open(mean_file, 'rb') as f:
obj.ParseFromString(f.read())
arr = np.array(obj.data).reshape((3, 256, 256)).astype('float32')
arr = np.transpose(arr, [1,2,0])
arr = np.transpose(arr, [1, 2, 0])
if size is not None:
arr = cv2.resize(arr, size[::-1])
return arr
class ILSVRC12(RNGDataFlow):
def __init__(self, dir, name, meta_dir=None, shuffle=True,
dir_structure='original', include_bb=False):
dir_structure='original', include_bb=False):
"""
:param dir: A directory containing a subdir named `name`, where the
original ILSVRC12_`name`.tar gets decompressed.
......@@ -145,7 +149,7 @@ class ILSVRC12(RNGDataFlow):
if include_bb:
bbdir = os.path.join(dir, 'bbox') if not \
isinstance(include_bb, six.string_types) else include_bb
isinstance(include_bb, six.string_types) else include_bb
assert name == 'train', 'Bounding box only available for training'
self.bblist = ILSVRC12.get_training_bbox(bbdir, self.imglist)
self.include_bb = include_bb
......@@ -171,11 +175,11 @@ class ILSVRC12(RNGDataFlow):
im = cv2.imread(fname.strip(), cv2.IMREAD_COLOR)
assert im is not None, fname
if im.ndim == 2:
im = np.expand_dims(im, 2).repeat(3,2)
im = np.expand_dims(im, 2).repeat(3, 2)
if self.include_bb:
bb = self.bblist[k]
if bb is None:
bb = [0, 0, im.shape[1]-1, im.shape[0]-1]
bb = [0, 0, im.shape[1] - 1, im.shape[0] - 1]
yield [im, label, bb]
else:
yield [im, label]
......@@ -216,12 +220,13 @@ class ILSVRC12(RNGDataFlow):
if __name__ == '__main__':
meta = ILSVRCMeta()
#print(meta.get_synset_words_1000())
# print(meta.get_synset_words_1000())
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', include_bb=True,
shuffle=False)
shuffle=False)
ds.reset_state()
for k in ds.get_data():
from IPython import embed; embed()
from IPython import embed
embed()
break
......@@ -17,6 +17,7 @@ __all__ = ['Mnist']
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here."""
filepath = os.path.join(work_directory, filename)
......@@ -25,18 +26,20 @@ def maybe_download(filename, work_directory):
download(SOURCE_URL + filename, work_directory)
return filepath
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def extract_images(filename):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError(
'Invalid magic number %d in MNIST image file: %s' %
(magic, filename))
'Invalid magic number %d in MNIST image file: %s' %
(magic, filename))
num_images = _read32(bytestream)
rows = _read32(bytestream)
cols = _read32(bytestream)
......@@ -46,24 +49,27 @@ def extract_images(filename):
data = data.astype('float32') / 255.0
return data
def extract_labels(filename):
"""Extract the labels into a 1D uint8 numpy array [index]."""
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2049:
raise ValueError(
'Invalid magic number %d in MNIST label file: %s' %
(magic, filename))
'Invalid magic number %d in MNIST label file: %s' %
(magic, filename))
num_items = _read32(bytestream)
buf = bytestream.read(num_items)
labels = numpy.frombuffer(buf, dtype=numpy.uint8)
return labels
class Mnist(RNGDataFlow):
"""
Return [image, label],
image is 28x28 in the range [0,1]
"""
def __init__(self, train_or_test, shuffle=True, dir=None):
"""
Args:
......@@ -107,6 +113,6 @@ class Mnist(RNGDataFlow):
if __name__ == '__main__':
ds = Mnist('train')
for (img, label) in ds.get_data():
from IPython import embed; embed()
from IPython import embed
embed()
break
......@@ -24,6 +24,7 @@ TRAIN_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.tra
VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'
TEST_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'
@memoized_ignoreargs
def get_PennTreeBank(data_dir=None):
if data_dir is None:
......@@ -35,6 +36,5 @@ def get_PennTreeBank(data_dir=None):
# TODO these functions in TF might not be available in the future
word_to_id = tfreader._build_vocab(os.path.join(data_dir, 'ptb.train.txt'))
data3 = [np.asarray(tfreader._file_to_word_ids(os.path.join(data_dir, fname), word_to_id))
for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']]
for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']]
return data3, word_to_id
......@@ -19,6 +19,7 @@ except ImportError:
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
class SVHNDigit(RNGDataFlow):
"""
SVHN Cropped Digit Dataset.
......@@ -41,12 +42,12 @@ class SVHNDigit(RNGDataFlow):
assert name in ['train', 'test', 'extra'], name
filename = os.path.join(data_dir, name + '_32x32.mat')
assert os.path.isfile(filename), \
"File {} not found! Please download it from {}.".format(filename, SVHN_URL)
"File {} not found! Please download it from {}.".format(filename, SVHN_URL)
logger.info("Loading {} ...".format(filename))
data = scipy.io.loadmat(filename)
self.X = data['X'].transpose(3,0,1,2)
self.X = data['X'].transpose(3, 0, 1, 2)
self.Y = data['y'].reshape((-1))
self.Y[self.Y==10] = 0
self.Y[self.Y == 10] = 0
SVHNDigit._Cache[name] = (self.X, self.Y)
def size(self):
......
......@@ -12,6 +12,7 @@ import json
__all__ = ['VisualQA']
def read_json(fname):
f = open(fname)
ret = json.load(f)
......@@ -19,11 +20,14 @@ def read_json(fname):
return ret
# TODO shuffle
class VisualQA(DataFlow):
"""
Visual QA dataset. See http://visualqa.org/
Simply read q/a json file and produce q/a pairs in their original format.
"""
def __init__(self, question_file, annotation_file):
with timed_operation('Reading VQA JSON file'):
qobj, aobj = list(map(read_json, [question_file, annotation_file]))
......@@ -62,7 +66,7 @@ class VisualQA(DataFlow):
""" Get the n most common words in questions
n=4600 ~= thresh 6
"""
from nltk.tokenize import word_tokenize # will need to download 'punkt'
from nltk.tokenize import word_tokenize  # will need to download 'punkt'
cnt = Counter()
for q in self.questions:
cnt.update(word_tokenize(q['question'].lower()))
......@@ -72,7 +76,7 @@ class VisualQA(DataFlow):
if __name__ == '__main__':
vqa = VisualQA('/home/wyx/data/VQA/MultipleChoice_mscoco_train2014_questions.json',
'/home/wyx/data/VQA/mscoco_train2014_annotations.json')
'/home/wyx/data/VQA/mscoco_train2014_annotations.json')
for k in vqa.get_data():
print(json.dumps(k))
break
......
......@@ -2,7 +2,8 @@
# File: dftools.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import sys, os
import sys
import os
import cv2
import multiprocessing as mp
import six
......@@ -23,6 +24,8 @@ else:
__all__.extend(['dump_dataflow_to_lmdb'])
# TODO pass a name_func to write label as filename?
def dump_dataset_images(ds, dirname, max_count=None, index=0):
""" Dump images from a `DataFlow` to a directory.
......@@ -43,6 +46,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
img = dp[index]
cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)
def dump_dataflow_to_lmdb(ds, lmdb_path):
""" Dump a `Dataflow` ds to a lmdb database, where the key is the index
and the data is the serialized datapoint.
......@@ -56,8 +60,8 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
assert not os.path.isfile(lmdb_path), "LMDB file exists!"
ds.reset_state()
db = lmdb.open(lmdb_path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False,
meminit=False, map_async=True) # need sync() at the end
map_size=1099511627776 * 2, readonly=False,
meminit=False, map_async=True) # need sync() at the end
try:
sz = ds.size()
except NotImplementedError:
......@@ -87,7 +91,9 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
the queue once you start it. Each element is (task_id, dp).
"""
q = mp.Queue(size)
class EnqueProc(mp.Process):
def __init__(self, ds, q, nr_consumer):
super(EnqueProc, self).__init__()
self.ds = ds
......@@ -104,4 +110,3 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
proc = EnqueProc(ds, q, nr_consumer)
return q, proc
......@@ -40,10 +40,13 @@ Adapters for different data format.
"""
# TODO lazy load
class HDF5Data(RNGDataFlow):
"""
Zip data from different paths in an HDF5 file. Will load all data into memory.
"""
def __init__(self, filename, data_paths, shuffle=True):
"""
:param filename: h5 data file.
......@@ -54,7 +57,7 @@ class HDF5Data(RNGDataFlow):
logger.info("Loading {} to memory...".format(filename))
self.dps = [self.f[k].value for k in data_paths]
lens = [len(k) for k in self.dps]
assert all([k==lens[0] for k in lens])
assert all([k == lens[0] for k in lens])
self._size = lens[0]
self.shuffle = shuffle
......@@ -71,6 +74,7 @@ class HDF5Data(RNGDataFlow):
class LMDBData(RNGDataFlow):
""" Read a lmdb and produce k,v pair """
def __init__(self, lmdb_path, shuffle=True):
self._lmdb_path = lmdb_path
self._shuffle = shuffle
......@@ -78,9 +82,9 @@ class LMDBData(RNGDataFlow):
def open_lmdb(self):
self._lmdb = lmdb.open(self._lmdb_path,
subdir=os.path.isdir(self._lmdb_path),
readonly=True, lock=False, readahead=False,
map_size=1099511627776 * 2, max_readers=100)
subdir=os.path.isdir(self._lmdb_path),
readonly=True, lock=False, readahead=False,
map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin()
self._size = self._txn.stat()['entries']
if self._shuffle:
......@@ -116,7 +120,9 @@ class LMDBData(RNGDataFlow):
v = self._txn.get(k)
yield [k, v]
class LMDBDataDecoder(LMDBData):
def __init__(self, lmdb_path, decoder, shuffle=True):
"""
:param decoder: a function taking k, v and return a data point,
......@@ -128,18 +134,24 @@ class LMDBDataDecoder(LMDBData):
def get_data(self):
for dp in super(LMDBDataDecoder, self).get_data():
v = self.decoder(dp[0], dp[1])
if v: yield v
if v:
yield v
class LMDBDataPoint(LMDBDataDecoder):
""" Read a LMDB file where each value is a serialized datapoint"""
def __init__(self, lmdb_path, shuffle=True):
super(LMDBDataPoint, self).__init__(
lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle)
lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle)
class CaffeLMDB(LMDBDataDecoder):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
def __init__(self, lmdb_path, shuffle=True):
cpb = get_caffe_pb()
def decoder(k, v):
try:
datum = cpb.Datum()
......@@ -152,10 +164,12 @@ class CaffeLMDB(LMDBDataDecoder):
return [img.transpose(1, 2, 0), datum.label]
super(CaffeLMDB, self).__init__(
lmdb_path, decoder=decoder, shuffle=shuffle)
lmdb_path, decoder=decoder, shuffle=shuffle)
class SVMLightData(RNGDataFlow):
""" Read X,y from a svmlight file """
def __init__(self, filename, shuffle=True):
self.X, self.y = sklearn.datasets.load_svmlight_file(filename)
self.X = np.asarray(self.X.todense())
......@@ -169,4 +183,4 @@ class SVMLightData(RNGDataFlow):
if self.shuffle:
self.rng.shuffle(idxs)
for id in idxs:
yield [self.X[id,:], self.y[id]]
yield [self.X[id, :], self.y[id]]
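The lmdb writer in dftools.py and the lmdb readers above pair up; a hypothetical round trip (the output path and dataset are made-up values):

from tensorpack.dataflow import LMDBDataPoint
from tensorpack.dataflow.dftools import dump_dataflow_to_lmdb
from tensorpack.dataflow.dataset import Mnist

df = Mnist('train')
dump_dataflow_to_lmdb(df, '/tmp/mnist-train.lmdb')            # key = index, value = serialized datapoint
df = LMDBDataPoint('/tmp/mnist-train.lmdb', shuffle=True)     # read the serialized datapoints back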
......@@ -11,7 +11,9 @@ from .imgaug import AugmentorList
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
class ImageFromFile(RNGDataFlow):
def __init__(self, files, channel=3, resize=None, shuffle=False):
"""
Generate images of 1 channel or 3 channels (in RGB order) from list of files.
......@@ -39,11 +41,12 @@ class ImageFromFile(RNGDataFlow):
if self.resize is not None:
im = cv2.resize(im, self.resize[::-1])
if self.channel == 1:
im = im[:,:,np.newaxis]
im = im[:, :, np.newaxis]
yield [im]
class AugmentImageComponent(MapDataComponent):
def __init__(self, ds, augmentors, index=0):
"""
Augment the image component of datapoints
......@@ -64,7 +67,8 @@ class AugmentImageComponent(MapDataComponent):
class AugmentImageComponents(MapData):
def __init__(self, ds, augmentors, index=(0,1)):
def __init__(self, ds, augmentors, index=(0, 1)):
""" Augment a list of images of the same shape, with the same parameters
:param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
......
......@@ -7,6 +7,7 @@ from pkgutil import walk_packages
__all__ = []
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
......@@ -19,4 +20,3 @@ for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
global_import(module_name)
......@@ -10,15 +10,15 @@ from .crop import *
from .imgproc import *
from .noname import *
from .deform import *
from .noise import SaltPepperNoise
from .noise import SaltPepperNoise
anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([
Contrast((0.8,1.2)),
Contrast((0.8, 1.2)),
Flip(horiz=True),
GaussianDeform(anchors, (360,480), 0.2, randrange=20),
#RandomCropRandomShape(0.3),
GaussianDeform(anchors, (360, 480), 0.2, randrange=20),
# RandomCropRandomShape(0.3),
SaltPepperNoise()
])
......
......@@ -9,6 +9,7 @@ from six.moves import zip
__all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']
@six.add_metaclass(ABCMeta)
class Augmentor(object):
""" Base class for an augmentor"""
......@@ -58,7 +59,9 @@ class Augmentor(object):
size = []
return self.rng.uniform(low, high, size)
class ImageAugmentor(Augmentor):
def augment(self, img):
"""
Perform augmentation on the image in-place.
......@@ -71,10 +74,12 @@ class ImageAugmentor(Augmentor):
def _fprop_coord(self, coord, param):
return coord
class AugmentorList(ImageAugmentor):
"""
Augment by a list of augmentors
"""
def __init__(self, augmentors):
"""
:param augmentors: list of `ImageAugmentor` instance to be applied
......@@ -107,4 +112,3 @@ class AugmentorList(ImageAugmentor):
""" Will reset state of each augmentor """
for a in self.augs:
a.reset_state()
......@@ -10,10 +10,12 @@ from six.moves import range
import numpy as np
__all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop',
'RandomCropRandomShape', 'perturb_BB', 'RandomCropAroundBox']
'RandomCropRandomShape', 'perturb_BB', 'RandomCropAroundBox']
class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """
def __init__(self, crop_shape):
"""
:param crop_shape: a shape like (h, w)
......@@ -25,7 +27,7 @@ class RandomCrop(ImageAugmentor):
def _get_augment_params(self, img):
orig_shape = img.shape
assert orig_shape[0] >= self.crop_shape[0] \
and orig_shape[1] >= self.crop_shape[1], orig_shape
and orig_shape[1] >= self.crop_shape[1], orig_shape
diffh = orig_shape[0] - self.crop_shape[0]
h0 = 0 if diffh == 0 else self.rng.randint(diffh)
diffw = orig_shape[1] - self.crop_shape[1]
......@@ -34,13 +36,15 @@ class RandomCrop(ImageAugmentor):
def _augment(self, img, param):
h0, w0 = param
return img[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]]
return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class CenterCrop(ImageAugmentor):
""" Crop the image at the center"""
def __init__(self, crop_shape):
"""
:param crop_shape: a shape like (h, w)
......@@ -52,13 +56,15 @@ class CenterCrop(ImageAugmentor):
orig_shape = img.shape
h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
return img[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]]
return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class FixedCrop(ImageAugmentor):
""" Crop a rectangle at a given location"""
def __init__(self, rect):
"""
Two arguments define the range in both axes to crop, min included, max excluded.
......@@ -69,15 +75,16 @@ class FixedCrop(ImageAugmentor):
def _augment(self, img, _):
orig_shape = img.shape
return img[self.rect.y0: self.rect.y1+1,
self.rect.x0: self.rect.x0+1]
return img[self.rect.y0: self.rect.y1 + 1,
self.rect.x0: self.rect.x0 + 1]
def _fprop_coord(self, coord, param):
raise NotImplementedError()
def perturb_BB(image_shape, bb, max_pertub_pixel,
rng=None, max_aspect_ratio_diff=0.3,
max_try=100):
rng=None, max_aspect_ratio_diff=0.3,
max_try=100):
"""
Perturb a bounding box.
:param image_shape: [h, w]
......@@ -113,6 +120,7 @@ class RandomCropAroundBox(ImageAugmentor):
"""
Crop a box around a bounding box
"""
def __init__(self, perturb_ratio, max_aspect_ratio_diff=0.3):
"""
:param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
......@@ -124,9 +132,9 @@ class RandomCropAroundBox(ImageAugmentor):
def _get_augment_params(self, img):
shape = img.shape[:2]
box = Rect(0, 0, shape[1] - 1, shape[0] - 1)
dist = self.perturb_ratio * np.sqrt(shape[0]*shape[1])
dist = self.perturb_ratio * np.sqrt(shape[0] * shape[1])
newbox = perturb_BB(shape, box, dist,
self.rng, self.max_aspect_ratio_diff)
self.rng, self.max_aspect_ratio_diff)
return newbox
def _augment(self, img, newbox):
......@@ -135,10 +143,12 @@ class RandomCropAroundBox(ImageAugmentor):
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class RandomCropRandomShape(ImageAugmentor):
def __init__(self, wmin, hmin,
wmax=None, hmax=None,
max_aspect_ratio=None):
wmax=None, hmax=None,
max_aspect_ratio=None):
"""
Randomly crop a box of shape (h, w), sampled from [min, max] (inclusive).
If max is None, will use the input image shape.
......@@ -151,18 +161,18 @@ class RandomCropRandomShape(ImageAugmentor):
def _get_augment_params(self, img):
hmax = self.hmax or img.shape[0]
wmax = self.wmax or img.shape[1]
h = self.rng.randint(self.hmin, hmax+1)
w = self.rng.randint(self.wmin, wmax+1)
h = self.rng.randint(self.hmin, hmax + 1)
w = self.rng.randint(self.wmin, wmax + 1)
diffh = img.shape[0] - h
diffw = img.shape[1] - w
assert diffh >= 0 and diffw >= 0
y0 = 0 if diffh == 0 else self.rng.randint(diffh)
x0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (y0,x0,h,w)
return (y0, x0, h, w)
def _augment(self, img, param):
y0, x0, h, w = param
return img[y0:y0+h,x0:x0+w]
return img[y0:y0 + h, x0:x0 + w]
if __name__ == '__main__':
print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
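A standalone NumPy sketch of the offset sampling and slicing used by RandomCrop above (no tensorpack dependency; same convention for zero-sized margins):

    import numpy as np

    def random_crop(img, crop_shape, rng=np.random):
        h, w = crop_shape
        diffh, diffw = img.shape[0] - h, img.shape[1] - w
        assert diffh >= 0 and diffw >= 0, img.shape
        h0 = 0 if diffh == 0 else rng.randint(diffh)
        w0 = 0 if diffw == 0 else rng.randint(diffw)
        return img[h0:h0 + h, w0:w0 + w]

    img = np.arange(6 * 8).reshape(6, 8)
    print(random_crop(img, (4, 4)).shape)   # (4, 4)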
......@@ -10,8 +10,10 @@ __all__ = ['GaussianDeform', 'GaussianMap']
# TODO really needs speedup
class GaussianMap(object):
""" Generate gaussian weighted deformation map"""
def __init__(self, image_shape, sigma=0.5):
assert len(image_shape) == 2
self.shape = image_shape
......@@ -25,17 +27,18 @@ class GaussianMap(object):
x = x.astype('float32') / ret.shape[1] - anchor[1]
g = np.exp(-(x**2 + y ** 2) / self.sigma)
#cv2.imshow(" ", g)
#cv2.waitKey()
# cv2.waitKey()
return g
def np_sample(img, coords):
# a numpy implementation of ImageSample layer
coords = np.maximum(coords, 0)
coords = np.minimum(coords, np.array([img.shape[0]-1, img.shape[1]-1]))
coords = np.minimum(coords, np.array([img.shape[0] - 1, img.shape[1] - 1]))
lcoor = np.floor(coords).astype('int32')
ucoor = lcoor + 1
ucoor = np.minimum(ucoor, np.array([img.shape[0]-1, img.shape[1]-1]))
ucoor = np.minimum(ucoor, np.array([img.shape[0] - 1, img.shape[1] - 1]))
diff = coords - lcoor
neg_diff = 1.0 - diff
......@@ -46,17 +49,20 @@ def np_sample(img, coords):
diffy, diffx = np.split(diff, 2, axis=2)
ndiffy, ndiffx = np.split(neg_diff, 2, axis=2)
ret = img[lcoory,lcoorx,:] * ndiffx * ndiffy + \
img[ucoory, ucoorx,:] * diffx * diffy + \
img[lcoory, ucoorx,:] * ndiffy * diffx + \
img[ucoory,lcoorx,:] * diffy * ndiffx
return ret[:,:,0,:]
ret = img[lcoory, lcoorx, :] * ndiffx * ndiffy + \
img[ucoory, ucoorx, :] * diffx * diffy + \
img[lcoory, ucoorx, :] * ndiffy * diffx + \
img[ucoory, lcoorx, :] * diffy * ndiffx
return ret[:, :, 0, :]
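The four-neighbour weighting in np_sample is plain bilinear interpolation; as a worked check, a single clamped coordinate (y, x) = (1.3, 2.6) gives lcoor = (1, 2), ucoor = (2, 3), diff = (0.3, 0.6), so

    value = (img[1, 2] * 0.4 * 0.7     # ndiffx * ndiffy = 0.28
             + img[2, 3] * 0.6 * 0.3   # diffx  * diffy  = 0.18
             + img[1, 3] * 0.7 * 0.6   # ndiffy * diffx  = 0.42
             + img[2, 2] * 0.3 * 0.4)  # diffy  * ndiffx = 0.12; the weights sum to 1.0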
# TODO input/output with different shape
class GaussianDeform(ImageAugmentor):
"""
Random deformation of the image, driven by Gaussian-weighted anchor displacements. Quite slow.
"""
def __init__(self, anchors, shape, sigma=0.5, randrange=None):
"""
:param anchors: in [0,1] coordinate
......@@ -69,13 +75,13 @@ class GaussianDeform(ImageAugmentor):
self.anchors = anchors
self.K = len(self.anchors)
self.shape = shape
self.grid = np.mgrid[0:self.shape[0], 0:self.shape[1]].transpose(1,2,0)
self.grid = self.grid.astype('float32') # HxWx2
self.grid = np.mgrid[0:self.shape[0], 0:self.shape[1]].transpose(1, 2, 0)
self.grid = self.grid.astype('float32') # HxWx2
gm = GaussianMap(self.shape, sigma=sigma)
self.gws = np.array([gm.get_gaussian_weight(ank)
for ank in self.anchors], dtype='float32') # KxHxW
self.gws = self.gws.transpose(1, 2, 0) #HxWxK
for ank in self.anchors], dtype='float32') # KxHxW
self.gws = self.gws.transpose(1, 2, 0) # HxWxK
if randrange is None:
self.randrange = self.shape[0] / 8
else:
......
......@@ -10,11 +10,13 @@ import numpy as np
__all__ = ['Rotation', 'RotationAndCropValid']
class Rotation(ImageAugmentor):
""" Random rotate the image w.r.t a random center"""
def __init__(self, max_deg, center_range=(0,1),
interp=cv2.INTER_CUBIC,
border=cv2.BORDER_REPLICATE):
def __init__(self, max_deg, center_range=(0, 1),
interp=cv2.INTER_CUBIC,
border=cv2.BORDER_REPLICATE):
"""
:param max_deg: max abs value of the rotation degree
:param center_range: the location of the rotation center
......@@ -24,19 +26,21 @@ class Rotation(ImageAugmentor):
def _get_augment_params(self, img):
center = img.shape[1::-1] * self._rand_range(
self.center_range[0], self.center_range[1], (2,))
self.center_range[0], self.center_range[1], (2,))
deg = self._rand_range(-self.max_deg, self.max_deg)
return cv2.getRotationMatrix2D(tuple(center), deg, 1)
def _augment(self, img, rot_m):
ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
flags=self.interp, borderMode=self.border)
flags=self.interp, borderMode=self.border)
return ret
class RotationAndCropValid(ImageAugmentor):
""" Random rotate and crop the largest possible rect without the border
This will produce images of different shapes.
"""
def __init__(self, max_deg, interp=cv2.INTER_CUBIC):
super(RotationAndCropValid, self).__init__()
self._init(locals())
......@@ -46,39 +50,39 @@ class RotationAndCropValid(ImageAugmentor):
return deg
def _augment(self, img, deg):
center = (img.shape[1]*0.5, img.shape[0]*0.5)
center = (img.shape[1] * 0.5, img.shape[0] * 0.5)
rot_m = cv2.getRotationMatrix2D(center, deg, 1)
ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
flags=self.interp, borderMode=cv2.BORDER_CONSTANT)
flags=self.interp, borderMode=cv2.BORDER_CONSTANT)
neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
neww = min(neww, ret.shape[1])
newh = min(newh, ret.shape[0])
newx = int(center[0] - neww * 0.5)
newy = int(center[1] - newh * 0.5)
#print(ret.shape, deg, newx, newy, neww, newh)
return ret[newy:newy+newh,newx:newx+neww]
return ret[newy:newy + newh, newx:newx + neww]
@staticmethod
def largest_rotated_rect(w, h, angle):
""" http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders """
angle = angle / 180.0 * math.pi
if w <= 0 or h <= 0:
return 0,0
return 0, 0
width_is_longer = w >= h
side_long, side_short = (w,h) if width_is_longer else (h,w)
side_long, side_short = (w, h) if width_is_longer else (h, w)
# since the solutions for angle, -angle and 180-angle are all the same,
# it suffices to look at the first quadrant and the absolute values of sin, cos:
sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
if side_short <= 2.*sin_a*cos_a*side_long:
# half constrained case: two crop corners touch the longer side,
# the other two corners are on the mid-line parallel to the longer line
x = 0.5*side_short
wr,hr = (x/sin_a,x/cos_a) if width_is_longer else (x/cos_a,x/sin_a)
if side_short <= 2. * sin_a * cos_a * side_long:
# half constrained case: two crop corners touch the longer side,
# the other two corners are on the mid-line parallel to the longer line
x = 0.5 * side_short
wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
else:
# fully constrained case: crop touches all 4 sides
cos_2a = cos_a*cos_a - sin_a*sin_a
wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a
# fully constrained case: crop touches all 4 sides
cos_2a = cos_a * cos_a - sin_a * sin_a
wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
return int(wr), int(hr)
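Two worked examples of the branches above (values approximate):

    # half-constrained: w=100, h=60, angle=30deg -> 2*sin*cos*100 ~ 86.6 > 60
    RotationAndCropValid.largest_rotated_rect(100, 60, 30)   # ~ (60, 34)
    # fully constrained: w=100, h=90, angle=10deg -> 2*sin*cos*100 ~ 34.2 < 90
    RotationAndCropValid.largest_rotated_rect(100, 90, 10)   # ~ (88, 75)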
......@@ -7,12 +7,14 @@ import numpy as np
import cv2
__all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur',
'Gamma', 'Clip', 'Saturation', 'Lighting']
'Gamma', 'Clip', 'Saturation', 'Lighting']
class Brightness(ImageAugmentor):
"""
Randomly adjust brightness.
"""
def __init__(self, delta, clip=True):
"""
Randomly add a value within [-delta, delta], and clip to [0, 255] if clip is True.
......@@ -31,11 +33,13 @@ class Brightness(ImageAugmentor):
img = np.clip(img, 0, 255)
return img
class Contrast(ImageAugmentor):
"""
Apply x = (x - mean) * contrast_factor + mean to each channel
and clip to [0, 255]
"""
def __init__(self, factor_range, clip=True):
"""
:param factor_range: an interval to random sample the `contrast_factor`.
......@@ -48,18 +52,20 @@ class Contrast(ImageAugmentor):
return self._rand_range(*self.factor_range)
def _augment(self, img, r):
mean = np.mean(img, axis=(0,1), keepdims=True)
mean = np.mean(img, axis=(0, 1), keepdims=True)
img = (img - mean) * r + mean
if self.clip:
img = np.clip(img, 0, 255)
return img
class MeanVarianceNormalize(ImageAugmentor):
"""
Linearly scales image to have zero mean and unit norm.
x = (x - mean) / adjusted_stddev
where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
"""
def __init__(self, all_channel=True):
"""
:param all_channel: if True, normalize all channels together. else separately.
......@@ -71,14 +77,15 @@ class MeanVarianceNormalize(ImageAugmentor):
mean = np.mean(img)
std = np.std(img)
else:
mean = np.mean(img, axis=(0,1), keepdims=True)
std = np.std(img, axis=(0,1), keepdims=True)
mean = np.mean(img, axis=(0, 1), keepdims=True)
std = np.std(img, axis=(0, 1), keepdims=True)
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape)))
img = (img - mean) / std
return img
class GaussianBlur(ImageAugmentor):
def __init__(self, max_size=3):
""":params max_size: (maximum kernel size-1)/2"""
super(GaussianBlur, self).__init__()
......@@ -92,10 +99,11 @@ class GaussianBlur(ImageAugmentor):
def _augment(self, img, s):
return cv2.GaussianBlur(img, s, sigmaX=0, sigmaY=0,
borderType=cv2.BORDER_REPLICATE)
borderType=cv2.BORDER_REPLICATE)
class Gamma(ImageAugmentor):
def __init__(self, range=(-0.5, 0.5)):
super(Gamma, self).__init__()
self._init(locals())
......@@ -109,7 +117,9 @@ class Gamma(ImageAugmentor):
img = cv2.LUT(img, lut).astype('float32')
return img
class Clip(ImageAugmentor):
def __init__(self, min=0, max=255):
self._init(locals())
......@@ -117,7 +127,9 @@ class Clip(ImageAugmentor):
img = np.clip(img, self.min, self.max)
return img
class Saturation(ImageAugmentor):
def __init__(self, alpha=0.4):
""" Saturation, see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
"""
......@@ -130,9 +142,11 @@ class Saturation(ImageAugmentor):
def _augment(self, img, v):
grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
return img * v + (grey * (1 - v))[:,:,np.newaxis]
return img * v + (grey * (1 - v))[:, :, np.newaxis]
class Lighting(ImageAugmentor):
def __init__(self, std, eigval, eigvec):
""" Lighting noise.
See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
......@@ -143,7 +157,7 @@ class Lighting(ImageAugmentor):
eigval = np.asarray(eigval)
eigvec = np.asarray(eigvec)
assert eigval.shape == (3,)
assert eigvec.shape == (3,3)
assert eigvec.shape == (3, 3)
self._init(locals())
def _get_augment_params(self, img):
......@@ -156,4 +170,3 @@ class Lighting(ImageAugmentor):
inc = np.dot(self.eigvec, v).reshape((3,))
img += inc
return img
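A standalone NumPy sketch of the PCA ("AlexNet-style") lighting noise implemented above; the eigval numbers are the commonly quoted ImageNet values, the identity eigvec is only a placeholder, and the noise vector is assumed to be N(0, std) scaled by the eigenvalues:

    import numpy as np

    rng = np.random.RandomState(0)
    std = 0.1
    eigval = np.array([0.2175, 0.0188, 0.0045], dtype='float32')  # assumed example values
    eigvec = np.eye(3, dtype='float32')                           # placeholder eigenvectors
    v = rng.randn(3) * std * eigval
    inc = eigvec.dot(v)        # one RGB offset, added to every pixel of the image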
......@@ -7,14 +7,18 @@
from .base import ImageAugmentor
__all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
'RandomOrderAug']
'RandomOrderAug']
class Identity(ImageAugmentor):
def _augment(self, img, _):
return img
class RandomApplyAug(ImageAugmentor):
""" Randomly apply the augmentor with a prob. Otherwise do nothing"""
def __init__(self, aug, prob):
self._init(locals())
super(RandomApplyAug, self).__init__()
......@@ -37,7 +41,9 @@ class RandomApplyAug(ImageAugmentor):
else:
return self.aug._augment(img, prm[1])
class RandomChooseAug(ImageAugmentor):
def __init__(self, aug_lists):
"""
:param aug_lists: list of augmentors, or list of (augmentor, probability) tuples
......@@ -65,7 +71,9 @@ class RandomChooseAug(ImageAugmentor):
idx, prm = prm
return self.aug_lists[idx]._augment(img, prm)
class RandomOrderAug(ImageAugmentor):
def __init__(self, aug_lists):
"""
Shuffle the augmentors into random order.
......@@ -93,10 +101,12 @@ class RandomOrderAug(ImageAugmentor):
img = self.aug_lists[k]._augment(img, prms[k])
return img
class MapImage(ImageAugmentor):
"""
Map the image array by a function.
"""
def __init__(self, func):
"""
:param func: a function which takes an image array and returns an augmented one
......@@ -105,4 +115,3 @@ class MapImage(ImageAugmentor):
def _augment(self, img, _):
return self.func(img)
......@@ -9,7 +9,9 @@ import cv2
__all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
class JpegNoise(ImageAugmentor):
def __init__(self, quality_range=(40, 100)):
super(JpegNoise, self).__init__()
self._init(locals())
......@@ -23,6 +25,7 @@ class JpegNoise(ImageAugmentor):
class GaussianNoise(ImageAugmentor):
def __init__(self, sigma=1, clip=True):
"""
Add Gaussian noise N(0, sigma^2) of the same shape as img.
......@@ -39,7 +42,9 @@ class GaussianNoise(ImageAugmentor):
ret = np.clip(ret, 0, 255)
return ret
class SaltPepperNoise(ImageAugmentor):
def __init__(self, white_prob=0.05, black_prob=0.05):
""" Salt and pepper noise.
Randomly set some elements in img to 0 or 255, regardless of its channels.
......
......@@ -10,10 +10,12 @@ import cv2
__all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge']
class Flip(ImageAugmentor):
"""
Random flip.
"""
def __init__(self, horiz=False, vert=False, prob=0.5):
"""
Only one of horiz, vert can be set.
......@@ -45,8 +47,10 @@ class Flip(ImageAugmentor):
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class Resize(ImageAugmentor):
""" Resize image to a target size"""
def __init__(self, shape, interp=cv2.INTER_CUBIC):
"""
:param shape: shape in (h, w)
......@@ -59,13 +63,15 @@ class Resize(ImageAugmentor):
img, self.shape[::-1],
interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis]
ret = ret[:, :, np.newaxis]
return ret
class ResizeShortestEdge(ImageAugmentor):
""" Resize the shortest edge to a certain number while
keeping the aspect ratio
"""
def __init__(self, size):
size = size * 1.0
self._init(locals())
......@@ -76,13 +82,15 @@ class ResizeShortestEdge(ImageAugmentor):
desSize = map(int, [scale * w, scale * h])
ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis]
ret = ret[:, :, np.newaxis]
return ret
class RandomResize(ImageAugmentor):
""" randomly rescale w and h of the image"""
def __init__(self, xrange, yrange, minimum=(0,0), aspect_ratio_thres=0.15,
interp=cv2.INTER_CUBIC):
def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15,
interp=cv2.INTER_CUBIC):
"""
:param xrange: (min, max) scaling ratio
:param yrange: (min, max) scaling ratio
......@@ -112,6 +120,5 @@ class RandomResize(ImageAugmentor):
def _augment(self, img, dsize):
ret = cv2.resize(img, dsize, interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:,:,np.newaxis]
ret = ret[:, :, np.newaxis]
return ret
......@@ -9,11 +9,12 @@ from abc import abstractmethod
import numpy as np
__all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
'RandomPaste']
'RandomPaste']
class BackgroundFiller(object):
""" Base class for all BackgroundFiller"""
def fill(self, background_shape, img):
"""
Return a proper background image of background_shape, given img
......@@ -28,8 +29,10 @@ class BackgroundFiller(object):
def _fill(self, background_shape, img):
pass
class ConstantBackgroundFiller(BackgroundFiller):
""" Fill the background by a constant """
def __init__(self, value):
"""
:param value: the value to fill the background.
......@@ -44,10 +47,12 @@ class ConstantBackgroundFiller(BackgroundFiller):
return_shape = background_shape
return np.zeros(return_shape) + self.value
class CenterPaste(ImageAugmentor):
"""
Paste the image onto the center of a background canvas.
"""
def __init__(self, background_shape, background_filler=None):
"""
:param background_shape: shape of the background canvas.
......@@ -66,16 +71,18 @@ class CenterPaste(ImageAugmentor):
self.background_shape, img)
y0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
x0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
background[y0:y0+img_shape[0], x0:x0+img_shape[1]] = img
background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
return background
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class RandomPaste(CenterPaste):
"""
Randomly paste the image onto a background canvas
"""
def _get_augment_params(self, img):
img_shape = img.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
......@@ -89,5 +96,5 @@ class RandomPaste(CenterPaste):
img_shape = img.shape[:2]
background = self.background_filler.fill(
self.background_shape, img)
background[y0:y0+img_shape[0], x0:x0+img_shape[1]] = img
background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
return background
......@@ -13,7 +13,7 @@ import os
from .base import ProxyDataFlow
from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal)
mask_sigint, start_proc_mask_signal)
from ..utils.serialize import loads, dumps
from ..utils import logger
from ..utils.gpu import change_gpu
......@@ -28,6 +28,7 @@ else:
class PrefetchProcess(mp.Process):
def __init__(self, ds, queue, reset_after_spawn=True):
"""
:param ds: ds to take data from
......@@ -46,10 +47,12 @@ class PrefetchProcess(mp.Process):
for dp in self.ds.get_data():
self.queue.put(dp)
class PrefetchData(ProxyDataFlow):
"""
Prefetch data from a `DataFlow` using multiprocessing
"""
def __init__(self, ds, nr_prefetch, nr_proc=1):
"""
:param ds: a `DataFlow` instance.
......@@ -82,6 +85,7 @@ class PrefetchData(ProxyDataFlow):
# do nothing. all ds are reset once and only once in spawned processes
pass
def BlockParallel(ds, queue_size):
# TODO more doc
"""
......@@ -92,7 +96,9 @@ def BlockParallel(ds, queue_size):
"""
return PrefetchData(ds, queue_size, 1)
class PrefetchProcessZMQ(mp.Process):
def __init__(self, ds, conn_name):
"""
:param ds: a `DataFlow` instance.
......@@ -112,8 +118,10 @@ class PrefetchProcessZMQ(mp.Process):
for dp in self.ds.get_data():
self.socket.send(dumps(dp), copy=False)
class PrefetchDataZMQ(ProxyDataFlow):
""" Work the same as `PrefetchData`, but faster. """
def __init__(self, ds, nr_proc=1, pipedir=None):
"""
:param ds: a `DataFlow` instance.
......@@ -176,9 +184,11 @@ class PrefetchDataZMQ(ProxyDataFlow):
except:
pass
class PrefetchOnGPUs(PrefetchDataZMQ):
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
variable"""
def __init__(self, ds, gpus, pipedir=None):
self.gpus = gpus
super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
......@@ -188,4 +198,3 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
for gpu, proc in zip(self.gpus, self.procs):
with change_gpu(gpu):
proc.start()
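A minimal wiring sketch for the prefetch wrappers above (FakeData appears later in this commit; constructor signatures as shown in these hunks):

    from tensorpack.dataflow import FakeData, PrefetchData, PrefetchDataZMQ

    ds = FakeData([[224, 224, 3]], size=1000)
    ds = PrefetchData(ds, nr_prefetch=256, nr_proc=4)   # queue-based multiprocessing
    # or: ds = PrefetchDataZMQ(ds, nr_proc=4)           # ZMQ pipes, usually faster
    ds.reset_state()
    for dp in ds.get_data():
        pass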
......@@ -17,8 +17,10 @@ except:
else:
__all__.append('DataFromSocket')
class FakeData(RNGDataFlow):
""" Generate fake fixed data of given shapes"""
def __init__(self, shapes, size, random=True, dtype='float32'):
"""
:param shapes: a list of lists/tuples
......@@ -44,8 +46,10 @@ class FakeData(RNGDataFlow):
for _ in range(self._size):
yield copy.deepcopy(v)
class DataFromQueue(DataFlow):
""" Produce data from a queue """
def __init__(self, queue):
self.queue = queue
......@@ -53,8 +57,10 @@ class DataFromQueue(DataFlow):
while True:
yield self.queue.get()
class DataFromList(RNGDataFlow):
""" Produce data from a list"""
def __init__(self, lst, shuffle=True):
super(DataFromList, self).__init__()
self.lst = lst
......@@ -73,8 +79,10 @@ class DataFromList(RNGDataFlow):
for k in idxs:
yield self.lst[k]
class DataFromSocket(DataFlow):
""" Produce data from a zmq socket"""
def __init__(self, socket_name):
self._name = socket_name
......@@ -89,4 +97,3 @@ class DataFromSocket(DataFlow):
yield dp
finally:
ctx.destroy(linger=0)
......@@ -17,6 +17,7 @@ from .common import RepeatedData
from ..utils import logger
from ..utils.serialize import dumps, loads
def serve_data(ds, addr):
ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH)
......@@ -36,7 +37,9 @@ def serve_data(ds, addr):
if not ctx.closed:
ctx.destroy(0)
class RemoteData(DataFlow):
def __init__(self, addr):
self.ctx = zmq.Context()
self.socket = self.ctx.socket(zmq.PULL)
......@@ -54,7 +57,7 @@ if __name__ == '__main__':
from .raw import FakeData
addr = "tcp://127.0.0.1:8877"
if sys.argv[1] == 'serve':
ds = FakeData([(128,244,244,3)], 1000)
ds = FakeData([(128, 244, 244, 3)], 1000)
serve_data(ds, addr)
else:
ds = RemoteData(addr)
......@@ -62,4 +65,3 @@ if __name__ == '__main__':
with tqdm(total=10000) as pbar:
for k in ds.get_data():
pbar.update()
......@@ -14,9 +14,11 @@ except ImportError:
else:
__all__ = ['TFFuncMapper']
class TFFuncMapper(ProxyDataFlow):
def __init__(self, ds,
get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'):
get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'):
"""
:param get_placeholders: a function returning the placeholders
:param symbf: a symbolic function taking the placeholders
......@@ -39,7 +41,7 @@ class TFFuncMapper(ProxyDataFlow):
def run_func(vals):
return self.sess.run(self.output_vars,
feed_dict=dict(zip(self.placeholders, vals)))
feed_dict=dict(zip(self.placeholders, vals)))
self.run_func = run_func
def get_data(self):
......@@ -63,16 +65,16 @@ if __name__ == '__main__':
v = tf.image.random_flip_left_right(v)
return v
ds = TFFuncMapper(ds,
lambda: [tf.placeholder(tf.float32, [224, 224, 3], name='img')],
tf_aug,
lambda dp, f: [f([dp[0]])[0]]
)
#ds = AugmentImageComponent(ds,
#[imgaug.Brightness(0.1, clip=False),
#imgaug.Contrast((0.8, 1.2), clip=False),
#imgaug.Flip(horiz=True)
#])
#ds = PrefetchDataZMQ(ds, 4)
lambda: [tf.placeholder(tf.float32, [224, 224, 3], name='img')],
tf_aug,
lambda dp, f: [f([dp[0]])[0]]
)
# ds = AugmentImageComponent(ds,
# [imgaug.Brightness(0.1, clip=False),
# imgaug.Contrast((0.8, 1.2), clip=False),
# imgaug.Flip(horiz=True)
# ])
# ds = PrefetchDataZMQ(ds, 4)
ds.reset_state()
import tqdm
......
......@@ -12,6 +12,7 @@ from ..utils import logger
__all__ = ['LinearWrap']
def _global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
......@@ -32,6 +33,7 @@ class LinearWrap(object):
"""
class TFModuleFunc(object):
def __init__(self, mod, tensor):
self._mod = mod
self._t = tensor
......@@ -88,4 +90,3 @@ class LinearWrap(object):
def print_tensor(self):
print(self._t)
return self
......@@ -5,7 +5,8 @@
import tensorflow as tf
from functools import wraps
import six
import copy, os
import copy
import os
from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str
......@@ -16,13 +17,16 @@ from ..utils.argtools import shape2d
# make sure each layer is only logged once
_layer_logged = set()
def disable_layer_logging():
class ContainEverything:
def __contains__(self, x):
return True
# can use nonlocal in python3, but how
globals()['_layer_logged'] = ContainEverything()
def layer_register(
summary_activation=False,
log_shape=True,
......@@ -42,13 +46,13 @@ def layer_register(
def wrapped_func(*args, **kwargs):
if use_scope:
name, inputs = args[0], args[1]
args = args[1:] # actual positional args used to call func
args = args[1:] # actual positional args used to call func
assert isinstance(name, six.string_types), name
else:
assert not log_shape and not summary_activation
if isinstance(args[0], six.string_types):
name, inputs = args[0], args[1]
args = args[1:] # actual positional args used to call func
args = args[1:] # actual positional args used to call func
else:
inputs = args[0]
name = None
......@@ -97,13 +101,14 @@ def layer_register(
# need some special handling for sphinx to work with the arguments
on_doc = os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING')
or os.environ.get('TENSORPACK_DOC_BUILDING')
if on_doc:
from decorator import decorator
wrapper = decorator(wrapper)
return wrapper
def shape4d(a):
# for use with tensorflow NHWC ops
return [1] + shape2d(a) + [1]
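For the NHWC helper above, shape2d (imported from utils.argtools) is assumed to follow its usual behaviour of duplicating a scalar into an [h, w] pair, so for example:

    shape4d(3)        # -> [1, 3, 3, 1], a 3x3 kernel/stride in NHWC layout
    shape4d([2, 3])   # -> [1, 2, 3, 1]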
......@@ -7,7 +7,9 @@ import tensorflow as tf
import numpy as np
import unittest
class TestModel(unittest.TestCase):
def run_variable(self, var):
sess = tf.Session()
sess.run(tf.global_variables_initializer())
......@@ -22,6 +24,7 @@ class TestModel(unittest.TestCase):
else:
return tf.Variable(args[0])
def run_test_case(case):
suite = unittest.TestLoader().loadTestsFromTestCase(case)
unittest.TextTestRunner(verbosity=2).run(suite)
......@@ -34,5 +37,3 @@ if __name__ == '__main__':
subs = tensorpack.models._test.TestModel.__subclasses__()
for cls in subs:
run_test_case(cls)
......@@ -18,6 +18,8 @@ __all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
# decay: being too close to 1 leads to slow start-up. torch uses 0.9.
# eps: torch: 1e-5. Lasagne: 1e-4
@layer_register(log_shape=False)
def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
"""
......@@ -41,9 +43,9 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
n_out = shape[-1] # channel
assert n_out is not None
beta = tf.get_variable('beta', [n_out],
initializer=tf.constant_initializer())
initializer=tf.constant_initializer())
gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0))
initializer=tf.constant_initializer(1.0))
if len(shape) == 2:
batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
......@@ -66,7 +68,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
#reuse = tf.get_variable_scope().reuse
with tf.variable_scope(tf.get_variable_scope(), reuse=False):
# BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
# TODO if reuse=True, try to find and use the existing statistics
# how to use multiple tensors to update one EMA? seems impossible
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
......@@ -93,7 +95,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
ema_mean = tf.get_variable('mean/' + emaname, [n_out])
ema_var = tf.get_variable('variance/' + emaname, [n_out])
else:
## use statistics in another tower
# use statistics in another tower
G = tf.get_default_graph()
ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name + ':0')
ema_var = ctx.find_tensor_in_main_tower(G, var_var_name + ':0')
......@@ -111,6 +113,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
return tf.nn.batch_normalization(
x, ema_mean, ema_var, beta, gamma, epsilon, 'output')
@layer_register(log_shape=False)
def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
"""
......@@ -135,9 +138,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
x = tf.reshape(x, [-1, 1, 1, n_out])
beta = tf.get_variable('beta', [n_out],
initializer=tf.constant_initializer())
initializer=tf.constant_initializer())
gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0))
initializer=tf.constant_initializer(1.0))
# x * gamma + beta
ctx = get_current_tower_context()
......@@ -147,22 +150,22 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
logger.warn("[BatchNorm] use_local_stat != is_training")
moving_mean = tf.get_variable('mean/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
initializer=tf.constant_initializer(), trainable=False)
moving_var = tf.get_variable('variance/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
initializer=tf.constant_initializer(), trainable=False)
if use_local_stat:
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
epsilon=epsilon, is_training=True)
epsilon=epsilon, is_training=True)
# maintain EMA only in the main training tower
if ctx.is_main_training_tower:
update_op1 = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay, zero_debias=False,
name='mean_ema_op')
moving_mean, batch_mean, decay, zero_debias=False,
name='mean_ema_op')
update_op2 = moving_averages.assign_moving_average(
moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op')
moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op')
add_model_variable(moving_mean)
add_model_variable(moving_var)
else:
......@@ -171,9 +174,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
# consider some fixed-param tasks, such as load model and fine tune one layer
# fused seems slower in inference
#xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
#moving_mean, moving_var,
#epsilon=epsilon, is_training=False, name='output')
# xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
# moving_mean, moving_var,
# epsilon=epsilon, is_training=False, name='output')
xn = tf.nn.batch_normalization(
x, moving_mean, moving_var, beta, gamma, epsilon)
......
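For reference, the moving-average update performed by assign_moving_average above (with zero_debias=False) is, per training step:

    moving_mean = decay * moving_mean + (1 - decay) * batch_mean
    moving_var  = decay * moving_var  + (1 - decay) * batch_var

so with decay = 0.9 the statistics effectively average over roughly the last 10 batches.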
......@@ -12,6 +12,7 @@ from ..utils.argtools import shape2d
__all__ = ['Conv2D', 'Deconv2D']
@layer_register()
def Conv2D(x, out_channel, kernel_shape,
padding='SAME', stride=1,
......@@ -61,14 +62,18 @@ def Conv2D(x, out_channel, kernel_shape,
for i, k in zip(inputs, kernels)]
conv = tf.concat(3, outputs)
if nl is None:
logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
logger.warn(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
nl = tf.nn.relu
return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
class StaticDynamicShape(object):
def __init__(self, static, dynamic):
self.static = static
self.dynamic = dynamic
def apply(self, f):
try:
st = f(self.static)
......@@ -76,11 +81,12 @@ class StaticDynamicShape(object):
except:
return StaticDynamicShape(None, f(self.dynamic))
@layer_register()
def Deconv2D(x, out_shape, kernel_shape,
stride, padding='SAME',
W_init=None, b_init=None,
nl=tf.identity, use_bias=True):
stride, padding='SAME',
W_init=None, b_init=None,
nl=tf.identity, use_bias=True):
"""
2D deconvolution on 4D inputs.
......
......@@ -11,6 +11,7 @@ from ..tfutils import symbolic_functions as symbf
__all__ = ['FullyConnected']
@layer_register()
def FullyConnected(x, out_dim,
W_init=None, b_init=None,
......@@ -40,6 +41,7 @@ def FullyConnected(x, out_dim,
b = tf.get_variable('b', [out_dim], initializer=b_init)
prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
if nl is None:
logger.warn("[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
logger.warn(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. Please use argscope instead.")
nl = tf.nn.relu
return nl(prod, name='output')
......@@ -12,6 +12,8 @@ __all__ = ['ImageSample']
# XXX TODO ugly.
# really need to fix this after tensorflow supports advanced indexing
# See github:tensorflow#418,#206
def sample(img, coords):
"""
:param img: bxhxwxc
......@@ -33,14 +35,15 @@ def sample(img, coords):
# bxh2xw2
batch_add = tf.range(tf.shape(img)[0]) * (shape[0] * shape[1])
batch_add = tf.reshape(batch_add, [-1, 1, 1]) #bx1x1
batch_add = tf.reshape(batch_add, [-1, 1, 1]) # bx1x1
flat_coords = coords + batch_add
img = tf.reshape(img, [-1, shape[2]]) #bhw x c
img = tf.reshape(img, [-1, shape[2]]) # bhw x c
sampled = tf.gather(img, flat_coords)
return sampled
@layer_register()
def ImageSample(inputs, borderMode='repeat'):
"""
......@@ -59,7 +62,7 @@ def ImageSample(inputs, borderMode='repeat'):
assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4
input_shape = template.get_shape().as_list()[1:]
assert None not in input_shape, \
"Images in ImageSample layer must have fully-defined shape"
"Images in ImageSample layer must have fully-defined shape"
assert borderMode in ['repeat', 'constant']
orig_mapping = mapping
......@@ -68,7 +71,7 @@ def ImageSample(inputs, borderMode='repeat'):
ucoor = lcoor + 1
diff = mapping - lcoor
neg_diff = 1.0 - diff #bxh2xw2x2
neg_diff = 1.0 - diff # bxh2xw2x2
lcoory, lcoorx = tf.split(3, 2, lcoor)
ucoory, ucoorx = tf.split(3, 2, ucoor)
......@@ -80,55 +83,59 @@ def ImageSample(inputs, borderMode='repeat'):
neg_diffy, neg_diffx = tf.split(3, 2, neg_diff)
#prod = tf.reduce_prod(diff, 3, keep_dims=True)
#diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
#tf.reduce_max(diff), diff], summarize=50)
# diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
# tf.reduce_max(diff), diff], summarize=50)
ret = tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy,
sample(template, ucoor) * diffx * diffy,
sample(template, lyux) * neg_diffy * diffx,
sample(template, uylx) * diffy * neg_diffx], name='sampled')
sample(template, ucoor) * diffx * diffy,
sample(template, lyux) * neg_diffy * diffx,
sample(template, uylx) * diffy * neg_diffx], name='sampled')
if borderMode == 'constant':
max_coor = tf.constant([input_shape[0] - 1, input_shape[1] - 1], dtype=tf.float32)
mask = tf.greater_equal(orig_mapping, 0.0)
mask2 = tf.less_equal(orig_mapping, max_coor)
mask = tf.logical_and(mask, mask2) #bxh2xw2x2
mask = tf.reduce_all(mask, [3]) # bxh2xw2 boolean
mask = tf.logical_and(mask, mask2) # bxh2xw2x2
mask = tf.reduce_all(mask, [3]) # bxh2xw2 boolean
mask = tf.expand_dims(mask, 3)
ret = ret * tf.cast(mask, tf.float32)
return ret
from ._test import TestModel
class TestSample(TestModel):
def test_sample(self):
import numpy as np
h, w = 3, 4
def np_sample(img, coords):
# a reference implementation
coords = np.maximum(coords, 0)
coords = np.minimum(coords,
np.array([img.shape[1]-1, img.shape[2]-1]))
xs = coords[:,:,:,1].reshape((img.shape[0], -1))
ys = coords[:,:,:,0].reshape((img.shape[0], -1))
np.array([img.shape[1] - 1, img.shape[2] - 1]))
xs = coords[:, :, :, 1].reshape((img.shape[0], -1))
ys = coords[:, :, :, 0].reshape((img.shape[0], -1))
ret = np.zeros((img.shape[0], coords.shape[1], coords.shape[2],
img.shape[3]), dtype='float32')
for k in range(img.shape[0]):
xss, yss = xs[k], ys[k]
ret[k,:,:,:] = img[k,yss,xss,:].reshape((coords.shape[1],
coords.shape[2], 3))
ret[k, :, :, :] = img[k, yss, xss, :].reshape((coords.shape[1],
coords.shape[2], 3))
return ret
bimg = np.random.rand(2, h, w, 3).astype('float32')
#mat = np.array([
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
# mat = np.array([
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]],
#[[[1,1], [1.2,1.2]], [[-1, -1], [2.5, 2.5]]]
#], dtype='float32') #2x2x2x2
mat = (np.random.rand(2, 5, 5, 2) - 0.2) * np.array([h + 3, w + 3])
true_res = np_sample(bimg, np.floor(mat + 0.5).astype('int32'))
inp, mapping = self.make_variable(bimg, mat)
output = sample(inp, tf.cast(tf.floor(mapping+0.5), tf.int32))
output = sample(inp, tf.cast(tf.floor(mapping + 0.5), tf.int32))
res = self.run_variable(output)
self.assertTrue((res == true_res).all())
......@@ -146,7 +153,7 @@ if __name__ == '__main__':
diff = 200
for x in range(w):
for y in range(h):
mapping[0,y,x,:] = np.array([y-diff+0.4, x-diff+0.5])
mapping[0, y, x, :] = np.array([y - diff + 0.4, x - diff + 0.5])
mapv = tf.Variable(mapping)
output = ImageSample('sample', [imv, mapv], borderMode='constant')
......@@ -155,12 +162,10 @@ if __name__ == '__main__':
#out = sess.run(tf.gradients(tf.reduce_sum(output), mapv))
#out = sess.run(output)
#print(out[0].min())
#print(out[0].max())
#print(out[0].sum())
# print(out[0].min())
# print(out[0].max())
# print(out[0].sum())
out = sess.run([output])[0]
im = out[0]
cv2.imwrite('sampled.jpg', im)
......@@ -16,21 +16,27 @@ from ..tfutils.common import get_tensors_by_names
from ..tfutils.gradproc import CheckGradient
from ..tfutils.tower import get_current_tower_context
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph' ]
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']
#_InputVar = namedtuple('InputVar', ['type', 'shape', 'name', 'sparse'])
class InputVar(object):
def __init__(self, type, shape, name, sparse=False):
self.type = type
self.shape = shape
self.name = name
self.sparse = sparse
def dumps(self):
return pickle.dumps(self)
@staticmethod
def loads(buf):
return pickle.loads(buf)
@six.add_metaclass(ABCMeta)
class ModelDesc(object):
""" Base class for a model description """
......@@ -99,22 +105,24 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
def get_gradient_processor(self):
""" Return a list of GradientProcessor. They will be executed in order"""
return [#SummaryGradient(),
CheckGradient()
]
return [ # SummaryGradient(),
CheckGradient()
]
class ModelFromMetaGraph(ModelDesc):
"""
Load the whole exact TF graph from a saved meta_graph.
Only useful for inference.
"""
def __init__(self, filename):
tf.train.import_meta_graph(filename)
all_coll = tf.get_default_graph().get_all_collection_keys()
for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
tf.GraphKeys().VARIABLES]:
tf.GraphKeys().VARIABLES]:
assert k in all_coll, \
"Collection {} not found in metagraph!".format(k)
"Collection {} not found in metagraph!".format(k)
def _get_input_vars(self):
col = tf.get_collection(INPUT_VARS_KEY)
......
......@@ -11,6 +11,7 @@ from .batch_norm import BatchNorm
__all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']
@layer_register()
def Maxout(x, num_unit):
"""
......@@ -31,6 +32,7 @@ def Maxout(x, num_unit):
x = tf.reshape(x, [-1, ch / num_unit, num_unit])
return tf.reduce_max(x, ndim, name='output')
@layer_register(log_shape=False)
def PReLU(x, init=tf.constant_initializer(0.001), name=None):
"""
......@@ -47,6 +49,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
name = 'output'
return tf.mul(x, 0.5, name=name)
@layer_register(use_scope=False, log_shape=False)
def LeakyReLU(x, alpha, name=None):
"""
......@@ -62,7 +65,8 @@ def LeakyReLU(x, alpha, name=None):
return tf.maximum(x, alpha * x, name=name)
#alpha = float(alpha)
#x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
#return tf.mul(x, 0.5, name=name)
# return tf.mul(x, 0.5, name=name)
@layer_register(log_shape=False, use_scope=False)
def BNReLU(x, name=None):
......
......@@ -12,6 +12,7 @@ from ..tfutils import symbolic_functions as symbf
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
'BilinearUpSample']
@layer_register()
def MaxPooling(x, shape, stride=None, padding='VALID'):
"""
......@@ -32,6 +33,7 @@ def MaxPooling(x, shape, stride=None, padding='VALID'):
return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding)
@layer_register()
def AvgPooling(x, shape, stride=None, padding='VALID'):
"""
......@@ -52,6 +54,7 @@ def AvgPooling(x, shape, stride=None, padding='VALID'):
return tf.nn.avg_pool(x, ksize=shape, strides=stride, padding=padding)
@layer_register()
def GlobalAvgPooling(x):
"""
......@@ -65,6 +68,8 @@ def GlobalAvgPooling(x):
return tf.reduce_mean(x, [1, 2])
# https://github.com/tensorflow/tensorflow/issues/2169
def UnPooling2x2ZeroFilled(x):
out = tf.concat(3, [x, tf.zeros_like(x)])
out = tf.concat(2, [out, tf.zeros_like(out)])
......@@ -79,6 +84,7 @@ def UnPooling2x2ZeroFilled(x):
ret.set_shape([None, None, None, sh[3]])
return ret
@layer_register()
def FixedUnPooling(x, shape, unpool_mat=None):
"""
......@@ -108,8 +114,8 @@ def FixedUnPooling(x, shape, unpool_mat=None):
# perform a tensor-matrix kronecker product
fx = symbf.flatten(tf.transpose(x, [0, 3, 1, 2]))
fx = tf.expand_dims(fx, -1) # (bchw)x1
mat = tf.expand_dims(symbf.flatten(unpool_mat), 0) #1x(shxsw)
prod = tf.matmul(fx, mat) #(bchw) x(shxsw)
mat = tf.expand_dims(symbf.flatten(unpool_mat), 0) # 1x(shxsw)
prod = tf.matmul(fx, mat) # (bchw) x(shxsw)
prod = tf.reshape(prod, tf.pack(
[-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]]))
prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1])
......@@ -117,6 +123,7 @@ def FixedUnPooling(x, shape, unpool_mat=None):
[-1, input_shape[1] * shape[0], input_shape[2] * shape[1], input_shape[3]]))
return prod
@layer_register()
def BilinearUpSample(x, shape):
"""
......@@ -125,9 +132,9 @@ def BilinearUpSample(x, shape):
:param shape: an integer, the upsample factor
"""
#inp_shape = tf.shape(x)
#return tf.image.resize_bilinear(x,
#tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
#align_corners=True)
# return tf.image.resize_bilinear(x,
# tf.pack([inp_shape[1]*shape,inp_shape[2]*shape]),
# align_corners=True)
inp_shape = x.get_shape().as_list()
ch = inp_shape[3]
......@@ -136,7 +143,6 @@ def BilinearUpSample(x, shape):
shape = int(shape)
filter_shape = 2 * shape
def bilinear_conv_filler(s):
"""
s: width, height of the conv filter
......@@ -147,7 +153,7 @@ def BilinearUpSample(x, shape):
ret = np.zeros((s, s), dtype='float32')
for x in range(s):
for y in range(s):
ret[x,y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
ret[x, y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
return ret
w = bilinear_conv_filler(filter_shape)
w = np.repeat(w, ch * ch).reshape((filter_shape, filter_shape, ch, ch))
......@@ -155,17 +161,22 @@ def BilinearUpSample(x, shape):
shape=(filter_shape, filter_shape, ch, ch),
name='bilinear_upsample_filter')
deconv = tf.nn.conv2d_transpose(x, weight_var,
tf.shape(x) * tf.constant([1, shape, shape, 1], tf.int32),
[1,shape,shape,1], 'SAME')
tf.shape(x) * tf.constant([1, shape, shape, 1], tf.int32),
[1, shape, shape, 1], 'SAME')
if inp_shape[1]: inp_shape[1] *= shape
if inp_shape[2]: inp_shape[2] *= shape
if inp_shape[1]:
inp_shape[1] *= shape
if inp_shape[2]:
inp_shape[2] *= shape
deconv.set_shape(inp_shape)
return deconv
from ._test import TestModel
class TestPool(TestModel):
def test_fixed_unpooling(self):
h, w = 3, 4
mat = np.random.rand(h, w, 3).astype('float32')
......@@ -173,13 +184,13 @@ class TestPool(TestModel):
inp = tf.reshape(inp, [1, h, w, 3])
output = FixedUnPooling('unpool', inp, 2)
res = self.run_variable(output)
self.assertEqual(res.shape, (1, 2*h, 2*w, 3))
self.assertEqual(res.shape, (1, 2 * h, 2 * w, 3))
# mat is on the corners
ele = res[0,::2,::2,0]
self.assertTrue((ele == mat[:,:,0]).all())
ele = res[0, ::2, ::2, 0]
self.assertTrue((ele == mat[:, :, 0]).all())
# the rest are zeros
res[0,::2,::2,:] = 0
res[0, ::2, ::2, :] = 0
self.assertTrue((res == 0).all())
def test_upsample(self):
......@@ -191,7 +202,7 @@ class TestPool(TestModel):
inp = tf.reshape(inp, [1, h, w, 1])
output = BilinearUpSample('upsample', inp, scale)
res = self.run_variable(output)[0,:,:,0]
res = self.run_variable(output)[0, :, :, 0]
from skimage.transform import rescale
res2 = rescale(mat, scale)
......@@ -199,9 +210,9 @@ class TestPool(TestModel):
diff = np.abs(res2 - res)
# not equivalent to rescale on edge?
diff[0,:] = 0
diff[:,0] = 0
diff[0, :] = 0
diff[:, 0] = 0
if not diff.max() < 1e-4:
import IPython;
import IPython
IPython.embed(config=IPython.terminal.ipapp.load_default_config())
self.assertTrue(diff.max() < 1e-4)
......@@ -12,6 +12,7 @@ from ._common import layer_register
__all__ = ['regularize_cost', 'l2_regularizer', 'l1_regularizer', 'Dropout']
@memoized
def _log_regularizer(name):
logger.info("Apply regularizer for {}".format(name))
......@@ -19,6 +20,7 @@ def _log_regularizer(name):
l2_regularizer = tf.contrib.layers.l2_regularizer
l1_regularizer = tf.contrib.layers.l1_regularizer
def regularize_cost(regex, func, name=None):
"""
Apply a regularizer on every trainable variable matching the regex.
......@@ -48,4 +50,3 @@ def Dropout(x, keep_prob=0.5, is_training=None):
is_training = get_current_tower_context().is_training
keep_prob = tf.constant(keep_prob if is_training else 1.0)
return tf.nn.dropout(x, keep_prob)
......@@ -8,6 +8,7 @@ from ._common import layer_register
__all__ = ['ConcatWith']
@layer_register(use_scope=False, log_shape=False)
def ConcatWith(x, dim, tensor):
"""
......
......@@ -8,6 +8,7 @@ from ._common import layer_register
__all__ = ['SoftMax']
@layer_register()
def SoftMax(x, use_temperature=False, temperature_init=1.0):
"""
......@@ -16,6 +17,6 @@ def SoftMax(x, use_temperature=False, temperature_init=1.0):
"""
if use_temperature:
t = tf.get_variable('invtemp', [],
initializer=tf.constant_initializer(1.0 / float(temperature_init)))
initializer=tf.constant_initializer(1.0 / float(temperature_init)))
x = x * t
return tf.nn.softmax(x, name='output')
......@@ -8,6 +8,7 @@ import os.path
__all__ = []
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
......@@ -25,4 +26,3 @@ for _, module_name, _ in walk_packages(
if module_name.startswith('_'):
continue
global_import(module_name)
......@@ -12,9 +12,10 @@ from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext
__all__ = ['OnlinePredictor', 'OfflinePredictor',
'AsyncPredictorBase',
'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
'DataParallelOfflinePredictor']
'AsyncPredictorBase',
'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
'DataParallelOfflinePredictor']
@six.add_metaclass(ABCMeta)
class PredictorBase(object):
......@@ -46,7 +47,9 @@ class PredictorBase(object):
:return: output as defined by the config
"""
class AsyncPredictorBase(PredictorBase):
@abstractmethod
def put_task(self, dp, callback=None):
"""
......@@ -67,7 +70,9 @@ class AsyncPredictorBase(PredictorBase):
# in Tornado, Future.result() doesn't wait
return fut.result()
class OnlinePredictor(PredictorBase):
def __init__(self, sess, input_tensors, output_tensors, return_input=False):
self.session = sess
self.return_input = return_input
......@@ -85,6 +90,7 @@ class OnlinePredictor(PredictorBase):
class OfflinePredictor(OnlinePredictor):
""" Build a predictor from a given config, in an independent graph"""
def __init__(self, config):
self.graph = tf.Graph()
with self.graph.as_default():
......@@ -98,7 +104,7 @@ class OfflinePredictor(OnlinePredictor):
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
super(OfflinePredictor, self).__init__(
sess, input_vars, output_vars, config.return_input)
sess, input_vars, output_vars, config.return_input)
def build_multi_tower_prediction_graph(build_tower_fn, towers):
......@@ -108,13 +114,15 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
"""
for k in towers:
logger.info(
"Building graph for predictor tower {}...".format(k))
"Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext('{}{}'.format(PREDICT_TOWER, k)):
build_tower_fn(k)
tf.get_variable_scope().reuse_variables()
class MultiTowerOfflinePredictor(OnlinePredictor):
def __init__(self, config, towers):
self.graph = tf.Graph()
self.predictors = []
......@@ -130,8 +138,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
for k in towers:
output_vars = get_tensors_by_names(
['{}{}/'.format(PREDICT_TOWER, k) + n \
for n in config.output_names])
['{}{}/'.format(PREDICT_TOWER, k) + n
for n in config.output_names])
self.predictors.append(OnlinePredictor(
self.sess, input_vars, output_vars, config.return_input))
......@@ -142,7 +150,9 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
def get_predictors(self, n):
return [self.predictors[k % len(self.predictors)] for k in range(n)]
class DataParallelOfflinePredictor(OnlinePredictor):
def __init__(self, config, towers):
self.graph = tf.Graph()
with self.graph.as_default():
......@@ -152,19 +162,19 @@ class DataParallelOfflinePredictor(OnlinePredictor):
for k in towers:
towername = PREDICT_TOWER + str(k)
input_vars = config.model.build_placeholders(
prefix=towername + '-')
prefix=towername + '-')
logger.info(
"Building graph for predictor tower {}...".format(k))
"Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext(towername, is_training=False):
config.model.build_graph(input_vars)
tf.get_variable_scope().reuse_variables()
input_var_names.extend([k.name for k in input_vars])
output_vars.extend(get_tensors_by_names(
[towername + '/' + n \
for n in config.output_names]))
[towername + '/' + n
for n in config.output_names]))
input_vars = get_tensors_by_names(input_var_names)
config.session_init.init(sess)
super(DataParallelOfflinePredictor, self).__init__(
sess, input_vars, output_vars, config.return_input)
sess, input_vars, output_vars, config.return_input)
......@@ -15,11 +15,13 @@ from .base import OfflinePredictor
import multiprocessing
__all__ = ['PredictConfig', 'get_predict_func', 'PredictResult' ]
__all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']
PredictResult = namedtuple('PredictResult', ['input', 'output'])
class PredictConfig(object):
def __init__(self, **kwargs):
"""
The config used by `get_predict_func`.
......@@ -61,12 +63,14 @@ class PredictConfig(object):
self.output_names = kwargs.pop('output_var_names')
#logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
assert len(self.input_names), self.input_names
for v in self.input_names: assert_type(v, six.string_types)
for v in self.input_names:
assert_type(v, six.string_types)
assert len(self.output_names), self.output_names
self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def get_predict_func(config):
"""
Produce an offline predictor that runs inside a new session.
......@@ -76,4 +80,3 @@ def get_predict_func(config):
a list of output values defined in ``config.output_var_names``.
"""
return OfflinePredictor(config)
......@@ -3,7 +3,8 @@
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing, threading
import multiprocessing
import threading
import tensorflow as tf
import time
import six
......@@ -25,10 +26,12 @@ except ImportError:
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
else:
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
'MultiThreadAsyncPredictor']
'MultiThreadAsyncPredictor']
class MultiProcessPredictWorker(multiprocessing.Process):
""" Base class for predict worker that runs offline in multiprocess"""
def __init__(self, idx, config):
"""
:param idx: index of the worker. the 0th worker will print log.
......@@ -51,8 +54,10 @@ class MultiProcessPredictWorker(multiprocessing.Process):
with self.predictor.graph.as_default():
describe_model()
class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" An offline predictor worker that takes input and produces output by queue"""
def __init__(self, idx, inqueue, outqueue, config):
"""
:param inqueue: input queue to get data point. elements are (task_id, dp)
......@@ -76,6 +81,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
class PredictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__()
self.queue = queue
......@@ -88,13 +94,13 @@ class PredictorWorkerThread(threading.Thread):
while True:
batched, futures = self.fetch_batch()
outputs = self.func(batched)
#print "Worker {} batched {} Queue {}".format(
#self.id, len(futures), self.queue.qsize())
# debug, for speed testing
#if not hasattr(self, 'xxx'):
#self.xxx = outputs = self.func(batched)
#else:
#outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
# print "Worker {} batched {} Queue {}".format(
# self.id, len(futures), self.queue.qsize())
# debug, for speed testing
# if not hasattr(self, 'xxx'):
# self.xxx = outputs = self.func(batched)
# else:
# outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
for idx, f in enumerate(futures):
f.set_result([k[idx] for k in outputs])
......@@ -119,11 +125,13 @@ class PredictorWorkerThread(threading.Thread):
cnt += 1
return batched, futures
class MultiThreadAsyncPredictor(AsyncPredictorBase):
"""
A multithreaded online async predictor which runs a list of PredictorBase.
It would do an extra batching internally.
"""
def __init__(self, predictors, batch_size=5):
""" :param predictors: a list of OnlinePredictor"""
assert len(predictors)
......@@ -131,7 +139,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
#assert isinstance(k, OnlinePredictor), type(k)
# TODO use predictors.return_input here
assert k.return_input == False
self.input_queue = queue.Queue(maxsize=len(predictors)*100)
self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
self.threads = [
PredictorWorkerThread(
self.input_queue, f, id, batch_size=batch_size)
......
......@@ -20,10 +20,12 @@ from .common import PredictConfig
from .base import OfflinePredictor
__all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
'MultiProcessDatasetPredictor']
'MultiProcessDatasetPredictor']
@six.add_metaclass(ABCMeta)
class DatasetPredictorBase(object):
def __init__(self, config, dataset):
"""
:param config: a `PredictConfig` instance.
......@@ -45,10 +47,12 @@ class DatasetPredictorBase(object):
"""
return list(self.get_result())
class SimpleDatasetPredictor(DatasetPredictorBase):
"""
Run the predict_config on a given `DataFlow`.
"""
def __init__(self, config, dataset):
super(SimpleDatasetPredictor, self).__init__(config, dataset)
self.predictor = OfflinePredictor(config)
......@@ -60,14 +64,17 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
sz = self.dataset.size()
except NotImplementedError:
sz = 0
with get_tqdm(total=sz, disable=(sz==0)) as pbar:
with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
for dp in self.dataset.get_data():
res = self.predictor(dp)
yield res
pbar.update()
# TODO allow unordered
class MultiProcessDatasetPredictor(DatasetPredictorBase):
def __init__(self, config, dataset, nr_proc, use_gpu=True, ordered=True):
"""
Run prediction in multiple processes, on either CPU or GPU. Mixed mode is not supported.
......@@ -87,14 +94,14 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
self.ordered = ordered
self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
self.dataset, nr_proc * 2, self.nr_proc) # put (idx, dp) to inqueue
self.dataset, nr_proc * 2, self.nr_proc) # put (idx, dp) to inqueue
if use_gpu:
try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
assert len(gpus) >= self.nr_proc, \
"nr_proc={} while only {} gpus available".format(
self.nr_proc, len(gpus))
"nr_proc={} while only {} gpus available".format(
self.nr_proc, len(gpus))
except KeyError:
# TODO number of GPUs not checked
gpus = list(range(self.nr_proc))
......@@ -103,8 +110,8 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
# worker produces (idx, result) to outqueue
self.outqueue = multiprocessing.Queue()
self.workers = [MultiProcessQueuePredictWorker(
i, self.inqueue, self.outqueue, self.config)
for i in range(self.nr_proc)]
i, self.inqueue, self.outqueue, self.config)
for i in range(self.nr_proc)]
# start inqueue and workers
self.inqueue_proc.start()
......@@ -118,7 +125,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
if ordered:
self.result_queue = OrderedResultGatherProc(
self.outqueue, nr_producer=self.nr_proc)
self.outqueue, nr_producer=self.nr_proc)
self.result_queue.start()
ensure_proc_terminate(self.result_queue)
else:
......@@ -130,7 +137,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
sz = self.dataset.size()
except NotImplementedError:
sz = 0
with get_tqdm(total=sz, disable=(sz==0)) as pbar:
with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
die_cnt = 0
while True:
res = self.result_queue.get()
......@@ -147,4 +154,5 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
self.result_queue.join()
self.result_queue.terminate()
for p in self.workers:
p.join(); p.terminate()
p.join()
p.terminate()
......@@ -12,6 +12,7 @@ __all__ = ['argscope', 'get_arg_scope']
_ArgScopeStack = []
@contextmanager
def argscope(layers, **param):
if not isinstance(layers, list):
......@@ -33,6 +34,7 @@ def argscope(layers, **param):
yield
del _ArgScopeStack[-1]
def get_arg_scope():
"""
:returns: the current argscope.
......
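A usage sketch for argscope: the keyword defaults apply to every matching layer call inside the block, and per-call arguments still override them (layer names here are made up):

    import tensorflow as tf
    from tensorpack.models import Conv2D
    from tensorpack.tfutils.argscope import argscope

    image = tf.placeholder(tf.float32, [None, 224, 224, 3], name='image')
    with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=64):
        l = Conv2D('conv1', image)
        l = Conv2D('conv2', l)
        l = Conv2D('conv3', l, out_channel=128)   # overrides the scoped default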
......@@ -22,6 +22,7 @@ __all__ = ['get_default_sess_config',
'freeze_collection',
'get_tf_version']
def get_default_sess_config(mem_fraction=0.99):
"""
Return a better session config to use as default.
......@@ -38,6 +39,7 @@ def get_default_sess_config(mem_fraction=0.99):
#conf.log_device_placement = True
return conf
def get_global_step_var():
""" :returns: the global_step variable in the current graph. create if not existed"""
try:
......@@ -45,19 +47,21 @@ def get_global_step_var():
except KeyError:
scope = tf.get_variable_scope()
assert scope.name == '', \
"Creating global_step_var under a variable scope would cause problems!"
"Creating global_step_var under a variable scope would cause problems!"
with tf.variable_scope(scope, reuse=False):
var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
initializer=tf.constant_initializer(dtype=tf.int32),
trainable=False, dtype=tf.int32)
initializer=tf.constant_initializer(dtype=tf.int32),
trainable=False, dtype=tf.int32)
return var
def get_global_step():
""" :returns: global_step value in current graph and session"""
return tf.train.global_step(
tf.get_default_session(),
get_global_step_var())
def get_op_tensor_name(name):
"""
Tensor name is assumed to be ``op_name + ':0'``
......@@ -72,6 +76,7 @@ def get_op_tensor_name(name):
get_op_var_name = get_op_tensor_name
def get_tensors_by_names(names):
"""
Get a list of tensors in the default graph by a list of names
......@@ -85,26 +90,31 @@ def get_tensors_by_names(names):
get_vars_by_names = get_tensors_by_names
def backup_collection(keys):
ret = {}
for k in keys:
ret[k] = copy(tf.get_collection(k))
return ret
def restore_collection(backup):
for k, v in six.iteritems(backup):
del tf.get_collection_ref(k)[:]
tf.get_collection_ref(k).extend(v)
def clear_collection(keys):
for k in keys:
del tf.get_collection_ref(k)[:]
@contextmanager
def freeze_collection(keys):
backup = backup_collection(keys)
yield
restore_collection(backup)
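freeze_collection backs up the given collections, runs the block, and then restores the backup, so anything added inside the block is discarded afterwards. A minimal sketch, assuming the import path below:

    import tensorflow as tf
    from tensorpack.tfutils.common import freeze_collection   # assumed import path

    with freeze_collection([tf.GraphKeys.SUMMARIES, tf.GraphKeys.UPDATE_OPS]):
        tf.summary.scalar('debug', tf.constant(1.0))   # added inside the frozen block ...
    # ... and gone again here, because restore_collection() put the backup back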
def get_tf_version():
return int(tf.__version__.split('.')[1])
......@@ -16,6 +16,7 @@ __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
'ScaleGradient', 'MapGradient', 'apply_grad_processors',
'GlobalNormClip']
def apply_grad_processors(grads, gradprocs):
"""
:param grads: list of (grad, var).
......@@ -32,6 +33,7 @@ def apply_grad_processors(grads, gradprocs):
g = proc.process(g)
return g
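A minimal sketch of chaining gradient processors with apply_grad_processors (the toy variable, cost, and optimizer are illustrative; the import path is an assumption):

    import tensorflow as tf
    from tensorpack.tfutils.gradproc import (                 # assumed import path
        apply_grad_processors, GlobalNormClip, ScaleGradient)

    w = tf.get_variable('fc/W', shape=[3])
    cost = tf.reduce_sum(tf.square(w))
    opt = tf.train.GradientDescentOptimizer(0.1)

    grads = opt.compute_gradients(cost)               # list of (grad, var)
    grads = apply_grad_processors(grads, [
        GlobalNormClip(5.0),                          # clip all gradients by their global norm
        ScaleGradient([('fc/.*', 0.1)], log=True),    # scale gradients of fc/* variables by 0.1
    ])
    train_op = opt.apply_gradients(grads)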
@six.add_metaclass(ABCMeta)
class GradientProcessor(object):
......@@ -51,6 +53,7 @@ class GradientProcessor(object):
class GlobalNormClip(GradientProcessor):
def __init__(self, global_norm):
""" Clip by global norm
Note that the global norm is computed over **all** gradients together
......@@ -63,11 +66,13 @@ class GlobalNormClip(GradientProcessor):
g, _ = tf.clip_by_global_norm(g, self._norm, name='clip_by_global_norm')
return list(zip(g, v))
class MapGradient(GradientProcessor):
"""
Apply a function to every gradient whose variable name matches the regex.
Keep the other gradients unchanged.
"""
def __init__(self, func, regex='.*'):
"""
:param func: takes a grad or (grad, var) pair and returns a grad. If return None, the
......@@ -77,7 +82,7 @@ class MapGradient(GradientProcessor):
args = inspect.getargspec(func).args
arg_num = len(args) - inspect.ismethod(func)
assert arg_num in [1, 2], \
"The function must take 1 or 2 arguments! ({})".format(args)
"The function must take 1 or 2 arguments! ({})".format(args)
if arg_num == 1:
self.func = lambda grad, var: func(grad)
else:
......@@ -100,10 +105,12 @@ class MapGradient(GradientProcessor):
_summaried_gradient = set()
class SummaryGradient(MapGradient):
"""
Summarize the histogram and RMS of each gradient variable
"""
def __init__(self):
super(SummaryGradient, self).__init__(self._mapper)
......@@ -115,10 +122,12 @@ class SummaryGradient(MapGradient):
add_moving_summary(rms(grad, name=name + '/rms'))
return grad
class CheckGradient(MapGradient):
"""
Check gradients for numeric issues.
"""
def __init__(self):
super(CheckGradient, self).__init__(self._mapper)
......@@ -128,10 +137,12 @@ class CheckGradient(MapGradient):
grad = tf.check_numerics(grad, 'CheckGradient-' + var.op.name)
return grad
class ScaleGradient(MapGradient):
"""
Scale certain gradients by a multiplier
"""
def __init__(self, multipliers, log=True):
"""
:param multipliers: list of (regex, float)
......
......@@ -9,6 +9,7 @@ from ..utils import logger
__all__ = ['describe_model', 'get_shape_str']
def describe_model():
""" print a description of the current model parameters """
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
......@@ -40,5 +41,3 @@ def get_shape_str(tensors):
assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
shape_str = str(tensors.get_shape().as_list())
return shape_str
......@@ -20,6 +20,7 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
# TODO they initialize_all at the beginning by default.
@six.add_metaclass(ABCMeta)
class SessionInit(object):
""" Base class for utilities to initialize a session"""
......@@ -35,23 +36,29 @@ class SessionInit(object):
def _init(self, sess):
pass
class JustCurrentSession(SessionInit):
""" Just use the current default session. This is a no-op placeholder"""
def _init(self, sess):
pass
class NewSession(SessionInit):
"""
Create a new session. All variables will be initialized by their
initializer.
"""
def _init(self, sess):
sess.run(tf.global_variables_initializer())
class SaverRestore(SessionInit):
"""
Restore an old model saved by `ModelSaver`.
"""
def __init__(self, model_path, prefix=None):
"""
:param model_path: a model name (model-xxxx) or a ``checkpoint`` file.
......@@ -71,7 +78,7 @@ class SaverRestore(SessionInit):
new_path = model_path.split('.index')[0]
if new_path != model_path:
logger.warn(
"[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
"[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
model_path = new_path
assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
self.set_path(model_path)
......@@ -146,10 +153,12 @@ class SaverRestore(SessionInit):
logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
return var_dict
class ParamRestore(SessionInit):
"""
Restore variables from a dictionary.
"""
def __init__(self, param_dict):
"""
:param param_dict: a dict of {name: value}
......@@ -158,7 +167,7 @@ class ParamRestore(SessionInit):
self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
def _init(self, sess):
variables = tf.get_collection(tf.GraphKeys().VARIABLES) # TODO
variables = tf.get_collection(tf.GraphKeys().VARIABLES) # TODO
variable_names = set([get_savename_from_varname(k.name) for k in variables])
param_names = set(six.iterkeys(self.prms))
......@@ -174,14 +183,15 @@ class ParamRestore(SessionInit):
logger.warn("Variable {} in the dict not found in the graph!".format(k))
upd = SessionUpdate(sess,
[v for v in variables if \
get_savename_from_varname(v.name) in intersect])
[v for v in variables if
get_savename_from_varname(v.name) in intersect])
logger.info("Restoring from dict ...")
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
class ChainInit(SessionInit):
""" Init a session by a list of SessionInit instance."""
def __init__(self, sess_inits, new_session=True):
"""
:params sess_inits: list of `SessionInit` instances.
......
......@@ -15,6 +15,7 @@ from .symbolic_functions import rms
__all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
'add_moving_summary', 'summary_moving_average']
def create_summary(name, v):
"""
Return a tf.Summary object with name and simple scalar value v
......@@ -25,6 +26,7 @@ def create_summary(name, v):
s.value.add(tag=name, simple_value=v)
return s
def add_activation_summary(x, name=None):
"""
Add summary to graph for an activation tensor x.
......@@ -44,6 +46,7 @@ def add_activation_summary(x, name=None):
tf.summary.scalar(name + '-sparsity', tf.nn.zero_fraction(x))
tf.summary.scalar(name + '-rms', rms(x))
def add_param_summary(summary_lists):
"""
Add summary for all trainable variables matching the regex
......@@ -54,6 +57,7 @@ def add_param_summary(summary_lists):
ctx = get_current_tower_context()
if ctx is not None and not ctx.is_main_training_tower:
return
def perform(var, action):
ndim = var.get_shape().ndims
name = var.name.replace(':0', '')
......@@ -87,6 +91,7 @@ def add_param_summary(summary_lists):
for act in actions:
perform(p, act)
def add_moving_summary(v, *args):
"""
:param v: tensor or list of tensor to summary
......@@ -102,6 +107,7 @@ def add_moving_summary(v, *args):
assert x.get_shape().ndims == 0, x.get_shape()
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)
@memoized
def summary_moving_average(tensors=None):
"""
......@@ -121,4 +127,3 @@ def summary_moving_average(tensors=None):
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name + '-summary', averager.average(c))
return avg_maintain_op
......@@ -6,6 +6,7 @@ import tensorflow as tf
import numpy as np
from ..utils import logger
def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
"""
:param logits: NxC
......@@ -13,7 +14,8 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
:returns: a float32 vector of length N with 0/1 values. 1 means incorrect prediction
"""
return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
tf.float32, name=name)
tf.float32, name=name)
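A minimal sketch of turning prediction_incorrect into an accuracy metric (placeholder shapes are illustrative; the import path is an assumption):

    import tensorflow as tf
    from tensorpack.tfutils.symbolic_functions import prediction_incorrect   # assumed import path

    logits = tf.placeholder(tf.float32, [None, 10])   # NxC
    label = tf.placeholder(tf.int32, [None])          # N
    wrong = prediction_incorrect(logits, label, topk=1)
    # entries are 1.0 for incorrect predictions, so accuracy is 1 - mean
    accuracy = 1.0 - tf.reduce_mean(wrong)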
def flatten(x):
"""
......@@ -21,6 +23,7 @@ def flatten(x):
"""
return tf.reshape(x, [-1])
def batch_flatten(x):
"""
Flatten the tensor except the first dimension.
......@@ -30,6 +33,7 @@ def batch_flatten(x):
return tf.reshape(x, [-1, int(np.prod(shape))])
return tf.reshape(x, tf.pack([tf.shape(x)[0], -1]))
def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
"""
The class-balanced cross entropy loss,
......@@ -53,6 +57,7 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
cost = tf.sub(loss_pos, loss_neg, name=name)
return cost
def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
"""
The class-balanced cross entropy loss,
......@@ -75,13 +80,14 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
cost = tf.reduce_mean(cost * (1 - beta), name=name)
#logstable = tf.log(1 + tf.exp(-tf.abs(z)))
#loss_pos = -beta * tf.reduce_mean(-y *
#(logstable - tf.minimum(0.0, z)))
#loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
#(logstable + tf.maximum(z, 0.0)))
# loss_pos = -beta * tf.reduce_mean(-y *
#(logstable - tf.minimum(0.0, z)))
# loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
#(logstable + tf.maximum(z, 0.0)))
#cost = tf.sub(loss_pos, loss_neg, name=name)
return cost
def print_stat(x, message=None):
""" a simple print op.
Use it like: x = print_stat(x)
......@@ -89,7 +95,8 @@ def print_stat(x, message=None):
if message is None:
message = x.op.name
return tf.Print(x, [tf.shape(x), tf.reduce_mean(x), x], summarize=20,
message=message, name='print_' + x.op.name)
message=message, name='print_' + x.op.name)
def rms(x, name=None):
if name is None:
......@@ -98,14 +105,16 @@ def rms(x, name=None):
return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
def huber_loss(x, delta=1, name='huber_loss'):
sqrcost = tf.square(x)
abscost = tf.abs(x)
return tf.reduce_sum(
tf.select(abscost < delta,
sqrcost * 0.5,
abscost * delta - 0.5 * delta ** 2),
name=name)
tf.select(abscost < delta,
sqrcost * 0.5,
abscost * delta - 0.5 * delta ** 2),
name=name)
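For reference, the piecewise rule implemented by huber_loss above, written as a plain numpy sketch:

    import numpy as np

    def huber_loss_np(x, delta=1.0):
        # 0.5 * x^2                    where |x| <  delta
        # delta * |x| - 0.5 * delta^2  otherwise, summed over all elements
        x = np.asarray(x, dtype=np.float64)
        quad = 0.5 * np.square(x)
        lin = delta * np.abs(x) - 0.5 * delta ** 2
        return np.where(np.abs(x) < delta, quad, lin).sum()

    huber_loss_np([0.5, 2.0])   # 0.125 + 1.5 = 1.625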
def get_scalar_var(name, init_value, summary=False, trainable=False):
"""
......@@ -113,8 +122,8 @@ def get_scalar_var(name, init_value, summary=False, trainable=False):
:param summary: summary this variable
"""
ret = tf.get_variable(name, shape=[],
initializer=tf.constant_initializer(init_value),
trainable=trainable)
initializer=tf.constant_initializer(init_value),
trainable=trainable)
if summary:
# this is recognized in callbacks.StatHolder
tf.summary.scalar(name + '-summary', ret)
......
......@@ -11,7 +11,9 @@ __all__ = ['get_current_tower_context', 'TowerContext']
_CurrentTowerContext = None
class TowerContext(object):
def __init__(self, tower_name, is_training=None):
""" tower_name: 'tower0', 'towerp0', or '' """
self._name = tower_name
......@@ -65,7 +67,7 @@ class TowerContext(object):
def __enter__(self):
global _CurrentTowerContext
assert _CurrentTowerContext is None, \
"Nesting TowerContext!"
"Nesting TowerContext!"
_CurrentTowerContext = self
if len(self._name):
self._scope = tf.name_scope(self._name)
......@@ -78,7 +80,7 @@ class TowerContext(object):
self._scope.__exit__(exc_type, exc_val, exc_tb)
return False
def get_current_tower_context():
global _CurrentTowerContext
return _CurrentTowerContext
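A minimal sketch of how a tower context wraps model-building code (the build call is hypothetical; the import path is an assumption):

    from tensorpack.tfutils.tower import TowerContext, get_current_tower_context   # assumed import path

    # contexts must not be nested (see the assertion in __enter__ above)
    with TowerContext('tower0', is_training=True):
        ctx = get_current_tower_context()
        # model code can branch on the context, e.g. ctx.is_main_training_tower
        build_graph()   # hypothetical model-building call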
......@@ -3,7 +3,8 @@
# File: varmanip.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import six, os
import six
import os
import tensorflow as tf
from collections import defaultdict
import re
......@@ -13,7 +14,8 @@ from ..utils.naming import *
from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'get_savename_from_varname', 'is_training_name']
'get_savename_from_varname', 'is_training_name']
def get_savename_from_varname(
varname, varname_prefix=None,
......@@ -33,13 +35,15 @@ def get_savename_from_varname(
name = re.sub('tower[p0-9]+/', '', name)
if varname_prefix is not None \
and name.startswith(varname_prefix):
name = name[len(varname_prefix)+1:]
name = name[len(varname_prefix) + 1:]
if savename_prefix is not None:
name = savename_prefix + '/' + name
return name
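A rough illustration of the renaming rules above (part of the function is elided in this diff, so the expected outputs are inferred; the import path is an assumption):

    from tensorpack.tfutils.varmanip import get_savename_from_varname   # assumed import path

    get_savename_from_varname('tower0/conv1/W:0')                  # expected: 'conv1/W:0'
    get_savename_from_varname('pred/fc/b', varname_prefix='pred')  # expected: 'fc/b'
    get_savename_from_varname('fc/b', savename_prefix='old')       # expected: 'old/fc/b'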
class SessionUpdate(object):
""" Update the variables in a session """
def __init__(self, sess, vars_to_update):
"""
:param vars_to_update: a collection of variables to update
......@@ -66,11 +70,12 @@ class SessionUpdate(object):
if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis
assert np.prod(varshape) == np.prod(value.shape), \
"{}: {}!={}".format(name, varshape, value.shape)
"{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape)
self.sess.run(op, feed_dict={p: value})
def dump_session_params(path):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as
npy format, loadable by ParamRestore
......@@ -90,6 +95,7 @@ the same name".format(v.name))
logger.info(str(result.keys()))
np.save(path, result)
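A minimal round-trip sketch: dump parameters from a trained session, then feed them back through ParamRestore when the model is rebuilt (import paths are assumptions):

    import numpy as np
    from tensorpack.tfutils.varmanip import dump_session_params   # assumed import path
    from tensorpack.tfutils.sessinit import ParamRestore          # assumed import path

    # with a trained graph and a default session active:
    dump_session_params('model.npy')

    # later, when the same model is built again:
    params = np.load('model.npy').item()   # the npy file holds a {name: value} dict
    init = ParamRestore(params)            # use as the session initializer, e.g. in a config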
def dump_chkpt_vars(model_path):
""" Dump all variables from a checkpoint to a dict"""
if os.path.basename(model_path) == model_path:
......@@ -101,6 +107,7 @@ def dump_chkpt_vars(model_path):
result[n] = reader.get_tensor(n)
return result
def is_training_name(name):
"""
This is only used to improve logging.
......
......@@ -8,6 +8,7 @@ import os.path
__all__ = []
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else []
......@@ -25,4 +26,3 @@ for _, module_name, _ in walk_packages(
if module_name.startswith('_'):
continue
global_import(module_name)
......@@ -21,8 +21,11 @@ from ..tfutils.summary import create_summary
__all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException):
pass
@six.add_metaclass(ABCMeta)
class Trainer(object):
""" Base class for a trainer."""
......@@ -91,7 +94,7 @@ class Trainer(object):
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
suffix = '-summary' # issue#6150
suffix = '-summary' # issue#6150
if val.tag.endswith(suffix):
val.tag = val.tag[:-len(suffix)]
self.stat_holder.add_stat(val.tag, val.simple_value)
......@@ -99,7 +102,7 @@ class Trainer(object):
def write_scalar_summary(self, name, val):
self.summary_writer.add_summary(
create_summary(name, val), get_global_step())
create_summary(name, val), get_global_step())
self.stat_holder.add_stat(name, val)
def setup(self):
......@@ -138,7 +141,7 @@ class Trainer(object):
callbacks.before_train()
logger.info("Start training with global_step={}".format(get_global_step()))
for epoch_num in range(
self.config.starting_epoch, self.config.max_epoch+1):
self.config.starting_epoch, self.config.max_epoch + 1):
with timed_operation(
'Epoch {} (global_step {})'.format(
epoch_num, get_global_step() + self.config.step_per_epoch)):
......@@ -147,7 +150,7 @@ class Trainer(object):
**get_tqdm_kwargs(leave=True)):
if self.coord.should_stop():
return
self.run_step() # implemented by subclass
self.run_step() # implemented by subclass
callbacks.trigger_step() # not useful?
# trigger epoch outside the timing region.
self.trigger_epoch()
......
......@@ -9,15 +9,17 @@ from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
get_default_sess_config, SessionInit)
from .input_data import InputData
__all__ = ['TrainConfig']
class TrainConfig(object):
"""
Config for training a model with a single loss
"""
def __init__(self, **kwargs):
"""
:param dataset: the dataset to train. a `DataFlow` instance.
......
......@@ -17,8 +17,10 @@ from .trainer import MultiPredictorTowerTrainer
__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer', 'SimpleFeedfreeTrainer', 'QueueInputTrainer']
class FeedfreeTrainer(Trainer):
""" A trainer which runs iteration without feed_dict (therefore faster) """
def _trigger_epoch(self):
# need to run summary_op every epoch
# note that summary_op will take a data from the queue
......@@ -33,7 +35,9 @@ class FeedfreeTrainer(Trainer):
assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
self._input_method._setup(self)
class SingleCostFeedfreeTrainer(FeedfreeTrainer):
def _get_cost_and_grad(self):
""" get the cost and gradient on a new tower"""
actual_inputs = self._get_input_tensors()
......@@ -41,35 +45,37 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
cost_var = self.model.get_cost()
# GATE_NONE faster?
grads = self.config.optimizer.compute_gradients(
cost_var,
gate_gradients=tf.train.Optimizer.GATE_NONE,
colocate_gradients_with_ops=False)
cost_var,
gate_gradients=tf.train.Optimizer.GATE_NONE,
colocate_gradients_with_ops=False)
add_moving_summary(cost_var)
return cost_var, grads
def run_step(self):
""" Simply run self.train_op"""
self.sess.run(self.train_op)
#if not hasattr(self, 'cnt'):
#self.cnt = 0
#else:
#self.cnt += 1
#if self.cnt % 10 == 0:
## debug-benchmark code:
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#run_metadata=run_metadata
#)
#from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
# if not hasattr(self, 'cnt'):
# self.cnt = 0
# else:
# self.cnt += 1
# if self.cnt % 10 == 0:
# # debug-benchmark code:
# run_metadata = tf.RunMetadata()
# self.sess.run([self.train_op],
# options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
# run_metadata=run_metadata
# )
# from tensorflow.python.client import timeline
# trace = timeline.Timeline(step_stats=run_metadata.step_stats)
# trace_file = open('timeline.ctf.json', 'w')
# trace_file.write(trace.generate_chrome_trace_format())
# import sys; sys.exit()
class SimpleFeedfreeTrainer(
MultiPredictorTowerTrainer,
SingleCostFeedfreeTrainer):
def __init__(self, config):
"""
A trainer with single cost, single training tower and feed-free input
......@@ -80,7 +86,7 @@ class SimpleFeedfreeTrainer(
super(SimpleFeedfreeTrainer, self).__init__(config)
self._setup_predictor_factory(config.predict_tower)
assert len(self.config.tower) == 1, \
"SimpleFeedfreeTrainer doesn't support multigpu!"
"SimpleFeedfreeTrainer doesn't support multigpu!"
def _setup(self):
super(SimpleFeedfreeTrainer, self)._setup()
......@@ -94,6 +100,7 @@ class SimpleFeedfreeTrainer(
# skip training
#self.train_op = tf.group(*self.dequed_inputs)
class QueueInputTrainer(SimpleFeedfreeTrainer):
def __init__(self, config, input_queue=None, predict_tower=None):
......@@ -110,5 +117,5 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!")
config.predict_tower = predict_tower
assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super(QueueInputTrainer, self).__init__(config)
......@@ -14,13 +14,16 @@ from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread
__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput',
'DummyConstantInput']
'DummyConstantInput']
@six.add_metaclass(ABCMeta)
class InputData(object):
pass
class FeedInput(InputData):
def __init__(self, ds):
assert isinstance(ds, DataFlow), ds
self.ds = ds
......@@ -39,7 +42,9 @@ class FeedInput(InputData):
feed = dict(zip(self.input_vars, data))
return feed
class FeedfreeInput(InputData):
def get_input_tensors(self):
return self._get_input_tensors()
......@@ -49,7 +54,9 @@ class FeedfreeInput(InputData):
always create and return a list of new input tensors
"""
class EnqueueThread(threading.Thread):
def __init__(self, trainer, queue, ds, input_placehdrs):
super(EnqueueThread, self).__init__()
self.name = 'EnqueueThread'
......@@ -77,7 +84,7 @@ class EnqueueThread(threading.Thread):
if self.coord.should_stop():
return
feed = dict(zip(self.placehdrs, dp))
#print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed)
except tf.errors.CancelledError as e:
pass
......@@ -91,7 +98,9 @@ class EnqueueThread(threading.Thread):
pass
logger.info("Enqueue Thread Exited.")
class QueueInput(FeedfreeInput):
def __init__(self, ds, queue=None):
"""
:param ds: a `DataFlow` instance
......@@ -108,32 +117,34 @@ class QueueInput(FeedfreeInput):
def _setup(self, trainer):
self.input_placehdrs = trainer.model.get_input_vars()
assert len(self.input_placehdrs) > 0, \
"QueueInput can only be used with input placeholders!"
"QueueInput can only be used with input placeholders!"
if self.queue is None:
self.queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_placehdrs],
name='input_queue')
50, [x.dtype for x in self.input_placehdrs],
name='input_queue')
self.thread = EnqueueThread(
trainer, self.queue, self.ds, self.input_placehdrs)
trainer, self.queue, self.ds, self.input_placehdrs)
trainer.config.callbacks.append(StartProcOrThread(self.thread))
def _get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs):
qv.set_shape(v.get_shape())
# test the overhead of queue
#with tf.device('/gpu:0'):
#ret = [tf.Variable(tf.random_normal([128,224,224,3],
#dtype=tf.float32), trainable=False),
#tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
# with tf.device('/gpu:0'):
# ret = [tf.Variable(tf.random_normal([128,224,224,3],
# dtype=tf.float32), trainable=False),
# tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
return ret
class DummyConstantInput(QueueInput):
""" only for debugging performance issues """
def __init__(self, ds, shapes):
super(DummyConstantInput, self).__init__(ds)
self.shapes = shapes
......@@ -146,11 +157,13 @@ class DummyConstantInput(QueueInput):
for idx, p in enumerate(placehdrs):
with tf.device('/gpu:0'):
ret.append(tf.get_variable('dummy-' + p.op.name,
shape=self.shapes[idx], dtype=p.dtype, trainable=False,
initializer=tf.constant_initializer()))
shape=self.shapes[idx], dtype=p.dtype, trainable=False,
initializer=tf.constant_initializer()))
return ret
class TensorInput(FeedfreeInput):
def __init__(self, get_tensor_fn, size=None):
self.get_tensor_fn = get_tensor_fn
self._size = size
......
......@@ -4,7 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import itertools, re
import itertools
import re
from six.moves import zip, range
from ..utils import logger
......@@ -12,7 +13,7 @@ from ..utils.naming import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..tfutils import (backup_collection, restore_collection,
get_global_step_var, TowerContext)
get_global_step_var, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .base import Trainer
......@@ -22,6 +23,7 @@ from .input_data import QueueInput
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
class MultiGPUTrainer(Trainer):
""" Base class for multi-gpu training"""
@staticmethod
......@@ -45,9 +47,11 @@ class MultiGPUTrainer(Trainer):
restore_collection(backup)
return grad_list
class SyncMultiGPUTrainer(MultiGPUTrainer,
SingleCostFeedfreeTrainer,
MultiPredictorTowerTrainer):
SingleCostFeedfreeTrainer,
MultiPredictorTowerTrainer):
def __init__(self, config, input_queue=None, predict_tower=None):
if hasattr(config, 'dataset'):
self._input_method = QueueInput(config.dataset, input_queue)
......@@ -64,7 +68,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
assert tf.test.is_gpu_available()
@staticmethod
def _average_grads(tower_grads):
if len(tower_grads) == 1:
......@@ -92,12 +95,12 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
def _setup(self):
super(SyncMultiGPUTrainer, self)._setup()
grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1])
self.config.tower, lambda: self._get_cost_and_grad()[1])
# debug tower performance:
#ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
#self.train_op = tf.group(*ops)
#return
# return
grads = SyncMultiGPUTrainer._average_grads(grad_list)
grads = apply_grad_processors(grads, self.model.get_gradient_processor())
......@@ -109,13 +112,15 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
def run_step(self):
self.sess.run(self.train_op)
class AsyncMultiGPUTrainer(MultiGPUTrainer,
SingleCostFeedfreeTrainer,
MultiPredictorTowerTrainer):
SingleCostFeedfreeTrainer,
MultiPredictorTowerTrainer):
def __init__(self, config,
input_queue=None,
average_gradient=True,
predict_tower=None):
input_queue=None,
average_gradient=True,
predict_tower=None):
if hasattr(config, 'dataset'):
self._input_method = QueueInput(config.dataset, input_queue)
else:
......@@ -134,7 +139,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
def _setup(self):
super(AsyncMultiGPUTrainer, self)._setup()
grad_list = MultiGPUTrainer._multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1])
self.config.tower, lambda: self._get_cost_and_grad()[1])
gradprocs = self.model.get_gradient_processor()
if self._average_gradient and self.config.nr_tower > 1:
# pretend to average the grads, in order to make async and
......@@ -157,7 +162,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
self.training_threads = []
for k in range(1, len(self.config.tower)):
train_op = self.config.optimizer.apply_gradients(grad_list[k])
def f(op=train_op): # avoid late-binding
def f(op=train_op): # avoid late-binding
self.sess.run([op])
next(self.async_step_counter)
th = LoopThread(f)
......@@ -169,7 +175,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
def run_step(self):
if not self.async_running:
self.async_running = True
for th in self.training_threads: # resume all threads
for th in self.training_threads: # resume all threads
th.resume()
next(self.async_step_counter)
self.sess.run(self.train_op)
......@@ -183,7 +189,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
async_step_total_cnt = int(re.findall(
'[0-9]+', self.async_step_counter.__str__())[0])
self.write_scalar_summary(
'async_global_step', async_step_total_cnt)
'async_global_step', async_step_total_cnt)
except:
logger.exception("Cannot log async_global_step")
super(AsyncMultiGPUTrainer, self)._trigger_epoch()
......@@ -10,13 +10,14 @@ from .base import Trainer
from ..utils import logger, SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import (get_tensors_by_names, freeze_collection,
get_global_step_var, TowerContext)
get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput, FeedfreeInput
__all__ = ['SimpleTrainer','MultiPredictorTowerTrainer']
__all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']
class PredictorFactory(object):
""" Make predictors for a trainer"""
......@@ -52,8 +53,10 @@ class PredictorFactory(object):
build_multi_tower_prediction_graph(fn, self.towers)
self.tower_built = True
class SimpleTrainer(Trainer):
""" A naive demo trainer """
def __init__(self, config):
super(SimpleTrainer, self).__init__(config)
self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
......@@ -78,7 +81,7 @@ class SimpleTrainer(Trainer):
grads = self.config.optimizer.compute_gradients(cost_var)
grads = apply_grad_processors(grads,
self.model.get_gradient_processor())
self.model.get_gradient_processor())
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
......@@ -93,13 +96,15 @@ class SimpleTrainer(Trainer):
def get_predict_func(self, input_names, output_names):
return self._predictor_factory.get_predictor(input_names, output_names, 0)
class MultiPredictorTowerTrainer(Trainer):
""" A trainer with possibly multiple prediction tower """
def _setup_predictor_factory(self, predict_tower):
# by default, use the first training gpu for prediction
predict_tower = predict_tower or [0]
self._predictor_factory = PredictorFactory(
self.sess, self.model, predict_tower)
self.sess, self.model, predict_tower)
def get_predict_func(self, input_names, output_names, tower=0):
"""
......
......@@ -12,6 +12,7 @@ These utils should be irrelevant to tensorflow.
__all__ = []
def _global_import(name):
p = __import__(name, globals(), None, level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
......@@ -23,7 +24,7 @@ _TO_IMPORT = set([
'naming',
'utils',
'gpu'
])
])
_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in walk_packages(
......@@ -36,5 +37,3 @@ for _, module_name, _ in walk_packages(
if module_name in _TO_IMPORT:
_global_import(module_name)
__all__.append(module_name)
......@@ -5,10 +5,13 @@
import operator
import inspect, six, functools
import inspect
import six
import functools
import collections
__all__ = [ 'map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']
__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']
def map_arg(**maps):
"""
......@@ -26,11 +29,13 @@ def map_arg(**maps):
return wrapper
return deco
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
......@@ -60,8 +65,11 @@ class memoized(object):
return functools.partial(self.__call__, obj)
_MEMOIZED_NOARGS = {}
def memoized_ignoreargs(func):
h = hash(func) # make sure it is hashable. is it necessary?
def wrapper(*args, **kwargs):
if func not in _MEMOIZED_NOARGS:
res = func(*args, **kwargs)
......@@ -70,15 +78,16 @@ def memoized_ignoreargs(func):
return _MEMOIZED_NOARGS[func]
return wrapper
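A minimal usage sketch of the two memoization helpers above (the import path is an assumption):

    from tensorpack.utils.utils import memoized, memoized_ignoreargs   # assumed import path

    @memoized
    def square(x):
        print('computing', x)
        return x * x

    square(3)   # prints 'computing 3' and returns 9
    square(3)   # returns the cached 9 without recomputing

    @memoized_ignoreargs
    def shared_resource(cfg):
        return object()   # computed once; later calls return the same object, whatever cfg is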
#_GLOBAL_MEMOIZED_CACHE = dict()
#def global_memoized(func):
#""" Make sure that the same `memoized` object is returned on different
#calls to global_memoized(func)
#"""
#ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
#if ret is None:
#ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
#return ret
# _GLOBAL_MEMOIZED_CACHE = dict()
# def global_memoized(func):
# """ Make sure that the same `memoized` object is returned on different
# calls to global_memoized(func)
# """
# ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
# if ret is None:
# ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
# return ret
def shape2d(a):
"""
......
......@@ -23,10 +23,12 @@ __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
'mask_sigint', 'start_proc_mask_signal']
class StoppableThread(threading.Thread):
"""
A thread that has a 'stop' event.
"""
def __init__(self):
super(StoppableThread, self).__init__()
self._stop_evt = threading.Event()
......@@ -56,8 +58,10 @@ class StoppableThread(threading.Thread):
except queue.Empty:
pass
class LoopThread(StoppableThread):
""" A pausable thread that simply runs a loop"""
def __init__(self, func, pausable=True):
"""
:param func: the function to run
......@@ -89,6 +93,7 @@ class DIE(object):
""" A placeholder class indicating end of queue """
pass
def ensure_proc_terminate(proc):
if isinstance(proc, list):
for p in proc:
......@@ -114,6 +119,7 @@ def mask_sigint():
yield
signal.signal(signal.SIGINT, sigint_handler)
def start_proc_mask_signal(proc):
if not isinstance(proc, list):
proc = [proc]
......@@ -122,11 +128,12 @@ def start_proc_mask_signal(proc):
for p in proc:
p.start()
def subproc_call(cmd, timeout=None):
try:
output = subprocess.check_output(
cmd, stderr=subprocess.STDOUT,
shell=True, timeout=timeout)
cmd, stderr=subprocess.STDOUT,
shell=True, timeout=timeout)
return output
except subprocess.TimeoutExpired as e:
logger.warn("Command timeout!")
......@@ -135,10 +142,12 @@ def subproc_call(cmd, timeout=None):
logger.warn("Commnad failed: {}".format(e.returncode))
logger.warn(e.output)
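A minimal usage sketch of subproc_call (the error branches above are partly elided in this diff, so the failure behaviour is only summarized in the comment; the import path is an assumption):

    from tensorpack.utils.concurrency import subproc_call   # assumed import path

    # runs a shell command with a timeout and returns its combined stdout/stderr;
    # on timeout or non-zero exit it logs a warning instead (see above)
    output = subproc_call("nvidia-smi -L", timeout=5)
    if output:
        print(output.decode('utf-8'))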
class OrderedContainer(object):
"""
Like a priority queue, but will always wait for item with index (x+1) before producing (x+2).
"""
def __init__(self, start=0):
self.ranks = []
self.data = []
......@@ -163,11 +172,13 @@ class OrderedContainer(object):
self.wait_for += 1
return rank, ret
class OrderedResultGatherProc(multiprocessing.Process):
"""
Gather indexed data from a data queue, and produce results with the
original index-based order.
"""
def __init__(self, data_queue, nr_producer, start=0):
"""
:param data_queue: a multiprocessing.Queue to produce input dp
......
......@@ -7,6 +7,7 @@
import sys
__all__ = ['enable_call_trace']
def enable_call_trace():
def tracer(frame, event, arg):
if event == 'call':
......@@ -21,9 +22,9 @@ def enable_call_trace():
if caller:
caller_line_no = caller.f_lineno
caller_filename = caller.f_code.co_filename
print('Call to `%s` on line %s:%s from %s:%s' % \
(func_name, func_filename, func_line_no,
caller_filename, caller_line_no))
print('Call to `%s` on line %s:%s from %s:%s' %
(func_name, func_filename, func_line_no,
caller_filename, caller_line_no))
return
sys.settrace(tracer)
......@@ -32,6 +33,7 @@ if __name__ == '__main__':
def b(a):
print(2)
def a():
print(1)
b(1)
......
......@@ -12,11 +12,14 @@ from six.moves import range
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']
@memoized
def log_once(s):
logger.warn(s)
# just a placeholder
@six.add_metaclass(ABCMeta)
class Discretizer(object):
......@@ -28,10 +31,13 @@ class Discretizer(object):
def get_bin(self, v):
pass
class Discretizer1D(Discretizer):
pass
class UniformDiscretizer1D(Discretizer1D):
def __init__(self, minv, maxv, spacing):
"""
:params minv: minimum value of the first bin
......@@ -54,8 +60,8 @@ class UniformDiscretizer1D(Discretizer1D):
log_once("UniformDiscretizer1D: value larger than max!")
return self.nr_bin - 1
return int(np.clip(
(v - self.minv) / self.spacing,
0, self.nr_bin - 1))
(v - self.minv) / self.spacing,
0, self.nr_bin - 1))
def get_bin_center(self, bin_id):
return self.minv + self.spacing * (bin_id + 0.5)
......@@ -69,17 +75,18 @@ class UniformDiscretizer1D(Discretizer1D):
if v >= self.maxv or v <= self.minv:
return ret
try:
for k in range(1, smooth_radius+1):
ret[b+k] = smooth_factor ** k
for k in range(1, smooth_radius + 1):
ret[b + k] = smooth_factor ** k
except IndexError:
pass
for k in range(1, min(smooth_radius+1, b+1)):
ret[b-k] = smooth_factor ** k
for k in range(1, min(smooth_radius + 1, b + 1)):
ret[b - k] = smooth_factor ** k
ret /= ret.sum()
return ret
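A minimal usage sketch of UniformDiscretizer1D (the import path is an assumption; the smoothed-distribution method whose body is shown above is omitted because its signature is elided in this diff):

    from tensorpack.utils.discretize import UniformDiscretizer1D   # assumed import path

    d = UniformDiscretizer1D(0.0, 10.0, 0.5)   # bins of width 0.5 covering [0, 10]
    d.get_bin(3.2)        # -> 6, since (3.2 - 0.0) / 0.5 = 6.4, truncated and clipped to a valid bin id
    d.get_bin_center(6)   # -> 3.25, i.e. minv + spacing * (bin_id + 0.5)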
class UniformDiscretizerND(Discretizer):
def __init__(self, *min_max_spacing):
"""
:params min_max_spacing: (minv, maxv, spacing) for each dimension
......@@ -122,6 +129,5 @@ class UniformDiscretizerND(Discretizer):
if __name__ == '__main__':
#u = UniformDiscretizer1D(-10, 10, 0.12)
u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1))
import IPython as IP;
import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config())
......@@ -3,13 +3,15 @@
# File: fs.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os, sys
import os
import sys
from six.moves import urllib
import errno
from . import logger
__all__ = ['mkdir_p', 'download', 'recursive_walk']
def mkdir_p(dirname):
""" make a dir recursively, but do nothing if the dir exists"""
assert dirname is not None
......@@ -21,6 +23,7 @@ def mkdir_p(dirname):
if e.errno != errno.EEXIST:
raise e
def download(url, dir):
mkdir_p(dir)
fname = url.split('/')[-1]
......@@ -29,7 +32,7 @@ def download(url, dir):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' %
(fname,
min(float(count * block_size)/ total_size,
min(float(count * block_size) / total_size,
1.0) * 100.0))
sys.stdout.flush()
try:
......@@ -45,6 +48,7 @@ def download(url, dir):
print('Successfully downloaded ' + fname + " " + str(size) + ' bytes.')
return fpath
def recursive_walk(rootdir):
for r, dirs, files in os.walk(rootdir):
for f in files:
......
......@@ -9,13 +9,15 @@ import argparse
__all__ = ['globalns', 'use_global_argument']
if six.PY2:
class NS: pass
class NS:
pass
else:
import types
NS = types.SimpleNamespace
globalns = NS()
def use_global_argument(args):
"""
Add the content of argparse.Namespace to globalns
......
......@@ -8,20 +8,22 @@ from .utils import change_env
__all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus']
def change_gpu(val):
val = str(val)
if val == '-1':
val = ''
return change_env('CUDA_VISIBLE_DEVICES', val)
def get_nr_gpu():
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO
return len(env.split(','))
def get_gpus():
""" return a list of GPU physical id"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO
return map(int, env.strip().split(','))
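A minimal usage sketch of the GPU helpers above (the import path is assumed; change_env is taken to be a context manager, as its use here suggests):

    import os
    from tensorpack.utils.gpu import change_gpu, get_nr_gpu, get_gpus   # assumed import path

    os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
    get_nr_gpu()       # -> 2, the number of visible devices
    list(get_gpus())   # -> [2, 3], the physical ids

    with change_gpu(0):    # temporarily sets CUDA_VISIBLE_DEVICES to '0'
        pass               # code here sees only that GPU
    with change_gpu(-1):   # '-1' is mapped to an empty string, i.e. no GPU
        pass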
......@@ -19,7 +19,9 @@ __all__ = ['load_caffe', 'get_caffe_pb']
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"
class CaffeLayerProcessor(object):
def __init__(self, net):
self.net = net
self.layer_names = net._layer_names
......@@ -42,14 +44,14 @@ class CaffeLayerProcessor(object):
self.param_dict.update(dic)
elif len(layer.blobs) != 0:
logger.warn(
"{} layer contains parameters but is not supported!".format(layer.type))
"{} layer contains parameters but is not supported!".format(layer.type))
return self.param_dict
def proc_conv(self, idx, name, param):
assert len(param) <= 2
assert param[0].data.ndim == 4
# caffe: ch_out, ch_in, h, w
W = param[0].data.transpose(2,3,1,0)
W = param[0].data.transpose(2, 3, 1, 0)
if len(param) == 1:
return {name + '/W': W}
else:
......@@ -65,7 +67,7 @@ class CaffeLayerProcessor(object):
logger.info("FC layer {} takes spatial data.".format(name))
W = param[0].data
# original: outx(CxHxW)
W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2,3,1,0)
W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2, 3, 1, 0)
# become: (HxWxC)xout
else:
W = param[0].data.transpose()
......@@ -74,8 +76,8 @@ class CaffeLayerProcessor(object):
def proc_bn(self, idx, name, param):
assert param[2].data[0] == 1.0
return {name +'/mean/EMA': param[0].data,
name +'/variance/EMA': param[1].data }
return {name + '/mean/EMA': param[0].data,
name + '/variance/EMA': param[1].data}
def proc_scale(self, idx, name, param):
bottom_name = self.net.bottom_names[name][0]
......@@ -89,7 +91,7 @@ class CaffeLayerProcessor(object):
logger.info("Merge {} and {} into one BatchNorm layer".format(
name, name2))
return {name2 + '/beta': param[1].data,
name2 + '/gamma': param[0].data }
name2 + '/gamma': param[0].data}
# assume this scaling layer is part of some BN
logger.error("Could not find a BN layer corresponding to this Scale layer!")
raise ValueError()
......@@ -104,10 +106,11 @@ def load_caffe(model_desc, model_file):
caffe.set_mode_cpu()
net = caffe.Net(model_desc, model_file, caffe.TEST)
param_dict = CaffeLayerProcessor(net).process()
logger.info("Model loaded from caffe. Params: " + \
logger.info("Model loaded from caffe. Params: " +
" ".join(sorted(param_dict.keys())))
return param_dict
def get_caffe_pb():
dir = get_dataset_path('caffe')
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
......@@ -116,7 +119,7 @@ def get_caffe_pb():
assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
assert ret == 0, \
"Command `protoc caffe.proto --python_out .` failed!"
"Command `protoc caffe.proto --python_out .` failed!"
import imp
return imp.load_source('caffepb', caffe_pb_file)
......@@ -131,4 +134,3 @@ if __name__ == '__main__':
import numpy as np
np.save(args.output, ret)
......@@ -3,7 +3,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import logging
import os, shutil
import os
import shutil
import os.path
from termcolor import colored
from datetime import datetime
......@@ -12,7 +13,9 @@ import sys
__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', 'warn_dependency']
class _MyFormatter(logging.Formatter):
def format(self, record):
date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
msg = '%(message)s'
......@@ -28,6 +31,7 @@ class _MyFormatter(logging.Formatter):
self._fmt = fmt
return super(_MyFormatter, self).format(record)
def _getlogger():
logger = logging.getLogger('tensorpack')
logger.propagate = False
......@@ -45,6 +49,8 @@ def get_time_str():
# logger file and directory:
global LOG_FILE, LOG_DIR
LOG_DIR = None
def _set_file(path):
if os.path.isfile(path):
backup_name = path + '.' + get_time_str()
......@@ -56,6 +62,7 @@ def _set_file(path):
_logger.addHandler(hdl)
_logger.info("Argv: " + ' '.join(sys.argv))
def set_logger_dir(dirname, action=None):
"""
Set the directory for global logging.
......@@ -98,11 +105,13 @@ _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception',
for func in _LOGGING_METHOD:
locals()[func] = getattr(_logger, func)
def disable_logger():
""" disable all logging ability from this moment"""
for func in _LOGGING_METHOD:
globals()[func] = lambda x: None
def auto_set_dir(action=None, overwrite=False):
""" set log directory to a subdir inside 'train_log', with the name being
the main python file currently running"""
......@@ -112,9 +121,10 @@ def auto_set_dir(action=None, overwrite=False):
mod = sys.modules['__main__']
basename = os.path.basename(mod.__file__)
set_logger_dir(
os.path.join('train_log',
basename[:basename.rfind('.')]),
action=action)
os.path.join('train_log',
basename[:basename.rfind('.')]),
action=action)
def warn_dependency(name, dependencies):
warn("Failed to import '{}', {} won't be available'".format(dependencies, name))
......@@ -7,10 +7,12 @@ import six
__all__ = ['LookUpTable']
class LookUpTable(object):
def __init__(self, objlist):
self.idx2obj = dict(enumerate(objlist))
self.obj2idx = {v : k for k, v in six.iteritems(self.idx2obj)}
self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)}
def size(self):
return len(self.idx2obj)
......
......@@ -5,6 +5,7 @@
import numpy as np
class Rect(object):
"""
A Rectangle.
......@@ -68,7 +69,7 @@ class Rect(object):
def roi(self, img):
assert self.validate(img.shape[:2]), "{} vs {}".format(self, img.shape[:2])
return img[self.y0:self.y1+1, self.x0:self.x1+1]
return img[self.y0:self.y1 + 1, self.x0:self.x1 + 1]
def expand(self, frac):
assert frac > 1.0, frac
......@@ -92,7 +93,7 @@ class Rect(object):
xmax = min(self.x1, img.shape[1])
ymax = min(self.y1, img.shape[0])
patch = img[ymin:ymax, xmin:xmax]
ret[ystart:ystart+patch.shape[0],xstart:xstart+patch.shape[1]] = patch
ret[ystart:ystart + patch.shape[0], xstart:xstart + patch.shape[1]] = patch
return ret
__repr__ = __str__
......@@ -101,6 +102,6 @@ class Rect(object):
if __name__ == '__main__':
x = Rect(2, 1, 3, 3, allow_neg=True)
img = np.random.rand(3,3)
img = np.random.rand(3, 3)
print(img)
print(x.roi_zeropad(img))
......@@ -10,10 +10,12 @@ msgpack_numpy.patch()
__all__ = ['loads', 'dumps']
def dumps(obj):
#return dill.dumps(obj)
# return dill.dumps(obj)
return msgpack.dumps(obj, use_bin_type=True)
def loads(buf):
#return dill.loads(buf)
# return dill.loads(buf)
return msgpack.loads(buf)
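A minimal round-trip sketch for the msgpack-based dumps/loads above (the import path is assumed; exact string handling depends on the msgpack version installed):

    import numpy as np
    from tensorpack.utils.serialize import dumps, loads   # assumed import path

    # msgpack_numpy.patch() above lets numpy arrays pass through msgpack
    buf = dumps({'step': 3, 'weights': np.arange(4, dtype=np.float32)})
    obj = loads(buf)   # deserializes back; the array round-trips intact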
......@@ -14,10 +14,12 @@ from .stats import StatCounter
from . import logger
__all__ = ['total_timer', 'timed_operation',
'print_total_timer', 'IterSpeedCounter']
'print_total_timer', 'IterSpeedCounter']
class IterSpeedCounter(object):
""" To count how often some code gets reached"""
def __init__(self, print_every, name=None):
self.cnt = 0
self.print_every = int(print_every)
......@@ -36,6 +38,7 @@ class IterSpeedCounter(object):
logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
self.name, t, self.cnt, t / self.cnt))
@contextmanager
def timed_operation(msg, log_start=False):
if log_start:
......@@ -47,6 +50,7 @@ def timed_operation(msg, log_start=False):
_TOTAL_TIMER_DATA = defaultdict(StatCounter)
@contextmanager
def total_timer(msg):
start = time.time()
......@@ -54,6 +58,7 @@ def total_timer(msg):
t = time.time() - start
_TOTAL_TIMER_DATA[msg].feed(t)
def print_total_timer():
if len(_TOTAL_TIMER_DATA) == 0:
return
......