// ImportalStackedBarChart.tsx
import { useChartDimensions } from '@/custom-hooks/reporting/useChartDimensions';
import { BarChartOrientation, ChartConfig, ChartUnitType, LegendPositionType } from 'common/interfaces/reports';
import * as d3 from 'd3';
import { useMemo, useState } from 'react';
import { CategoricalXAxis } from './CategoricalXAxis';
import { CategoricalYAxis } from './CategoricalYAxis';
import { ImportalChartLegend } from './ImportalChartLegend';
import { ImportalHorizontalBars } from './ImportalHorizontalBars';
import { ImportalVerticalBars } from './ImportalVerticalBars';
import { NumericXAxis } from './NumericXAxis';
import { NumericYAxis } from './NumericYAxis';

/** Props accepted by the stacked bar chart component below. */
interface Props {
  /** Flat rows; each row supplies one segment value for one category bar. */
  data: Array<any>;
  /** Chart configuration; `barConfig` names the row fields used to pivot `data`. */
  chartConfig: ChartConfig;
}

// NOTE(review): rows are arbitrary records keyed by barConfig fields (the `data` prop is untyped),
// so points stay loosely typed.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type FlattenedPoint = Record<string, any>;

/** One stack layer after flattening: its points plus the segment key d3.stack assigned to it. */
type FlattenedSeries = FlattenedPoint[] & { key: string };

/**
 * Replaces the `data` back-reference on each stacked point with just the category value, so every
 * point carries `0` (lower bound), `1` (upper bound) and `[categoryKey]`.
 *
 * Lives at module scope so the hook that calls it does not need it as a dependency.
 *
 * @param stackedData - Output of `d3.stack()`: one series per segment key.
 * @param categoryKey - Row field holding the category (bar) name.
 * @returns One flattened array per input series, tagged with that series' `key`.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function transformStackedData(stackedData: d3.Series<any, string>[], categoryKey: string): FlattenedSeries[] {
  return stackedData.map((stackSeries) => {
    const points = stackSeries.map((point): FlattenedPoint => {
      if (!point.data) return point;
      // Object rest copies the own index props `0` and `1` but drops `data`; the result is a plain object.
      const { data, ...bounds } = point;
      return { ...bounds, [categoryKey]: data[categoryKey] };
    });
    return Object.assign(points, { key: stackSeries.key });
  });
}

/**
 * Stacked bar chart.
 *
 * Rows in `data` are grouped by `barConfig.seriesNameKey` (one bar per category). Within each bar,
 * every distinct `barConfig.seriesSegmentKey` value becomes a stacked segment sized by
 * `barConfig.seriesValueKey`. Missing category/segment combinations stack as 0.
 * `barConfig.orientation === VERTICAL` draws columns (categories on x); anything else draws
 * horizontal bars (categories on y).
 */
