import React from "react";
import useComponentSize from "../../modules/hooks/use-component-size";
import { scaleLinear, scaleBand, scalePoint } from "d3-scale";
import styles from "./vertical-performance-trends.module.scss";
import SvgFlag from "../../components/svg-flag";
import { HexagonPolygon, HexagonPolygonSmall, HexagonSvg, SquareSvg, DiamondSvg, Arrow } from "../../components/shapes";
import Flag from "../../components/flag";

// Margins between the SVG edges and the plotting frame. The extra top space
// holds the y-axis title; the extra left space holds the tick labels.
// padding.left is also reused to indent the legend.
const padding = { top: 54, bottom: 0, left: 44, right: 0 };

// Suffix appended to a year score column name to read each datum's value
// from schoolData (the school's own score has no suffix).
const DATUM_TO_KEY = {
    school: "",
    country: "_cty",
    oecd: "_oecd",
};
// Datums plotted inside every year band, in left-to-right order.
const DATUMS = Object.keys(DATUM_TO_KEY);
// Gap between year bands; used to derive the band scale's inner padding.
const sectionPadding = 40;

/**
 * Collects the five year slots shown on the chart.
 *
 * @param schoolData record holding each year's display label under "lab_<key>"
 * @param visOptions currently unused; kept so all use* helpers share a signature
 * @returns {{keys: string[], values: Array}} keys "1y".."5y" and their labels
 *   (an entry is undefined when schoolData has no label for that year)
 */
const useYears = (schoolData, visOptions) => {
    const keys = [];
    const values = [];
    for (let yearNumber = 1; yearNumber <= 5; yearNumber += 1) {
        const key = `${yearNumber}y`;
        keys.push(key);
        values.push(schoolData[`lab_${key}`]);
    }
    return { keys, values };
};

/**
 * Derives the chart domains from the school record.
 *
 * @param schoolData record holding the y-range bounds under visOptions.rangeColumns
 * @param visOptions reads rangeColumns: [lowerBoundColumn, upperBoundColumn]
 * @param years output of useYears ({ keys, values })
 * @returns {{bands: string[], years: Array, innerPoints: string[], y: number[]}}
 *   y is [lower, upper] rounded outward to multiples of 50. It falls back to
 *   [0, 700] when either bound is missing (undefined/null) or non-numeric.
 */
const useDomains = (schoolData, visOptions, years) => {
    const bands = years.keys;
    const fallbackRange = [0, 700];
    const rangeColumns = visOptions.rangeColumns || [];
    let yRaw = rangeColumns.map(d => schoolData[d]);

    // Both bounds are required. yRaw[0]/yRaw[1] are undefined when fewer than two
    // columns are configured, so an empty or short list also takes the fallback
    // instead of producing a NaN domain.
    const isMissing = value => value === undefined || value === null || !Number.isFinite(Number(value));
    if (isMissing(yRaw[0]) || isMissing(yRaw[1])) {
        console.log("----- ----- Couldn't find " + visOptions.rangeColumns + " in schoolData");
        // Build the message from the actual fallback so the two cannot drift apart.
        console.log("----- ----- Making range [" + fallbackRange.join(", ") + "].");
        yRaw = fallbackRange;
    }

    // Round outward to the nearest 50s so the data sits inside the axis range.
    const y = [Math.floor(yRaw[0] / 50) * 50, Math.ceil(yRaw[1] / 50) * 50];
    const innerPoints = DATUMS;

    return {
        bands,
        years: years.values,
        innerPoints,
        y,
    };
};

/**
 * Builds the d3 scales that position every element of the chart.
 *
 * @param dimensions measured SVG size ({ width, height }); falsy until measured
 * @param domains output of useDomains
 * @param visOptions reads showRanges, which switches to smaller icons
 * @returns undefined while dimensions is falsy, otherwise
 *   { frame, bands, innerPoints, y, bandwidth, icon }. icon.square is the
 *   half-size used for datum icons.
 */
const useScales = (dimensions, domains, visOptions) => {
    if (!dimensions) {
        return undefined;
    }

    const { width, height } = dimensions;
    const frame = {
        top: padding.top,
        right: width - padding.right,
        bottom: height - padding.bottom,
        left: padding.left,
    };
    const frameWidth = frame.right - frame.left;

    // NOTE(review): d3 paddingInner is a fraction of the band *step*, not of the
    // full range, so the pixel gap is not exactly sectionPadding. Confirm intent.
    const innerPaddingRatio = sectionPadding / frameWidth;
    const bands = scaleBand().domain(domains.bands).rangeRound([frame.left, frame.right]).paddingOuter(0).paddingInner(innerPaddingRatio);
    const bandwidth = bands.bandwidth();

    // SVG y grows downward, so the lowest score maps to the frame bottom.
    const y = scaleLinear().domain(domains.y).rangeRound([frame.bottom, frame.top]);

    // Offsets of each datum within a single band.
    const innerPoints = scalePoint().domain(domains.innerPoints).rangeRound([0, bandwidth]).padding(1);

    const compactIcons = visOptions.showRanges;
    const icon = {
        square: compactIcons ? 8 : 12,
        diamond: compactIcons ? 11 : 15,
    };

    return { frame, bands, innerPoints, y, bandwidth, icon };
};

