#!/usr/bin/env python3

import datetime
import optparse
import sys
import numpy as np
import logging
import pprint
import json

def parse_args():
    '''Parse and validate the command-line options.

    Options:
        --agg1, --agg2  required; one of srcip, dstip, srcport, dstport
        --size          number of top agg1 values to keep (default 20, >= 1)
        --metric        bytes (default), packets or flows

    Returns the optparse ``options`` object. Any invalid option is reported
    through ``parser.error``, which prints a usage message and exits.
    '''
    # Must stay in sync with the aggregate keys understood by get_nf_indices().
    aggregates = ('srcip', 'dstip', 'srcport', 'dstport')

    parser = optparse.OptionParser()

    parser.add_option("--agg1")
    parser.add_option("--agg2")
    parser.add_option("--size", default=20, type="int")
    parser.add_option("--metric", default='bytes')

    options, _args = parser.parse_args()

    if options.agg1 is None:
        parser.error("agg1 is a required argument.")
    if options.agg2 is None:
        parser.error("agg2 is a required argument.")
    # Reject unknown aggregates here with a usage message instead of letting
    # the column lookup fail later with a bare KeyError traceback.
    if options.agg1 not in aggregates or options.agg2 not in aggregates:
        parser.error("agg1 and agg2 must be srcip, dstip, srcport or dstport")
    if options.metric not in ['bytes', 'packets', 'flows']:
        parser.error("metric must be bytes, packets or flows")
    # A negative size would silently slice from the end of the ranking.
    if options.size < 1:
        parser.error("size must be a positive integer")

    return options

def get_nf_indices(agg1, agg2, metric):
    '''Translate option names into column positions of a whitespace-split line.

    Returns ``(agg1_index, agg2_index, metric_index)``. An aggregate or metric
    name that is not listed below raises KeyError.
    '''
    aggregate_columns = {
        'srcip': 3,
        'dstip': 4,
        'srcport': 5,
        'dstport': 6,
    }
    metric_columns = {
        'packets': 7,
        'bytes': 8,
        'flows': 11,
    }
    first_column = aggregate_columns[agg1]
    second_column = aggregate_columns[agg2]
    value_column = metric_columns[metric]
    return first_column, second_column, value_column

def get_names(traffic):
    '''Return the matrix axis labels for *traffic*.

    Every key except ``"other"`` is listed in the mapping's iteration order,
    followed by a single trailing ``"Other"`` label.
    '''
    labels = [key for key in traffic if key != "other"]
    labels.append("Other")
    return labels

def get_diffs(traffic):
    '''Return one ``False`` flag per entry of the *traffic* mapping.'''
    return [False] * len(traffic)

def _aggregate(stream, agg1_index, agg2_index, metric_index, max_records):
    '''Sum the metric column per (agg1, agg2) pair over the data lines of *stream*.

    Returns ``(relations, totals)``: ``relations[a1][a2]`` is the summed metric
    for that pair and ``totals[a1]`` is the sum over all of a1's pairs. Both
    dicts list agg1 values in first-seen order. At most *max_records* data
    lines are consumed.
    '''
    relations = {}
    totals = {}
    records = 0
    for line in stream:
        if records >= max_records:
            # Cap the input: per the original author, around 5000 records the
            # run used so much memory that it would most likely never finish.
            break
        fields = line.split()
        try:
            # A non-numeric (or missing) last column marks a header line that
            # cannot become a matrix row. Short or malformed rows are skipped
            # as well rather than aborting the whole run.
            float(fields[-1])
            agg1_value = fields[agg1_index]
            agg2_value = fields[agg2_index]
            traffic_value = float(fields[metric_index])
        except (IndexError, ValueError):
            continue
        row = relations.setdefault(agg1_value, {})
        row[agg2_value] = row.get(agg2_value, 0) + traffic_value
        totals[agg1_value] = totals.get(agg1_value, 0) + traffic_value
        records += 1
    return relations, totals


def _collapse_to_top(relations, totals, size):
    '''Keep the *size* agg1 values with the largest totals; fold the rest into "other".

    Returns ``{top_value: {peer: amount, ..., 'other': amount}, ..., 'other': {...}}``
    with the top values in descending-total order and ``'other'`` last. Peers
    are restricted to the top values plus ``'other'``.
    '''
    # sorted() stays stable with reverse=True, so equal totals keep first-seen order.
    top = sorted(totals, key=totals.get, reverse=True)[:size]
    top_set = set(top)

    traffic = {}
    for addr1 in top:
        row = {'other': 0}
        for addr2, amount in relations[addr1].items():
            if addr2 in top_set:
                row[addr2] = amount
            else:
                # Accumulate every non-top peer into this row's "other" cell.
                row['other'] += amount
        traffic[addr1] = row

    # The "other" row holds traffic from every non-top agg1 value to each top
    # value. Traffic between two non-top values is not represented, so the
    # other->other cell stays 0.
    other_row = {}
    for source, row in relations.items():
        if source in top_set:
            continue
        for target, amount in row.items():
            if target in top_set:
                other_row[target] = other_row.get(target, 0) + amount
    traffic['other'] = other_row
    return traffic


def _build_matrix(traffic):
    '''Lay *traffic* out as a square float matrix: non-"other" keys in order, "other" last.'''
    names = [name for name in traffic if name != 'other']
    locations = {name: index for index, name in enumerate(names)}
    locations['other'] = len(names)

    matrix = np.zeros((len(locations), len(locations)))
    for addr1, row in traffic.items():
        for addr2, amount in row.items():
            matrix[locations[addr1], locations[addr2]] = amount
    return matrix


def parse_stdin(agg1, agg2, size, metric, stream=None, max_records=4000):
    '''Read flow records, build a traffic matrix and print it as JSON.

    Each data line of *stream* (default: ``sys.stdin``) is split on whitespace.
    The columns chosen by ``get_nf_indices(agg1, agg2, metric)`` are summed into
    ``{agg1_value: {agg2_value: metric}}``. At most *max_records* data lines
    are read. The *size* agg1 values with the largest totals become the matrix
    rows and columns, in descending order. Everything else is folded into a
    trailing "Other" row and column.

    Prints one JSON object with keys ``names``, ``matrix``, ``diff`` and
    ``warning`` to stdout, then terminates the process via ``sys.exit(0)``.
    '''
    agg1_index, agg2_index, metric_index = get_nf_indices(agg1, agg2, metric)
    if stream is None:
        stream = sys.stdin

    relations, totals = _aggregate(stream, agg1_index, agg2_index,
                                   metric_index, max_records)
    traffic = _collapse_to_top(relations, totals, size)
    matrix = _build_matrix(traffic)

    dump = {'names': get_names(traffic),
            'matrix': matrix.tolist(),
            'diff': get_diffs(traffic),
            'warning': False}

    print(json.dumps(dump))
    sys.exit(0)

if __name__ == "__main__":
    # Script entry point: enable DEBUG-level root logging, read the CLI
    # options, then hand them to parse_stdin, which reads stdin and prints
    # the JSON result.
    logging.basicConfig(level=logging.DEBUG)
    options = parse_args()
    parse_stdin(
        agg1=options.agg1,
        agg2=options.agg2,
        size=options.size,
        metric=options.metric,
    )