import React from 'react';
import { scaleLinear } from '@visx/scale';
import { Group } from '@visx/group';
import discreteColorScale from './Colors';
import { useSprings, animated } from 'react-spring';
import { Sample, SampleBreakdown } from './Common';
import { Container } from 'semantic-ui-react';

/** Props for the AnimatedSample chart. Layout values are in SVG viewBox units unless noted. */
interface AnimatedSampleProps {
  /** Values looked up by each breakdown entry (`sample[entry]`) and normalized per series. */
  sample: Sample;
  /** For each series key, the ordered sample entries that form its bins (one column per bin). */
  breakdown: SampleBreakdown;
  /** Series keys into `breakdown`. */
  keys: string[];
  /** Used only to build React `key`s for the rendered circles and labels. */
  dataKey: string;
  /** Rendered width of the `<svg>` element (default `'100%'`). */
  width?: string | number;
  /** Rendered height of the `<svg>` element (default `500`). */
  height?: string | number;
  /** Series currently shown. Owned by the controlling parent's state; expected to be one of `keys`. */
  currentKey: string;

  /** Number of circles across each column (default 5). */
  columnWidth?: number;
  /** Gap between adjacent columns, in viewBox pixels (default 30). */
  gutterWidth?: number;
  /** Fraction of each circle's radius left empty as padding (default 0.2). */
  circlePadding?: number;
  /** Fraction of `viewboxHeight` reserved above the plot for labels (default 0). */
  labelPadding?: number;
  /** Pixels reserved inside the viewBox on the left (default 0). */
  marginLeft?: number;
  /** Pixels reserved inside the viewBox on the right (default 0). */
  marginRight?: number;
  /** Pixels reserved inside the viewBox at the top (default 0). */
  marginTop?: number;
  /** Pixels reserved inside the viewBox at the bottom (default 0). */
  marginBottom?: number;
  /** Height of the SVG coordinate system (default 500). */
  viewboxHeight?: number;
  /** Width of the SVG coordinate system (default 800). */
  viewboxWidth?: number;
  /** Upper bound on the circle radius (default unbounded). */
  maxCircleRadius?: number;
  /** Number of circles each series is rounded to (default 100). */
  nCircles?: number;
  /** Render a label above each column when true. */
  labels?: boolean;
  /**
   * Formats a column label. Receives the bin's breakdown entry and that bin's share of the
   * series as a percentage (0-100). When omitted, the entry itself is shown.
   */
  labelTransform?: (binLabel: string, percentage: number) => string;
}

/** One drawable mark: a position plus the colour it is drawn in. */
interface Center {
  /** Fill / text colour, as returned by `discreteColorScale`. */
  color: string;
  /** Horizontal position (circle centers are in plot coordinates and mapped through `xScale` at render). */
  x: number;
  /** Vertical position (circle centers are in plot coordinates and mapped through `yScale` at render). */
  y: number;
}

/** A non-empty column of circles produced by `roundToTotal`. */
interface RoundedBin {
  /** Number of circles in the column (always > 0). */
  count: number;
  /** Index of the breakdown entry this column represents, before empty bins were dropped. */
  binIndex: number;
}

/**
 * Round fractional per-bin counts to whole numbers that sum to `total`, using
 * largest-remainder rounding: floor every value, then hand out the remaining units one at a
 * time to the bin with the largest `raw - rounded` gap (ties go to the lowest index).
 *
 * Bins that round to zero are dropped. Each surviving bin keeps its original index so
 * callers can still map it back to its label and colour.
 *
 * @param rawCount Fractional counts, expected to sum (approximately) to `total`.
 * @param total Target sum of the rounded counts.
 * @returns The non-empty bins in their original order; `[]` for an empty `rawCount`.
 *   If `rawCount` contains NaN (e.g. a series whose values sum to 0), no rounding-up is
 *   performed and NaN bins are dropped.
 */
function roundToTotal(rawCount: ReadonlyArray<number>, total: number): RoundedBin[] {
  if (rawCount.length === 0) {
    return [];
  }
  const counts = rawCount.map(Math.floor);
  let assigned = counts.reduce((a, b) => a + b, 0);
  // A NaN `assigned` fails this comparison, so degenerate input falls straight through.
  while (assigned < total) {
    let maxIndex = 0;
    let maxGap = -Infinity;
    rawCount.forEach((value, idx) => {
      const gap = value - counts[idx];
      if (gap > maxGap) {
        maxGap = gap;
        maxIndex = idx;
      }
    });
    // The chosen bin's gap drops below zero, so other bins are preferred next time round.
    counts[maxIndex] += 1;
    assigned += 1;
  }
  return counts.flatMap((count, binIndex): RoundedBin[] => (count > 0 ? [{ count, binIndex }] : []));
}

/**
 * Animated unit chart. Draws `nCircles` circles for the series selected by `currentKey`,
 * packed into side-by-side columns that are `columnWidth` circles wide. There is one column
 * per non-empty bin of that series' breakdown.
 *
 * Circle positions and fills are driven by react-spring `useSprings`, so changing
 * `currentKey` moves the same circles into the new series' layout.
 */
const AnimatedSample: React.FC<AnimatedSampleProps> = (props: AnimatedSampleProps) => {
  const {
    sample,
    keys,
    dataKey,
    breakdown,
    width = '100%',
    height = 500,
    labels,
    labelTransform,
    columnWidth = 5, // Number of circles across each column
    gutterWidth = 30, // Gap between adjacent columns, in viewBox pixels
    circlePadding = 0.2, // Fraction of each circle's radius left empty as padding
    labelPadding = 0, // Fraction of viewboxHeight reserved above the plot for labels
    marginBottom = 0,
    marginLeft = 0,
    marginRight = 0,
    marginTop = 0, // Margins in pixels to reserve inside the svg viewBox
    viewboxHeight = 500,
    viewboxWidth = 800, // Height and width of the svg viewBox in pixels
    maxCircleRadius = Infinity,
    nCircles = 100, // Total number of circles per series
    currentKey, // Series to display; owned by the parent (controller) component
  } = props;

  // Resolve the displayed series once. An unknown currentKey falls back to the first series,
  // and both the circles and the labels use the resolved key.
  const foundIndex = keys.indexOf(currentKey);
  const keyIndex = foundIndex === -1 ? 0 : foundIndex;
  const activeKey = keys[keyIndex];

  // Normalize each series so its bins sum to nCircles (fractional counts).
  // NOTE(review): assumes sample[entry] is numeric - confirm against the Sample type.
  const rawCounts = keys.map((key) => {
    const values = breakdown[key].map((entry) => sample[entry]);
    const seriesTotal = values.reduce((a, b) => a + b, 0);
    return values.map((value) => (value / seriesTotal) * nCircles);
  });

  // Whole-circle counts per series. Empty bins are dropped but keep their breakdown index,
  // so colours and labels still refer to the right entry.
  const roundedBins = rawCounts.map((rawCount) => roundToTotal(rawCount, nCircles));

  // Size the grid for the largest series in both directions so the circle radius does not
  // change when switching series.
  const nColumns = roundedBins.map((bins) => bins.length);
  const maxColumns = Math.max(...nColumns);
  const nRows = roundedBins.map((bins) => Math.ceil(Math.max(...bins.map(({ count }) => count)) / columnWidth));
  const maxRows = Math.max(...nRows);

  // Pixel bounds of the plot area; the label band sits between marginTop and minY.
  const labelHeight = labelPadding * viewboxHeight;
  const [minX, maxX] = [marginLeft, viewboxWidth - marginRight];
  const [minY, maxY] = [marginTop + labelHeight, viewboxHeight - marginBottom];
  const [plotWidth, plotHeight] = [maxX - minX, maxY - minY];

  // Plot coordinates -> viewBox pixels.
  const xScale = scaleLinear({ domain: [0, plotWidth], range: [minX, maxX] });
  const yScale = scaleLinear({ domain: [0, plotHeight], range: [minY, maxY] });

  // gutterWidth is a length, not a position. Convert it as the difference of two inverted
  // points so the scale's offset (marginLeft) cancels out.
  const xGutterWidth = xScale.invert(gutterWidth) - xScale.invert(0);

  const xRadiusBound = (plotWidth - xGutterWidth * (maxColumns - 1)) / (columnWidth * maxColumns * 2);
  const yRadiusBound = plotHeight / (maxRows * 2);

  // fullCircleRadius is the layout cell radius. The drawn radius is shrunk by circlePadding
  // (linearly) so neighbouring circles don't touch.
  const fullCircleRadius = Math.min(xRadiusBound, yRadiusBound, maxCircleRadius);
  const displayCircleRadius = fullCircleRadius * (1 - circlePadding);

  // Horizontal distance between the left edges of adjacent columns.
  const columnPitch = xGutterWidth + fullCircleRadius * columnWidth * 2;

  // Left offset that centers each series' columns within the plot width.
  const centerShifts = nColumns.map(
    (nCols) => (plotWidth - (xGutterWidth * (nCols - 1) + fullCircleRadius * 2 * columnWidth * nCols)) / 2
  );

  // Circle centers in plot coordinates: fill each column row by row, columnWidth per row.
  const circleCenters = roundedBins.map((bins, seriesIndex) =>
    bins.flatMap(({ count, binIndex }, columnIndex) =>
      Array.from(
        { length: count },
        (_, offset): Center => ({
          x:
            columnIndex * columnPitch +
            (offset % columnWidth) * fullCircleRadius * 2 +
            fullCircleRadius +
            centerShifts[seriesIndex],
          y: Math.floor(offset / columnWidth) * fullCircleRadius * 2 + fullCircleRadius,
          color: discreteColorScale(binIndex),
        })
      )
    )
  );

  // Label boxes in viewBox pixels: one per column, aligned with the column's left edge and
  // placed at the top of the reserved label band.
  const labelCenters = roundedBins.map((bins, seriesIndex) =>
    bins.map(({ binIndex }, columnIndex) => ({
      x: xScale(centerShifts[seriesIndex] + columnPitch * columnIndex),
      y: marginTop,
      binIndex,
      color: discreteColorScale(binIndex),
    }))
  );

  const springs = useSprings(
    nCircles,
    circleCenters[keyIndex].map((center: Center) => ({
      cx: xScale(center.x),
      cy: yScale(center.y),
      fill: center.color,
    }))
  );

  // The viewBox spans the whole coordinate system so the margins and label band stay inside
  // the svg; the scales already offset the plot by the margins.
  return (
    <svg viewBox={`0 0 ${viewboxWidth} ${viewboxHeight}`} width={width} height={height}>
      {labels && (
        <Group>
          {labelCenters[keyIndex].map(({ x, y, color, binIndex }) => {
            const binLabel = breakdown[activeKey][binIndex];
            return (
              <foreignObject
                key={`label-${dataKey}-${binIndex}`}
                x={x}
                y={y}
                width={columnWidth * fullCircleRadius * 2}
                height={labelHeight}
              >
                <Container fluid textAlign="center">
                  <p style={{ color }}>
                    {labelTransform
                      ? labelTransform(binLabel, (rawCounts[keyIndex][binIndex] / nCircles) * 100)
                      : binLabel}
                  </p>
                </Container>
              </foreignObject>
            );
          })}
        </Group>
      )}
      <Group>
        {springs.map((styles, idx) => (
          <animated.circle key={`${dataKey}-circle-${idx}`} r={displayCircleRadius} className="dot" style={styles} />
        ))}
      </Group>
    </svg>
  );
};

export default AnimatedSample;