const VerticalPerformanceTrends = props => {
    const { visOptions, schoolData } = props;
    const ref = React.useRef();
    const dimensions = useComponentSize(ref);
    const years = useYears(schoolData, visOptions);
    const domains = useDomains(schoolData, visOptions, years);
    const scales = useScales(dimensions, domains, visOptions);

    const chartProps = {
        ...props,
        dimensions,
        domains,
        scales,
    };

    let chartContent;
    let overlayContent;
    if (scales) {
        chartContent = (
            <>
                <AxisGrid {...chartProps} />
                <ConfidenceIntervals {...chartProps} />
                <ChartData {...chartProps} />
            </>
        );
        overlayContent = (
            <>
                <DomainLabels {...chartProps} />
            </>
        );
    }

    return (
        <div className={styles.verticalPerformance}>
            <div className={styles.container}>
                <svg className={styles.svg} ref={ref}>
                    {chartContent}
                </svg>
                <div className={styles.overlay}>{overlayContent}</div>
            </div>
            <Legend {...chartProps} />
        </div>
    );
};

const AxisGrid = props => {
    const { visOptions, localiser, scales } = props;

    const yTicks = scales.y.ticks();

    // Lines protrude past the axis a bit.
    const linesStartX = scales.frame.left - 8;

    return (
        <g>
            <g>
                {yTicks.map((yVal, i) => (
                    <line
                        key={yVal}
                        className={"gridLine"}
                        x1={linesStartX}
                        x2={scales.frame.right}
                        y1={scales.y(yVal)}
                        y2={scales.y(yVal)}
                        data-bottom={i === 0}
                    />
                ))}
            </g>
            <g>
                {yTicks.map(yVal => (
                    <text key={yVal} className={"gridLabel"} x={linesStartX - 8} y={scales.y(yVal)}>
                        {yVal}
                    </text>
                ))}
            </g>
            <g>
                <line
                    className={"axisLine"}
                    x1={scales.frame.left}
                    x2={scales.frame.left}
                    y1={scales.frame.top - 32}
                    y2={scales.frame.bottom}
                />
                <text className={"axisText"} x={0} y={scales.frame.top - 41}>
                    {localiser(visOptions.yAxisLabel || "PISA score")}
                </text>
            </g>
        </g>
    );
};

const ConfidenceIntervals = props => {
    const { scales, visOptions, domains, schoolData } = props;
    const { confidencePrefix } = visOptions;

    const intervals = domains.bands
        .filter(d => !d.startsWith("+"))
        .map(year => {
            const top = scales.y(schoolData[`CI_upper_${confidencePrefix}${year === "0" ? "" : "_" + year}`]);
            const bottom = scales.y(schoolData[`CI_lower_${confidencePrefix}${year === "0" ? "" : "_" + year}`]);
            const height = bottom - top;
            const x = scales.bands(year);
            const endX = x + scales.bandwidth;

            return {
                top,
                bottom,
                height,
                x,
                endX,
                yearIndex: year,
            };
        })
        .filter(d => (d.top || d.top === 0) && (d.bottom || d.bottom === 0));

    const rects = intervals.map((domain, i) => {
        return (
            <g key={i}>
                <rect
                    className={"confidenceIntervalRect"}
                    x={domain.x}
                    width={scales.bandwidth}
                    y={domain.top}
                    height={domain.height}
                />
                <line
                    className={"confidenceIntervalLine"}
                    x1={domain.x}
                    x2={domain.endX}
                    y1={domain.top}
                    y2={domain.top}
                />
                <line
                    className={"confidenceIntervalLine"}
                    x1={domain.x}
                    x2={domain.endX}
                    y1={domain.bottom}
                    y2={domain.bottom}
                />
                {i === 0 && <ConfidenceIntervalExplainer {...props} domain={domain} />}
            </g>
        );
    });
    return <g>{rects}</g>;
};

const ConfidenceIntervalExplainer = props => {
    const { domain, localiser, scales, domains, schoolData, visOptions } = props;
    const left = domain.x + 8;

    const collisions = getBandValues(domain.yearIndex, domains, schoolData, visOptions)
        .map(d => scales.y(d))
        .concat([domain.top, domain.bottom])
        .sort();
    const lowerTop = collisions[collisions.length - 1] + 12 + 14 + scales.icon.square;
    const upperBottom = collisions[0] - 12 - scales.icon.square;

    return (
        <g>
            <g transform={`translate(${left} ${upperBottom})`}>
                <text x={0} y={0} className={styles.confidenceExplainerLabel}>
                    {localiser("Statistically significant")}
                </text>
                <Arrow transform={`translate(10 -24) rotate(180)`} />
            </g>
            <g transform={`translate(${left} ${lowerTop})`}>
                <text x={0} y={0} className={styles.confidenceExplainerLabel}>
                    {localiser("Statistically significant")}
                </text>
                <Arrow transform={`translate(0 12)`} />
            </g>
        </g>
    );
};

