# Source code for rhesis.sdk.models.factory

"""
Model Factory for Rhesis SDK

This module provides a simple and intuitive way to create language model instances
and embedder instances with smart defaults and comprehensive error handling.

"""

from dataclasses import dataclass, field
from typing import Callable, Dict, Optional

from rhesis.sdk.models.base import BaseEmbedder, BaseLLM

# Provider used by get_model() when neither an argument nor a config names one.
DEFAULT_PROVIDER = "rhesis"

# Default language model per provider, used when no model name is supplied.
DEFAULT_MODELS = dict(
    rhesis="rhesis-default",
    anthropic="claude-4",
    cohere="command-r-plus",
    gemini="gemini-2.0-flash",
    groq="llama3-8b-8192",
    huggingface="meta-llama/Llama-2-7b-chat-hf",
    lmformatenforcer="meta-llama/Llama-2-7b-chat-hf",
    meta_llama="Llama-3.3-70B-Instruct",
    mistral="mistral-medium-latest",
    ollama="llama3.1",
    openai="gpt-4o",
    openrouter="openai/gpt-4o-mini",
    perplexity="sonar-pro",
    polyphemus="",  # empty on purpose: Polyphemus uses the API's default model
    replicate="llama-2-70b-chat",
    together_ai="togethercomputer/llama-2-70b-chat",
    vertex_ai="gemini-2.0-flash",  # Best performance - avoid 2.5-flash
)

# Provider used by get_embedder() when neither an argument nor a config names one.
DEFAULT_EMBEDDER_PROVIDER = "openai"

# Default embedding model per embedding provider.
DEFAULT_EMBEDDING_MODELS = dict(
    openai="text-embedding-3-small",
    gemini="gemini-embedding-001",
    vertex_ai="text-embedding-005",
)


# Factory functions for each provider. They create the model instance and import the
# provider class lazily (inside the function body) to avoid circular imports.


def _create_rhesis_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a RhesisLLM from the native provider module.

    ``model_name``, ``api_key`` and any extra keyword arguments are passed straight
    through. The provider import lives inside the function to avoid circular imports.
    """
    from rhesis.sdk.models.providers.native import RhesisLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return RhesisLLM(**init_args)


def _create_gemini_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a GeminiLLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.gemini import GeminiLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return GeminiLLM(**init_args)


def _create_ollama_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build an OllamaLLM.

    ``api_key`` is accepted to match the common factory signature but is not passed
    to OllamaLLM. ``model_name`` and extra keyword arguments are forwarded.
    """
    from rhesis.sdk.models.providers.ollama import OllamaLLM

    del api_key  # intentionally not forwarded (see docstring)
    init_args = {"model_name": model_name, **kwargs}
    return OllamaLLM(**init_args)


def _create_openai_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Factory function for OpenAILLM.

    Args:
        model_name: OpenAI model identifier (e.g. "gpt-4o").
        api_key: API key forwarded to OpenAILLM.
        **kwargs: Extra keyword arguments forwarded to OpenAILLM.

    Returns:
        BaseLLM: The constructed OpenAILLM instance.
    """
    from rhesis.sdk.models.providers.openai import OpenAILLM

    # Forward **kwargs like every other provider factory, so options passed through
    # get_model(...) actually reach the OpenAI model instead of being dropped.
    return OpenAILLM(model_name=model_name, api_key=api_key, **kwargs)


def _create_vertex_ai_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a VertexAILLM.

    ``api_key`` is accepted only to match the common factory signature and is
    discarded: Vertex AI authenticates with service account credentials.
    """
    from rhesis.sdk.models.providers.vertex_ai import VertexAILLM

    del api_key  # unused for Vertex AI (see docstring)
    init_args = {"model_name": model_name, **kwargs}
    return VertexAILLM(**init_args)


def _create_anthropic_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build an AnthropicLLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.anthropic import AnthropicLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return AnthropicLLM(**init_args)


def _create_cohere_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a CohereLLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.cohere import CohereLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return CohereLLM(**init_args)


def _create_groq_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a GroqLLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.groq import GroqLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return GroqLLM(**init_args)


