import React, {useEffect, useRef, useState} from 'react';
import * as Plot from "@observablehq/plot";
import * as d3 from "d3";
import {makeStyles} from "@mui/styles";

// Style hook for the heatmap container: full width, with vertical spacing
// applied to any nested <figure> element.
// NOTE(review): presumably targets the <figure> wrapper Plot emits when a
// legend is requested — confirm against the rendered DOM.
const useStyles = makeStyles(() => {
    const figureSpacing = { margin: '10px 0' };

    return {
        heatmap: {
            width: '100%',
            '& figure': figureSpacing,
        },
    };
});

/**
 * Pearson correlation coefficient between two equally long numeric columns.
 *
 * Computed as the sum of cross-deviations from the means divided by the
 * square root of the product of the two sums of squared deviations.
 * When that product is 0 (e.g. a constant column) the result is NaN.
 *
 * @param {number[]} x - First column of values.
 * @param {number[]} y - Second column of values; must be as long as `x`.
 * @returns {number} The correlation coefficient.
 * @throws {Error} If the two columns differ in length.
 */
const corr = (x, y) => {
    if (x.length !== y.length) {
        throw new Error("The two columns must have the same length.");
    }

    const meanX = d3.mean(x);
    const meanY = d3.mean(y);

    // Sums of squared deviations for each column (the denominator terms).
    const squaredDevX = d3.sum(x, (value) => (value - meanX) ** 2);
    const squaredDevY = d3.sum(y, (value) => (value - meanY) ** 2);

    // Sum of the pairwise products of deviations (the numerator).
    const crossDev = d3.sum(x, (value, index) => (value - meanX) * (y[index] - meanY));

    return crossDev / Math.sqrt(squaredDevX * squaredDevY);
};

/**
 * Encodes string-valued columns as integer category codes so they can be fed
 * into numeric computations.
 *
 * A column is treated as categorical when its value in the FIRST row is a
 * string. Each distinct value in such a column is assigned the index at which
 * it is first seen (0, 1, 2, ...), and every row's value is replaced by that
 * index. All other columns are copied through unchanged.
 *
 * The input rows are never modified: a new array of new row objects is
 * returned, so it is safe to pass React props or other shared data.
 *
 * @param {Array<Object>} data - Row objects sharing the same keys.
 * @returns {Array<Object>|undefined} Converted copies of the rows, or
 *   `undefined` when `data` is missing or empty.
 */
const distinctConversion = (data) => {
    if (!data || !data.length) return undefined;

    const firstRow = data[0];
    const categoricalKeys = Object.keys(firstRow).filter(
        (key) => typeof firstRow[key] === 'string'
    );

    // Per categorical column: distinct value -> first-seen index.
    // A Map keeps keys un-coerced and avoids prototype-key pitfalls.
    const codesByKey = new Map();
    for (const key of categoricalKeys) {
        const codes = new Map();
        for (const row of data) {
            const value = row[key];
            if (!codes.has(value)) codes.set(value, codes.size);
        }
        codesByKey.set(key, codes);
    }

    return data.map((row) => {
        // Shallow copy so the caller's row objects stay untouched.
        const converted = { ...row };
        for (const [key, codes] of codesByKey) {
            converted[key] = codes.get(row[key]);
        }
        return converted;
    });
};

function CorrelationHeatmap({ data, cohorts }) {
    const ref = useRef();
    const classes = useStyles();
    const [corrData, setCorrData] = useState([])
    const [vars, setVars] = useState([])

    useEffect(() => {
        const distData = distinctConversion(data)
        const selectedVars = vars.filter(i => i.selected).map(i => i.var_key)
        const fields = distData && distData[0] ? Object.keys(distData[0]).filter(k => selectedVars.includes(k)) : []
        const correlations = d3.cross(fields, fields).map(([a, b]) => ({
            a,
            b,
            correlation: corr(Plot.valueof(distData, a), Plot.valueof(distData, b))
        }))
        setCorrData(correlations)
    }, [data, vars])

    useEffect(() => {
        const chart = Plot.plot({
            marginLeft: 100,
            label: null,
            width: 900,
            height: 300,
            color: { scheme: "rdylbu", pivot: 0, legend: true, label: "Correlation", marginLeft: 20, marginRight: -20 },
            marks: [
                Plot.cell(corrData, { x: "a", y: "b", fill: "correlation" }),
                Plot.text(corrData, {
                    x: "a",
                    y: "b",
                    text: (d) => d.correlation.toFixed(2),
                    fill: (d) => (Math.abs(d.correlation) > 0.6 ? "white" : "black")
                })
            ]
        })
        if(corrData.length) ref.current.append(chart)
        return () => chart.remove()
    }, [corrData])

    useEffect(() => {
        if(cohorts && cohorts[0]) {
            setVars(cohorts[0].variables.map(i => ({...i, selected: true})))
        }
    }, [cohorts])

    const onCheckbox = (e) => {
        const checked = vars.map(i => {
            if(i.var_key === e.target.dataset.key) i.selected = !i.selected
            return i
        })
        setVars(checked)
    }

    return (
        <div>
            <div className={classes.heatmap} ref={ref} />
            <div>
                {vars.map((v, k) => (
                  <div key={k}>
                      <input type="checkbox"
                             id={`${k}`}
                             checked={v.selected}
                             data-key={v.var_key}
                             onChange={onCheckbox}
                      />
                      <label htmlFor={`${k}`}>{v.name}</label>
                  </div>
                ))}
            </div>
        </div>
    )
}

export default CorrelationHeatmap;