import { useMemo, useState } from 'react';
import { Sankey, ResponsiveContainer } from 'recharts';
import { Button } from '../button';
import { CardContent, Card } from '../card';
import { useTranslation } from 'react-i18next';
import {
  Dialog,
  DialogTrigger,
  DialogContent,
  DialogHeader,
  DialogTitle,
  DialogClose,
} from '@/components/ui/dialog';
import type { CustomNodeProps, SankeyData, SankeyChartProps, SankeyChartWithDialogProps } from "./types";

/**
 * Custom recharts Sankey node renderer.
 *
 * Draws a rounded rectangle coloured by `payload.category` and writes
 * `payload.name` beside it:
 * - 'source' nodes use --chart-1 and are labelled on their right.
 * - 'target' nodes use --chart-2; 'other' nodes use --chart-3; any other
 *   category falls back to --chart-4. All of these are labelled on their left.
 *
 * Negative or non-finite heights from the layout are clamped to 0 so the
 * `<rect>` always receives a valid size.
 */
const CustomNode = ({ x, y, width, height, payload }: CustomNodeProps) => {
  const isSource = payload.category === 'source';
  const isTarget = payload.category === 'target';
  const isOther = payload.category === 'other';
  // Guard against NaN/Infinity as well as negatives; Math.max(0, NaN) is NaN.
  const safeHeight = Number.isFinite(height) ? Math.max(0, height) : 0;

  // Colour coding based on category.
  const getNodeColor = (): string => {
    if (isSource) return 'var(--chart-1)';
    if (isTarget) return 'var(--chart-2)';
    if (isOther) return 'var(--chart-3)';
    return 'var(--chart-4)';
  };

  const nodeColor = getNodeColor();
  const textColor = 'var(--foreground)';

  // Source labels sit to the right of the node; all others to the left.
  const textX = isSource ? x + width + 6 : x - 6;
  const textAnchor = isSource ? 'start' : 'end';

  return (
    <g>
      <rect
        x={x}
        y={y}
        width={width}
        height={safeHeight}
        fill={nodeColor}
        fillOpacity={0.8}
        stroke={nodeColor}
        strokeWidth={2}
        rx={4}
      />
      {/* dominantBaseline (not alignmentBaseline) is what vertically centres
          a <text> element; Firefox ignores alignment-baseline on <text>. */}
      <text
        x={textX}
        y={y + safeHeight / 2}
        textAnchor={textAnchor}
        dominantBaseline="middle"
        fontSize="12"
        fill={textColor}
        fontWeight="500"
      >
        {payload.name}
      </text>
    </g>
  );
};



/**
 * Removes every link that closes a cycle (a DFS back edge) so the remaining
 * links form a directed acyclic graph. Links that do not themselves close a
 * cycle are kept, and input order is preserved.
 */
function removeCycleLinks<TLink extends { source: number; target: number }>(
  links: TLink[],
): TLink[] {
  // node -> outgoing edges, each remembered by its position in `links`.
  const outgoing = new Map<number, { index: number; target: number }[]>();
  links.forEach((link, index) => {
    const edges = outgoing.get(link.source) ?? [];
    edges.push({ index, target: link.target });
    outgoing.set(link.source, edges);
  });

  const onStack = new Set<number>();
  const finished = new Set<number>();
  const backEdges = new Set<number>();

  const visit = (node: number): void => {
    onStack.add(node);
    for (const { index, target } of outgoing.get(node) ?? []) {
      if (onStack.has(target)) {
        backEdges.add(index); // this edge would close a cycle
      } else if (!finished.has(target)) {
        visit(target);
      }
    }
    onStack.delete(node);
    finished.add(node);
  };

  for (const link of links) {
    if (!finished.has(link.source)) visit(link.source);
  }

  return links.filter((_, index) => !backEdges.has(index));
}

/**
 * Turns raw Sankey input (object or JSON string) into a renderable graph.
 *
 * Steps:
 * 1. Parse JSON strings and drop links with value below 1.
 * 2. Keep the `maxNodes` targets with the largest total inbound flow.
 * 3. Fold all remaining links into one "Other" target, summed per source.
 *    An existing node whose name contains "other" (case-insensitive) or whose
 *    category is 'other' is reused; otherwise a new "Other" node is appended.
 * 4. Drop self-loops and unused nodes, and reindex links to the compacted
 *    node array.
 * 5. Drop cycle-closing links.
 *
 * The caller's data is never mutated.
 *
 * @returns the processed graph, or `null` when input is missing, malformed,
 *          or processing throws (the error is logged).
 */
