import {
  getClaimAnalyticsAgentNode,
  getClaimAnalyticsAgentToolNode,
  getDocumentRetrievalNode,
  getRagAgentNode,
  getRejectClarifyNode,
  getRespondToUserToolNode,
  getSignalEventAnalyticsAgentNode,
  getSignalEventAnalyticsAgentToolNode,
  getSupervisorAgentNode,
} from "duck/graph/nodes";
import { graphState } from "duck/graph/state";
import { DuckGraphParams } from "duck/graph/types";
import { Runnable } from "@langchain/core/runnables";
import { END, MemorySaver, START, StateGraph } from "@langchain/langgraph/web";

import { GenericToolNodeName, getNextNode, NodeNames } from "./utils";

/**
 * @summary Get DUCK's compiled state graph.
 * @param params The parameters for the agent from the UI
 * @param withMemory True to use the memory saver, false to not use memory at all
 * @returns The compiled state graph
 */
const getGraph = async (
  params: DuckGraphParams,
  withMemory: boolean = false
): Promise<Runnable> => {
  // Create agent executor, optionally with a checkpointer.
  // The MemorySaver checkpointer is meant for experimentation and is not intended for production usage.
  const checkpointer = withMemory ? new MemorySaver() : undefined;

  /**
   * State graph topology (keep this in sync with the edges declared below).
   * "--getNextNode-->" marks a conditional edge whose target is chosen by `getNextNode`.
   *
   *   START -> DOCUMENT_RETRIEVAL -> SUPERVISOR
   *
   *   SUPERVISOR --getNextNode--> one of:
   *       RAG
   *     | REJECT_CLARIFY
   *     | CLAIM_ANALYTICS
   *     | SIGNAL_EVENT_ANALYTICS
   *     | RESPOND_TO_USER_TOOL   (GenericToolNodeName route)
   *     | END                    (safeguard)
   *
   *   RAG            --getNextNode--> RESPOND_TO_USER_TOOL | END (safeguard)
   *   REJECT_CLARIFY --getNextNode--> RESPOND_TO_USER_TOOL | END (safeguard)
   *   RESPOND_TO_USER_TOOL -> END
   *
   *   Analytics sub-loops. These agents never route to RESPOND_TO_USER_TOOL,
   *   so they do not respond to the user directly; they hand control back to SUPERVISOR:
   *
   *   CLAIM_ANALYTICS --getNextNode--> CLAIM_ANALYTICS_TOOLS | SUPERVISOR | END (safeguard)
   *   CLAIM_ANALYTICS_TOOLS -> CLAIM_ANALYTICS
   *
   *   SIGNAL_EVENT_ANALYTICS --getNextNode--> SIGNAL_EVENT_ANALYTICS_TOOLS | SUPERVISOR | END (safeguard)
   *   SIGNAL_EVENT_ANALYTICS_TOOLS -> SIGNAL_EVENT_ANALYTICS
   */
  // NOTE(review): the async node factories below are awaited one after another.
  // If they are confirmed to be independent, they could be created concurrently with Promise.all.
  const stateGraph = new StateGraph(graphState)
    .addNode(NodeNames.DOCUMENT_RETRIEVAL, getDocumentRetrievalNode())
    .addNode(NodeNames.SUPERVISOR, await getSupervisorAgentNode(params))
    .addNode(NodeNames.REJECT_CLARIFY, await getRejectClarifyNode(params))
    .addNode(NodeNames.RAG, await getRagAgentNode(params))
    .addNode(
      NodeNames.CLAIM_ANALYTICS,
      await getClaimAnalyticsAgentNode(params)
    )
    .addNode(
      NodeNames.CLAIM_ANALYTICS_TOOLS,
      getClaimAnalyticsAgentToolNode(params)
    )
    .addNode(
      NodeNames.SIGNAL_EVENT_ANALYTICS,
      await getSignalEventAnalyticsAgentNode(params)
    )
    .addNode(
      NodeNames.SIGNAL_EVENT_ANALYTICS_TOOLS,
      getSignalEventAnalyticsAgentToolNode(params)
    )
    .addNode(NodeNames.RESPOND_TO_USER_TOOL, getRespondToUserToolNode(params))
    .addEdge(START, NodeNames.DOCUMENT_RETRIEVAL)
    .addEdge(NodeNames.DOCUMENT_RETRIEVAL, NodeNames.SUPERVISOR)
    .addConditionalEdges(NodeNames.SUPERVISOR, getNextNode, {
      [NodeNames.RAG]: NodeNames.RAG,
      [NodeNames.REJECT_CLARIFY]: NodeNames.REJECT_CLARIFY,
      [NodeNames.CLAIM_ANALYTICS]: NodeNames.CLAIM_ANALYTICS,
      [NodeNames.SIGNAL_EVENT_ANALYTICS]: NodeNames.SIGNAL_EVENT_ANALYTICS,
      [GenericToolNodeName]: NodeNames.RESPOND_TO_USER_TOOL,
      // safeguard: end the conversation if no tool call is found
      [END]: END,
    })
    .addConditionalEdges(NodeNames.RAG, getNextNode, {
      [GenericToolNodeName]: NodeNames.RESPOND_TO_USER_TOOL,
      // safeguard: end the conversation if no tool call is found
      [END]: END,
    })
    .addConditionalEdges(NodeNames.REJECT_CLARIFY, getNextNode, {
      [GenericToolNodeName]: NodeNames.RESPOND_TO_USER_TOOL,
      // safeguard: end the conversation if no tool call is found
      [END]: END,
    })
    .addEdge(NodeNames.RESPOND_TO_USER_TOOL, END)
    .addConditionalEdges(NodeNames.CLAIM_ANALYTICS, getNextNode, {
      [GenericToolNodeName]: NodeNames.CLAIM_ANALYTICS_TOOLS,
      [NodeNames.SUPERVISOR]: NodeNames.SUPERVISOR,
      // safeguard: end the conversation if no tool call is found
      [END]: END,
    })
    // Tool results flow back into the claim analytics agent.
    .addEdge(NodeNames.CLAIM_ANALYTICS_TOOLS, NodeNames.CLAIM_ANALYTICS)
    .addConditionalEdges(NodeNames.SIGNAL_EVENT_ANALYTICS, getNextNode, {
      [GenericToolNodeName]: NodeNames.SIGNAL_EVENT_ANALYTICS_TOOLS,
      [NodeNames.SUPERVISOR]: NodeNames.SUPERVISOR,
      // safeguard: end the conversation if no tool call is found
      [END]: END,
    })
    // Tool results flow back into the signal event analytics agent.
    .addEdge(
      NodeNames.SIGNAL_EVENT_ANALYTICS_TOOLS,
      NodeNames.SIGNAL_EVENT_ANALYTICS
    );

  // Compile the state graph; `checkpointer` is undefined when memory is disabled.
  const app = stateGraph.compile({
    checkpointer,
  });

  return app;
};

export default getGraph;