def _create_huggingface_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a HuggingFaceLLM.

    ``api_key`` is accepted only to match the common factory signature and is
    discarded, because this provider runs local models.
    """
    from rhesis.sdk.models.providers.huggingface import HuggingFaceLLM

    del api_key  # unused for local HuggingFace models (see docstring)
    init_args = {"model_name": model_name, **kwargs}
    return HuggingFaceLLM(**init_args)


def _create_lmformatenforcer_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build an LMFormatEnforcerLLM.

    ``api_key`` is accepted only to match the common factory signature and is
    discarded, because this provider runs local models.
    """
    from rhesis.sdk.models.providers.lmformatenforcer import LMFormatEnforcerLLM

    del api_key  # unused for local LMFormatEnforcer models (see docstring)
    init_args = {"model_name": model_name, **kwargs}
    return LMFormatEnforcerLLM(**init_args)


def _create_meta_llama_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a MetaLlamaLLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.meta_llama import MetaLlamaLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return MetaLlamaLLM(**init_args)


def _create_mistral_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a MistralLLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.mistral import MistralLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return MistralLLM(**init_args)


def _create_perplexity_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a PerplexityLLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.perplexity import PerplexityLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return PerplexityLLM(**init_args)


def _create_replicate_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a ReplicateLLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.replicate import ReplicateLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return ReplicateLLM(**init_args)


def _create_together_ai_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a TogetherAILLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.together_ai import TogetherAILLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return TogetherAILLM(**init_args)


def _create_openrouter_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build an OpenRouterLLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.openrouter import OpenRouterLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return OpenRouterLLM(**init_args)


def _create_polyphemus_llm(model_name: str, api_key: Optional[str], **kwargs) -> BaseLLM:
    """Build a PolyphemusLLM.

    ``model_name``, ``api_key`` and any extra keyword arguments are forwarded
    unchanged. The provider is imported lazily to avoid circular imports.
    """
    from rhesis.sdk.models.providers.polyphemus import PolyphemusLLM

    init_args = {"model_name": model_name, "api_key": api_key, **kwargs}
    return PolyphemusLLM(**init_args)


# Provider registry mapping provider names to their factory functions.
# Factories are invoked as ``factory(model_name, api_key, **kwargs)``, so the
# parameter list is typed as ``...`` (matching EMBEDDER_REGISTRY) rather than
# a fixed two-argument signature that would reject the forwarded kwargs.
PROVIDER_REGISTRY: Dict[str, Callable[..., BaseLLM]] = {
    "rhesis": _create_rhesis_llm,
    "anthropic": _create_anthropic_llm,
    "cohere": _create_cohere_llm,
    "gemini": _create_gemini_llm,
    "groq": _create_groq_llm,
    "huggingface": _create_huggingface_llm,
    "lmformatenforcer": _create_lmformatenforcer_llm,
    "meta_llama": _create_meta_llama_llm,
    "mistral": _create_mistral_llm,
    "ollama": _create_ollama_llm,
    "openai": _create_openai_llm,
    "openrouter": _create_openrouter_llm,
    "perplexity": _create_perplexity_llm,
    "polyphemus": _create_polyphemus_llm,
    "replicate": _create_replicate_llm,
    "together_ai": _create_together_ai_llm,
    "vertex_ai": _create_vertex_ai_llm,
}


@dataclass
class ModelConfig:
    """Settings that ``get_model`` can use to build a language model.

    Attributes:
        provider: Provider key, e.g. "rhesis", "anthropic", "gemini", "openai" or "ollama".
        model_name: Provider-specific model identifier, e.g. "gpt-4o",
            "gemini-2.0-flash" or "claude-4".
        api_key: API key used to authenticate with the provider.
        extra_params: Extra parameters intended for the model.
    """

    provider: Optional[str] = None
    model_name: Optional[str] = None
    api_key: Optional[str] = None
    # default_factory gives every instance its own dict instead of a shared one.
    extra_params: dict = field(default_factory=dict)


