import socket
import struct
import time
import threading
import random
import time
import numpy as np
import argparse
import csv

parser = argparse.ArgumentParser(description='Mininet demo')

parser.add_argument('--fid', help='Funtion id',
                    type=int, action="store", required=False)
parser.add_argument('--c', help='Concurrency',
                    type=int, action="store", required=True)
parser.add_argument('--t', help='Runtime',
                    type=int, action="store", required=True)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--rps', help='Requests per second',
                   type=int, action="store")
group.add_argument('--n', help='Number of requests to send',
                   type=int, action="store")


args = parser.parse_args()

# Fail fast with a usage message instead of a traceback later on:
#   --t <= 0    -> time.sleep(runtime) rejects negatives, and the throughput
#                  calculation divides by runtime;
#   --rps <= 0  -> the inter-arrival time divides by rps (or goes negative);
#   --c <= 0    -> --rps mode would start no sender threads at all;
#   --n <= 0    -> nothing would be sent;
#   --fid       -> is packed as an unsigned 32-bit header field.
if args.t <= 0:
    parser.error("--t must be a positive number of seconds")
if args.rps is not None:
    if args.rps <= 0:
        parser.error("--rps must be positive")
    if args.c <= 0:
        parser.error("--c must be positive when --rps is used")
if args.n is not None and args.n <= 0:
    parser.error("--n must be positive")
if args.fid is not None and not 0 <= args.fid < 2 ** 32:
    parser.error("--fid must fit in an unsigned 32-bit integer")

PORT = 8000                # UDP port every request is sent to
dataInt = 0                # NOTE(review): genPacket packs its own data value; this global is not read
fid = args.fid
runtime = args.t           # seconds (used with time.sleep / time.time)
concurrency = args.c
SERVER_IP = "192.168.2.3"  # destination of every request

# One list of (payload, receive_time) tuples per receiver thread.  There are
# 12 slots although the driver at the bottom of the file starts receivers
# 0..10 only; the spare slot simply stays empty.
packet_holder = [[] for _ in range(12)]
# exec_id -> timestamp recorded when that request was sent.
ingress_time = {}
# Raised by send() once the workload is done; polled by receive().
stop_thread = False


def receive(i):
    """Collect reply datagrams on UDP port ``10000 + i`` into ``packet_holder[i]``.

    Every datagram (read with a 1024-byte buffer) is stored as a
    ``(payload, arrival_time)`` tuple, with ``arrival_time`` taken from
    ``time.time()`` right after ``recvfrom`` returns.

    The module-level ``stop_thread`` flag is checked only before each
    blocking ``recvfrom``.  Once the flag is set, the thread records at most
    one more datagram (the one it is already waiting for) and then returns.

    :param i: receiver index; selects both the listening port and the
        ``packet_holder`` slot.
    """
    listen_ip = "0.0.0.0"
    listen_port = 10000 + i
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.bind((listen_ip, listen_port))
    print("listening to {} at port {}".format(listen_ip, listen_port))

    # Start this receiver's slot from an empty list.
    received = packet_holder[i] = []
    while not stop_thread:
        payload, _sender = sock.recvfrom(1024)
        received.append((payload, time.time()))
    print("stop thread r")


def genPacket(function_id=None):
    """Build one request datagram and return ``(packet, exec_id)``.

    Layout (22 bytes, multi-byte fields big-endian):
    chain_id u32 | exec_id u32 | function_id u32 | data u32 |
    function_count u8 | f0 u8 | f1 u8 | f2 u8 | f3 u8 | f4 u8

    :param function_id: value for the function_id field.  When omitted it
        falls back to the module-level ``fid`` (the ``--fid`` option), and to
        1 when that is also unset.  An explicit 0 is kept as 0; it is no
        longer silently replaced by 1.
    :returns: the packed bytes, and the random exec_id
        (``0 <= exec_id <= 2**30``) stored at bytes 4-7.
    """
    if function_id is None:
        function_id = fid if fid is not None else 1
    exec_id = random.randint(0, 2 ** 30)
    chain_id = 1
    data = 1
    function_count = 5
    # f0..f4 bytes.  They are annotated as f2 -> f0, f3 -> f1 f2, f4 -> f3,
    # presumably dependency edges; confirm the encoding with the server.
    functions = (0, 1, 2, 0, 0)
    packet = struct.pack(">4I6B", chain_id, exec_id, function_id, data,
                         function_count, *functions)
    return packet, exec_id