function processSankeyData(
  sankeyData: SankeyChartProps['sankeyData'],
  maxNodes: number,
): SankeyData | null {
  if (!sankeyData) return null;

  try {
    const data: SankeyData = typeof sankeyData === 'string'
      ? JSON.parse(sankeyData)
      : sankeyData;

    if (!data || !Array.isArray(data.nodes) || !Array.isArray(data.links)) return null;

    // Filter out very small flows.
    const minFlowThreshold = 1;
    const filteredLinks = data.links.filter(link => link.value >= minFlowThreshold);

    // Rank targets by total inbound flow and keep the top `maxNodes`.
    const targetFlows = new Map<number, number>();
    for (const link of filteredLinks) {
      targetFlows.set(link.target, (targetFlows.get(link.target) ?? 0) + link.value);
    }
    const topTargets = new Set(
      Array.from(targetFlows.entries())
        .sort((a, b) => b[1] - a[1])
        .slice(0, maxNodes)
        .map(([nodeIndex]) => nodeIndex),
    );

    const topTargetLinks = filteredLinks.filter(link => topTargets.has(link.target));
    const otherLinks = filteredLinks.filter(link => !topTargets.has(link.target));

    const existingOtherIndex = data.nodes.findIndex(node =>
      (node.name ?? '').toLowerCase().includes('other') || node.category === 'other'
    );

    // Tag the existing "Other" node on a copy so the caller's objects stay untouched.
    const finalNodes = data.nodes.map((node, index) =>
      index === existingOtherIndex ? { ...node, category: 'other' as const } : node
    );
    let finalLinks = [...topTargetLinks];

    let otherNodeIndex: number | undefined;
    if (existingOtherIndex >= 0) {
      otherNodeIndex = existingOtherIndex;
    } else if (otherLinks.length > 0) {
      otherNodeIndex = finalNodes.length;
      finalNodes.push({
        id: `other-combined`,
        name: "Other",
        category: "other"
      });
    }

    if (otherLinks.length > 0 && otherNodeIndex !== undefined) {
      // One aggregated link per source into the "Other" node.
      const otherFlowsBySource = new Map<number, number>();
      const addFlow = (link: { source: number; value: number }): void => {
        otherFlowsBySource.set(link.source, (otherFlowsBySource.get(link.source) ?? 0) + link.value);
      };

      if (existingOtherIndex >= 0) {
        // Merge flows already going to the existing "Other" node, then drop
        // those links since they are recreated below in aggregated form.
        topTargetLinks.filter(link => link.target === existingOtherIndex).forEach(addFlow);
        finalLinks = finalLinks.filter(link => link.target !== existingOtherIndex);
      }

      otherLinks.forEach(addFlow);

      const targetIndex = otherNodeIndex;
      for (const [sourceIndex, totalValue] of otherFlowsBySource) {
        finalLinks.push({ source: sourceIndex, target: targetIndex, value: totalValue });
      }
    }

    // Remove self-loops and record which node positions are referenced.
    const usedNodeIndices = new Set<number>();
    finalLinks = finalLinks.filter(link => {
      if (link.source === link.target) return false;
      usedNodeIndices.add(link.source);
      usedNodeIndices.add(link.target);
      return true;
    });

    // Compact the node array and map old positions to new ones by index
    // (not by `id`, which may be missing or duplicated).
    const nodeIndexMap = new Map<number, number>();
    const filteredNodes = finalNodes.filter((_, index) => {
      if (!usedNodeIndices.has(index)) return false;
      nodeIndexMap.set(index, nodeIndexMap.size);
      return true;
    });

    // Links pointing at out-of-range indices have no mapping and are dropped.
    const updatedLinks = finalLinks.flatMap(link => {
      const source = nodeIndexMap.get(link.source);
      const target = nodeIndexMap.get(link.target);
      return source === undefined || target === undefined || source === target
        ? []
        : [{ ...link, source, target }];
    });

    return {
      nodes: filteredNodes,
      links: removeCycleLinks(updatedLinks),
    };
  } catch (error) {
    console.error('Error processing Sankey data:', error);
    return null;
  }
}

/**
 * Renders a Sankey flow diagram with a small legend derived from `title`.
 *
 * `title` is expected in the form "A → B"; the two halves label the
 * source/target legend swatches. When `wasPreprocessed` is true and
 * `sankeyData` is already an object, it is rendered as-is; otherwise it is
 * run through `processSankeyData` (which limits targets to `maxNodes`).
 * While there is nothing to draw, a loading skeleton or an empty-state
 * message is shown depending on `isLoading`.
 */
