#!/usr/bin/env python3
import collections
from redis.client import StrictRedis
import os
import time

# Exports per-port SAI counters from redis as a Prometheus textfile.
#
# Every POLL_INTERVAL seconds the counters in redis db 2 are read (LAG
# membership comes from db 0), rendered in the Prometheus text exposition
# format, written to TEMP_FILE and then renamed over OUTPUT_FILE so a reader
# never observes a partially written file.

OUTPUT_FILE = 'sai.prom'
TEMP_FILE = '.sai.prom.new'
POLL_INTERVAL = 5  # seconds, passed to time.sleep

ETHER_IN_PREFIX = 'SAI_PORT_STAT_ETHER_IN_PKTS_'
ETHER_OUT_PREFIX = 'SAI_PORT_STAT_ETHER_OUT_PKTS_'
OCTETS_SUFFIX = '_OCTETS'
PFC_PREFIX = 'SAI_PORT_STAT_PFC_'

# Packet-size buckets as (bucket key produced by classify_counter, "le" bound).
# Buckets are emitted cumulatively; the last one is reported as +Inf.
HISTOGRAM_BUCKETS = (
    ('64', 64),
    ('65:127', 127),
    ('128:255', 255),
    ('256:511', 511),
    ('512:1023', 1023),
    ('1024:1518', 1518),
    ('1519:2047', 2047),
    ('2048:4095', 4095),
    ('4096:9216', 9216),
    ('9217:16383', '+Inf'),
)

# Histogram metric key -> (Prometheus base name, HELP text, counter used as _sum).
HISTOGRAMS = {
    ETHER_IN_PREFIX: (
        'sai_port_in_packet_size_bytes',
        'SAI metric for SAI_PORT_STAT_ETHER_IN_PKTS_*_OCTETS',
        'SAI_PORT_STAT_IF_IN_OCTETS',
    ),
    ETHER_OUT_PREFIX: (
        'sai_port_out_packet_size_bytes',
        'SAI metric for SAI_PORT_STAT_ETHER_OUT_PKTS_*_OCTETS',
        'SAI_PORT_STAT_IF_OUT_OCTETS',
    ),
}

# _sum counter name -> histogram metric key it belongs to.
HISTOGRAM_SUMS = {spec[2]: key for key, spec in HISTOGRAMS.items()}

# Per-priority PFC packet metric key -> (Prometheus name, HELP text).
PFC_METRICS = {
    'SAI_PORT_STAT_PFC_RX_PKTS': (
        'sai_port_rx_pfc_packets_total',
        'SAI metric for SAI_PORT_STAT_PFC_*_RX_PKTS',
    ),
    'SAI_PORT_STAT_PFC_TX_PKTS': (
        'sai_port_tx_pfc_packets_total',
        'SAI metric for SAI_PORT_STAT_PFC_*_TX_PKTS',
    ),
}

