#
# Copyright (C) 2015-2017,  Netronome Systems, Inc.  All rights reserved.
#

import os, sys, struct, pprint, threading
import functools
from urlparse import urlparse
from contextlib import contextmanager

#sys.path.append('../../out/gen_thrift/py')

from thrift.transport import TTransport, TZlibTransport, TSocket
from thrift.protocol import TBinaryProtocol

from RTERPCInterface import (DesignBase, CountersBase, TablesBase, 
    ParserValueSetsBase, RegistersBase, MetersBase, TrafficClassBase, 
    DigestsBase, MulticastBase, DebugCtlBase, SystemBase)

from sdk6_rte import RunTimeEnvironment
from sdk6_rte.ttypes import *

import RTERPCInterface

# Translation tables from the Thrift-generated enums (star-imported from
# sdk6_rte.ttypes) to the equivalent RTERPCInterface enums.  Each table pairs
# the members of the same name on both sides, so the member tuple below is the
# single list to extend when a new enum member is added to both.
COUNTER_TYPE_MAP = dict(
    (getattr(P4CounterType, _member), getattr(RTERPCInterface.P4CounterType, _member))
    for _member in ('Global', 'Direct', 'Static'))

REGISTER_TYPE_MAP = dict(
    (getattr(RegisterType, _member), getattr(RTERPCInterface.RegisterType, _member))
    for _member in ('Global', 'Direct', 'Static'))

METER_TYPE_MAP = dict(
    (getattr(MeterType, _member), getattr(RTERPCInterface.MeterType, _member))
    for _member in ('Invalid', 'Global', 'Direct', 'Static'))

METER_CLASS_MAP = dict(
    (getattr(MeterClass, _member), getattr(RTERPCInterface.MeterClass, _member))
    for _member in ('Invalid', 'Packets', 'Bytes'))
# decorator to transform non SUCCESS RteReturn values to exceptions
def RteReturnHandler(err_msg=None):
    """Return a decorator that raises when the wrapped call reports failure.

    The decorated callable must return an object with ``value`` and
    ``reason`` attributes (an RteReturn).  When ``value`` equals
    ``RteReturnValue.SUCCESS`` the wrapper returns None; otherwise it raises.

    :param err_msg: message placed in the raised error; when None,
        ``'Error in <wrapped function name>'`` is used instead.
    :raises RTERPCInterface.RTECommError: if the call raises TException.
    :raises RTERPCInterface.RTEReturnError: if ``value`` is not SUCCESS; the
        message is prefixed with ``RTE_RETURN_CODES[value]`` and suffixed
        with the server-supplied ``reason`` when one is present.
    """
    def _RteReturnHandler(func):
        # functools.wraps keeps the wrapped method's __name__/__doc__, so
        # tracebacks and introspection show e.g. 'design_load', not the wrapper.
        @functools.wraps(func)
        def __RteReturnHandler(*args, **kwargs):
            try:
                rte_ret = func(*args, **kwargs)
            except TException as err:
                raise RTERPCInterface.RTECommError(
                    "Communication failure with RPC server: %s" % str(err))
            if rte_ret.value == RteReturnValue.SUCCESS:
                return None
            error = err_msg
            if error is None:
                error = 'Error in %s' % func.__name__
            reason = ''
            if rte_ret.reason:
                reason = ": %s.\nPlease see RTE log for more info" % rte_ret.reason
            raise RTERPCInterface.RTEReturnError('%s: %s%s' % (
                RTERPCInterface.RTE_RETURN_CODES[rte_ret.value], error, reason))
        return __RteReturnHandler
    return _RteReturnHandler

        
# decorator to catch thrift communication failures
def RPC(func):
    """Run a method while holding ``self.rte.THRIFT_API_LOCK``, mapping Thrift errors.

    Any TException raised by the wrapped method is re-raised as
    ``RTERPCInterface.RTECommError``; other exceptions propagate unchanged.
    The lock object is whatever ``self.rte.THRIFT_API_LOCK`` holds, which must
    support the context-manager protocol.
    """
    # functools.wraps keeps the decorated method's name and docstring instead
    # of every API method reporting itself as '_RPC'.
    @functools.wraps(func)
    def _RPC(self, *args, **kwargs):
        with self.rte.THRIFT_API_LOCK:
            try:
                return func(self, *args, **kwargs)
            except TException as err:
                raise RTERPCInterface.RTECommError(
                    "Communication failure with RPC server: %s" % str(err))
    return _RPC
    