export function ReportSankeyChart({
  sankeyData,
  title = "Network Flow Diagram",
  height = 600,
  className = "",
  nodePadding = 10,
  isLoading = false,
  maxNodes = 20,
  wasPreprocessed,
}: SankeyChartProps) {
  const { t } = useTranslation();

  // Hooks must run unconditionally on every render (Rules of Hooks), so the
  // "already processed" short-circuit lives inside the memo, not around it.
  const processedData = useMemo((): SankeyData | null => {
    if (wasPreprocessed && typeof sankeyData !== 'string') {
      return sankeyData ?? null;
    }
    return processSankeyData(sankeyData, maxNodes);
  }, [sankeyData, maxNodes, wasPreprocessed]);

  if (!processedData || processedData.nodes.length === 0) {
    return (
      <CardContent className="w-full p-8 text-center w-[330px] h-[302px] mb-auto mt-6">
        <div className="text-foreground">
          {isLoading ? (
            <>
              <div className="animate-pulse space-y-2 mb-auto h-full">
                <div className="h-4 bg-muted rounded w-3/4 mx-auto mb-2"></div>
                <div className="h-3 bg-muted rounded w-1/2 mx-auto"></div>
                <div className="h-4 bg-muted rounded w-3/4 mx-auto mb-2"></div>
                <div className="h-3 bg-muted rounded w-1/2 mx-auto"></div>
                <div className="h-4 bg-muted rounded w-3/4 mx-auto mb-2"></div>
                <div className="h-3 bg-muted rounded w-1/2 mx-auto"></div>
                <div className="h-4 bg-muted rounded w-3/4 mx-auto mb-2"></div>
              </div>
              <p className="text-sm mt-2">{t('Loading flow data...')}</p>
            </>
          ) : (
            <>
              <p className="font-medium">{t('No flow data available')}</p>
              <p className="text-sm">{t('No significant traffic flows found for the selected criteria.')}</p>
            </>
          )}
        </div>
      </CardContent>
    );
  }

  // Legend labels come from the "A → B" title; agg2 is undefined without an arrow.
  const [agg1, agg2] = title.split("→").map(s => s.trim());

  return (
    <div className={`flex pt-6 ${className} gap-0`}>
      <CardContent className="flex-1 pb-0">
        <ResponsiveContainer width="100%" height={height}>
          <Sankey
            data={processedData}
            node={CustomNode}
            nodePadding={nodePadding}
            margin={{
              left: 25,
              right: 25,
              top: 10,
              bottom: 10
            }}
            link={{
              stroke: 'var(--chart-4)',
              strokeOpacity: 0.65,
              fill: 'none'
            }}
            iterations={32}
          >
          </Sankey>
        </ResponsiveContainer>

        <div className="mb-2">
          <h3 className="text-md font-semibold text-foreground">{title}</h3>
          <div className="flex gap-4 mt-2 text-xs">
            <div className="flex items-center gap-1">
              <div className="w-3 h-3 bg-chart-1 rounded"></div>
              <span>{agg1}</span>
            </div>
            <div className="flex items-center gap-1">
              <div className="w-3 h-3 bg-chart-2 rounded"></div>
              <span>{agg2}</span>
            </div>
            <div className="flex items-center gap-1">
              <div className="w-3 h-3 bg-chart-3 rounded"></div>
              <span>{t('Other')}</span>
            </div>
          </div>
        </div>
      </CardContent>
    </div>
  );
}


/**
 * A card showing a compact Sankey preview, plus an "Expand View" button
 * that opens a dialog rendering the same data at a larger size.
 *
 * The preview uses the chart's default target limit; the dialog view
 * raises it to 50 nodes.
 */
export function ReportSankeyChartWithDialog({
  sankeyData,
  title,
  isLoading = false,
  className = "",
  wasPreprocessed = false,
}: SankeyChartWithDialogProps) {
  const { t } = useTranslation();
  const [isDialogOpen, setIsDialogOpen] = useState(false);

  // Fixed pixel sizes for the inline preview and the expanded dialog view.
  const preview = { width: 330, height: 250 };
  const expanded = { width: 700, height: 700 };

  // Props shared by both chart instances.
  const chartProps = { sankeyData, title, isLoading, wasPreprocessed };
  const expandLabel = `Open large view for ${title}`;

  return (
    <Dialog open={isDialogOpen} onOpenChange={setIsDialogOpen}>
      <Card className="cursor-pointer py-0 flex gap-0" aria-label={expandLabel}>
        <ReportSankeyChart
          {...chartProps}
          width={preview.width}
          height={preview.height}
          className={className}
        />
        <DialogTrigger asChild>
          <div className="flex justify-center items-center pb-2">
            <Button aria-label={expandLabel}>
              {t("Expand View")}
            </Button>
          </div>
        </DialogTrigger>
      </Card>

      {/* Expanded chart rendered inside the dialog */}
      <DialogContent className="sm:max-w-[50vw] max-h-[750px] p-6 overflow-y-scroll overflow-x-hidden">
        <DialogHeader>
          <DialogTitle>{title}</DialogTitle>
          <DialogClose className="absolute right-4 top-4" />
        </DialogHeader>
        <Card>
          <ReportSankeyChart
            {...chartProps}
            width={expanded.width}
            height={expanded.height}
            nodePadding={10}
            maxNodes={50}
          />
        </Card>
      </DialogContent>
    </Dialog>
  );
}