# Raw SAI counter name -> Prometheus metric name, for one-sample-per-port metrics.
SIMPLE_METRICS = {
    'SAI_PORT_STAT_ETHER_RX_OVERSIZE_PKTS': 'sai_port_rx_oversize_packets_total',
    'SAI_PORT_STAT_ETHER_STATS_FRAGMENTS': 'sai_port_fragment_packets_total',
    'SAI_PORT_STAT_ETHER_STATS_JABBERS': 'sai_port_jabbers_packets_total',
    'SAI_PORT_STAT_ETHER_STATS_TX_NO_ERRORS': 'sai_port_tx_packets_no_errors_total',
    'SAI_PORT_STAT_ETHER_STATS_UNDERSIZE_PKTS': 'sai_port_undersized_packets_total',
    'SAI_PORT_STAT_ETHER_TX_OVERSIZE_PKTS': 'sai_port_tx_oversized_packets_total',
    'SAI_PORT_STAT_IF_IN_BROADCAST_PKTS': 'sai_port_in_broadcast_packets_total',
    'SAI_PORT_STAT_IF_IN_DISCARDS': 'sai_port_in_discarded_packets_total',
    'SAI_PORT_STAT_IF_IN_ERRORS': 'sai_port_in_errored_packets_total',
    'SAI_PORT_STAT_IF_IN_MULTICAST_PKTS': 'sai_port_in_multicast_packets_total',
    'SAI_PORT_STAT_IF_IN_NON_UCAST_PKTS': 'sai_port_in_non_unicast_packets_total',
    'SAI_PORT_STAT_IF_IN_UCAST_PKTS': 'sai_port_in_unicast_packets_total',
    'SAI_PORT_STAT_IF_IN_UNKNOWN_PROTOS': 'sai_port_in_unknown_protocols_total',
    'SAI_PORT_STAT_IF_OUT_BROADCAST_PKTS': 'sai_port_out_broadcast_packets_total',
    'SAI_PORT_STAT_IF_OUT_DISCARDS': 'sai_port_out_discarded_packets_total',
    'SAI_PORT_STAT_IF_OUT_ERRORS': 'sai_port_out_errored_packets_total',
    'SAI_PORT_STAT_IF_OUT_MULTICAST_PKTS': 'sai_port_out_multicast_packets_total',
    'SAI_PORT_STAT_IF_OUT_NON_UCAST_PKTS': 'sai_port_out_non_unicast_packets_total',
    'SAI_PORT_STAT_IF_OUT_QLEN': 'sai_port_out_queue_length',
    'SAI_PORT_STAT_IF_OUT_UCAST_PKTS': 'sai_port_out_unicast_packets_total',
    'SAI_PORT_STAT_IP_IN_RECEIVES': 'sai_port_ip_in_packets_total',
    'SAI_PORT_STAT_IP_IN_UCAST_PKTS': 'sai_port_ip_in_unicast_packets_total',
    'SAI_PORT_STAT_PAUSE_RX_PKTS': 'sai_port_rx_pause_frames_total',
    'SAI_PORT_STAT_PAUSE_TX_PKTS': 'sai_port_tx_pause_frames_total',
    'SAI_PORT_STAT_OUT_DROPPED_PKTS': 'sai_port_out_dropped_packets_total',
    'SAI_PORT_STAT_IN_DROPPED_PKTS': 'sai_port_in_dropped_packets_total',
    'SAI_PORT_STAT_IF_IN_FEC_SYMBOL_ERRORS': 'sai_port_in_fec_symbol_errors_total',
    'SAI_PORT_STAT_IF_IN_FEC_NOT_CORRECTABLE_FRAMES': 'sai_port_in_fec_not_correctable_frames_total',
    'SAI_PORT_STAT_IF_IN_FEC_CORRECTABLE_FRAMES': 'sai_port_in_fec_correctable_frames_total',
}


def classify_counter(name):
    """Map a raw SAI counter name to its (metric key, sub key) in the metrics tree.

    - ``SAI_PORT_STAT_ETHER_IN_PKTS_65_TO_127_OCTETS`` ->
      ``('SAI_PORT_STAT_ETHER_IN_PKTS_', '65:127')`` (same for ``ETHER_OUT``).
    - ``SAI_PORT_STAT_PFC_3_RX_PKTS`` -> ``('SAI_PORT_STAT_PFC_RX_PKTS', '3')``
      (same for ``TX_PKTS``).
    - Anything else -> ``(name, 0)``.
    """
    for prefix in (ETHER_IN_PREFIX, ETHER_OUT_PREFIX):
        if name.startswith(prefix):
            bucket = name[len(prefix):-len(OCTETS_SUFFIX)]
            return prefix, bucket.replace('_TO_', ':')
    if name.startswith(PFC_PREFIX):
        prio, _, kind = name[len(PFC_PREFIX):].partition('_')
        # Only the exact per-priority packet counters feed the PFC packet
        # metrics; any other per-priority PFC counter is not a packet count
        # and must not be summed into them.
        if kind == 'RX_PKTS':
            return 'SAI_PORT_STAT_PFC_RX_PKTS', prio
        if kind == 'TX_PKTS':
            return 'SAI_PORT_STAT_PFC_TX_PKTS', prio
    return name, 0


