import { Connection, HandleType, Node, useReactFlow } from "reactflow";
import { useCallback } from "react";
import { NodeHandle, NodeType, SourceHandle, TargetHandle } from "../models/nodeType";

/**
 * Decides whether a source handle may be wired to a target handle, based only on
 * their `handleCategory` and `handleName`:
 *
 * - "start" sources connect only to "flow" targets.
 * - "flow" sources connect to "flow" or "end" targets.
 * - "data" / "element" sources connect only to targets of the same category whose
 *   `handleName` is identical.
 * - Any other source category cannot be connected.
 *
 * It reads nothing but its arguments, so it lives outside the hook and needs no memoization.
 */
const isValidConnectionType = (
  { handleCategory: sourceHandleCategory, handleName: sourceHandleName }: SourceHandle,
  { handleCategory: targetHandleCategory, handleName: targetHandleName }: TargetHandle
): boolean => {
  switch (sourceHandleCategory) {
    case "start":
      return targetHandleCategory === "flow";
    case "flow":
      return targetHandleCategory === "flow" || targetHandleCategory === "end";
    case "data":
      return targetHandleCategory === "data" && sourceHandleName === targetHandleName;
    case "element":
      return targetHandleCategory === "element" && sourceHandleName === targetHandleName;
    default:
      return false;
  }
};

/**
 * Connection-validation helpers for the React Flow editor, built on `useReactFlow<NodeType>()`.
 *
 * NOTE: handles are matched by `handleId` alone (never paired with a node id), both when
 * counting edges and when locating the node that owns a handle. Handle ids are therefore
 * assumed to be unique across the whole graph.
 */
