Commit e750306b authored by Yuxin Wu's avatar Yuxin Wu

serve data

parent 8cdc6efd
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: serve_data.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import argparse
import imp
#import cv2
#import os
from tensorpack.dataflow import serve_data
parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
parser.add_argument('-p', '--port', help='port', type=int, required=True)
args = parser.parse_args()
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
ds = config.dataset
serve_data(ds, "tcp://*:{}".format(args.port))
...@@ -71,7 +71,7 @@ class ModelSaver(Callback): ...@@ -71,7 +71,7 @@ class ModelSaver(Callback):
except OSError: except OSError:
pass pass
os.symlink(basename, linkname) os.symlink(basename, linkname)
except OSError, IOError: # disk error sometimes.. just ignore it except (OSError, IOError): # disk error sometimes.. just ignore it
logger.exception("Exception in ModelSaver.trigger_epoch!") logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Callback): class MinSaver(Callback):
......
...@@ -60,4 +60,3 @@ def dataflow_to_process_queue(ds, size, nr_consumer): ...@@ -60,4 +60,3 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
proc = EnqueProc(ds, q, nr_consumer) proc = EnqueProc(ds, q, nr_consumer)
return q, proc return q, proc
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: remote.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
try:
import zmq
except ImportError:
logger.warn("Error in 'import zmq'. remote feature won't be available")
__all__ = []
else:
__all__ = ['serve_data', 'RemoteData']
from .base import DataFlow
from .common import RepeatedData
from ..utils import logger
from ..utils.serialize import dumps, loads
def serve_data(ds, addr):
ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH)
socket.set_hwm(10)
socket.bind(addr)
ds = RepeatedData(ds, -1)
try:
logger.info("Serving data at {}".format(addr))
while True:
for dp in ds.get_data():
socket.send(dumps(dp), copy=False)
finally:
socket.setsockopt(zmq.LINGER, 0)
socket.close()
if not ctx.closed:
ctx.destroy(0)
class RemoteData(DataFlow):
def __init__(self, addr):
self.ctx = zmq.Context()
self.socket = self.ctx.socket(zmq.PULL)
self.socket.set_hwm(10)
self.socket.connect(addr)
def get_data(self):
while True:
dp = loads(self.socket.recv(copy=False))
yield dp
if __name__ == '__main__':
import sys
from tqdm import tqdm
from .raw import FakeData
addr = "tcp://127.0.0.1:8877"
if sys.argv[1] == 'serve':
ds = FakeData([(128,244,244,3)], 1000)
serve_data(ds, addr)
else:
ds = RemoteData(addr)
logger.info("Each DP is 73.5MB")
with tqdm(total=10000) as pbar:
for k in ds.get_data():
pbar.update()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment