# Source code for rhesis.sdk.services.mcp.agent

"""Generic MCP Agent using ReAct (Reason-Act-Observe) loop."""

import asyncio
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import jinja2

from rhesis.sdk.models.base import BaseLLM
from rhesis.sdk.models.factory import get_model
from rhesis.sdk.services.mcp.client import MCPClient
from rhesis.sdk.services.mcp.exceptions import (
    MCPApplicationError,
    MCPConfigurationError,
    MCPConnectionError,
    MCPValidationError,
)
from rhesis.sdk.services.mcp.executor import ToolExecutor
from rhesis.sdk.services.mcp.schemas import (
    AgentAction,
    AgentResult,
    ExecutionStep,
    ToolCall,
    ToolResult,
)

logger = logging.getLogger(__name__)


class MCPAgent:
    """
    Generic MCP Agent for autonomous tool usage with customizable prompts.

    Uses a ReAct (Reason-Act-Observe) loop to autonomously call MCP tools
    and accomplish tasks. Clients can customize behavior via system prompts.
    """

    def __init__(
        self,
        model: Optional[Union[str, BaseLLM]] = None,
        mcp_client: Optional[MCPClient] = None,
        system_prompt: Optional[str] = None,
        max_iterations: int = 10,
        verbose: bool = False,
    ):
        """
        Initialize the MCP agent.

        Args:
            model: Language model for reasoning and decision-making. Can be a
                string (provider name), BaseLLM instance, or None (uses default).
            mcp_client: Client connected to an MCP server. Required; the
                ``None`` default exists only so the argument can be passed by
                keyword after ``model``.
            system_prompt: Custom system prompt to define agent behavior
                (optional). A falsy value falls back to the bundled template.
            max_iterations: Maximum reasoning loops before stopping (default: 10)
            verbose: Print detailed execution logs to stdout (default: False)

        Raises:
            ValueError: If ``mcp_client`` is not provided.
        """
        # Compare against None explicitly so that a client object which happens
        # to be falsy (e.g. defines __len__/__bool__) is not rejected.
        if mcp_client is None:
            raise ValueError("mcp_client is required")

        # Initialize template environment. Templates are plain-text prompts,
        # not HTML, so autoescaping is deliberately disabled.
        templates_dir = Path(__file__).parent / "prompt_templates"
        self._jinja_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(str(templates_dir)),
            autoescape=False,
            trim_blocks=True,
            lstrip_blocks=True,
        )

        # Convert model to BaseLLM instance if needed
        self.model = self._set_model(model)
        self.mcp_client = mcp_client
        self.system_prompt = system_prompt or self._load_default_system_prompt()
        self.max_iterations = max_iterations
        self.verbose = verbose
        self.executor = ToolExecutor(mcp_client)

    def _set_model(self, model: Optional[Union[str, BaseLLM]]) -> BaseLLM:
        """Convert model string or instance to BaseLLM instance."""
        if isinstance(model, BaseLLM):
            return model
        # Strings and None are both resolved by the model factory.
        return get_model(model)

    def _load_default_system_prompt(self) -> str:
        """Load the default system prompt from the ``system_prompt.j2`` template."""
        template = self._jinja_env.get_template("system_prompt.j2")
        return template.render()

    async def run_async(self, user_query: str) -> AgentResult:
        """
        Execute the agent's ReAct loop asynchronously.

        Connects to MCP server, discovers tools, and iteratively reasons about
        what actions to take until the task is complete or max iterations
        reached. The client is always disconnected before returning or raising.

        Args:
            user_query: User's query or task description

        Returns:
            AgentResult with final answer and execution history. ``success`` is
            False when the terminating step carries an error (LLM response
            could not be parsed, unknown action, ...).

        Raises:
            MCPConnectionError: If connecting to the MCP server fails.
            MCPValidationError: If max iterations is reached without finishing,
                or an unexpected error occurs during execution.
            MCPConfigurationError, MCPApplicationError: Propagated unchanged.
        """
        history: List[ExecutionStep] = []
        iteration = 0

        try:
            try:
                await self.mcp_client.connect()
            except ConnectionError as e:
                # Convert ConnectionError to MCPConnectionError
                raise MCPConnectionError(
                    f"Failed to connect to MCP server: {str(e)}", original_error=e
                ) from e
            logger.info("[MCPAgent] Connected to MCP server")

            if self.verbose:
                print("\n" + "=" * 70)
                print("🤖 MCP Agent Starting")
                print("=" * 70)
                print(f"Max iterations: {self.max_iterations}")

            available_tools = await self.executor.get_available_tools()
            logger.info(f"[MCPAgent] Discovered {len(available_tools)} tools")

            # ReAct loop
            while iteration < self.max_iterations:
                iteration += 1

                if self.verbose:
                    print(f"\n{'=' * 70}")
                    print(f"Iteration {iteration}/{self.max_iterations}")
                    print("=" * 70)

                step, should_finish = await self._execute_iteration(
                    user_query, available_tools, history, iteration
                )
                history.append(step)

                if should_finish:
                    if self.verbose:
                        print(f"\n✓ MCP Agent finished after {iteration} iteration(s)")
                    # Always return once a step asks to finish. Previously a
                    # terminal step whose action was not "finish" (the
                    # unknown-action path) fell through and the loop kept
                    # running until max_iterations.
                    return self._build_final_result(step, history)

            # Max iterations reached
            if self.verbose:
                print(f"\n⚠️ Max iterations ({self.max_iterations}) reached")

            raise MCPValidationError(
                f"Agent did not complete task within {self.max_iterations} iterations. "
                "Consider increasing max_iterations or simplifying the task."
            )

        except (
            MCPConnectionError,
            MCPConfigurationError,
            MCPValidationError,
            MCPApplicationError,
        ):
            # Propagate MCP exceptions directly
            raise
        except Exception as e:
            error_msg = f"Agent execution failed: {str(e)}"
            logger.error(error_msg, exc_info=True)
            if self.verbose:
                print(f"\n❌ Error: {error_msg}")
            # Wrap unexpected errors
            raise MCPValidationError(
                f"Agent execution failed: {str(e)}", original_error=e
            ) from e
        finally:
            await self.mcp_client.disconnect()
            logger.info("[MCPAgent] Disconnected from MCP server")

    def _build_final_result(
        self, step: ExecutionStep, history: List[ExecutionStep]
    ) -> AgentResult:
        """Translate the terminating step into an AgentResult (success or error)."""
        outcome = step.tool_results[0] if step.tool_results else None

        if outcome is not None and outcome.success:
            return AgentResult(
                final_answer=outcome.content,
                execution_history=history,
                iterations_used=len(history),
                max_iterations_reached=False,
                success=True,
            )

        # Error occurred during execution (or the step carried no result at all)
        error_msg = (outcome.error if outcome is not None else None) or "Unknown error"
        return AgentResult(
            final_answer="",
            execution_history=history,
            iterations_used=len(history),
            max_iterations_reached=False,
            success=False,
            error=error_msg,
        )

    def run(self, user_query: str) -> AgentResult:
        """
        Execute the agent synchronously.

        Convenience wrapper around run_async for non-async code. It uses
        ``asyncio.run``, so it cannot be called from a thread that already has
        a running event loop; use ``await run_async(...)`` there instead.

        Args:
            user_query: User's query or task description

        Returns:
            AgentResult with final answer and execution history
        """
        return asyncio.run(self.run_async(user_query))

    async def _execute_iteration(
        self,
        user_query: str,
        available_tools: List[Dict[str, Any]],
        history: List[ExecutionStep],
        iteration: int,
    ) -> Tuple[ExecutionStep, bool]:
        """
        Execute one ReAct iteration: build prompt, get LLM decision, execute tools.

        Returns:
            Tuple of (execution_step, should_finish)
        """
        prompt = self._build_prompt(user_query, available_tools, history)

        # Get LLM decision
        action = await self._get_llm_action(prompt, iteration)
        if action is None:
            # LLM parsing error occurred
            return self._create_error_step(iteration, "Failed to parse LLM response"), True

        if self.verbose:
            print(f"\n Reasoning: {action.reasoning}")
            print(f" Action: {action.action}")

        # Handle different action types
        if action.action == "finish":
            return self._handle_finish_action(action, iteration)
        if action.action == "call_tool":
            return await self._handle_tool_calls(action, iteration)
        return self._handle_unknown_action(action, iteration)

    async def _get_llm_action(self, prompt: str, iteration: int) -> Optional[AgentAction]:
        """
        Get and parse the LLM's action decision.

        Returns:
            The parsed AgentAction, or None if the model call or the parsing of
            its response raised (the error is logged, not propagated).
        """
        logger.info(f"[MCPAgent] Iteration {iteration}: Sending prompt to LLM")

        if self.verbose:
            print("\n💭 Reasoning...")

        try:
            # NOTE(review): generate() is called synchronously here, so it
            # blocks the event loop for the duration of the model call.
            response = self.model.generate(
                prompt=prompt, system_prompt=self.system_prompt, schema=AgentAction
            )

            if isinstance(response, AgentAction):
                # Already structured; nothing to parse.
                action = response
            elif isinstance(response, dict):
                action = AgentAction(**response)
            else:
                action = AgentAction(**json.loads(response))

            logger.info(
                f"[MCPAgent] Iteration {iteration}: Action={action.action}, "
                f"Reasoning='{action.reasoning[:100]}...'"
            )
            return action

        except Exception as e:
            # Deliberately broad: any failure here is reported to the caller as
            # "could not obtain an action" rather than crashing the loop.
            logger.error(f"[MCPAgent] Failed to parse LLM response: {e}", exc_info=True)
            return None

    def _handle_finish_action(
        self, action: AgentAction, iteration: int
    ) -> Tuple[ExecutionStep, bool]:
        """Handle the finish action by wrapping the final answer in a terminal step."""
        logger.info("[MCPAgent] Agent finishing")

        if self.verbose:
            ans = action.final_answer[:200] if action.final_answer else ""
            print(f"\n✓ Final Answer: {ans}...")

        return (
            ExecutionStep(
                iteration=iteration,
                reasoning=action.reasoning,
                action="finish",
                tool_calls=[],
                tool_results=[
                    ToolResult(
                        tool_name="finish",
                        success=True,
                        content=action.final_answer or "",
                    )
                ],
            ),
            True,
        )

    async def _handle_tool_calls(
        self, action: AgentAction, iteration: int
    ) -> Tuple[ExecutionStep, bool]:
        """
        Handle tool call actions.

        A "call_tool" action without any tool calls is recorded as an error
        step but does not stop the loop, so the LLM can correct itself.
        """
        if not action.tool_calls:
            logger.warning("Action is 'call_tool' but no tool_calls provided")
            return (
                ExecutionStep(
                    iteration=iteration,
                    reasoning=action.reasoning,
                    action="call_tool",
                    tool_calls=[],
                    tool_results=[
                        ToolResult(
                            tool_name="error",
                            success=False,
                            error="No tool calls specified",
                        )
                    ],
                ),
                False,
            )

        tool_names = [tc.tool_name for tc in action.tool_calls]
        logger.info(f"[MCPAgent] Calling {len(action.tool_calls)} tool(s): {tool_names}")

        if self.verbose:
            print(f"\n🔧 Calling {len(action.tool_calls)} tool(s):")
            for tc in action.tool_calls:
                print(f" • {tc.tool_name}")

        # Execute all tool calls
        tool_results = await self._execute_tools(action.tool_calls)

        return (
            ExecutionStep(
                iteration=iteration,
                reasoning=action.reasoning,
                action="call_tool",
                tool_calls=action.tool_calls,
                tool_results=tool_results,
            ),
            False,
        )

    async def _execute_tools(self, tool_calls: List[ToolCall]) -> List[ToolResult]:
        """
        Execute multiple tool calls sequentially and return results.

        All ToolResults (success or failure) are passed to the LLM for reasoning.
        The LLM decides how to handle failures (retry, different approach, give up).

        Exceptions raised by the executor (infrastructure, configuration or
        application failures) are not handled here and propagate immediately.
        """
        tool_results: List[ToolResult] = []

        for tool_call in tool_calls:
            result = await self.executor.execute_tool(tool_call)
            tool_results.append(result)

            if result.success:
                logger.info(
                    f"[MCPAgent] Tool {result.tool_name} succeeded, "
                    f"returned {len(result.content)} chars"
                )
                if self.verbose:
                    print(f" ✓ {result.tool_name}: {len(result.content)} chars")
            else:
                logger.warning(f"[MCPAgent] Tool {result.tool_name} failed: {result.error}")
                if self.verbose:
                    print(f" ✗ {result.tool_name}: {result.error}")

        return tool_results

    def _handle_unknown_action(
        self, action: AgentAction, iteration: int
    ) -> Tuple[ExecutionStep, bool]:
        """Handle unknown action types by recording an error step that ends the run."""
        logger.warning(f"Unknown action: {action.action}")
        return (
            ExecutionStep(
                iteration=iteration,
                reasoning=action.reasoning,
                action=action.action,
                tool_calls=[],
                tool_results=[
                    ToolResult(
                        tool_name="error",
                        success=False,
                        error=f"Unknown action: {action.action}",
                    )
                ],
            ),
            True,
        )

    def _create_error_step(self, iteration: int, error_msg: str) -> ExecutionStep:
        """Create a terminal ("finish") execution step that carries an error result."""
        return ExecutionStep(
            iteration=iteration,
            reasoning=f"Error: {error_msg}",
            action="finish",
            tool_calls=[],
            tool_results=[ToolResult(tool_name="error", success=False, error=error_msg)],
        )

    def _build_prompt(
        self, user_query: str, available_tools: List[Dict[str, Any]], history: List[ExecutionStep]
    ) -> str:
        """Build the user prompt for the current iteration from ``iteration_prompt.j2``."""
        tools_text = self._format_tools(available_tools)
        history_text = self._format_history(history)

        template = self._jinja_env.get_template("iteration_prompt.j2")
        return template.render(
            user_query=user_query,
            tools_text=tools_text,
            history_text=history_text,
        )

    def _format_tools(self, tools: List[Dict[str, Any]]) -> str:
        """
        Format tool list into human-readable text.

        Each tool becomes a ``- name: description`` line; when the tool has an
        ``inputSchema`` with ``properties``, the parameter names are listed on
        a following line.
        """
        descriptions = []
        for tool in tools:
            desc = f"- {tool['name']}: {tool.get('description', 'No description')}"

            schema = tool.get("inputSchema")
            if schema and "properties" in schema:
                # Iterating the mapping yields the parameter names.
                params = ", ".join(schema["properties"])
                desc += f"\n Parameters: {params}"

            descriptions.append(desc)

        return "\n".join(descriptions)

    def _format_history(self, history: List[ExecutionStep]) -> str:
        """
        Format execution history into readable text for LLM context.

        Returns an empty string for an empty history; otherwise one section per
        step, each followed by a blank line.
        """
        if not history:
            return ""

        parts = []
        for step in history:
            parts.append(f"Iteration {step.iteration}:")
            parts.append(f" Reasoning: {step.reasoning}")
            parts.append(f" Action: {step.action}")

            if step.tool_calls:
                parts.append(" Tools called:")
                parts.extend(f" • {tc.tool_name}" for tc in step.tool_calls)

            if step.tool_results:
                parts.append(" Results:")
                for tr in step.tool_results:
                    if tr.success:
                        parts.append(f" • {tr.tool_name}: {tr.content}")
                    else:
                        parts.append(f" • {tr.tool_name}: ERROR - {tr.error}")

            parts.append("")

        return "\n".join(parts)