const useNodeHandlesValidation = () => {
  const { getNode, getNodes, getEdges } = useReactFlow<NodeType>();

  /**
   * Returns `true` when the handle may NOT accept another edge.
   *
   * - Returns `false` if any edge flagged `data.isDragging` is attached to the handle.
   * - Source handles of category "data", "start" or "flow" are unlimited.
   * - Any other source handle, and every target handle, allows a single edge.
   * - A handle whose `handleType` is neither "source" nor "target" is reported as at its limit.
   */
  const isValidConnectionsAtLimit = useCallback(
    ({ handleId, handleType, handleCategory }: NodeHandle): boolean => {
      const edges = getEdges();

      // An edge being dragged that touches this handle does not block it
      // (presumably so the dragged edge can be re-attached here — confirm with the drag handlers).
      const hasDraggingEdge = edges.some(
        (edge) => edge.data?.isDragging === true && (edge.targetHandle === handleId || edge.sourceHandle === handleId)
      );

      if (hasDraggingEdge) {
        return false;
      }

      if (handleType === "source") {
        // These source categories may fan out to any number of targets.
        if (handleCategory === "data" || handleCategory === "start" || handleCategory === "flow") {
          return false;
        }

        // "At least one" rather than "exactly one": a handle that already carries more
        // than one edge must still be reported as full.
        return edges.some(({ sourceHandle }) => sourceHandle === handleId);
      }

      if (handleType === "target") {
        return edges.some(({ targetHandle }) => targetHandle === handleId);
      }

      return true;
    },
    [getEdges]
  );

  /**
   * React Flow `isValidConnection` callback. Accepts a connection only if:
   * - all four ids are present;
   * - both nodes and both handles exist;
   * - neither handle is at its connection limit;
   * - the handle categories and names are compatible.
   */
  const isValidConnection = useCallback(
    ({
      source: sourceNodeId,
      sourceHandle: sourceHandleId,
      target: targetNodeId,
      targetHandle: targetHandleId,
    }: Connection): boolean => {
      if (sourceNodeId == null || sourceHandleId == null || targetNodeId == null || targetHandleId == null) {
        return false;
      }

      const sourceHandle = getNode(sourceNodeId)?.data.sourceHandles?.find(
        ({ handleId }) => handleId === sourceHandleId
      );
      const targetHandle = getNode(targetNodeId)?.data.targetHandles?.find(
        ({ handleId }) => handleId === targetHandleId
      );

      if (sourceHandle == null || targetHandle == null) {
        return false;
      }

      if (isValidConnectionsAtLimit(sourceHandle) || isValidConnectionsAtLimit(targetHandle)) {
        return false;
      }

      return isValidConnectionType(sourceHandle, targetHandle);
    },
    [getNode, isValidConnectionsAtLimit]
  );

  /**
   * Lists the handles of node `nodeId` that could take part in a connection started from
   * handle `initiatingConnectionHandleId` of type `initiatingConnectionHandleType`.
   *
   * The initiating handle itself is included when it belongs to `nodeId`. An empty array is
   * returned in these cases:
   * - any argument is null;
   * - either node cannot be found;
   * - the initiating handle is already at its limit.
   */
  const getValidConnectionHandles = useCallback(
    (
      nodeId: string | null,
      initiatingConnectionHandleId: string | null,
      initiatingConnectionHandleType: HandleType | null
    ): NodeHandle[] => {
      const validConnectionHandles: NodeHandle[] = [];

      if (nodeId == null || initiatingConnectionHandleId == null || initiatingConnectionHandleType == null) {
        return validConnectionHandles;
      }

      const node = getNode(nodeId);

      if (node == null) {
        return validConnectionHandles;
      }

      // Locate the node (and handle) the connection is being dragged from.
      let initiatingConnectionNode: Node<NodeType> | undefined;
      let initiatingConnectionNodeHandle: SourceHandle | TargetHandle | undefined;

      for (const candidate of getNodes()) {
        const handle =
          initiatingConnectionHandleType === "source"
            ? candidate.data.sourceHandles?.find(({ handleId }) => handleId === initiatingConnectionHandleId)
            : candidate.data.targetHandles?.find(({ handleId }) => handleId === initiatingConnectionHandleId);

        if (handle != null) {
          initiatingConnectionNode = candidate;
          initiatingConnectionNodeHandle = handle;
          break;
        }
      }

      if (initiatingConnectionNode == null || initiatingConnectionNodeHandle == null) {
        return validConnectionHandles;
      }

      if (isValidConnectionsAtLimit(initiatingConnectionNodeHandle)) {
        return validConnectionHandles;
      }

      const initiatingConnectionNodeId = initiatingConnectionNode.id;

      if (initiatingConnectionHandleType === "source") {
        // The initiating handle itself counts as valid when it lives on `node`.
        if (node.data.sourceHandles?.some(({ handleId }) => handleId === initiatingConnectionHandleId)) {
          validConnectionHandles.push(initiatingConnectionNodeHandle);
        }

        // Dragging from a source handle: candidates are this node's target handles.
        node.data.targetHandles?.forEach((targetHandle) => {
          if (
            isValidConnection({
              source: initiatingConnectionNodeId,
              sourceHandle: initiatingConnectionHandleId,
              target: nodeId,
              targetHandle: targetHandle.handleId ?? null,
            })
          ) {
            validConnectionHandles.push(targetHandle);
          }
        });
      }

      if (initiatingConnectionHandleType === "target") {
        // The initiating handle itself counts as valid when it lives on `node`.
        if (node.data.targetHandles?.some(({ handleId }) => handleId === initiatingConnectionHandleId)) {
          validConnectionHandles.push(initiatingConnectionNodeHandle);
        }

        // Dragging from a target handle: candidates are this node's source handles. The
        // initiating node is the connection's `target` (a node id) and the initiating
        // handle is its `targetHandle` (a handle id).
        node.data.sourceHandles?.forEach((sourceHandle) => {
          if (
            isValidConnection({
              source: nodeId,
              sourceHandle: sourceHandle.handleId ?? null,
              target: initiatingConnectionNodeId,
              targetHandle: initiatingConnectionHandleId,
            })
          ) {
            validConnectionHandles.push(sourceHandle);
          }
        });
      }

      return validConnectionHandles;
    },
    [getNode, getNodes, isValidConnection, isValidConnectionsAtLimit]
  );

  return {
    isValidConnection,
    isValidConnectionsAtLimit,
    getValidConnectionHandles,
  };
};

export default useNodeHandlesValidation;