def collect_metrics(counters_db, appl_db):
    """Read all port and LAG counters from redis.

    Args:
        counters_db: redis client for db 2 (``COUNTERS_*`` keys).
        appl_db: redis client for db 0 (``LAG_MEMBER_TABLE:*`` keys).

    Returns:
        Nested mapping ``metrics[metric_key][sub_key][interface] -> int``,
        keyed as produced by :func:`classify_counter`.
    """
    metrics = collections.defaultdict(lambda: collections.defaultdict(collections.Counter))
    port_oids = counters_db.hgetall('COUNTERS_PORT_NAME_MAP')
    lag_oids = counters_db.hgetall('COUNTERS_LAG_NAME_MAP')
    for intf, oid in list(port_oids.items()) + list(lag_oids.items()):
        if not intf:
            continue
        oids = [oid]
        if intf.startswith('PortChannel'):
            # TODO: There seem to be no SAI counters for LAGs, so we emulate them by
            # summing all the underlying interface counters. This will break if
            # any of the underlying interfaces reset to zero without all of them
            # doing the same.
            oids = []
            for member_rec in appl_db.keys(f'LAG_MEMBER_TABLE:{intf}:Ethernet*'):
                member = member_rec.split(':')[2]
                member_oid = port_oids.get(member)
                # A member missing from the port name map has no counters to
                # add; skip it instead of crashing the exporter.
                if member_oid is not None:
                    oids.append(member_oid)
        for counter_oid in oids:
            for cntr, value in counters_db.hgetall('COUNTERS:' + counter_oid).items():
                key, sub = classify_counter(cntr)
                metrics[key][sub][intf] += int(value)
    return metrics


def emit(f, name, values, **labels):
    """Write one sample line ``name{labels,interface="..."} value`` per interface."""
    for interface, val in values.items():
        labels['interface'] = interface
        label_str = ','.join('%s="%s"' % (k, str(v)) for k, v in labels.items())
        print('%s{%s} %s' % (name, label_str, val), file=f)


def write_histogram(f, base, help_text, buckets, sums):
    """Write one packet-size histogram family (buckets, optional _sum, _count).

    Args:
        f: text file to write to.
        base: Prometheus base metric name.
        help_text: text for the ``# HELP`` line.
        buckets: mapping of bucket key -> {interface: count}.
        sums: {interface: total} for the ``_sum`` series, or None to omit it.
    """
    print('# HELP', base, help_text, file=f)
    print('# TYPE', base, 'histogram', file=f)
    running = collections.Counter()
    for key, bound in HISTOGRAM_BUCKETS:
        # Counter.update keeps zero counts (Counter '+' would drop them), so an
        # interface with an empty bucket still gets its bucket sample.
        running.update(buckets.get(key, {}))
        emit(f, base + '_bucket', running, le=bound)
    # Emitting _sum here keeps the whole histogram family contiguous.
    if sums is not None:
        emit(f, base + '_sum', sums)
    emit(f, base + '_count', running)


def write_metrics(f, metrics):
    """Render *metrics* (as built by collect_metrics) to the text file *f*.

    The ``IF_*_OCTETS`` totals are written as the ``_sum`` series inside the
    matching packet-size histogram; they are only written standalone when no
    packet-size counters were collected. Unrecognised counters are reported
    on stdout and skipped.
    """
    for metric, values in metrics.items():
        if metric in SIMPLE_METRICS:
            name = SIMPLE_METRICS[metric]
            print(file=f)
            print('# HELP', name, 'SAI metric', metric, file=f)
            emit(f, name, values[0])
        elif metric in PFC_METRICS:
            name, help_text = PFC_METRICS[metric]
            print(file=f)
            print('# HELP', name, help_text, file=f)
            for prio, per_intf in values.items():
                emit(f, name, per_intf, priority=prio)
        elif metric in HISTOGRAMS:
            base, help_text, sum_metric = HISTOGRAMS[metric]
            sums = metrics[sum_metric][0] if sum_metric in metrics else None
            print(file=f)
            write_histogram(f, base, help_text, values, sums)
        elif metric in HISTOGRAM_SUMS:
            hist_metric = HISTOGRAM_SUMS[metric]
            if hist_metric in metrics:
                continue  # already written as part of the histogram family
            print(file=f)
            emit(f, HISTOGRAMS[hist_metric][0] + '_sum', values[0])
        else:
            print('Unknown metric:', metric)


def main():
    """Poll redis forever, atomically replacing OUTPUT_FILE each round."""
    counters_db = StrictRedis(db=2, decode_responses=True)
    appl_db = StrictRedis(db=0, decode_responses=True)
    while True:
        metrics = collect_metrics(counters_db, appl_db)
        with open(TEMP_FILE, 'w') as f:
            write_metrics(f, metrics)
        # Rename only after the file is closed so readers see a complete file.
        os.replace(TEMP_FILE, OUTPUT_FILE)
        time.sleep(POLL_INTERVAL)


if __name__ == '__main__':
    main()