def sendThread(start_time, runtime, sleep_time):
    """Send paced requests until ``runtime`` seconds have elapsed since ``start_time``.

    Each iteration builds a packet with ``genPacket``.  A packet whose exec_id
    has already been used is discarded without sending or sleeping.
    Otherwise the packet is sent to ``(SERVER_IP, PORT)``, its send time
    (``time.time()``, in seconds) is recorded in ``ingress_time``, and the
    worker sleeps ``sleep_time`` seconds.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    destination = (SERVER_IP, PORT)
    while time.time() - start_time <= runtime:
        packet, exec_id = genPacket()
        if exec_id not in ingress_time:
            sock.sendto(packet, destination)
            ingress_time[exec_id] = time.time()
            time.sleep(sleep_time)


def send():
    """Send the workload, then raise ``stop_thread`` for the receivers.

    The mode is selected by the mutually exclusive CLI flags:

    * ``--n``: send ``args.n`` requests back-to-back from this thread.
    * ``--rps``: start ``concurrency`` daemon ``sendThread`` workers, each
      sleeping ``concurrency / rps`` seconds between requests, and block for
      ``runtime`` seconds while they run.

    Send timestamps go into ``ingress_time`` in seconds (``time.time()``).
    That is the same unit ``sendThread`` and the receivers use, and
    ``printStatistics`` subtracts them directly.
    """
    global stop_thread
    print("Sending packet to %s at port %s" % (SERVER_IP, PORT))
    print("Runtime: %d Concurrency %d" % (runtime, concurrency))
    print("chain id, exec id, data, function count, functions dependencies...")

    if args.n is not None:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            for _ in range(args.n):
                packet, exec_id = genPacket()
                s.sendto(packet, (SERVER_IP, PORT))
                # Store seconds, not milliseconds; otherwise every --n
                # latency computed by printStatistics would be wrong.
                ingress_time[exec_id] = time.time()
                # The log line still reports the send time in milliseconds.
                print("send", "{0:f}".format(ingress_time[exec_id] * 1000))
        finally:
            s.close()

    elif args.rps is not None:
        start_time = time.time()
        # Each worker sleeps this long between requests, so together the
        # `concurrency` workers approximate args.rps requests per second.
        sleep_time = concurrency / float(args.rps)
        print("calculated inter-arrival time, offload mode", sleep_time)
        for _ in range(concurrency):
            t = threading.Thread(target=sendThread, args=[
                                 start_time, runtime, sleep_time])
            t.daemon = True
            t.start()
        time.sleep(runtime)

    # receive() only checks this flag before each blocking recvfrom, so a
    # receiver notices it only after its next datagram arrives.
    print("stoppping thread")
    stop_thread = True
    print("thread stopped")

def _collect_e2e_ms():
    """Return ``(latencies_ms, unmatched)`` for all datagrams in ``packet_holder``.

    Latency is the receive time minus the ``ingress_time`` entry for the
    reply's exec_id, scaled by 1000 to milliseconds.  Datagrams shorter than
    the 8-byte chain_id/exec_id prefix, or whose exec_id was never recorded
    as sent, are counted in ``unmatched`` instead of raising.
    """
    latencies = []
    unmatched = 0
    for bucket in packet_holder:
        for packet, received_at in bucket:
            if len(packet) < 8:
                unmatched += 1
                continue
            # Bytes 0-3 hold chain_id (unused here); bytes 4-7 hold exec_id.
            exec_id = struct.unpack(">I", packet[4:8])[0]
            sent_at = ingress_time.get(exec_id)
            if sent_at is None:
                unmatched += 1
                continue
            latencies.append((received_at - sent_at) * 1000)
    return latencies, unmatched


def printStatistics():
    """Summarise end-to-end latency and append one result row to a CSV.

    This function:

    * writes every latency (ms) to ``bm_static_1.csv``;
    * prints mean, p50, p95 and p99;
    * appends ``[rps, mean, throughput, sent, p50, p95, p99]`` to
      ``speedo_data_static2_1f_host.csv``.

    Throughput is matched replies per second of ``runtime``, and ``sent`` is
    ``len(ingress_time)``.  With no matched replies the statistics are
    reported as NaN instead of raising.

    :returns: 0
    """
    e2e_time, unmatched = _collect_e2e_ms()
    if unmatched:
        print("skipped %d unmatched replies" % unmatched)

    data = np.array(e2e_time, dtype=float)
    np.savetxt("bm_static_1.csv", data, delimiter=' ', header='')
    if data.size:
        p50, p95, p99 = np.percentile(data, [50, 95, 99])
        mean = np.mean(data)
    else:
        # np.percentile raises on an empty array.
        p50 = p95 = p99 = mean = float('nan')
    print("mean \t p50 \t p95 \t p99")
    print(mean, p50, p95, p99)

    # float() so Python 2 does not truncate the rate via integer division.
    if runtime > 0:
        throughput = len(e2e_time) / float(runtime)
    else:
        throughput = float('nan')

    fields = [args.rps, mean, throughput, len(ingress_time), p50, p95, p99]
    with open('speedo_data_static2_1f_host.csv', 'a') as f:
        writer = csv.writer(f)
        writer.writerow(fields)

    print("rps", throughput, len(ingress_time))
    return 0

# Driver steps:
#   1. Start one daemon receiver per reply port (10000..10010, i.e.
#      packet_holder slots 0..10).
#   2. Pause one second, presumably so every socket is bound before the
#      first request goes out.
#   3. Run the workload.
#   4. Wait a fixed 170 s grace period, then compute statistics.
# NOTE(review): send() raises stop_thread before the grace period, and
# receive() only checks that flag between datagrams.  Each receiver
# therefore records at most one more reply during those 170 s.  Confirm
# this is intended.
receiver_threads = [
    threading.Thread(name="receive", target=receive, args=[slot])
    for slot in range(11)
]
for worker in receiver_threads:
    worker.daemon = True
    worker.start()

time.sleep(1)
send()
time.sleep(170)
printStatistics()