#!/usr/bin/env python3

import optparse
import sys
import numpy as np
import pprint
import json

def parse_args():
    """Parse and validate the command-line options.

    Options:
      --agg1, --agg2  required; one of srcip, dstip, srcport, dstport
      --size          int, default 20; must be non-negative
      --metric        one of bytes (default), packets, flows

    Returns the optparse options object. On invalid input this calls
    ``parser.error``, which prints usage and exits with status 2.
    """
    # Must stay in sync with the aggregate names get_nf_indices understands.
    aggregates = ('srcip', 'dstip', 'srcport', 'dstport')

    parser = optparse.OptionParser()

    parser.add_option("--agg1")
    parser.add_option("--agg2")
    parser.add_option("--size", default=20, type="int")
    parser.add_option("--metric", default='bytes')

    options, _args = parser.parse_args()

    if options.agg1 is None:
        parser.error("agg1 is a required argument.")
    if options.agg2 is None:
        parser.error("agg2 is a required argument.")
    # Validate here so a typo yields a usage error instead of a KeyError
    # traceback later on.
    if options.agg1 not in aggregates:
        parser.error("agg1 must be srcip, dstip, srcport or dstport")
    if options.agg2 not in aggregates:
        parser.error("agg2 must be srcip, dstip, srcport or dstport")
    if options.metric not in ['bytes', 'packets', 'flows']:
        parser.error("metric must be bytes, packets or flows")
    if options.size < 0:
        parser.error("size must be a non-negative integer")

    return options

def get_nf_indices(agg1, agg2, metric):
    """Translate aggregate and metric names into column positions.

    The positions index into a whitespace-split input line (the ``nf``
    prefix suggests nfdump output — confirm against the producer).

    Returns a tuple ``(agg1_column, agg2_column, metric_column)``.
    Raises KeyError for an unknown aggregate or metric name.
    """
    aggregate_column = dict(srcip=3, dstip=4, srcport=5, dstport=6)
    metric_column = dict(packets=7, bytes=8, flows=11)

    first = aggregate_column[agg1]
    second = aggregate_column[agg2]
    value = metric_column[metric]
    return first, second, value

def _read_pairs(lines, a1, a2, m):
    """Collect node names and summed metric values per (row, col) pair.

    Returns ``(names, diff, pair_totals)``:
      names       -- distinct node names in first-seen order
      diff        -- parallel list, True if the node was first seen in the
                     agg1 column (position ``a1``), else False
      pair_totals -- dict mapping (row index, col index) to the summed metric
    """
    location = {}      # node name -> index into names
    names = []
    diff = []
    pair_totals = {}
    min_fields = max(a1, a2, m) + 1

    for line in lines:
        fields = line.split()
        # Skip blank or truncated lines that cannot hold the needed columns.
        if len(fields) < min_fields:
            continue
        # Skip header-like lines whose last field is not numeric.
        try:
            float(fields[-1])
        except ValueError:
            continue

        src = fields[a1]
        dst = fields[a2]
        value = int(fields[m])

        # Register src before dst so that a name seen in both columns on the
        # same line is tagged as an agg1 node.
        for name, is_agg1 in ((src, True), (dst, False)):
            if name not in location:
                location[name] = len(names)
                names.append(name)
                diff.append(is_agg1)

        # Accumulate rather than overwrite, so repeated pairs are not lost.
        key = (location[src], location[dst])
        pair_totals[key] = pair_totals.get(key, 0) + value

    return names, diff, pair_totals


def _build_matrix(node_count, pair_totals):
    """Return a float square matrix of side node_count + 1.

    The extra last row/column is the 'Other' node and starts at zero.
    """
    matrix = np.zeros((node_count + 1, node_count + 1))
    for (row, col), value in pair_totals.items():
        matrix[row, col] += value
    return matrix


def _fold_minor_nodes(matrix, diff, size):
    """Fold low-traffic non-agg1 nodes into the last ('Other') row/column.

    Modifies ``matrix`` in place. Nodes are ranked by total traffic (row sum
    plus column sum). Nodes ranked ``size`` or lower are folded unless they
    are agg1 nodes (``diff`` True), which are always kept. A negative
    ``size`` is treated as 0.

    Returns the ascending list of row/column indices to keep; it always
    ends with the 'Other' index.
    """
    other = matrix.shape[0] - 1
    # Row sums alone are zero for nodes that only ever appear in the agg2
    # column, so they would be ranked by input order instead of traffic.
    traffic = matrix.sum(axis=1) + matrix.sum(axis=0)
    ranked = sorted(range(other), key=lambda idx: traffic[idx], reverse=True)

    folded = set()
    for idx in ranked[max(size, 0):]:
        if diff[idx]:
            continue
        # Column first, then row: a self-loop idx->idx reaches Other->Other
        # exactly once.
        matrix[:, other] += matrix[:, idx]
        matrix[other, :] += matrix[idx, :]
        folded.add(idx)

    return [idx for idx in range(other + 1) if idx not in folded]


def parse_stdin(agg1, agg2, size, metric, stream=None):
    """Build an aggregate-to-aggregate traffic matrix and print it as JSON.

    Each input line is whitespace-split. The columns chosen by
    ``get_nf_indices(agg1, agg2, metric)`` give the row node, the column
    node and an integer metric value. Values for repeated (row, column)
    pairs are summed. Lines too short to contain those columns, or whose
    last field is not numeric (e.g. headers), are skipped.

    Nodes first seen in the agg1 column are always kept. Every other node
    outside the ``size`` highest-traffic nodes is merged into a trailing
    'Other' node.

    Prints one JSON object to stdout with keys:
      names   -- kept node names, followed by 'Other'
      matrix  -- square list of float lists in the same order as names;
                 matrix[i][j] is the summed metric from names[i] to names[j]
      diff    -- per name, True if first seen as agg1 ('Other' is False)
      warning -- True when more than 4998 distinct nodes were read
                 (the meaning of this limit is not visible here — confirm)

    Args:
        agg1, agg2: aggregate names accepted by get_nf_indices.
        size: number of top-ranked nodes to keep before folding.
        metric: 'bytes', 'packets' or 'flows'.
        stream: iterable of text lines; defaults to sys.stdin.
    """
    if stream is None:
        stream = sys.stdin

    a1, a2, m = get_nf_indices(agg1, agg2, metric)
    names, diff, pair_totals = _read_pairs(stream, a1, a2, m)

    matrix = _build_matrix(len(names), pair_totals)
    keep = _fold_minor_nodes(matrix, diff, size)

    all_names = names + ['Other']
    all_diff = diff + [False]

    # Select kept rows/columns explicitly so the matrix always matches
    # names/diff in length, whether or not anything was folded.
    dump = {'names': [all_names[idx] for idx in keep],
            'matrix': matrix[np.ix_(keep, keep)].tolist(),
            'diff': [all_diff[idx] for idx in keep],
            'warning': len(names) > 4998}

    print(json.dumps(dump))


if __name__ == "__main__":
    # Script entry point: read CLI options, then turn stdin into JSON on stdout.
    cli = parse_args()
    parse_stdin(agg1=cli.agg1, agg2=cli.agg2, size=cli.size, metric=cli.metric)