export default function ImportalStackedBarChart({ chartConfig, data }: Props) {
  const [ref, dms] = useChartDimensions({
    // width: 700,
    // height: 600,
    marginLeft: 100,
    marginRight: 40,
    marginTop: 20,
    marginBottom: 120,
  });

  const [tooltip, setTooltip] = useState<{ x: number; y: number; value: number } | null>(null);

  const { seriesNameKey, seriesSegmentKey, seriesValueKey, xAxisName, yAxisName } = chartConfig.barConfig ?? {};
  const isVertical = chartConfig.barConfig?.orientation === BarChartOrientation.VERTICAL;

  // Distinct segment values in first-seen order; these become the stack keys (and thus legend order).
  const segmentKeys = useMemo(() => {
    if (!seriesSegmentKey) return [];
    return d3.union(data.map((d) => d[seriesSegmentKey]));
  }, [data, seriesSegmentKey]);

  // Pivot: one row per category, with one property per segment value holding that segment's value.
  const groupedData = useMemo(() => {
    if (!seriesNameKey || !seriesSegmentKey || !seriesValueKey) return [];
    return d3
      .rollups(
        data,
        (v) => {
          // Combine all segmentKey:value pairs for this seriesNameKey
          const row: any = { [seriesNameKey]: v[0][seriesNameKey] };
          for (const item of v) {
            row[item[seriesSegmentKey]] = item[seriesValueKey];
          }
          return row;
        },
        (d) => d[seriesNameKey]
      )
      .map(([, row]) => row);
  }, [data, seriesNameKey, seriesSegmentKey, seriesValueKey]);

  // Depend on the pivoted rows and keys actually read here, not on raw `data`.
  const series = useMemo(() => {
    if (!seriesNameKey || !seriesSegmentKey || !seriesValueKey) return [];
    return d3
      .stack()
      .keys(segmentKeys)
      .value((row, key) => row[key] ?? 0)(groupedData);
  }, [groupedData, segmentKeys, seriesNameKey, seriesSegmentKey, seriesValueKey]);

  const transformedSeries = useMemo((): FlattenedSeries[] => {
    if (!seriesNameKey) return [];
    return transformStackedData(series, seriesNameKey);
  }, [series, seriesNameKey]);

  // For the legend, we want the stack group keys in the same order as used by d3.stack:
  const legendKeys = useMemo(() => transformedSeries.map((layer) => layer.key), [transformedSeries]);

  // Every layer holds one point per category in the same order, so the first layer lists them all.
  const categories = useMemo(() => {
    const [firstLayer] = transformedSeries;
    if (!firstLayer || !seriesNameKey) return [];
    return firstLayer.map((point) => point[seriesNameKey]);
  }, [transformedSeries, seriesNameKey]);

  // Top of the tallest stack (largest upper bound across all layers); 0 when there is no data.
  const maxValue = useMemo(
    (): number => d3.max(transformedSeries, (layer) => d3.max(layer, (point): number => point[1])) ?? 0,
    [transformedSeries]
  );

  const xScale = useMemo(() => {
    if (!seriesNameKey) return null;
    return isVertical
      ? d3.scaleBand<string>().domain(categories).range([0, dms.boundedWidth]).padding(0.2)
      : d3.scaleLinear().domain([0, maxValue]).range([0, dms.boundedWidth]).nice();
  }, [isVertical, categories, maxValue, dms.boundedWidth, seriesNameKey]);

  const yScale = useMemo(() => {
    if (!seriesNameKey) return null;
    return isVertical
      ? d3.scaleLinear().domain([0, maxValue]).range([dms.boundedHeight, 0]).nice()
      : d3.scaleBand<string>().domain(categories).range([0, dms.boundedHeight]).padding(0.2);
  }, [isVertical, categories, maxValue, dms.boundedHeight, seriesNameKey]);

  const colorScale = useMemo(() => {
    if (!legendKeys.length) return null;
    return d3.scaleOrdinal<string, string>().domain(legendKeys).range(d3.schemeCategory10).unknown('#ccc');
  }, [legendKeys]);

  // Places the tooltip relative to the chart container and shows the hovered segment's size (y1 - y0).
  const showTooltip = (event: React.MouseEvent<SVGRectElement>, d: [number, number]) => {
    const container = ref.current;
    if (!container) return;
    const { left, top } = container.getBoundingClientRect();
    setTooltip({
      x: event.clientX - left,
      y: event.clientY - top,
      value: d[1] - d[0],
    });
  };

  // NOTE(review): these handlers are not passed to the bar components below, so the tooltip is never shown yet.
  const handleMouseEnter = showTooltip;
  const handleMouseMove = showTooltip;
  const handleMouseLeave = () => {
    setTooltip(null);
  };

  // The segment key is required too: without it nothing can be stacked.
  if (!seriesNameKey || !seriesSegmentKey || !seriesValueKey) {
    return <div style={{ color: '#999' }}>Missing axis keys.</div>;
  }

  if (!xScale || !yScale) {
    return <div style={{ color: '#999' }}>Scales could not be computed.</div>;
  }

  // NOTE(review): in vertical orientation the value axis is y; confirm xAxisUnits is the intended source.
  const barUnit = chartConfig.barConfig?.xAxisUnits || ChartUnitType.Number;

  return (
    // `position: relative` anchors the absolutely-positioned tooltip to this container,
    // matching the container-relative coordinates computed in showTooltip.
    <div ref={ref} style={{ position: 'relative', height: '100%', width: '100%' }}>
      <figure style={{ margin: 0, width: dms.width, height: dms.height }}>
        {/* Render the legend above the chart if a color scale and keys exist */}
        {colorScale &&
          legendKeys.length > 0 && ( // only works for legendPosition === TOP right now
            <ImportalChartLegend
              keys={legendKeys}
              colorScale={colorScale}
              legendPosition={chartConfig.legendPosition ?? LegendPositionType.TOP}
            />
          )}
        <svg width={dms.width} height={dms.height}>
          <g transform={`translate(${dms.marginLeft}, ${dms.marginTop})`}>
            {/* Background */}
            <rect width={dms.boundedWidth} height={dms.boundedHeight} fill="#fff" rx={8} />

            {/* Draw each stack series */}
            {transformedSeries.map((stackSeries) => {
              const barColor = colorScale?.(stackSeries.key) ?? '#ccc';
              return isVertical ? (
                <ImportalVerticalBars
                  key={stackSeries.key}
                  data={stackSeries}
                  xScale={xScale}
                  yScale={yScale}
                  seriesNameKey={seriesNameKey}
                  seriesValueKey={seriesValueKey}
                  barColor={barColor}
                  unit={barUnit}
                />
              ) : (
                <ImportalHorizontalBars
                  key={stackSeries.key}
                  data={stackSeries}
                  xScale={xScale}
                  yScale={yScale}
                  seriesNameKey={seriesNameKey}
                  seriesValueKey={seriesValueKey}
                  barColor={barColor}
                  unit={barUnit}
                />
              );
            })}

            {/* Y Axis */}
            {isVertical ? (
              <NumericYAxis axisLabel={yAxisName || ''} range={[0, maxValue]} height={dms.boundedHeight} />
            ) : (
              <CategoricalYAxis
                axisLabel={yAxisName || ''}
                categories={categories}
                height={dms.boundedHeight}
                showAxisLine={chartConfig.showAxisLine}
              />
            )}
          </g>

          {/* X Axis - positioned below the chart */}
          <g transform={`translate(${dms.marginLeft}, ${dms.marginTop + dms.boundedHeight})`}>
            {isVertical ? (
              <CategoricalXAxis axisLabel={xAxisName || ''} categories={categories} width={dms.boundedWidth} />
            ) : (
              <NumericXAxis
                axisLabel={xAxisName || ''}
                range={[0, maxValue]}
                width={dms.boundedWidth}
                showAxisLine={chartConfig.showAxisLine}
                showAxisTicks={chartConfig.showAxisTicks}
              />
            )}
          </g>
        </svg>
      </figure>

      {/* Tooltip */}
      {tooltip && (
        <div
          style={{
            position: 'absolute',
            top: tooltip.y + 10, // Offset slightly so the tooltip doesn't obscure the bar
            left: tooltip.x + 10,
            background: 'rgba(255, 255, 255, 0.8)',
            border: '1px solid #ccc',
            padding: '4px 8px',
            pointerEvents: 'none',
            fontSize: '12px',
            borderRadius: '4px',
          }}
        >
          {tooltip.value}
        </div>
      )}
    </div>
  );
}
