Commit 321440af authored by Yuxin Wu

misc clean-ups

parent 8ef16f14
......@@ -61,7 +61,9 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
     def __init__(self, idx, pipe_c2s, pipe_s2c):
         """
-        :param idx: idx of this process
+        Args:
+            idx: idx of this process
+            pipe_c2s, pipe_s2c (str): name of the pipe
         """
         super(SimulatorProcessStateExchange, self).__init__(idx)
         self.c2s = pipe_c2s
......@@ -177,111 +179,6 @@ class SimulatorMaster(threading.Thread):
         self.context.destroy(linger=0)
-# ------------------- the following code are not used at all. Just experimental
-class SimulatorProcessDF(SimulatorProcessBase):
-    """ A simulator which contains a forward model itself, allowing
-    it to produce data points directly """
-    def __init__(self, idx, pipe_c2s):
-        super(SimulatorProcessDF, self).__init__(idx)
-        self.pipe_c2s = pipe_c2s
-    def run(self):
-        self.player = self._build_player()
-        self.ctx = zmq.Context()
-        self.c2s_socket = self.ctx.socket(zmq.PUSH)
-        self.c2s_socket.setsockopt(zmq.IDENTITY, self.identity)
-        self.c2s_socket.set_hwm(5)
-        self.c2s_socket.connect(self.pipe_c2s)
-        self._prepare()
-        for dp in self.get_data():
-            self.c2s_socket.send(dumps(dp), copy=False)
-    @abstractmethod
-    def _prepare(self):
-        pass
-    @abstractmethod
-    def get_data(self):
-        pass
-class SimulatorProcessSharedWeight(SimulatorProcessDF):
-    """ A simulator process with an extra thread waiting for event,
-    and take shared weight from shm.
-    Start me under some CUDA_VISIBLE_DEVICES set!
-    """
-    def __init__(self, idx, pipe_c2s, condvar, shared_dic, pred_config):
-        super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
-        self.condvar = condvar
-        self.shared_dic = shared_dic
-        self.pred_config = pred_config
-    def _prepare(self):
-        disable_layer_logging()
-        self.predictor = OfflinePredictor(self.pred_config)
-        with self.predictor.graph.as_default():
-            vars_to_update = self._params_to_update()
-            self.sess_updater = SessionUpdate(
-                self.predictor.session, vars_to_update)
-        # TODO setup callback for explore?
-        self.predictor.graph.finalize()
-        self.weight_lock = threading.Lock()
-        # start a thread to wait for notification
-        def func():
-            self.condvar.acquire()
-            while True:
-                self.condvar.wait()
-                self._trigger_evt()
-        self.evt_th = threading.Thread(target=func)
-        self.evt_th.daemon = True
-        self.evt_th.start()
-    def _trigger_evt(self):
-        with self.weight_lock:
-            self.sess_updater.update(self.shared_dic['params'])
-            logger.info("Updated.")
-    def _params_to_update(self):
-        # can be overwritten to update more params
-        return tf.trainable_variables()
-class WeightSync(Callback):
-    """ Sync weight from main process to shared_dic and notify"""
-    def __init__(self, condvar, shared_dic):
-        self.condvar = condvar
-        self.shared_dic = shared_dic
-    def _setup_graph(self):
-        self.vars = self._params_to_update()
-    def _params_to_update(self):
-        # can be overwritten to update more params
-        return tf.trainable_variables()
-    def _before_train(self):
-        self._sync()
-    def _trigger_epoch(self):
-        self._sync()
-    def _sync(self):
-        logger.info("Updating weights ...")
-        dic = {v.name: v.eval() for v in self.vars}
-        self.shared_dic['params'] = dic
-        self.condvar.acquire()
-        self.condvar.notify_all()
-        self.condvar.release()
 if __name__ == '__main__':
     import random
     from tensorpack.RL import NaiveRLEnvironment
......
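The removed `SimulatorProcessSharedWeight`/`WeightSync` pair implemented a simple publish-and-notify scheme: the training process writes the latest weights into a shared dict and notifies a condition variable, while each simulator process runs a daemon thread that waits on the condvar and copies the weights into its own session. Below is a minimal sketch of that pattern with plain `multiprocessing` primitives; the `'params'` key follows the removed code, everything else (`consumer`, the dummy weights) is illustrative, not tensorpack API.

```python
import multiprocessing as mp

def consumer(condvar, shared_dic):
    # Stands in for the removed SimulatorProcessSharedWeight: block until the
    # trainer publishes weights, then copy them out of the shared dict.
    with condvar:
        condvar.wait_for(lambda: 'params' in shared_dic)
    print("got params:", shared_dic['params'])

if __name__ == '__main__':
    manager = mp.Manager()
    shared_dic = manager.dict()
    condvar = mp.Condition()
    proc = mp.Process(target=consumer, args=(condvar, shared_dic))
    proc.start()
    # Stands in for the removed WeightSync._sync: publish, then notify all.
    shared_dic['params'] = {'fc/W:0': [0.1, 0.2]}
    with condvar:
        condvar.notify_all()
    proc.join()
```

Using `wait_for` with a predicate avoids the race where the notification fires before the consumer reaches `wait()`, something the removed code tolerated only because it re-waited in an endless loop.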
......@@ -94,7 +94,6 @@ class OnlineTensorboardExport(Callback):
 class Model(ModelDesc):
     def _get_inputs(self):
-        # TODO: allow arbitrary batch sizes
         return [InputDesc(tf.float32, (BATCH, ), 'theta'),
                 InputDesc(tf.float32, (BATCH, SHAPE, SHAPE), 'image'),
                 InputDesc(tf.float32, (BATCH, SHAPE, SHAPE), 'gt_image'),
......@@ -120,9 +119,9 @@ class Model(ModelDesc):
         logger.info('Parameter net output: {}'.format(pred_filter.get_shape().as_list()))
         return pred_filter
-    def _build_graph(self, input_vars):
+    def _build_graph(self, inputs):
         kernel_size = 9
-        theta, image, gt_image, gt_filter = input_vars
+        theta, image, gt_image, gt_filter = inputs
         image = image
         gt_image = gt_image
......
......@@ -73,8 +73,8 @@ class Model(GANModelDesc):
                  .FullyConnected('fct', 1, nl=tf.identity)())
         return l
-    def _build_graph(self, input_vars):
-        image_pos, y = input_vars
+    def _build_graph(self, inputs):
+        image_pos, y = inputs
         image_pos = tf.expand_dims(image_pos * 2.0 - 1, -1)
         y = tf.one_hot(y, 10, name='label_onehot')
......
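For context on the two kept lines above: `image_pos * 2.0 - 1` rescales images from [0, 1] to [-1, 1] (the usual input range when the generator ends in a tanh), and `tf.one_hot(y, 10)` encodes the digit label. A tiny numpy check of both transforms:

```python
import numpy as np

# Rescale [0, 1] images to [-1, 1], as image_pos * 2.0 - 1 does above.
x = np.array([0.0, 0.5, 1.0])
print(x * 2.0 - 1)      # [-1.  0.  1.]

# numpy analogue of tf.one_hot(y, 10) for a single label
y = 3
print(np.eye(10)[y])    # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```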
......@@ -138,7 +138,7 @@ class Model(ModelDesc):
             br1 = AvgPooling('avgpool', l, 5, 3, padding='VALID')
             br1 = Conv2D('conv11', br1, 128, 1)
             shape = br1.get_shape().as_list()
-            br1 = Conv2D('convout', br1, 768, shape[1:3], padding='VALID')  # TODO gauss, stddev=0.01
+            br1 = Conv2D('convout', br1, 768, shape[1:3], padding='VALID')
             br1 = FullyConnected('fc', br1, 1000, nl=tf.identity)
         with tf.variable_scope('incep-17-1280a'):
......
......@@ -45,7 +45,7 @@ class HookToCallback(Callback):
     def _before_train(self):
         sess = tf.get_default_session()
-        # TODO fix coord?
+        # coord is set to None when converting
         self._hook.after_create_session(sess, None)
     def _before_run(self, ctx):
......
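For reference, `HookToCallback` adapts a standard `tf.train.SessionRunHook` so it can run as a tensorpack `Callback`; the reworded comment documents that the wrapped hook's `after_create_session` receives `coord=None` during this conversion. A hedged usage sketch, assuming the import path from tensorpack's public callbacks package:

```python
import tensorflow as tf
from tensorpack.callbacks import HookToCallback

# Wrap a stock SessionRunHook so it can run inside a tensorpack trainer.
hook = tf.train.StopAtStepHook(last_step=100000)
callback = HookToCallback(hook)
# When training starts, after_create_session(sess, None) is called on the
# wrapped hook -- coord is None, exactly as the comment above says.
```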
......@@ -151,11 +151,13 @@ class Cifar100(CifarBase):
 if __name__ == '__main__':
     ds = Cifar10('train')
-    from tensorpack.dataflow.dftools import dump_dataflow_images
     mean = ds.get_per_channel_mean()
     print(mean)
-    dump_dataflow_images(ds, '/tmp/cifar', 100)
-    # for (img, label) in ds.get_data():
-    #     from IPython import embed; embed()
-    #     break
+    import cv2
+    ds.reset_state()
+    for i, dp in enumerate(ds.get_data()):
+        if i == 100:
+            break
+        img = dp[0]
+        cv2.imwrite("{:04d}.jpg".format(i), img)
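The deleted `dump_dataflow_images` call is replaced by an explicit loop that follows the general DataFlow contract: call `reset_state()` once, then iterate `get_data()`. If the old helper is still wanted locally, a self-contained sketch like the following reproduces it (`dump_images` is a hypothetical name, not tensorpack API):

```python
import os
import cv2

def dump_images(df, dirname, max_count=100, index=0):
    # df: any DataFlow whose data points hold an image at position `index`.
    os.makedirs(dirname, exist_ok=True)
    df.reset_state()
    for i, dp in enumerate(df.get_data()):
        if i == max_count:
            break
        cv2.imwrite(os.path.join(dirname, "{:04d}.jpg".format(i)), dp[index])
```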
......@@ -2,7 +2,6 @@
 # File: dftools.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
-import sys
 import os
 import multiprocessing as mp
 from six.moves import range
......@@ -11,35 +10,11 @@ from .base import DataFlow
 from ..utils import get_tqdm, logger
 from ..utils.concurrency import DIE
 from ..utils.serialize import dumps
-from ..utils.fs import mkdir_p
-__all__ = ['dump_dataflow_images', 'dump_dataflow_to_process_queue',
+__all__ = ['dump_dataflow_to_process_queue',
            'dump_dataflow_to_lmdb', 'dump_dataflow_to_tfrecord']
-def dump_dataflow_images(df, dirname, max_count=None, index=0):
-    """ Dump images from a DataFlow to a directory.
-    Args:
-        df (DataFlow): the DataFlow to dump.
-        dirname (str): name of the directory.
-        max_count (int): limit max number of images to dump. Defaults to unlimited.
-        index (int): the index of the image component in the data point.
-    """
-    # TODO pass a name_func to write label as filename?
-    mkdir_p(dirname)
-    if max_count is None:
-        max_count = sys.maxint
-    df.reset_state()
-    for i, dp in enumerate(df.get_data()):
-        if i % 100 == 0:
-            print(i)
-        if i > max_count:
-            return
-        img = dp[index]
-        cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)
 def dump_dataflow_to_process_queue(df, size, nr_consumer):
     """
     Convert a DataFlow to a :class:`multiprocessing.Queue`.
......@@ -160,9 +135,3 @@ try:
 except ImportError:
     dump_dataflow_to_tfrecord = create_dummy_func(  # noqa
         'dump_dataflow_to_tfrecord', 'tensorflow')
-try:
-    import cv2
-except ImportError:
-    dump_dataflow_images = create_dummy_func(  # noqa
-        'dump_dataflow_images', 'cv2')
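The deleted block used the same optional-dependency trick that remains above for tensorflow: if an import fails, the public function is swapped for a stub that raises only when actually called, so importing the module itself stays cheap. A self-contained sketch of the pattern (`create_dummy_func` here is a simplified stand-in for `tensorpack.utils.develop.create_dummy_func`, and the `lmdb` guard is illustrative):

```python
def create_dummy_func(func_name, dependency):
    # The returned stub defers the ImportError until the function is used.
    def _dummy(*args, **kwargs):
        raise ImportError("Cannot use {} because {} is not installed.".format(
            func_name, dependency))
    return _dummy

try:
    import lmdb  # illustrative optional dependency
except ImportError:
    dump_dataflow_to_lmdb = create_dummy_func(  # noqa
        'dump_dataflow_to_lmdb', 'lmdb')
```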
......@@ -107,7 +107,7 @@ class ModelDescBase(object):
         :returns: a list of InputDesc
         """
-        # TODO only use InputSource in the future? Now mainly used in predict/
+        # TODO only use InputSource in the future? Now only used in predictor_factory
     def build_graph(self, inputs):
         """
         Build the whole symbolic graph.
......
......@@ -67,6 +67,8 @@ class PredictorFactory(object):
         input.setup(self._model.get_inputs_desc())
         input = input.get_input_tensors()
         assert isinstance(input, (list, tuple)), input
+        # TODO still using tensors here instead of inputsource
+        # can be fixed after having towertensorhandle inside modeldesc
         self._model.build_graph(input)
         desc_names = [k.name for k in self._model.get_inputs_desc()]
......@@ -88,7 +90,7 @@ class PredictorFactory(object):
         tower = self._towers[tower]
         device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
         # use a previously-built tower
-        # TODO conflict with inference runner??
+        # TODO check conflict with inference runner??
         if tower_name not in self._names_built:
             with tf.variable_scope(self._vs_name, reuse=True):
                 handle = self.build(tower_name, device)
......
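The reuse branch above leans on TF 1.x variable sharing: entering the same `variable_scope` again with `reuse=True` hands back the existing variables instead of creating new ones, which is what lets a prediction tower share weights with a previously built one. A minimal standalone sketch of that mechanism (not the factory code itself):

```python
import tensorflow as tf

with tf.variable_scope('tower'):
    w1 = tf.get_variable('W', shape=[3, 3])
with tf.variable_scope('tower', reuse=True):
    w2 = tf.get_variable('W')   # returns the existing variable, not a copy
assert w1 is w2
```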
......@@ -196,15 +196,14 @@ class DictRestore(SessionInit):
         self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)}
     def _run_init(self, sess):
-        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)  # TODO
+        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
         variable_names = set([k.name for k in variables])
         param_names = set(six.iterkeys(self.prms))
         intersect = variable_names & param_names
-        logger.info("Params to restore: {}".format(
-            ', '.join(map(str, intersect))))
+        logger.info("Params to restore: {}".format(', '.join(map(str, intersect))))
         mismatch = MismatchLogger('graph', 'dict')
         for k in sorted(variable_names - param_names):
......
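The surrounding code decides what `DictRestore` can restore by set arithmetic on names: the intersection is restorable, and each set difference becomes a mismatch to log. A tiny standalone illustration with made-up names:

```python
variable_names = {'conv0/W:0', 'conv0/b:0', 'fc/W:0'}    # names in the graph
param_names = {'conv0/W:0', 'conv0/b:0', 'fc_old/W:0'}   # keys in the dict

print(sorted(variable_names & param_names))   # restorable params
print(sorted(variable_names - param_names))   # in graph but not in dict
print(sorted(param_names - variable_names))   # in dict but not in graph
```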