Source code for rhesis.sdk.entities.model

from typing import TYPE_CHECKING, Any, ClassVar, Dict, Literal, Optional, Union

from pydantic import model_validator

from rhesis.sdk.clients import APIClient, Endpoints, Methods

if TYPE_CHECKING:
    from rhesis.sdk.models.base import BaseEmbedder, BaseLLM
from rhesis.sdk.entities.base_collection import BaseCollection
from rhesis.sdk.entities.base_entity import BaseEntity

# API endpoint shared by the Model entity and the Models collection below.
ENDPOINT = Endpoints.MODELS


class Model(BaseEntity):
    """
    Model entity for interacting with the Rhesis API.

    Models represent AI model configurations (language models or embeddings)
    that can be used for generation, evaluation, embedding, and other
    AI-powered tasks. Each model configuration includes the provider, model
    name, and API key.

    Examples:
        Create a new language model:

        >>> model = Model(
        ...     name="GPT-4 Production",
        ...     provider="openai",
        ...     model_name="gpt-4",
        ...     key="sk-..."
        ... )
        >>> model.push()

        Create an embedding model:

        >>> model = Model(
        ...     name="OpenAI Embeddings",
        ...     provider="openai",
        ...     model_name="text-embedding-3-small",
        ...     model_type="embedding",
        ...     key="sk-..."
        ... )
        >>> model.push()

        Load an existing model:

        >>> model = Models.pull(name="GPT-4 Production")
        >>> print(model.model_name)

        List all models:

        >>> models = Models.all()
        >>> for m in models:
        ...     print(m.name, m.model_type, m.model_name)

    Supported providers:
        - openai, anthropic, gemini, mistral, cohere, groq
        - vertex_ai, together_ai, replicate, perplexity
        - ollama, vllm (for self-hosted models)
    """

    endpoint: ClassVar[Endpoints] = ENDPOINT

    # Core identification
    id: Optional[str] = None
    name: Optional[str] = None
    description: Optional[str] = None

    # Model configuration
    provider: Optional[str] = None  # Provider name (e.g., "openai", "anthropic")
    model_name: Optional[str] = None
    model_type: Optional[Literal["llm", "embedding"]] = "llm"  # "llm" unless set
    key: Optional[str] = None  # Provider API key

    # Relationships (resolved automatically from provider)
    provider_type_id: Optional[str] = None
    status_id: Optional[str] = None

    @model_validator(mode="after")
    def _set_default_description(self) -> "Model":
        """Set a default description based on the provider if none was given.

        Runs after field validation. When ``description`` is ``None`` and a
        provider is set, the description becomes e.g.
        ``"Openai model connection"``.
        """
        if self.description is None and self.provider:
            self.description = f"{self.provider.title()} model connection"
        return self

    def _resolve_provider_type_id(self) -> Optional[str]:
        """Resolve the provider name to a ``provider_type_id`` via an API lookup.

        Queries the type-lookup endpoint for a ``ProviderType`` entry whose
        ``type_value`` matches ``provider`` case-insensitively.

        Returns:
            The ``id`` of the first matching lookup entry. If no provider name
            is set, the current ``provider_type_id`` is returned unchanged.

        Raises:
            ValueError: If no lookup entry matches the provider name. The
                message lists the available providers.
        """
        if not self.provider:
            return self.provider_type_id

        # The provider name is user input embedded in a quoted filter literal.
        # Double any single quotes so a name containing "'" cannot terminate
        # the literal early and alter the filter. This assumes the backend
        # follows OData string-literal rules, which the `$filter` / `eq` /
        # `tolower` syntax suggests.
        provider_literal = self.provider.lower().replace("'", "''")
        filter_query = (
            f"type_name eq 'ProviderType' and tolower(type_value) eq '{provider_literal}'"
        )

        client = APIClient()
        response = client.send_request(
            endpoint=Endpoints.TYPE_LOOKUPS,
            method=Methods.GET,
            params={"$filter": filter_query},
        )
        if response:
            return response[0].get("id")

        available_providers = Models.list_providers()
        raise ValueError(
            f"Unsupported provider '{self.provider}'. "
            f"Supported providers: {', '.join(available_providers)}"
        )

    def push(self) -> Optional[Dict[str, Any]]:
        """Save the model to the platform (create or update).

        If a provider name is set and ``provider_type_id`` is not, the name is
        resolved to a ``provider_type_id`` first. An ``icon`` field equal to
        the lower-cased provider name is added to the payload. ``icon`` is not
        a user-facing field.

        Returns:
            The response of the create/update request. The annotation allows
            this to be ``None``.

        Raises:
            ValueError: If neither ``provider`` nor ``provider_type_id`` is
                set, or if ``provider`` does not match a known provider.
        """
        if not self.provider and not self.provider_type_id:
            available_providers = Models.list_providers()
            raise ValueError(
                f"Provider is required. Supported providers: {', '.join(available_providers)}"
            )

        # Resolve the provider name to provider_type_id only when it is missing.
        if self.provider and not self.provider_type_id:
            self.provider_type_id = self._resolve_provider_type_id()

        data = self.model_dump(mode="json")
        # Icon mirrors the provider value (same as the frontend does).
        if self.provider:
            data["icon"] = self.provider.lower()

        if data.get("id") is not None:
            response = self._update(data["id"], data)
        else:
            response = self._create(data)

        # The return type is Optional. Guard against a None response instead
        # of failing with "'NoneType' object is not subscriptable".
        if response is not None:
            self.id = response["id"]
        return response

    def _set_default_for(self, purpose: str) -> None:
        """Make this model the current user's default model for ``purpose``.

        Sends a PATCH to the users ``settings`` endpoint with the payload
        ``{"models": {purpose: {"model_id": self.id}}}``.

        Args:
            purpose: Settings key under ``models``. The public callers use
                ``"generation"``, ``"evaluation"`` or ``"embedding"``.

        Raises:
            ValueError: If the model has no ``id`` yet (it must be pushed
                first).
        """
        if not self.id:
            raise ValueError(
                "Model must be saved before setting as default. Call push() first."
            )
        client = APIClient()
        client.send_request(
            endpoint=Endpoints.USERS,
            method=Methods.PATCH,
            url_params="settings",
            data={"models": {purpose: {"model_id": self.id}}},
        )

    def set_default_generation(self) -> None:
        """Set this model as the default for test generation.

        This updates the current user's settings to use this model when
        generating new test cases.

        Raises:
            ValueError: If model ID is not set (model must be saved first)

        Example:
            >>> model = Models.pull(name="GPT-4 Production")
            >>> model.set_default_generation()
        """
        self._set_default_for("generation")

    def set_default_evaluation(self) -> None:
        """Set this model as the default for evaluation (LLM as Judge).

        This updates the current user's settings to use this model when
        running metrics and evaluations.

        Raises:
            ValueError: If model ID is not set (model must be saved first)

        Example:
            >>> model = Models.pull(name="GPT-4 Production")
            >>> model.set_default_evaluation()
        """
        self._set_default_for("evaluation")

    def set_default_embedding(self) -> None:
        """Set this model as the default for embedding generation.

        This updates the current user's settings to use this model when
        generating embeddings for semantic search and similarity.

        Raises:
            ValueError: If model ID is not set (model must be saved first)

        Example:
            >>> model = Models.pull(name="OpenAI Embeddings")
            >>> model.set_default_embedding()
        """
        self._set_default_for("embedding")

    def get_model_instance(
        self,
    ) -> "Union[BaseLLM, BaseEmbedder]":
        """Create a model instance configured with this model's settings.

        Returns an embedder when ``model_type`` is ``"embedding"`` and an LLM
        otherwise. Both are built from this entity's provider, model name and
        API key.

        Returns:
            BaseLLM or BaseEmbedder: Ready-to-use model instance

        Raises:
            ValueError: If provider or model_name is not set

        Example:
            >>> model = Models.pull(name="GPT-4 Production")
            >>> llm = model.get_model_instance()
            >>> response = llm.generate("Hello, how are you?")

            >>> model = Models.pull(name="OpenAI Embeddings")
            >>> embedder = model.get_model_instance()
            >>> vector = embedder.generate("Hello, world!")
        """
        if not self.provider:
            raise ValueError("Provider is required to create a model instance")
        if not self.model_name:
            raise ValueError("Model name is required to create a model instance")

        # Factory imports are deferred to call time, so loading this entity
        # module does not import the models package.
        if self.model_type == "embedding":
            from rhesis.sdk.models.factory import get_embedder

            return get_embedder(
                provider=self.provider,
                model_name=self.model_name,
                api_key=self.key,
            )

        from rhesis.sdk.models.factory import get_model

        return get_model(
            provider=self.provider,
            model_name=self.model_name,
            api_key=self.key,
        )
class Models(BaseCollection):
    """Collection of ``Model`` entities served by the models endpoint."""

    endpoint = ENDPOINT
    entity_class = Model

    @classmethod
    def list_providers(cls) -> list[str]:
        """Return the provider names accepted when creating models.

        Reads ``ProviderType`` entries from the type-lookup endpoint (the
        request asks for at most 100) and keeps each entry's non-empty
        ``type_value``.

        Returns:
            The provider names, or an empty list when the lookup returns
            nothing.

        Example:
            >>> providers = Models.list_providers()
            >>> print(providers)
            ['openai', 'anthropic', 'gemini', 'mistral', ...]
        """
        lookups = APIClient().send_request(
            endpoint=Endpoints.TYPE_LOOKUPS,
            method=Methods.GET,
            params={"$filter": "type_name eq 'ProviderType'", "limit": 100},
        )
        if not lookups:
            return []
        return [value for lookup in lookups if (value := lookup.get("type_value"))]