import { Node, Schema } from "prosemirror-model";
import { Plugin, PluginKey, Selection, TextSelection } from "prosemirror-state";
import { selectionFocusKey } from "../../../editor/plugins/selection-focus";

/**
 * State held by the input-rank select plugin: a replacement selection that
 * `appendTransaction` should apply, or `null` when no adjustment is pending.
 */
type PluginState = Selection | null;

/** Key used to look up this plugin's state (see `PluginState`) on an editor state. */
export const inputRankSelectKey = new PluginKey<PluginState, Schema>(
  "inputRankSelectPlugin"
);

/**
 * Builds the plugin that widens a range selection which only partially covers
 * an `inputRank` node so that it extends past that node's boundaries.
 *
 * The work is split in two steps:
 *  - `state.apply` computes the widened selection (or `null`) for the new
 *    editor state and stores it as the plugin state;
 *  - `appendTransaction` reads that stored selection and, when present,
 *    appends a transaction that applies it.
 *
 * Nothing is computed while the `selectionFocusKey` plugin state is truthy or
 * while the selection is collapsed (`anchor === head`).
 *
 * @returns A ProseMirror plugin registered under `inputRankSelectKey`.
 */
export function inputRankSelectPlugin() {
  return new Plugin<PluginState, Schema>({
    key: inputRankSelectKey,
    state: {
      init(_config, _state): PluginState {
        return null;
      },
      apply(_tr, _value, _oldState, newState): PluginState {
        const { selection, doc, schema } = newState;
        const { from, to, $to, $from, anchor, head } = selection;

        // Skip when the selection-focus plugin reports a truthy state, or when
        // the selection is a plain cursor with nothing to widen.
        if (selectionFocusKey.getState(newState) || anchor === head) {
          return null;
        }

        // Valid document positions are 0..doc.content.size; `doc.resolve`
        // throws a RangeError outside of that range, so the computed
        // positions below are clamped to it.
        const maxPos = doc.content.size;
        let widened: PluginState = null;

        doc.nodesBetween(from, to, (node, pos) => {
          if (node.type !== schema.nodes.inputRank) {
            return;
          }
          if (!isOutsideInput(from, to, node, pos)) {
            return;
          }

          const nodeEnd = pos + node.nodeSize;
          // Keep the side of the selection that already lies outside the node;
          // otherwise move it one position beyond the node's boundary.
          const newAnchor =
            from < pos ? $from : doc.resolve(Math.max(0, pos - 1));
          const newHead =
            to > nodeEnd ? $to : doc.resolve(Math.min(maxPos, nodeEnd + 1));

          // Preserve the direction of the original selection.
          // NOTE(review): every matching node overwrites the previous result,
          // so with several `inputRank` nodes in range the last one wins.
          widened =
            anchor > head
              ? new TextSelection(newHead, newAnchor)
              : new TextSelection(newAnchor, newHead);
        });

        return widened;
      }
    },
    appendTransaction(_transactions, _oldState, newState) {
      // The plugin state is the selection computed by `apply`, if any.
      const pendingSelection = inputRankSelectKey.getState(newState);

      if (pendingSelection) {
        return newState.tr.setSelection(pendingSelection);
      }
      return undefined;
    }
  });
}

/**
 * Tells whether the range `from`..`to` reaches outside the given node.
 *
 * The node is taken to span `pos` to `pos + node.nodeSize`; both ends are
 * inclusive, so a range that exactly matches the node counts as inside.
 *
 * @param from - Start position of the range.
 * @param to - End position of the range.
 * @param node - Node whose `nodeSize` determines the span.
 * @param pos - Position at which the node starts.
 * @returns `true` when the range starts before `pos` or ends after
 *   `pos + node.nodeSize`; `false` when it lies entirely within the node.
 */
function isOutsideInput(
  from: number,
  to: number,
  node: Node,
  pos: number
): boolean {
  const nodeEnd = pos + node.nodeSize;
  return from < pos || to > nodeEnd;
}
