import React from 'react'
import PropTypes from 'prop-types'

import * as d3 from 'd3'


// =============================================================================
const Legend = ({ children }) => (
  <div className="legend">
    {children}
  </div>
)


// =============================================================================
const LegendItem = ({ name, colour }) => (
  <div className="series">
    <div className="colour" style={{ backgroundColor: colour }} />
    <span className="label">{name}</span>
  </div>
)


// =============================================================================
class HeatmapChart extends React.Component {
  constructor(props) {
    super(props)

    this.updateChart = this.updateChart.bind(this)
  }

  updateChart() {
    const data = this.props.data

    // Set the dimensions of the graph using the user-specified margins.
    const margin = this.props.margin,
          width  = this.props.width - margin.left - margin.right,
          height = this.props.height - margin.top - margin.bottom

    // Create an SVG group element to use as the host for the chart.  We use a
    // here to support the transformation required by the margins.
    const chart = d3
      .select(this.svg)
      .append("g")
      .attr("transform", "translate(" + margin.left + "," + margin.top + ")")

    // Create scales for the X and Y series.  These behave like a map from the
    // row/column value (the domains) onto the pixel-value in the X and Y
    // directions (the ranges).  The range of the Y scale is inverted because
    // the origin of the chart should be the bottom-left, but SVGs (like all
    // computer graphics) uses the top-left as the origin.
    const xScale = d3
      .scaleBand()
      .range([ 0, width ])
      .domain(data.columns)
      .padding(0.01)

    const yScale = d3
      .scaleBand()
      .range([ height, 0 ])
      .domain(data.rows)
      .padding(0.01)

    // The data for this chart is an integer in [0,3] which corresponds to these
    // values.  This is purely a convenience to make the colour-coding of the
    // data points simple.  In a more general data-set, we'd likely need some
    // form of mapping function from data-values onto a gradient or something.
    const colours = [
      '#a6d96a',  // True negative
      '#fdae61',  // False negative
      '#d7191c',  // False positive
      '#1a9641',  // True positive
    ]

    // SVG enforces a drawing order where the first drawn has the lowest z-index
    // and, unfortunately, this cannot be overridden using CSS.  Therefore, we
    // draw the series data first.
    //
    // There's quite a lot happening in this function chain.  First, a group is
    // added to contain all the series data points purely for organisation.  D3
    // uses an interesting data-binding technique using entrance and exit
    // functions.  Below, the function chain following enter() provides D3 with
    // the appropriate functions to invoke when a new data point is required.
    // By binding data to this query (i.e., the selectAll() class), D3 calls
    // this function chain which ultimately creates an SVG rect for each sample
    // in the series.  Like React, they each need a key (provided by the
    // function passed as second argument to data()).
    //
    // The attributes of position, size, and fill-colour are either functions
    // (if the value must be computed) or constants.  Here we use the two scales
    // created above to position and size the data-point rects and use the
    // simple array of colours above instead of some computation or other
    // mapping.
    chart
      .append("g")
      .classed("series", true)
      .selectAll()
      .data(data.samples, d => d.col + ':' + d.row)
      .enter()
        .append("rect")
        .attr("x", d => xScale(d.col))
        .attr("y", d => yScale(d.row))
        .attr("width", xScale.bandwidth())
        .attr("height", yScale.bandwidth())
        .style("fill", d => colours[d.value])

    // With the data points visualised, we can add the axes so they appear on
    // top of the left-most and bottom-most series rects.  For each axis, we use
    // D3 to create a function that, when called, will return the SVG elements
    // required to visualise an axis (the axis*() class).  Each of them require
    // the scales created earlier, and are given the exact values to display.
    // Typically this isn't required, but without this filtering the axes are
    // very cluttered with this data-set.  The filter here simply selects the
    // ticks with odd-valued indices to cut the number displayed in half.
    //
    // Then, for each axis, we create a group for organisational purposes and
    // instruct D3 to call these axis functions on this newly-created group.
    // This is identical to creating an element and using selectAll() to chain
    // functions together.  The function passed to call() is invoked with the D3
    // selection passed as the sole parameter.
    const tickFilter = (d, i) => !(i % 2)

    const xAxis = d3
      .axisBottom(xScale)
      .tickValues(xScale.domain().filter(tickFilter))

    chart
      .append("g")
      .classed("axis", true)
      .attr("transform", "translate(0," + height + ")")
      .call(xAxis)

    const yAxis = d3
      .axisLeft(yScale)
      .tickValues(yScale.domain().filter(tickFilter))

    chart
      .append("g")
      .classed("axis", true)
      .call(yAxis)
  }

  componentDidMount() {
    this.updateChart()
  }

  componentDidUpdate() {
    this.updateChart()
  }

  render() {
    return (
      <div className="content">
        <svg
          viewBox={`0 0 ${this.props.width} ${this.props.height}`}
          ref={e => this.svg = e}
        />
        <Legend>
          <LegendItem colour="#1a9641" name="True positive" />
          <LegendItem colour="#a6d96a" name="True negative" />
          <LegendItem colour="#d7191c" name="False positive" />
          <LegendItem colour="#fdae61" name="False negative" />
        </Legend>
      </div>
    )
  }
}

// -----------------------------------------------------------------------------
HeatmapChart.propTypes = {
  // The chart reads `columns` and `rows` as the X and Y band-scale domains and
  // draws one cell per entry in `samples`.  Each cell is positioned by its
  // `col`/`row` and coloured by `value`, which indexes a 4-entry colour table
  // (so only 0-3 are meaningful).  Validating these fields gives a clear dev
  // warning instead of an opaque D3 failure when the data is malformed.
  data: PropTypes.shape({
    columns: PropTypes.array.isRequired,
    rows:    PropTypes.array.isRequired,
    samples: PropTypes.arrayOf(PropTypes.shape({
      col:   PropTypes.any.isRequired,
      row:   PropTypes.any.isRequired,
      value: PropTypes.oneOf([ 0, 1, 2, 3 ]).isRequired,
    })).isRequired,
  }).isRequired,
  // Overall SVG viewBox size, including the margins.
  width: PropTypes.number.isRequired,
  height: PropTypes.number.isRequired,
  // Space reserved around the plot area (e.g. for the axes).
  margin: PropTypes.shape({
    top:    PropTypes.number,
    right:  PropTypes.number,
    bottom: PropTypes.number,
    left:   PropTypes.number,
  }),
}

// -----------------------------------------------------------------------------
// Default margins leave room for the bottom (X) and left (Y) axes.
HeatmapChart.defaultProps = {
  margin: { top: 0, right: 0, bottom: 30, left: 35 },
}

// -----------------------------------------------------------------------------
export default HeatmapChart
