import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import getopt
import csv
import sys

filename = ''

# Parse the command line. The only supported option is --data=<path to CSV>;
# getopt raises GetoptError for unknown options or a --data without a value.
try:
    opts, args = getopt.getopt(sys.argv[1:], "", ["data="])
except getopt.GetoptError as err:
    # Catch only option-parsing failures (a bare except would also swallow
    # SystemExit/KeyboardInterrupt) and exit non-zero so shells see the error.
    print('Invalid options specified', file=sys.stderr)
    print(err, file=sys.stderr)
    sys.exit(1)

for arg, val in opts:
    if arg == '--data':
        filename = val

if not filename:
    # No input file is an error as well: exit with status 1, not 0.
    print('Invalid file specified', file=sys.stderr)
    sys.exit(1)

df = pd.read_csv(filename)

# Draw one log-log figure per instance: mean regret versus horizon, with one
# line per (algorithm, epsilon) combination, saved as instance<N>.png.
#
# Group by the scalar column name rather than a one-element list: with a list,
# pandas >= 2.0 yields 1-tuples as group keys, which would break the string
# handling of the instance id below.
for instance_id, grp1 in df.groupby('instance'):
    instance_name = str(instance_id)
    fig, ax = plt.subplots(figsize=(8, 6))

    # dropna=False keeps rows whose epsilon is missing; the default (True)
    # would silently leave those algorithms out of the plot.
    grouped = grp1.groupby(['algorithm', 'epsilon'], dropna=False)
    for (algorithm, epsilon), grp2 in grouped:
        algo_label = str(algorithm)
        # Append epsilon to the legend label only when it is present and
        # non-zero (NaN is truthy, so it must be checked explicitly).
        if pd.notna(epsilon) and epsilon:
            algo_label = algo_label + " " + str(epsilon)
        # Average the regret of all runs at each horizon. groupby sorts the
        # horizons ascending, so the line is drawn left to right.
        mean_regret = grp2.groupby('horizon')['regret'].mean()
        ax.plot(mean_regret.index, mean_regret.to_numpy(),
                label=algo_label, marker=".")

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.legend()
    ax.set_xlabel("Horizon")
    ax.set_ylabel("Regret")
    ax.set_title(instance_name)

    # File name uses the part after the first '-' ('instance-1' -> 'instance1.png').
    # Fall back to the whole id when there is no '-' instead of raising IndexError.
    parts = instance_name.split('-')
    suffix = parts[1] if len(parts) > 1 else instance_name
    fig.savefig(f"instance{suffix}.png")
    # Release the figure; otherwise every instance leaves an open figure behind.
    plt.close(fig)

""" i1=[]
i2=[]
i3=[]
with open('data.csv', newline='') as csvfile:
    csvreader = csv.reader(csvfile)
    for row in csvreader:
        if row[0] == "instance-1":
            i1.append(row)
        elif row[0] == "instance-2":
            i2.append(row)
        elif row[0] == "instance-3":
            i3.append(row)
instance_1=pd.DataFrame(i1)
instance_1.columns=['instance','algorithm','random_seed','epsilon','horizon','regret']
instance_2=pd.DataFrame(i2)
instance_3=pd.DataFrame(i3)

#instance_1.set_index('horizon', inplace=True)
#instance_1.groupby(['algorithm'])['regret'].plot(legend=True)

print(instance_1)

for key, grp in instance_1.groupby(['algorithm']):
    plt.plot(grp['horizon'], grp['regret'], label=key, marker='o')
    pass

for key, grp in instance_1.groupby(['algorithm']):
    hori_list = []
    regreti_list = []
    for hori, grp2 in grp.groupby(['horizon']):
        regrmax = ((grp2['regret'].explode().astype(int).max()))
        regreti = ((grp2['regret'].apply(lambda x: regrmax-int(x))).sum())
        hori_list.append(hori)
        regreti_list.append(regreti_list)
    plt.plot(hori_list, regreti_list)


plt.xscale("log")
plt.yscale("log")
plt.show()
 """
#plt.show()

#grouped_df=instance_1.groupby(['algorithm'])

#for key, item in grouped_df:
    #print(grouped_df.get_group(key), "\n\n")
#instance_1.plot(kind='scatter',x='horizon',y='regret',color='red')
#plt.xscale("log")
#plt.yscale("log")
#plt.show()

#data.groupby(level="instance",axis=0)