def get_model(
    provider: Optional[str] = None,
    model_name: Optional[str] = None,
    api_key: Optional[str] = None,
    config: Optional[ModelConfig] = None,
    **kwargs,
) -> BaseLLM:
    """Create a model instance with smart defaults and comprehensive error handling.

    This function provides multiple ways to create a model instance:

    1. **Minimal**: `get_model()` - uses all defaults
    2. **Provider only**: `get_model("rhesis")` - uses default model for provider
    3. **Provider + Model**: `get_model("rhesis", "rhesis-llm-v1")`
    4. **Shorthand**: `get_model("rhesis/rhesis-llm-v1")`
    5. **Full config**: `get_model(config=ModelConfig(...))`

    Each setting is resolved in order of precedence: the explicit argument, then
    the value on ``config``, then the module defaults (``DEFAULT_PROVIDER`` and
    ``DEFAULT_MODELS``).

    Args:
        provider: Provider name (e.g., "rhesis", "anthropic", "gemini", "openai",
            "mistral", "ollama"). When ``model_name`` is not given, a
            "provider/model" string is split on its first "/".
        model_name: Specific model name
        api_key: API key for authentication
        config: Complete configuration object. It is never modified; keyword
            arguments whose names match one of its attributes are applied to a copy.
        **kwargs: Additional keyword arguments, forwarded unchanged to the
            provider's factory function.

    Returns:
        BaseLLM: Configured model instance

    Raises:
        ValueError: If configuration is invalid or provider not supported
        ImportError: If required dependencies are missing

    Examples:
        >>> # Basic usage with defaults
        >>> model = get_model()

        >>> # Specify provider and model
        >>> model = get_model("rhesis", "rhesis-llm-v1")

        >>> # Use provider/model shorthand
        >>> model = get_model("rhesis/rhesis-llm-v1")

        >>> # Use different providers
        >>> model = get_model("anthropic", "claude-4")
        >>> model = get_model("openai", "gpt-4o")
        >>> model = get_model("mistral/mistral-medium-latest")

        >>> # With custom configuration
        >>> config = ModelConfig(
        ...     provider="gemini",
        ...     model_name="gemini-pro",
        ...     api_key="your-api-key"
        ... )
        >>> model = get_model(config=config)

        >>> # With extra parameters
        >>> model = get_model(
        ...     "rhesis",
        ...     "rhesis-llm-v1",
        ...     extra_params={"temperature": 0.5}
        ... )
    """
    from dataclasses import replace

    if config is not None:
        # Apply matching keyword overrides to a copy so the caller's config object is
        # left untouched and can safely be reused across calls.
        overrides = {key: value for key, value in kwargs.items() if hasattr(config, key)}
        cfg = replace(config, **overrides) if overrides else config
    else:
        cfg = ModelConfig()

    # Shorthand "provider/model": split on the first "/" only, so model names that
    # contain slashes (e.g. "openrouter/openai/gpt-4o-mini") stay intact.
    if provider and "/" in provider and model_name is None:
        provider, model_name = provider.split("/", 1)

    provider = provider or cfg.provider or DEFAULT_PROVIDER
    if provider not in DEFAULT_MODELS:
        available = ", ".join(sorted(DEFAULT_MODELS))
        raise ValueError(f"Provider {provider} not supported. Available providers: {available}")

    model_name = model_name or cfg.model_name or DEFAULT_MODELS[provider]
    api_key = api_key or cfg.api_key
    # NOTE(review): cfg.extra_params is never forwarded to the provider, although
    # ModelConfig describes it as parameters for the model. Confirm whether the
    # provider constructors expect it unpacked or as an ``extra_params`` kwarg
    # before wiring it through.

    factory_func = PROVIDER_REGISTRY.get(provider)
    if factory_func is None:
        raise ValueError(f"Provider {provider} not supported")

    return factory_func(model_name, api_key, **kwargs)
