import { useCallback, useMemo } from "react";
import { useEdges, useNodes } from "reactflow";
import { FlowNodeData } from "../components/reactFlow/nodeTypes/FlowNode";
import { NodeType } from "../models/nodeType";

/**
 * Optional criteria used by `useNodeLookup` to narrow node lookups.
 * Every field that is set (truthy) must match; unset fields are ignored.
 */
export type NodeLookupFilters = {
  /** Exact match against the node's `nodeClass`. */
  nodeClass?: string;
  /** Exact match against the node's `nodeName`. */
  nodeName?: string;
  /** Exact match against the node's `nodeCategory`. */
  nodeCategory?: string;
  /** Requires the node to expose a source handle whose `handleName` equals this value. */
  handleName?: string;
};

/**
 * Returns true when `node` satisfies every filter set in `filters`.
 * Filters whose value is falsy (unset or empty string) are ignored.
 */
const matchesFilters = (node: NodeType, filters: NodeLookupFilters): boolean => {
  const { nodeClass, nodeName, nodeCategory, handleName } = filters;

  if (nodeClass && node.nodeClass !== nodeClass) {
    return false;
  }

  if (nodeName && node.nodeName !== nodeName) {
    return false;
  }

  if (nodeCategory && node.nodeCategory !== nodeCategory) {
    return false;
  }

  // A handle-name filter requires a source handle with that name;
  // nodes without `sourceHandles` never satisfy it.
  if (handleName && !node.sourceHandles?.some((handle) => handle.handleName === handleName)) {
    return false;
  }

  return true;
};

/**
 * Hook for finding the node connected to given target handles in the current React Flow graph.
 *
 * It indexes every edge by its `targetHandle` and points it at the data of the edge's source node.
 * The index is recomputed whenever the graph's nodes or edges change.
 *
 * @returns
 *  - `getNodeTypeByTargetHandles(targetHandles, filters?)`: the source-node data connected to each
 *    given target handle. Unknown or `undefined` handles are skipped, as are nodes that fail `filters`.
 *  - `getTargetNodeOutputSchema(targetHandles, filters?)`: the `outputSchema` of the first matching
 *    node, or `undefined` if none match.
 */
const useNodeLookup = () => {
  const edges = useEdges();
  const nodes = useNodes<NodeType>();

  // targetHandle id -> data of the node on the source side of the edge.
  const targetHandleToNodeDataMap = useMemo(() => {
    // Later nodes with a duplicate id overwrite earlier ones.
    const nodeDataById = new Map<string, NodeType>(
      nodes.map((node): [string, NodeType] => [node.id, node.data])
    );

    const handleToNodeData = new Map<string, NodeType>();
    for (const edge of edges) {
      const sourceData = nodeDataById.get(edge.source);
      // Edges without a target handle, or whose source node is unknown, are ignored.
      // If several edges share a target handle, the last one in `edges` wins.
      if (edge.targetHandle && sourceData !== undefined) {
        handleToNodeData.set(edge.targetHandle, sourceData);
      }
    }

    return handleToNodeData;
  }, [edges, nodes]);

  const getNodeTypeByTargetHandles = useCallback(
    <T = unknown>(targetHandles: (string | undefined)[], filters: NodeLookupFilters = {}) => {
      const matches = targetHandles
        .map((handle) => (handle ? targetHandleToNodeDataMap.get(handle) : undefined))
        .filter((node): node is NodeType => !!node && matchesFilters(node, filters));

      // `T` is the caller's assertion about the node's payload type; it is not checked at runtime.
      return matches as NodeType<T>[];
    },
    [targetHandleToNodeDataMap]
  );

  const getTargetNodeOutputSchema = useCallback(
    (
      targetHandles: (string | undefined)[],
      // The default only applies when no filters are passed. Any filters object,
      // including `{}`, replaces it entirely.
      filters: NodeLookupFilters = {
        nodeClass: "flow",
      }
    ) => {
      const matches = getNodeTypeByTargetHandles<FlowNodeData>(targetHandles, filters);

      if (matches.length > 1) {
        console.warn(`More than one nodes found for target handles`, matches);
      }

      return matches[0]?.nodeData?.outputSchema;
    },
    [getNodeTypeByTargetHandles]
  );

  return {
    getNodeTypeByTargetHandles,
    getTargetNodeOutputSchema,
  };
};

export default useNodeLookup;
