import * as d3 from 'd3'

/**
 * Factory for a d3 Sankey layout that tolerates cycles in the link graph.
 *
 * Configure it through chainable getter/setter methods, then call
 * `layout(iterations)`. The layout mutates the supplied objects in place:
 *   - nodes gain `sourceLinks`, `targetLinks`, `value`, `x`, `dx`, `y`, `dy`;
 *   - links gain `dy`, `sy`, `ty`, and `cycleBreaker` (set to true on the links
 *     chosen to break cycles); a numeric `source`/`target` is replaced by the
 *     node object at that index.
 *
 * @returns {object} The sankey layout object.
 */
const d3Sankey = () => {
  const sankey = {}
  let nodeWidth = 24
  let nodePadding = 8
  let size = [1, 1]
  let nodes = []
  let links = []
  let sinksRight = true

  // Getter/setter pairs: called without an argument they return the current
  // value; otherwise they store the argument and return `sankey` for chaining.
  // The explicit `undefined` test (rather than `!_`) lets falsy settings such
  // as `nodePadding(0)` or `sinksRight(false)` actually take effect.
  sankey.nodeWidth = _ => {
    if (_ === undefined) return nodeWidth
    nodeWidth = +_
    return sankey
  }

  sankey.nodePadding = _ => {
    if (_ === undefined) return nodePadding
    nodePadding = +_
    return sankey
  }

  sankey.nodes = _ => {
    if (_ === undefined) return nodes
    nodes = _
    return sankey
  }

  sankey.links = _ => {
    if (_ === undefined) return links
    links = _
    return sankey
  }

  sankey.size = _ => {
    if (_ === undefined) return size
    size = _
    return sankey
  }

  // When true, nodes without outgoing links are moved to the last column.
  sankey.sinksRight = _ => {
    if (_ === undefined) return sinksRight
    sinksRight = _
    return sankey
  }

  /**
   * Run the full layout: link bookkeeping, node values, horizontal positions
   * (breaking cycles as needed) and vertical positions.
   * @param {number} iterations - Number of relaxation passes refining the
   *   vertical positions; 0 or undefined skips relaxation.
   * @returns {object} `sankey`, for chaining.
   */
  sankey.layout = iterations => {
    computeNodeLinks()
    computeNodeValues()
    computeNodeBreadthsWithLoops()
    computeNodeDepths(iterations)
    return sankey
  }

  /**
   * Recompute only the link endpoint offsets (`sy`/`ty`) from the current
   * node positions.
   * @returns {object} `sankey`, for chaining.
   */
  sankey.relayout = () => {
    computeLinkDepths()
    return sankey
  }

  // SVG path data generator, to be used as the "d" attribute of a "path"
  // selection. Regular links are a cubic Bezier from the right edge of the
  // source node to the left edge of the target node. Cycle-breaking links are
  // drawn as a loop that bulges upward (negative y) when its midpoint lies in
  // the upper half of the layout and downward otherwise.
  sankey.link = () => {
    let curvature = 0.5

    function link(d) {
      const xs = d.source.x + d.source.dx
      const xt = d.target.x
      const xi = d3.interpolateNumber(xs, xt)
      const ys = d.source.y + d.sy + d.dy / 2
      const yt = d.target.y + d.ty + d.dy / 2

      if (!d.cycleBreaker) {
        const xsc = xi(curvature)
        const xtc = xi(1 - curvature)
        return `M${xs},${ys}C${xsc},${ys} ${xtc},${yt} ${xt},${yt}`
      }

      // Control points pushed outward past both endpoints so the loop clears
      // the nodes it connects.
      const xdelta = 1.5 * d.dy + 0.05 * Math.abs(xs - xt)
      const xsLoop = xs + xdelta
      const xtLoop = xt - xdelta
      const xm = xi(0.5)
      const ym = d3.interpolateNumber(ys, yt)(0.5)
      // The sign selects the loop direction and must scale the whole offset,
      // not just its last term.
      const direction = ym < size[1] / 2 ? -1 : 1
      const ydelta = (2 * d.dy + 0.1 * Math.abs(xs - xt) + 0.1 * Math.abs(ys - yt)) * direction
      return `M${xs},${ys}C${xsLoop},${ys} ${xsLoop},${ys + ydelta} ${xm},${ym + ydelta}S${xtLoop},${yt} ${xt},${yt}`
    }

    // Getter/setter for the Bezier curvature of regular links; 0 is a valid
    // setting (straight links).
    link.curvature = _ => {
      if (_ === undefined) return curvature
      curvature = +_
      return link
    }

    return link
  }

  // Populate `sourceLinks` and `targetLinks` for each node. A numeric link
  // `source`/`target` is treated as an index into `nodes` and replaced by
  // that node object.
  function computeNodeLinks() {
    nodes.forEach(node => {
      node.sourceLinks = [] // links leaving this node
      node.targetLinks = [] // links entering this node
    })
    links.forEach(link => {
      if (typeof link.source === 'number') link.source = nodes[link.source]
      if (typeof link.target === 'number') link.target = nodes[link.target]
      link.source.sourceLinks.push(link)
      link.target.targetLinks.push(link)
    })
  }

  // A node's value is the larger of its total outgoing and total incoming
  // link value.
  function computeNodeValues() {
    nodes.forEach(node => {
      node.value = Math.max(d3.sum(node.sourceLinks, value), d3.sum(node.targetLinks, value))
    })
  }

  // Assign the breadth (column / x-position) of each node, breaking cycles.
  // Starting from all nodes in column 0, each pass places the current set in
  // column x and continues with the targets of their non-cycle-breaking links.
  // If a pass does not shrink the set, the remaining subgraph contains a
  // cycle: one link of it is marked as a cycle breaker and placement restarts.
  function computeNodeBreadthsWithLoops() {
    let remainingNodes = nodes
    let x = 0

    while (remainingNodes.length && x < nodes.length) {
      const nextNodes = assignBreadth(remainingNodes, x)
      if (nextNodes.length === remainingNodes.length && findAndMarkCycleBreaker(nextNodes)) {
        // A new cycle breaker changes the graph; start over from column 0.
        remainingNodes = nodes
        x = 0
      } else {
        remainingNodes = nextNodes
        x += 1
      }
    }

    // Optionally move pure sinks to the rightmost column.
    if (sinksRight) {
      moveSinksRight(x)
    }
    // With a single column (x <= 1) there is no horizontal span to divide;
    // scaling by 0 keeps every node at x = 0 instead of NaN (0 * Infinity).
    scaleNodeBreadths(x > 1 ? (size[0] - nodeWidth) / (x - 1) : 0)
  }

  // Place every node of `nodeList` in column `x` and return the distinct
  // targets of their non-cycle-breaking outgoing links, in first-seen order.
  function assignBreadth(nodeList, x) {
    const targets = new Set()
    nodeList.forEach(node => {
      node.x = x
      node.dx = nodeWidth
      node.sourceLinks.forEach(link => {
        if (!link.cycleBreaker) targets.add(link.target)
      })
    })
    return Array.from(targets)
  }

  // Search for a cycle reachable from the nodes of `nodeList` (last node
  // first), mark its lowest-value link as `cycleBreaker`, and return that
  // link. Returns undefined when no cycle is found.
  function findAndMarkCycleBreaker(nodeList) {
    for (let n = nodeList.length - 1; n >= 0; n -= 1) {
      const breaker = depthFirstCycleSearch(nodeList[n], [])
      if (breaker) {
        return breaker
      }
    }
    return undefined

    // Depth-first search over non-breaker links; `path` holds the links
    // walked so far. Nodes are not memoized, so dense graphs may be explored
    // along many paths; the search stops at the first cycle found.
    function depthFirstCycleSearch(cursorNode, path) {
      for (let n = cursorNode.sourceLinks.length - 1; n >= 0; n -= 1) {
        const currentLink = cursorNode.sourceLinks[n]
        // Skip already known cycle breakers.
        if (currentLink.cycleBreaker) continue

        // The link closes a cycle if its target is the source of a link
        // already on the path.
        const currentTarget = currentLink.target
        const cycleStart = path.findIndex(step => step.source === currentTarget)
        if (cycleStart >= 0) {
          // Pick the weakest link of the cycle (path[cycleStart..] + currentLink).
          let weakest = currentLink
          for (let k = cycleStart; k < path.length; k += 1) {
            if (path[k].value < weakest.value) {
              weakest = path[k]
            }
          }
          weakest.cycleBreaker = true
          return weakest
        }

        // Recurse deeper along this link.
        path.push(currentLink)
        const found = depthFirstCycleSearch(currentTarget, path)
        path.pop()
        if (found) {
          return found
        }
      }
      return undefined
    }
  }

  // Put every node without outgoing links into the last column (x - 1).
  function moveSinksRight(x) {
    nodes.forEach(node => {
      if (!node.sourceLinks.length) {
        node.x = x - 1
      }
    })
  }

  // Convert column indices into horizontal coordinates.
  function scaleNodeBreadths(kx) {
    nodes.forEach(node => {
      node.x *= kx
    })
  }

  // Compute the depth (y-position) and height (dy) of each node, then refine
  // positions with `iterations` alternating relaxation passes.
  function computeNodeDepths(iterations) {
    // Group nodes by breadth (column), ordered left to right.
    const nodesByBreadth = Array.from(d3.group(nodes, d => d.x))
      .sort((a, b) => d3.ascending(a[0], b[0]))
      .map(entry => entry[1])

    initializeNodeDepth()
    resolveCollisions()
    computeLinkDepths()

    let alpha = 1
    for (let remaining = iterations; remaining > 0; remaining -= 1) {
      alpha *= 0.99
      relaxRightToLeft(alpha)
      resolveCollisions()
      computeLinkDepths()
      relaxLeftToRight(alpha)
      resolveCollisions()
      computeLinkDepths()
    }

    function initializeNodeDepth() {
      // Vertical scale: the largest factor that lets every column, including
      // its padding, fit within size[1].
      const fit = d3.min(nodesByBreadth, nodeList => {
        return (size[1] - (nodeList.length - 1) * nodePadding) / d3.sum(nodeList, value)
      })
      // If every column sums to zero the factor is Infinity (undefined with
      // no nodes); use 0 so heights are 0 instead of NaN.
      const ky = Number.isFinite(fit) ? fit : 0

      nodesByBreadth.forEach(nodeList => {
        nodeList.forEach((node, i) => {
          node.y = i
          node.dy = node.value * ky
        })
      })

      links.forEach(link => {
        link.dy = link.value * ky
      })
    }

    // Move each node toward the value-weighted centre of its incoming links.
    function relaxLeftToRight(alpha) {
      nodesByBreadth.forEach(nodeList => {
        nodeList.forEach(node => {
          if (node.targetLinks.length) {
            const y = d3.sum(node.targetLinks, weightedSource) / d3.sum(node.targetLinks, value)
            node.y += (y - center(node)) * alpha
          }
        })
      })

      function weightedSource(link) {
        return (link.source.y + link.sy + link.dy / 2) * link.value
      }
    }

    // Move each node toward the value-weighted centre of its outgoing links,
    // visiting columns right to left.
    function relaxRightToLeft(alpha) {
      nodesByBreadth
        .slice()
        .reverse()
        .forEach(nodeList => {
          nodeList.forEach(node => {
            if (node.sourceLinks.length) {
              const y = d3.sum(node.sourceLinks, weightedTarget) / d3.sum(node.sourceLinks, value)
              node.y += (y - center(node)) * alpha
            }
          })
        })

      function weightedTarget(link) {
        return (link.target.y + link.ty + link.dy / 2) * link.value
      }
    }

    // Remove vertical overlaps within each column and keep the column inside
    // [0, size[1]] where possible.
    function resolveCollisions() {
      nodesByBreadth.forEach(nodeList => {
        const n = nodeList.length
        let node
        let dy
        let y0 = 0
        let i

        // Sort this column (not the shared `nodes` array) so the sweep visits
        // nodes top to bottom, then push any overlapping nodes down.
        nodeList.sort(ascendingDepth)
        for (i = 0; i < n; i += 1) {
          node = nodeList[i]
          dy = y0 - node.y
          if (dy > 0) node.y += dy
          y0 = node.y + node.dy + nodePadding
        }

        // If the bottommost node goes outside the bounds, push it back up.
        dy = y0 - nodePadding - size[1]
        if (dy > 0) {
          node.y -= dy
          y0 = node.y

          // Push any overlapping nodes back up.
          for (i = n - 2; i >= 0; i -= 1) {
            node = nodeList[i]
            dy = node.y + node.dy + nodePadding - y0
            if (dy > 0) node.y -= dy
            y0 = node.y
          }
        }
      })
    }

    function ascendingDepth(a, b) {
      return a.y - b.y
    }
  }

  // Compute the y-offset of each link's endpoint at its source (sy) and at
  // its target (ty), relative to that node's y-position. Links are stacked in
  // the vertical order of the node at their other end.
  function computeLinkDepths() {
    nodes.forEach(node => {
      node.sourceLinks.sort(ascendingTargetDepth)
      node.targetLinks.sort(ascendingSourceDepth)
    })
    nodes.forEach(node => {
      let sy = 0
      let ty = 0
      node.sourceLinks.forEach(link => {
        link.sy = sy
        sy += link.dy
      })
      node.targetLinks.forEach(link => {
        link.ty = ty
        ty += link.dy
      })
    })

    function ascendingSourceDepth(a, b) {
      return a.source.y - b.source.y
    }

    function ascendingTargetDepth(a, b) {
      return a.target.y - b.target.y
    }
  }

  // Y-position of the middle of a node.
  function center(node) {
    return node.y + node.dy / 2
  }

  // Value property accessor.
  function value(x) {
    return x.value
  }

  return sankey
}

export default d3Sankey
