import { Node, ResolvedPos, Schema } from "prosemirror-model";
import {
  AllSelection,
  EditorState,
  Plugin,
  PluginKey,
  TextSelection
} from "prosemirror-state";
import { Decoration, DecorationSet } from "prosemirror-view";
import { findParent } from "../../../util";

/**
 * A non-text node that `findNodesInRange` found fully inside the selection.
 *
 * `pos` is the document position directly before `node` (as reported by
 * `Node.nodesBetween`), so the node occupies `pos` .. `pos + node.nodeSize`.
 */
type SelectedNode = {
  node: Node;
  pos: number;
};

/**
 * Lists the outermost non-text nodes that lie entirely within the selection.
 *
 * Only `TextSelection` and `AllSelection` are inspected; any other selection
 * type yields `null`. Text nodes and textblocks are never collected
 * themselves, but traversal continues into them (and into partially covered
 * nodes) so that fully covered descendants are still found. Once a node is
 * collected, its descendants are skipped.
 *
 * @param state Editor state whose document and selection are inspected.
 * @returns The fully covered nodes with their start positions (possibly an
 *   empty array), or `null` for unsupported selection types.
 */
function findNodesInRange(state: EditorState): SelectedNode[] | null {
  const { doc, selection } = state;
  const supported =
    selection instanceof TextSelection || selection instanceof AllSelection;
  if (!supported) {
    return null;
  }

  const { from, to } = selection;
  const covered: SelectedNode[] = [];

  doc.nodesBetween(from, to, (node, pos) => {
    // Text and textblocks are never collected; keep descending into them.
    if (node.isText || node.isTextblock) {
      return undefined;
    }

    const end = pos + node.nodeSize;
    if (pos < from || end > to) {
      // Only partially covered: descend to look for fully covered children.
      return undefined;
    }

    // Fully covered: record it and skip its descendants.
    covered.push({ node, pos });
    return false;
  });

  return covered;
}

/**
 * Builds the decoration set for `state`.
 *
 * Each node returned by `findNodesInRange` gets a node decoration carrying the
 * `ProseMirror-selection` class. When `findNodesInRange` returns `null`
 * (unsupported selection type), the empty set is returned.
 *
 * @param state Editor state to derive decorations from.
 * @returns Decorations anchored in `state.doc`.
 */
function getDecorationsForState<S extends Schema>(
  state: EditorState<S>
): DecorationSet {
  const covered = findNodesInRange(state);
  if (covered == null) {
    return DecorationSet.empty;
  }

  const decorations = covered.map(({ node, pos }) =>
    Decoration.node(pos, pos + node.nodeSize, {
      class: "ProseMirror-selection"
    })
  );

  return DecorationSet.create(state.doc, decorations);
}

/**
 * Key of the `nodeRangeSelection` plugin. Its plugin state is the
 * `DecorationSet` currently highlighting fully selected nodes.
 */
export const nodeRangeSelectionKey: PluginKey<DecorationSet> = new PluginKey(
  "nodeRangeSelection"
);

/**
 * Creates the node-range-selection plugin.
 *
 * - Plugin state: a `DecorationSet` (see `getDecorationsForState`) that adds
 *   the `ProseMirror-selection` class to every non-text node fully covered by
 *   a text or "select all" selection. It is exposed through the `decorations`
 *   prop.
 * - `appendTransaction`: when one of the dispatched transactions set the
 *   selection and the resulting selection is a `TextSelection`, the selection
 *   is normalized by `limitSelectionWithinOrSelectBlock`. If that yields a
 *   different selection, a transaction applying it is appended.
 */