def get_available_language_models(provider: str) -> list[str]:
    """Return the language model names a provider can list.

    The lookup is delegated to ``get_available_models()`` on the provider's LLM
    class. Only the LiteLLM-based providers in the mapping below support it.

    Args:
        provider: Provider name (e.g., "anthropic", "openai", "gemini", "groq")

    Returns:
        List of available language model names for the provider

    Raises:
        ValueError: If the provider is unknown or cannot list its models
        ImportError: If required dependencies for the provider are missing

    Examples:
        >>> # Get Anthropic models
        >>> models = get_available_language_models("anthropic")
        >>> print(models)
        ['claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', ...]

        >>> # Get OpenAI models
        >>> models = get_available_language_models("openai")

        >>> # Get Gemini models
        >>> models = get_available_language_models("gemini")
    """
    if provider not in PROVIDER_REGISTRY:
        known_providers = ", ".join(sorted(PROVIDER_REGISTRY))
        raise ValueError(
            f"Provider '{provider}' not supported. Available providers: {known_providers}"
        )

    # LiteLLM-based providers that can enumerate their models; each lister imports
    # its provider class lazily.
    model_listers = {
        "anthropic": _get_anthropic_models,
        "cohere": _get_cohere_models,
        "gemini": _get_gemini_models,
        "groq": _get_groq_models,
        "meta_llama": _get_meta_llama_models,
        "mistral": _get_mistral_models,
        "ollama": _get_ollama_models,
        "openai": _get_openai_models,
        "openrouter": _get_openrouter_models,
        "perplexity": _get_perplexity_models,
        "replicate": _get_replicate_models,
        "together_ai": _get_together_ai_models,
        "vertex_ai": _get_vertex_ai_models,
    }

    lister = model_listers.get(provider)
    if lister is None:
        listable = ", ".join(sorted(model_listers))
        raise ValueError(
            f"Provider '{provider}' does not support listing available models. "
            f"Only the following providers support this feature: {listable}"
        )
    return lister()


def get_available_embedding_models(provider: str) -> list[str]:
    """Return the embedding model names a provider can list.

    The lookup is delegated to ``get_available_models()`` on the provider's
    embedder class. OpenAI, Gemini and Vertex AI are supported.

    Args:
        provider: Provider name (e.g., "openai", "gemini", "vertex_ai")

    Returns:
        List of available embedding model names for the provider

    Raises:
        ValueError: If the provider is unknown or cannot list embedding models
        ImportError: If required dependencies for the provider are missing

    Examples:
        >>> # Get OpenAI embedding models
        >>> models = get_available_embedding_models("openai")
        >>> print(models)
        ['text-embedding-3-small', 'text-embedding-3-large', 'text-embedding-ada-002']

        >>> # Get Gemini embedding models
        >>> models = get_available_embedding_models("gemini")

        >>> # Get Vertex AI embedding models
        >>> models = get_available_embedding_models("vertex_ai")
    """
    if provider not in EMBEDDER_REGISTRY:
        known_providers = ", ".join(sorted(EMBEDDER_REGISTRY))
        raise ValueError(
            f"Embedding provider '{provider}' not supported. "
            f"Available embedding providers: {known_providers}"
        )

    # Embedding providers that can enumerate their models.
    embedding_listers = {
        "openai": _get_openai_embedding_models,
        "gemini": _get_gemini_embedding_models,
        "vertex_ai": _get_vertex_ai_embedding_models,
    }

    lister = embedding_listers.get(provider)
    if lister is None:
        listable = ", ".join(sorted(embedding_listers))
        raise ValueError(
            f"Provider '{provider}' does not support listing available embedding models. "
            f"Only the following providers support this feature: {listable}"
        )
    return lister()


# Provider-specific functions to get available embedding models


def _get_openai_embedding_models() -> list[str]:
    """List OpenAI embedding models via OpenAIEmbedder (imported on demand)."""
    from rhesis.sdk.models.providers.openai import OpenAIEmbedder

    return OpenAIEmbedder.get_available_models()


def _get_gemini_embedding_models() -> list[str]:
    """List Gemini embedding models via GeminiEmbedder (imported on demand)."""
    from rhesis.sdk.models.providers.gemini import GeminiEmbedder

    return GeminiEmbedder.get_available_models()


def _get_vertex_ai_embedding_models() -> list[str]:
    """List Vertex AI embedding models via VertexAIEmbedder (imported on demand)."""
    from rhesis.sdk.models.providers.vertex_ai import VertexAIEmbedder

    return VertexAIEmbedder.get_available_models()


# Provider-specific functions to get available models


def _get_anthropic_models() -> list[str]:
    """List Anthropic models via AnthropicLLM (imported on demand)."""
    from rhesis.sdk.models.providers.anthropic import AnthropicLLM

    return AnthropicLLM.get_available_models()


def _get_cohere_models() -> list[str]:
    """List Cohere models via CohereLLM (imported on demand)."""
    from rhesis.sdk.models.providers.cohere import CohereLLM

    return CohereLLM.get_available_models()


