#!/usr/bin/env python

import json
import logging
import math
import optparse
import sys
from collections import defaultdict

def parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Argument list to parse, without the program name. When None,
            optparse falls back to ``sys.argv[1:]``.

    Returns:
        The optparse ``Values`` object with ``agg1``, ``agg2``, ``size``,
        ``metric`` and ``min_flow`` attributes.

    Exits (via ``parser.error``, status 2) when ``--agg1``/``--agg2`` is
    missing, ``--metric`` is not one of bytes/packets/flows, or ``--size`` is
    below 1.
    """
    parser = optparse.OptionParser()

    parser.add_option("--agg1", help="First aggregation field (e.g., srcip)")
    parser.add_option("--agg2", help="Second aggregation field (e.g., dstip)")
    parser.add_option("--size", default=20, type="int", help="Maximum number of nodes to show")
    parser.add_option("--metric", default='bytes', help="Metric to use (bytes, packets, flows)")
    parser.add_option("--min-flow", default=0, type="float", help="Minimum flow value to include")

    (options, _args) = parser.parse_args(argv)

    if options.agg1 is None:
        parser.error("agg1 is a required argument.")
    if options.agg2 is None:
        parser.error("agg2 is a required argument.")
    if options.metric not in ['bytes', 'packets', 'flows']:
        parser.error("metric must be bytes, packets or flows")
    # The node count is used as a slice bound downstream: a negative value
    # would slice from the end (keeping almost every node) and 0 yields an
    # empty graph, so reject both up front.
    if options.size < 1:
        parser.error("size must be at least 1.")

    return options

def get_nf_indices(agg1, agg2, metric):
    """Translate field names into column positions of an nfdump output line.

    Args:
        agg1: Name of the first aggregation field
            (srcip, dstip, srcport or dstport).
        agg2: Name of the second aggregation field (same choices).
        metric: Name of the metric (packets, bytes or flows).

    Returns:
        A ``(agg1_index, agg2_index, metric_index)`` tuple.

    Raises:
        KeyError: if any name is not one of the known fields.

    NOTE(review): the positions assume one fixed nfdump text layout, split on
    whitespace; confirm against the nfdump output format actually piped in.
    """
    field_columns = dict(srcip=3, dstip=4, srcport=5, dstport=6)
    metric_columns = dict(packets=7, bytes=8, flows=11)

    first_column = field_columns[agg1]
    second_column = field_columns[agg2]
    value_column = metric_columns[metric]
    return first_column, second_column, value_column

def _aggregate_flows(stream, agg1_index, agg2_index, metric_index, min_flow, max_lines):
    """Sum the metric per (source, target) pair and per node.

    Lines are skipped when they:
    - lack the needed columns or have a metric column that is not a finite
      number (headers, summaries, blank lines);
    - have a metric below ``min_flow``.

    At most ``max_lines`` valid lines are aggregated.

    Returns:
        ``(flows, node_totals, count)``:
        - ``flows`` maps ``(source, target)`` to summed traffic;
        - ``node_totals`` maps each node to the traffic it sent plus received;
        - ``count`` is the number of lines aggregated.
    """
    flows = defaultdict(float)
    node_totals = defaultdict(float)
    count = 0

    for line in stream:
        fields = line.split()
        try:
            traffic_value = float(fields[metric_index])
            source = fields[agg1_index]
            target = fields[agg2_index]
        except (ValueError, IndexError):
            continue
        # float() also accepts "nan"/"inf"; those would poison the sums and
        # json.dumps would then emit non-standard NaN/Infinity tokens.
        if not math.isfinite(traffic_value) or traffic_value < min_flow:
            continue
        # Check the cap only when another valid line is pending. This way
        # exactly max_lines lines are aggregated, and trailing non-data lines
        # (e.g. a summary) never trigger the warning.
        if count >= max_lines:
            logging.warning("Hit maximum line limit (%d), stopping processing", max_lines)
            break

        flows[(source, target)] += traffic_value
        node_totals[source] += traffic_value
        node_totals[target] += traffic_value
        count += 1

    return flows, node_totals, count


def _build_nodes(flows, top_node_set):
    """Create Sankey nodes for the top nodes and map each node id to its index.

    Top nodes that appear as a source of any flow come first, sorted, with
    category 'source'. Top nodes that only ever appear as a target follow,
    sorted, with category 'target'.
    """
    source_nodes = {source for source, _ in flows if source in top_node_set}
    target_nodes = {
        target for _, target in flows
        if target in top_node_set and target not in source_nodes
    }
    nodes = [{'id': n, 'name': n, 'category': 'source'} for n in sorted(source_nodes)]
    nodes += [{'id': n, 'name': n, 'category': 'target'} for n in sorted(target_nodes)]
    node_index = {node['id']: i for i, node in enumerate(nodes)}
    return nodes, node_index


def _make_link(node_index, source_id, target_id, value):
    """Build one Sankey link dict that references nodes both by index and by id."""
    return {
        'source': node_index[source_id],
        'target': node_index[target_id],
        'value': value,
        'sourceId': source_id,
        'targetId': target_id,
    }


def _build_links(flows, top_node_set, nodes, node_index, min_flow):
    """Create Sankey links, folding traffic to/from non-top nodes into "Other".

    - Flows between two top nodes become individual links.
    - Traffic between a top node and non-top nodes is summed per top node
      into one link to or from an "Other" node.
    - Flows between two non-top nodes are dropped.

    Every emitted link carries at least ``min_flow``. The "Other" node is
    appended to ``nodes``/``node_index`` (mutated in place) only when at least
    one link references it.

    Returns:
        ``(links, has_other)``, with ``links`` sorted by value, largest first.
    """
    links = []
    from_other = defaultdict(float)  # top target -> traffic from non-top sources
    to_other = defaultdict(float)    # top source -> traffic to non-top targets

    for (source, target), value in flows.items():
        source_in_top = source in top_node_set
        target_in_top = target in top_node_set
        if source_in_top and target_in_top:
            if value >= min_flow:
                links.append(_make_link(node_index, source, target, value))
        elif target_in_top:
            from_other[target] += value
        elif source_in_top:
            to_other[source] += value

    # Apply the same inclusive threshold as for direct links.
    from_other = {t: v for t, v in from_other.items() if v >= min_flow}
    to_other = {s: v for s, v in to_other.items() if v >= min_flow}
    has_other = bool(from_other or to_other)

    if has_other:
        nodes.append({'id': 'Other', 'name': 'Other', 'category': 'other'})
        node_index['Other'] = len(nodes) - 1
        for target, value in from_other.items():
            links.append(_make_link(node_index, 'Other', target, value))
        for source, value in to_other.items():
            links.append(_make_link(node_index, source, 'Other', value))

    # list.sort is stable, so equal-valued links keep their insertion order.
    links.sort(key=lambda link: link['value'], reverse=True)
    return links, has_other


def parse_stdin(agg1, agg2, size, metric, min_flow, stream=None, max_lines=5000):
    """Parse nfdump text output and build a Sankey data structure.

    Args:
        agg1: Field whose value becomes the link source (a name accepted by
            get_nf_indices).
        agg2: Field whose value becomes the link target.
        size: Maximum number of nodes, ranked by total traffic sent plus
            received, to show individually. Traffic exchanged with any other
            node is folded into a single "Other" node.
        metric: 'bytes', 'packets' or 'flows'.
        min_flow: Input lines with a metric below this are ignored, and
            aggregated links below it are dropped.
        stream: Iterable of text lines. Defaults to ``sys.stdin``.
        max_lines: Maximum number of valid data lines to aggregate.

    Returns:
        A dict with 'nodes', 'links' and 'meta' keys, ready for json.dumps.

    Raises:
        KeyError: if agg1, agg2 or metric is not known to get_nf_indices.

    NOTE(review): nodes are shared between the source and target side. As a
    result, A->B together with B->A (or src == dst) produces a cyclic graph;
    confirm the front-end Sankey layout tolerates cycles.
    """
    if stream is None:
        stream = sys.stdin

    agg1_index, agg2_index, metric_index = get_nf_indices(agg1, agg2, metric)
    flows, node_totals, count = _aggregate_flows(
        stream, agg1_index, agg2_index, metric_index, min_flow, max_lines)

    # Keep the busiest `size` nodes. Clamp at 0, because a negative slice
    # bound would drop from the end and keep almost every node.
    ranked = sorted(node_totals.items(), key=lambda item: item[1], reverse=True)
    top_node_set = {node for node, _ in ranked[:max(size, 0)]}

    nodes, node_index = _build_nodes(flows, top_node_set)
    links, has_other = _build_links(flows, top_node_set, nodes, node_index, min_flow)

    return {
        'nodes': nodes,
        'links': links,
        'meta': {
            'total_nodes': len(nodes),
            'total_links': len(links),
            'total_flows_processed': count,
            'metric': metric,
            'agg1': agg1,
            'agg2': agg2,
            'has_other_category': has_other,
            'top_nodes_count': len(top_node_set),
        },
    }

if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)

    options = parse_args()

    try:
        sankey_data = parse_stdin(
            options.agg1,
            options.agg2,
            options.size,
            options.metric,
            options.min_flow,
        )

        print(json.dumps(sankey_data, indent=2))

    except Exception as e:
        # Top-level boundary. Log the full traceback to stderr so the failure
        # can be diagnosed, and still print a JSON document on stdout so a
        # JSON consumer of this script receives parseable output.
        logging.exception("Error processing data: %s", e)
        error_output = {
            'error': str(e),
            'nodes': [],
            'links': [],
            'meta': {'error': True}
        }
        print(json.dumps(error_output))
        sys.exit(1)