export function nodeRangeSelection<S extends Schema>() {
  return new Plugin<DecorationSet, S>({
    key: nodeRangeSelectionKey,
    state: {
      init(_, state) {
        return getDecorationsForState(state);
      },
      apply(tr, value, _oldState, newState) {
        // The decorations depend only on the document and the selection. A
        // transaction that changes neither (e.g. meta-only) leaves both
        // untouched, so the previous set is still exact.
        if (!tr.docChanged && !tr.selectionSet) {
          return value;
        }
        return getDecorationsForState(newState);
      }
    },
    props: {
      decorations(state) {
        return this.getState(state);
      }
    },
    appendTransaction: (transactions, oldState, newState) => {
      const selectionTr = transactions.find((tr) => tr.selectionSet);
      if (!selectionTr) {
        return undefined;
      }

      const { selection } = newState;
      if (!(selection instanceof TextSelection)) {
        return undefined;
      }

      // NOTE(review): the "pointer" meta is presumably the one prosemirror-view
      // attaches to mouse-driven selection changes. Confirm no other producer.
      const isPointer = selectionTr.getMeta("pointer") === true;
      const normalizedSelection = limitSelectionWithinOrSelectBlock(
        newState.doc,
        oldState.selection.$head,
        selection.$anchor,
        selection.$head,
        isPointer
      );

      if (normalizedSelection == null || selection.eq(normalizedSelection)) {
        return undefined;
      }
      return newState.tr.setSelection(normalizedSelection);
    }
  });
}

/**
 * Normalizes a text selection against two custom node-spec flags.
 *
 * - `constrainSelection: true`: if the anchor lies inside such a node and the
 *   head is not inside that same node, the head is clamped back into the
 *   anchor's node. It is aimed at the node's edge on the side the selection
 *   extends towards, then `TextSelection.between` picks a valid text position.
 * - `blockSelection: true`: if the anchor is outside every constrained node,
 *   the head is inside one, and the head has a `blockSelection` ancestor found
 *   by `findParent`, the head is snapped to the position just before or just
 *   after that ancestor.
 *   - For pointer selections the snap follows the selection direction.
 *   - For keyboard selections it follows the direction the head moved
 *     relative to `$oldHead`.
 *
 * @param doc        Document in which the new head positions are resolved.
 *                   Expected to be the document `$newAnchor`/`$newHead` belong to.
 * @param $oldHead   Head before the change; only `pos` is read. Compared with
 *                   `$newHead.pos` to decide growth vs. shrink.
 * @param $newAnchor Anchor of the selection to normalize.
 * @param $newHead   Head of the selection to normalize.
 * @param isPointer  Whether the selection was made with the pointer.
 * @returns A replacement selection, or `undefined` when none is needed.
 */
function limitSelectionWithinOrSelectBlock<S extends Schema>(
  doc: Node<S>,
  $oldHead: ResolvedPos<S>,
  $newAnchor: ResolvedPos<S>,
  $newHead: ResolvedPos<S>,
  isPointer: boolean
) {
  const isContent = (node: Node<S>) =>
    node.type.spec.constrainSelection === true;
  const isParent = (node: Node<S>) => node.type.spec.blockSelection === true;

  const forward = $newAnchor.pos <= $newHead.pos;
  const anchorContent = findParent($newAnchor, isContent);
  const headContent = findParent($newHead, isContent);

  if (anchorContent != null) {
    // Identify "the same node" by position and identity, not `Node.eq`. `eq` is
    // structural equality, so two distinct constrained nodes with identical
    // content (e.g. two empty cells) would otherwise be mistaken for one and
    // the selection would escape unclamped.
    const sameNode =
      headContent != null &&
      headContent.pos === anchorContent.pos &&
      headContent.node === anchorContent.node;
    if (sameNode) {
      return undefined;
    }

    // Aim at the boundary of the anchor's node on the extension side.
    // TextSelection.between then resolves it to a text position inside the node.
    const edge = forward
      ? anchorContent.pos + anchorContent.node.nodeSize
      : anchorContent.pos;
    return TextSelection.between(
      $newAnchor,
      doc.resolve(edge),
      forward ? 1 : -1
    );
  }

  if (headContent == null) {
    return undefined;
  }

  const headParent = findParent($newHead, isParent);
  if (headParent == null) {
    return undefined;
  }

  // Pointer: snap in the selection's direction. Keyboard: snap in the direction
  // the head just moved, so shrinking a selection releases the block again.
  let headAfterBlock: boolean;
  if (isPointer) {
    headAfterBlock = forward;
  } else if (forward) {
    headAfterBlock = $newHead.pos >= $oldHead.pos;
  } else {
    headAfterBlock = $newHead.pos > $oldHead.pos;
  }

  const newHeadPos = headAfterBlock
    ? headParent.pos + headParent.node.nodeSize
    : headParent.pos;
  return new TextSelection($newAnchor, doc.resolve(newHeadPos));
}