def _get_gemini_models() -> list[str]:
    """List Gemini models via GeminiLLM (imported on demand)."""
    from rhesis.sdk.models.providers.gemini import GeminiLLM

    return GeminiLLM.get_available_models()


def _get_groq_models() -> list[str]:
    """List Groq models via GroqLLM (imported on demand)."""
    from rhesis.sdk.models.providers.groq import GroqLLM

    return GroqLLM.get_available_models()


def _get_meta_llama_models() -> list[str]:
    """List Meta Llama models via MetaLlamaLLM (imported on demand)."""
    from rhesis.sdk.models.providers.meta_llama import MetaLlamaLLM

    return MetaLlamaLLM.get_available_models()


def _get_mistral_models() -> list[str]:
    """List Mistral models via MistralLLM (imported on demand)."""
    from rhesis.sdk.models.providers.mistral import MistralLLM

    return MistralLLM.get_available_models()
def _get_ollama_models() -> list[str]:
    """Get available Ollama models."""
    from rhesis.sdk.models.providers.ollama import OllamaLLM

    return OllamaLLM.get_available_models()


def _get_openai_models() -> list[str]:
    """Get available OpenAI models."""
    from rhesis.sdk.models.providers.openai import OpenAILLM

    return OpenAILLM.get_available_models()


def _get_perplexity_models() -> list[str]:
    """Get available Perplexity models."""
    from rhesis.sdk.models.providers.perplexity import PerplexityLLM

    return PerplexityLLM.get_available_models()


def _get_replicate_models() -> list[str]:
    """Get available Replicate models."""
    from rhesis.sdk.models.providers.replicate import ReplicateLLM

    return ReplicateLLM.get_available_models()


def _get_together_ai_models() -> list[str]:
    """Get available Together AI models."""
    from rhesis.sdk.models.providers.together_ai import TogetherAILLM

    return TogetherAILLM.get_available_models()


def _get_vertex_ai_models() -> list[str]:
    """Get available Vertex AI models."""
    from rhesis.sdk.models.providers.vertex_ai import VertexAILLM

    return VertexAILLM.get_available_models()


def _get_openrouter_models() -> list[str]:
    """Get available OpenRouter models."""
    from rhesis.sdk.models.providers.openrouter import OpenRouterLLM

    return OpenRouterLLM.get_available_models()


# =============================================================================
# Embedder Factory
# =============================================================================


def _create_openai_embedder(
    model_name: str, api_key: Optional[str], dimensions: Optional[int], **kwargs
) -> BaseEmbedder:
    """Factory function for OpenAIEmbedder."""
    from rhesis.sdk.models.providers.openai import OpenAIEmbedder

    return OpenAIEmbedder(model_name=model_name, api_key=api_key, dimensions=dimensions, **kwargs)


def _create_gemini_embedder(
    model_name: str, api_key: Optional[str], dimensions: Optional[int], **kwargs
) -> BaseEmbedder:
    """Factory function for GeminiEmbedder."""
    from rhesis.sdk.models.providers.gemini import GeminiEmbedder

    return GeminiEmbedder(model_name=model_name, api_key=api_key, dimensions=dimensions, **kwargs)


def _create_vertex_ai_embedder(
    model_name: str, api_key: Optional[str], dimensions: Optional[int], **kwargs
) -> BaseEmbedder:
    """Factory function for VertexAIEmbedder.

    Note: api_key is ignored for Vertex AI, which uses service account credentials.
    """
    from rhesis.sdk.models.providers.vertex_ai import VertexAIEmbedder

    # Extract Vertex AI-specific parameters from kwargs so they are passed explicitly
    # (defaulting to None) and not duplicated in the remaining **kwargs.
    credentials = kwargs.pop("credentials", None)
    location = kwargs.pop("location", None)
    project = kwargs.pop("project", None)

    return VertexAIEmbedder(
        model_name=model_name,
        credentials=credentials,
        location=location,
        project=project,
        dimensions=dimensions,
        **kwargs,
    )


# Embedder provider registry. Factories are called as
# ``factory(model_name, api_key, dimensions, **kwargs)``.
EMBEDDER_REGISTRY: Dict[str, Callable[..., BaseEmbedder]] = {
    "openai": _create_openai_embedder,
    "gemini": _create_gemini_embedder,
    "vertex_ai": _create_vertex_ai_embedder,
}