class Design(DesignBase):
    """Thrift-backed firmware/design lifecycle operations."""

    @RPC
    def Load(self, elf_fw, pif_design, pif_config):
        """Unload the current design, then load new firmware and P4 files.

        :param elf_fw: path to the firmware file (always read).
        :param pif_design: optional path to the PIF design JSON; "" is sent
            when falsy.
        :param pif_config: optional path to the PIF config JSON; "" is sent
            when falsy.
        :returns: ``(True, '')`` once design_unload and design_load return.
        """
        def slurp(path):
            # Binary mode: the file contents are forwarded to the RPC verbatim.
            with open(path, "rb") as handle:
                return handle.read()

        firmware = slurp(elf_fw)
        design_json = slurp(pif_design) if pif_design else ""
        config_json = slurp(pif_config) if pif_config else ""

        client = self.rte.client
        # design_unload() is always issued before the new design is sent.
        client.design_unload()
        client.design_load(DesignLoadArgs(
            nfpfw=firmware,
            pif_design_json=design_json,
            pif_config_json=config_json,
        ))
        return True, ''

    @RPC
    def Unload(self):
        """Issue design_unload on the RPC client."""
        self.rte.client.design_unload()

    @RPC
    def ConfigReload(self, pif_config):
        """Send the raw contents of the ``pif_config`` file to design_reconfig."""
        with open(pif_config, "rb") as handle:
            config_json = handle.read()
            self.rte.client.design_reconfig(config_json)

    @RPC
    def LoadStatus(self):
        """Return design_load_status() as a plain dict.

        A None ``frontend_version`` is reported as ''.
        """
        status = self.rte.client.design_load_status()
        version = status.frontend_version
        if version is None:
            version = ''
        return dict(
            is_loaded=status.is_loaded,
            uuid=status.uuid,
            frontend_build_date=status.frontend_build_date,
            frontend_source=status.frontend_source,
            frontend_version=version,
            uptime=status.uptime,
        )
        
