import { Node, Schema } from "prosemirror-model";
import {
  EditorState,
  NodeSelection,
  Plugin,
  PluginKey,
  TextSelection
} from "prosemirror-state";
import { CellSelection } from "prosemirror-tables";
import { Decoration, DecorationSet } from "prosemirror-view";
import { findParent } from "../../../util";
import { editableKey } from "../editable";
import { GapCursor } from "../gap-cursor";

/**
 * The node that currently holds "focus" according to the editor selection.
 *
 * `pos` is the document position immediately before `node`, so the node
 * occupies the range `pos` .. `pos + node.nodeSize`.
 */
export interface Focus<S extends Schema = Schema> {
  // The focused node (a node whose spec sets `focusable: true`).
  node: Node<S>;
  // Position directly before `node` in the document.
  pos: number;
}

/** Key of the `selectionFocus` plugin; its state is the current `Focus`, or `null` when nothing is focused. */
export const selectionFocusKey = new PluginKey<Focus | null>("selectionFocus");

/**
 * Works out which focusable node (spec `focusable: true`) the selection is in.
 *
 * Returns `null` when:
 * - the `editableKey` plugin state exists and reports `focusable` as false;
 * - the selection is not a text, node, cell or gap-cursor selection;
 * - a node selection selects a non-focusable node;
 * - no single focusable ancestor strictly contains both selection ends.
 *
 * @param state - the editor state to inspect
 * @returns the focused node with the position directly before it, or `null`
 */
function findFocusedNode(state: EditorState): Focus | null {
  const editable = editableKey.getState(state);
  if (editable != null && !editable.focusable) return null;

  const sel = state.selection;
  const supported =
    sel instanceof TextSelection ||
    sel instanceof NodeSelection ||
    sel instanceof CellSelection ||
    sel instanceof GapCursor;
  if (!supported) return null;

  const isFocusable = (candidate: Node<Schema>) =>
    candidate.type.spec.focusable === true;

  // A node selection only counts when the selected node itself is focusable.
  if (sel instanceof NodeSelection) {
    return isFocusable(sel.node) ? { node: sel.node, pos: sel.from } : null;
  }

  // Otherwise both ends of the selection must resolve to the same focusable ancestor.
  const startAncestor = findParent(sel.$from, isFocusable);
  const endAncestor = findParent(sel.$to, isFocusable);
  if (startAncestor == null || endAncestor == null) return null;
  if (startAncestor.node !== endAncestor.node) return null;

  const start = startAncestor.pos;
  const end = start + startAncestor.node.nodeSize;
  // Identity alone is not enough (a node instance could be shared); the whole
  // selection must fall strictly inside this particular occurrence.
  const inside = sel.from > start && sel.to < end;
  return inside ? { node: startAncestor.node, pos: start } : null;
}

/**
 * Builds the plugin that tracks the node the selection is focused in.
 *
 * The plugin state (keyed by `selectionFocusKey`) is recomputed from scratch
 * by `findFocusedNode` for the initial state and after every transaction.
 * When a node is focused, it is decorated with the
 * `ProseMirror-focusednode` CSS class.
 */
export function selectionFocus<S extends Schema>() {
  return new Plugin<Focus | null, S>({
    key: selectionFocusKey,
    state: {
      init: (_config, editorState) => findFocusedNode(editorState),
      apply: (_tr, _prev, _oldState, nextState) => findFocusedNode(nextState)
    },
    props: {
      decorations(state) {
        const focused = this.getState(state);
        if (focused == null) {
          return;
        }

        const start = focused.pos;
        const end = start + focused.node.nodeSize;
        const highlight = Decoration.node(start, end, {
          class: "ProseMirror-focusednode"
        });
        return DecorationSet.create(state.doc, [highlight]);
      }
    }
  });
}