@dataclass
class EmbedderConfig:
    """Configuration for an embedder instance.

    Attributes:
        provider: The provider name (e.g., "openai").
        model_name: Specific model name (e.g., "text-embedding-3-small").
        api_key: The API key to use for the embedder.
        dimensions: Optional embedding dimensions.
        extra_params: Extra parameters intended for the embedder.
    """

    provider: str | None = None
    model_name: str | None = None
    api_key: str | None = None
    dimensions: int | None = None
    extra_params: dict = field(default_factory=dict)


def get_embedder(
    provider: Optional[str] = None,
    model_name: Optional[str] = None,
    api_key: Optional[str] = None,
    dimensions: Optional[int] = None,
    config: Optional[EmbedderConfig] = None,
    **kwargs,
) -> BaseEmbedder:
    """Create an embedder instance with smart defaults and comprehensive error handling.

    This function provides multiple ways to create an embedder instance:

    1. **Minimal**: `get_embedder()` - uses all defaults (OpenAI text-embedding-3-small)
    2. **Provider only**: `get_embedder("openai")` - uses default model for provider
    3. **Provider + Model**: `get_embedder("openai", "text-embedding-3-large")`
    4. **Shorthand**: `get_embedder("openai/text-embedding-3-large")`
    5. **Full config**: `get_embedder(config=EmbedderConfig(...))`

    Each setting is resolved in order of precedence: the explicit argument, then
    the value on ``config``, then the module defaults
    (``DEFAULT_EMBEDDER_PROVIDER`` and ``DEFAULT_EMBEDDING_MODELS``).

    Args:
        provider: Provider name (e.g., "openai"). When ``model_name`` is not given,
            a "provider/model" string is split on its first "/".
        model_name: Specific embedding model name.
        api_key: API key for authentication.
        dimensions: Optional embedding dimensions (model-dependent).
        config: Complete configuration object. It is never modified; keyword
            arguments whose names match one of its attributes are applied to a copy.
        **kwargs: Additional parameters passed unchanged to the embedder factory.

    Returns:
        BaseEmbedder: Configured embedder instance.

    Raises:
        ValueError: If configuration is invalid or provider not supported.

    Examples:
        >>> # Basic usage with defaults
        >>> embedder = get_embedder()

        >>> # Specify provider and model
        >>> embedder = get_embedder("openai", "text-embedding-3-large")

        >>> # Use provider/model shorthand
        >>> embedder = get_embedder("openai/text-embedding-3-small")

        >>> # With dimensions
        >>> embedder = get_embedder("openai", dimensions=256)

        >>> # With custom configuration
        >>> config = EmbedderConfig(
        ...     provider="openai",
        ...     model_name="text-embedding-3-small",
        ...     dimensions=512
        ... )
        >>> embedder = get_embedder(config=config)
    """
    from dataclasses import replace

    if config is not None:
        # Apply matching keyword overrides to a copy so the caller's config object is
        # left untouched and can safely be reused across calls.
        overrides = {key: value for key, value in kwargs.items() if hasattr(config, key)}
        cfg = replace(config, **overrides) if overrides else config
    else:
        cfg = EmbedderConfig()

    # Shorthand "provider/model": split on the first "/" only.
    if provider and "/" in provider and model_name is None:
        provider, model_name = provider.split("/", 1)

    provider = provider or cfg.provider or DEFAULT_EMBEDDER_PROVIDER
    if provider not in DEFAULT_EMBEDDING_MODELS:
        available = ", ".join(sorted(DEFAULT_EMBEDDING_MODELS.keys()))
        raise ValueError(f"Embedder provider '{provider}' not supported. Available: {available}")

    model_name = model_name or cfg.model_name or DEFAULT_EMBEDDING_MODELS[provider]
    api_key = api_key or cfg.api_key
    dimensions = dimensions or cfg.dimensions
    # NOTE(review): cfg.extra_params is never forwarded to the embedder. Confirm the
    # embedder constructors' expected form before wiring it through.

    factory_func = EMBEDDER_REGISTRY.get(provider)
    if factory_func is None:
        raise ValueError(f"Embedder provider '{provider}' not supported")

    return factory_func(model_name, api_key, dimensions, **kwargs)