class Counters(CountersBase):
    """Thrift-backed access to P4 counters and RTE system counters."""

    def ExtractRteValue(self, rv):
        """Return the integer carried by an RteValue.

        Int64-typed values come from ``intval``; any other type is parsed
        from ``stringval`` with int().
        """
        if rv.type == RteValueType.Int64:
            return rv.intval
        return int(rv.stringval)

    @RPC
    def ListP4Counters(self):
        """Return one dict per P4 counter (count, name, width, tableid, table,
        type, id); ``type`` is translated through COUNTER_TYPE_MAP."""
        counters = self.rte.client.p4_counter_list_all()
        return [dict(count=c.count, name=c.name, width=c.width,
                     tableid=c.tableid, table=c.table,
                     type=COUNTER_TYPE_MAP[c.type], id=c.id)
                for c in counters]

    @RPC
    def GetP4Counter(self, counter):
        """Return a tuple of unsigned 64-bit values for ``counter``.

        An empty tuple is returned when the reply's ``count`` is -1.
        """
        counter_id = self.ResolveToCounterId(counter)
        reply = self.rte.client.p4_counter_retrieve(counter_id)
        if reply.count == -1:
            return ()
        # 'count' is divided by 8 to get the number of native 'Q' (8-byte)
        # words in reply.data; presumably it is a byte length - confirm.
        return struct.unpack('%sQ' % (reply.count // 8), reply.data)

    @RPC
    def ClearP4Counter(self, counter):
        """Clear a single P4 counter, given anything ResolveToCounterId accepts."""
        counter_id = self.ResolveToCounterId(counter)
        self.rte.client.p4_counter_clear(counter_id)

    @RPC
    def ClearAllP4Counters(self):
        """Clear every P4 counter."""
        self.rte.client.p4_counter_clear_all()

    @RPC
    def GetSystemCounters(self):
        """Return one dict per system counter: name, integer value, id."""
        system_counters = self.rte.client.sys_counter_retrieve_all()
        return [dict(name=s.name, value=self.ExtractRteValue(s.value), id=s.id)
                for s in system_counters]

    @RPC
    def ClearAllSysCounters(self):
        """Clear every system counter."""
        self.rte.client.sys_counter_clear_all()

class Tables(TablesBase):
    """Thrift-backed match-action table rule management."""

    @staticmethod
    def _MakeTableEntry(rule_name, default_rule, match, actions,
                        priority=None, timeout=None):
        """Build a TableEntry; priority/timeout are only set when not None."""
        tbl_entry = TableEntry()
        tbl_entry.rule_name = rule_name
        tbl_entry.default_rule = default_rule
        tbl_entry.match = match
        tbl_entry.actions = actions
        if priority is not None:
            tbl_entry.priority = priority
        if timeout is not None:
            tbl_entry.timeout_seconds = timeout
        return tbl_entry

    @RPC
    def AddRule(self, tbl_id, rule_name, default_rule, match, actions, priority = None, timeout = None):
        """Add a rule to the table identified by ``tbl_id`` (id or name, as
        accepted by ResolveToTableId)."""
        tbl_entry = self._MakeTableEntry(rule_name, default_rule, match,
                                         actions, priority, timeout)
        self.rte.client.table_entry_add(self.ResolveToTableId(tbl_id), tbl_entry)

    @RPC
    def EditRule(self, tbl_id, rule_name, default_rule, match, actions, priority = None, timeout = None):
        """Edit an existing rule in the table identified by ``tbl_id``."""
        tbl_entry = self._MakeTableEntry(rule_name, default_rule, match,
                                         actions, priority, timeout)
        self.rte.client.table_entry_edit(self.ResolveToTableId(tbl_id), tbl_entry)

    @RPC
    def DeleteRule(self, tbl_id, rule_name, default_rule, match, actions):
        """Delete a rule from the table identified by ``tbl_id``.

        Priority and timeout are never set on the entry sent for deletion.
        """
        tbl_entry = self._MakeTableEntry(rule_name, default_rule, match, actions)
        self.rte.client.table_entry_delete(self.ResolveToTableId(tbl_id), tbl_entry)

    @RPC
    def List(self):
        """Return one dict per table: tbl_name, tbl_id, support_timeout,
        tbl_entries_max."""
        return [{
            'tbl_name': td.tbl_name,
            'tbl_id': td.tbl_id,
            'support_timeout': td.support_timeout,
            'tbl_entries_max': td.tbl_entries_max,
        } for td in self.rte.client.table_list_all()]

    @RPC
    def ListRules(self, tbl_id):
        """Return one dict per rule in table ``tbl_id``."""
        return [{
            'timeout_seconds': te.timeout_seconds,
            'actions': te.actions,
            'priority': te.priority,
            'rule_name': te.rule_name,
            'default_rule': te.default_rule,
            'match': te.match,
        } for te in self.rte.client.table_retrieve(self.ResolveToTableId(tbl_id))]

    @RPC
    def GetVersion(self):
        """Return the result of table_version_get()."""
        return self.rte.client.table_version_get()

class Registers(RegistersBase):
    """Thrift-backed access to P4 register arrays."""

    @RPC
    def List(self):
        """Return one dict per register (name, id, type, count, table,
        tableid, fields); ``type`` is translated through REGISTER_TYPE_MAP."""
        return [{
            'name': reg.name,
            'id': reg.id,
            'type': REGISTER_TYPE_MAP[reg.type],
            'count': reg.count,
            'table': reg.table,
            'tableid': reg.tableid,
            'fields': [{
                'name': fld.name,
                'width': fld.width,
            } for fld in reg.fields],
        } for reg in self.rte.client.register_list_all()]

    def ResolveToRegisterArrayArg(self, register, index, count):
        """Build a RegisterArrayArg for ``register`` (id or name).

        :param register: integer register id, or register name.
        :param index: first element index.
        :param count: number of elements; -1 means the register's full count.
        :raises RTERPCInterface.RTEError: on an unsupported ``register`` type,
            or an integer id with ``count == -1`` that matches no register.
        """
        reg = None
        if isinstance(register, (int, long)):
            reg_id = register
        elif isinstance(register, basestring):
            reg = self.GetRegisterByName(register)
            reg_id = reg['id']
        else:
            raise RTERPCInterface.RTEError("Unhandled register parameter type: %s" % type(register))

        if count == -1:
            if reg is not None:
                count = reg['count']
            else:
                # An integer id must be resolved by id; looking it up by
                # name (as before) can never match.
                count = self._RegisterCountById(reg_id)
        return RegisterArrayArg(reg_id=reg_id, index=index, count=count)

    def _RegisterCountById(self, reg_id):
        """Return the element count of register ``reg_id`` from register_list_all()."""
        # Query the client directly rather than self.List(): callers are @RPC
        # methods already holding THRIFT_API_LOCK, which may be non-reentrant.
        for info in self.rte.client.register_list_all():
            if info.id == reg_id:
                return info.count
        raise RTERPCInterface.RTEError("Unknown register id: %s" % reg_id)

    @RPC
    def Get(self, register, index=0, count=1):
        """Return register_retrieve() for ``count`` elements starting at ``index``."""
        return self.rte.client.register_retrieve(self.ResolveToRegisterArrayArg(register, index, count))

    @RPC
    def Clear(self, register, index=0, count=1):
        """Clear ``count`` elements of ``register`` starting at ``index``."""
        self.rte.client.register_clear(self.ResolveToRegisterArrayArg(register, index, count))

    @RPC
    def Set(self, register, values, index=0, count=1):
        """Write ``values`` into ``count`` elements starting at ``index``."""
        self.rte.client.register_set(self.ResolveToRegisterArrayArg(register, index, count), values)

    @RPC
    def SetField(self, register, field_id, value, index=0, count=1):
        """Write ``value`` into field ``field_id`` of the selected elements."""
        self.rte.client.register_field_set(self.ResolveToRegisterArrayArg(register, index, count), field_id, value)

class TrafficClass(TrafficClassBase):
    """Thrift-backed per-port traffic class configuration."""

    @RPC
    def Get(self, port_id):
        """Return one dict per traffic class configured on ``port_id``."""
        cfgs = self.rte.client.traffic_class_get(port_id)
        return [dict(class_id=c.class_id, weight=c.weight,
                     queue_no=c.queue_no, committed=c.committed)
                for c in cfgs]

    @RPC
    def Set(self, port_id, cfgs):
        """Send ``cfgs`` (dicts of TrafficClassCfg keyword args) for ``port_id``."""
        thrift_cfgs = [TrafficClassCfg(**c) for c in cfgs]
        self.rte.client.traffic_class_set(port_id, thrift_cfgs)

    @RPC
    def Commit(self, port_id):
        """Issue traffic_class_commit for ``port_id``."""
        self.rte.client.traffic_class_commit(port_id)

class Meters(MetersBase):
    """Thrift-backed meter listing and configuration."""

    @RPC
    def List(self):
        """Return one dict per meter; ``type`` and ``mclass`` are translated
        through METER_TYPE_MAP and METER_CLASS_MAP."""
        return [{
            'name': mtr.name,
            'id': mtr.id,
            'type': METER_TYPE_MAP[mtr.type],
            'mclass': METER_CLASS_MAP[mtr.mclass],
            'count': mtr.count,
            'table': mtr.table,
            'tableid': mtr.tableid,
        } for mtr in self.rte.client.meter_list_all()]

    @RPC
    def GetConfig(self, meter_id):
        """Return the meter_config_get() entries as dicts keyed by field name."""
        return [{
            'rate_k': mcfg.rate_k,
            'burst_k': mcfg.burst_k,
            'array_offset': mcfg.array_offset,
            'count': mcfg.count,
        } for mcfg in self.rte.client.meter_config_get(meter_id)]

    @RPC
    def SetConfig(self, meter_id, configs):
        """Send meter configs for ``meter_id``.

        :param configs: iterable of dicts with keys 'rate', 'burst', 'off' and
            'cnt', sent as MeterCfg rate_k, burst_k, array_offset and count.
        """
        # Keyword arguments bind each value to its field by name instead of
        # depending on the generated constructor's positional field order.
        ops = [MeterCfg(rate_k=cfg['rate'], burst_k=cfg['burst'],
                        array_offset=cfg['off'], count=cfg['cnt'])
               for cfg in configs]
        self.rte.client.meter_config_set(meter_id, ops)

class Digests(DigestsBase):
    """Thrift-backed digest listing, registration and retrieval."""

    @RPC
    def List(self):
        """Return one dict per digest (name, id, app_id, field_list_name,
        fields), where ``fields`` is a list of {'name', 'width'} dicts."""
        digests = []
        for desc in self.rte.client.digest_list_all():
            field_dicts = [dict(name=f.name, width=f.width) for f in desc.fields]
            digests.append(dict(name=desc.name, id=desc.id, app_id=desc.app_id,
                                field_list_name=desc.field_list_name,
                                fields=field_dicts))
        return digests

    @RPC
    def Register(self, digest_id):
        """Return the result of digest_register(digest_id)."""
        return self.rte.client.digest_register(digest_id)

    @RPC
    def Deregister(self, digest_regid):
        """Return the result of digest_deregister(digest_regid)."""
        return self.rte.client.digest_deregister(digest_regid)

    @RPC
    def Get(self, digest_handle):
        """Return the result of digest_retrieve(digest_handle)."""
        return self.rte.client.digest_retrieve(digest_handle)

class Multicast(MulticastBase):
    """Thrift-backed multicast group configuration."""

    @RPC
    def List(self):
        """Return one dict per multicast group; a None port list becomes []."""
        return [{
            'group_id': mcce.group_id,
            'max_entries': mcce.max_entries,
            'ports': [] if mcce.ports is None else mcce.ports,
        } for mcce in self.rte.client.mcast_config_get_all()]

    @RPC
    def SetConfig(self, group_id, ports):
        """Configure multicast group ``group_id`` with the given ports.

        :param ports: any iterable of ports; None is treated as no ports
            (mirroring List()).  ``max_entries`` is set to the port count.
        """
        # Materialise once so generators work and len() matches what is sent.
        port_list = [] if ports is None else list(ports)
        cfg = McastCfgEntry(group_id=group_id, max_entries=len(port_list),
                            ports=port_list)
        self.rte.client.mcast_config_set(cfg)

class System(SystemBase):
    """Thrift-backed system-level operations: shutdown, ping, version, log level, ports."""

    @RPC
    def Shutdown(self):
        """Return the result of sys_shutdown()."""
        client = self.rte.client
        return client.sys_shutdown()

    @RPC
    def Ping(self):
        """Return the result of sys_ping()."""
        client = self.rte.client
        return client.sys_ping()

    @RPC
    def Echo(self, echo_msg):
        """Return the result of sys_echo(echo_msg)."""
        client = self.rte.client
        return client.sys_echo(echo_msg)

    @RPC
    def GetVersion(self):
        """Return the result of sys_version_get()."""
        client = self.rte.client
        return client.sys_version_get()

    @RPC
    def GetLogLevel(self):
        """Return the result of sys_log_level_get()."""
        client = self.rte.client
        return client.sys_log_level_get()

    @RPC
    def SetLogLevel(self, level):
        """Set the RTE log level via sys_log_level_set(level)."""
        client = self.rte.client
        client.sys_log_level_set(level)

    @RPC
    def GetPortInfo(self):
        """Return one {'id', 'token', 'info'} dict per entry of ports_info_retrieve()."""
        ports = self.rte.client.ports_info_retrieve()
        return [dict(id=p.id, token=p.token, info=p.info) for p in ports]

class DebugCtl(DebugCtlBase):
    """Thrift-backed pass-through for RTE debug control commands."""

    @RPC
    def Execute(self, debug_id, debug_data):
        """Run debugctl(debug_id, debug_data) and return its ``return_data``.

        :raises RTERPCInterface.RTEError: if the reply's ``return_value`` is -1.
        """
        reply = self.rte.client.debugctl(debug_id, debug_data)
        if reply.return_value != -1:
            return reply.return_data
        raise RTERPCInterface.RTEError("Error encountered during debugctl '%s'" % debug_id)

class ParserValueSets(ParserValueSetsBase):
    """Thrift-backed parser value set management."""

    @RPC
    def List(self):
        """Return one dict per parser value set (pvs_id, pvs_name,
        pvs_entries_max, key_layout of {'name', 'width'} dicts)."""
        value_sets = self.rte.client.parser_value_set_list_all()
        return [dict(pvs_id=vs.pvs_id,
                     pvs_name=vs.pvs_name,
                     pvs_entries_max=vs.pvs_entries_max,
                     key_layout=[dict(name=k.name, width=k.width)
                                 for k in vs.key_layout])
                for vs in value_sets]

    @RPC
    def Clear(self, pvs_id):
        """Remove all entries from parser value set ``pvs_id``."""
        self.rte.client.parser_value_set_clear(pvs_id)

    @RPC
    def Add(self, pvs_id, pvs_entries):
        """Add entries to ``pvs_id``; each entry's [0] is the value and [1] the mask."""
        entries = [ParserValueSetEntry(value=item[0], mask=item[1])
                   for item in pvs_entries]
        self.rte.client.parser_value_set_add(pvs_id, entries)

    @RPC
    def Retrieve(self, pvs_id):
        """Return the entries of ``pvs_id`` as {'value', 'mask'} dicts."""
        stored = self.rte.client.parser_value_set_retrieve(pvs_id)
        return [dict(value=item.value, mask=item.mask) for item in stored]


def DoConnect(conn, host, port, device_id=0, use_zlib=True, serialise_api=False):
    """Open a Thrift connection to the RTE server and attach it to ``conn``.

    Sets ``conn.transport``, ``conn.client`` and ``conn.THRIFT_API_LOCK``.
    Selected client methods are wrapped with RteReturnHandler, so a
    non-SUCCESS RteReturn from them raises RTERPCInterface.RTEReturnError.

    :param device_id: accepted for interface compatibility; unused here.
    :param use_zlib: wrap the buffered socket transport in TZlibTransport.
    :param serialise_api: when true, @RPC methods share a reentrant lock;
        otherwise RTERPCInterface.NullCtx() is used.
    :raises RTERPCInterface.RTECommError: if the transport cannot be opened.
    """
    transport = TTransport.TBufferedTransport(TSocket.TSocket(host, port))
    if use_zlib:
        transport = TZlibTransport.TZlibTransport(transport)
    conn.transport = transport
    conn.client = RunTimeEnvironment.Client(TBinaryProtocol.TBinaryProtocol(transport))

    # Client methods returning an RteReturn, with the message used on failure.
    checked_methods = (
        ('design_load', 'Loading firmware failed'),
        ('design_unload', 'Unloading firmware failed'),
        ('design_reconfig', 'Reload of user config failed'),
        ('sys_log_level_set', 'Set log level failed'),
        ('table_entry_add', 'Adding table entry failed'),
        ('table_entry_edit', 'Editing table entry failed'),
        ('table_entry_delete', 'Deleting table entry failed'),
        ('p4_counter_clear', 'P4 counter clear failed'),
        ('p4_counter_clear_all', 'P4 counter clear all failed'),
        ('sys_counter_clear_all', 'System counter clear all failed'),
        ('register_clear', 'Register clear failed'),
        ('register_field_set', 'Register field set failed'),
        ('register_set', 'Register set failed'),
        ('mcast_config_set', 'Multicast config set failed'),
        ('meter_config_set', 'Meter config set failed'),
        ('digest_deregister', 'Digest deregister failed'),
        ('parser_value_set_add', 'Parser value set add failed'),
        ('parser_value_set_clear', 'Parser value set clear failed'),
        ('traffic_class_commit', 'Traffic Class commit failed'),
        ('traffic_class_set', 'Traffic Class set failed'),
    )
    # post apply decorators
    for method_name, err_msg in checked_methods:
        setattr(conn.client, method_name,
                RteReturnHandler(err_msg)(getattr(conn.client, method_name)))

    # Reentrant lock: if an @RPC method ends up calling another @RPC method on
    # the same thread (e.g. through base-class lookup helpers), a plain Lock
    # would deadlock. NOTE(review): confirm which base helpers re-enter.
    # Set before open() so a failed connect still leaves conn consistent.
    conn.THRIFT_API_LOCK = threading.RLock() if serialise_api else RTERPCInterface.NullCtx()

    try:
        conn.transport.open()
    except TException as err:
        raise RTERPCInterface.RTECommError(
            "Communication failure with RPC server: %s" % str(err))


def DoDisconnect(conn):
    """Close ``conn.transport`` if present and reset it to None.

    Safe to call on a connection that was never connected (no ``transport``
    attribute) or was already disconnected.  Any exception from close()
    propagates, but the transport reference is cleared first.
    """
    transport = getattr(conn, 'transport', None)
    if transport is None:
        return
    try:
        transport.close()
    finally:
        # Drop the reference even if close() raised, so a later call does
        # not try to close the same broken transport again.
        conn.transport = None
                