/**
 * Looks up each datum's score for one year band.
 *
 * @param yearIndex band key, e.g. "1y"
 * @param domains reads innerPoints (datum names, in plotting order)
 * @param schoolData record read at "<yearScorePrefix>_<yearIndex><datum suffix>"
 * @param visOptions reads yearScorePrefix
 * @returns {Array} one value per datum, in innerPoints order (undefined when absent)
 */
const getBandValues = (yearIndex, domains, schoolData, visOptions) => {
    const { innerPoints } = domains;
    const { yearScorePrefix } = visOptions;
    const columnBase = `${yearScorePrefix}_${yearIndex}`;
    return innerPoints.map(datumName => schoolData[columnBase + DATUM_TO_KEY[datumName]]);
};

const ChartData = props => {
    const { scales, visOptions, params, domains, schoolData } = props;

    const domainIcons = domains.bands.map(yearIndex => {
        const domainLeft = scales.bands(yearIndex);
        const bandValues = getBandValues(yearIndex, domains, schoolData, visOptions);
        const domainDatums = domains.innerPoints.map((datumName, datumIndex) => {
            const value = bandValues[datumIndex];
            if (value === undefined) return null;
            const datumOffsetX = scales.innerPoints(datumName);
            const x = domainLeft + datumOffsetX;
            const y = scales.y(value);

            if (datumName === "country") {
                return (
                    <g key={datumName}>
                        <SvgFlag
                            key={datumName}
                            className={styles.chartDatum}
                            params={params}
                            x={x - scales.icon.square}
                            y={y - scales.icon.square}
                            width={scales.icon.square * 2}
                            height={scales.icon.square * 2}
                            data-datum={datumName}
                        />
                        <circle className={"flagStroke"} cx={x} cy={y} r={scales.icon.square} />
                    </g>
                );
            }

            if (datumName === "school") {
                const HexagonPolyComponent = visOptions.showRanges ? HexagonPolygonSmall : HexagonPolygon;
                return (
                    <HexagonPolyComponent
                        key={datumName}
                        className={styles.chartDatum}
                        transform={`translate(${x - scales.icon.square} ${y - scales.icon.square})`}
                        data-datum={datumName}
                    />
                );
            }

            // OECD
            return (
                <circle
                    key={datumName}
                    className={styles.chartDatum}
                    cx={x}
                    cy={y}
                    r={scales.icon.square}
                    data-datum={datumName}
                />
            );
        });

        return <g key={yearIndex}>{domainDatums}</g>;
    });

    return (
        <g>
            <g>{domainIcons}</g>
        </g>
    );
};

const DomainLabels = props => {
    const { scales, domains } = props;
    const domainLabels = domains.bands.map((yearId, domainIndex) => {
        const label = domains.years[domainIndex];
        const domainLeft = scales.bands(yearId);
        const domainMid = domainLeft + scales.bandwidth / 2;

        return (
            <div
                key={label}
                className={styles.domainLabel}
                style={{
                    left: domainMid,
                    top: scales.frame.top,
                    maxWidth: scales.bands.bandwidth() * 0.85,
                }}
            >
                {label}
            </div>
        );
    });

    return <g>{domainLabels}</g>;
};

const Legend = props => {
    const { batchConfig, localiser, visOptions, params } = props;

    return (
        <div className={"legend"} style={{ paddingLeft: padding.left }}>
            <div className={"legendItem"}>
                <div className={"legendIconGroup"}>
                    <HexagonSvg className={"legendIcon"} data-school />
                </div>
                <span className={"legendLabel"}>{props.yourSchool}</span>
            </div>
            <div className={"legendItem"}>
                <div className={"legendIconGroup"}>
                    <Flag className={"legendIcon"} {...params} />
                </div>
                <span className={"legendLabel"}>{batchConfig.countryName}</span>
            </div>

            <div className={"legendItem"}>
                <div className={"legendIconGroup"}>
                    <div className={"legendOecdIcon"} />
                </div>
                <span className={"legendLabel"}>{localiser("OECD")}</span>
            </div>
            {visOptions.showRanges && (
                <>
                    <div className={"legendItem"}>
                        <div className={"legendIconGroup"} style={{ height: 21 }}>
                            <DiamondSvg className={"legendRangeIcon"} data-significant="false" />
                            <DiamondSvg className={"legendRangeIcon"} />
                        </div>
                        <span className={"legendLabel"}>{localiser(visOptions.rangeLabels[0])}</span>
                    </div>
                    <div className={"legendItem"}>
                        <div className={"legendIconGroup"}>
                            <SquareSvg className={"legendRangeIcon"} data-significant="false" />
                            <SquareSvg className={"legendRangeIcon"} />
                        </div>
                        <span className={"legendLabel"}>{localiser(visOptions.rangeLabels[1])}</span>
                    </div>
                </>
            )}
            <div className={"legendItem"}>
                <div className={"legendIconGroup"}>
                    <div className={styles.legendConfidenceInterval} />
                </div>
                <span className={"legendLabel"}>{localiser("95% confidence interval")}</span>
            </div>
        </div>
    );
};

export default VerticalPerformanceTrends;
