from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

if TYPE_CHECKING:
    from rhesis.sdk.entities.model import Model
# Type alias for an embedding vector: a flat list of floats.
Embedding = List[float]
class BaseLLM(ABC):
    """Abstract base class for LLM wrappers.

    Concrete subclasses must implement :meth:`load_model`, :meth:`generate`
    and :meth:`generate_batch`. They should also override :attr:`PROVIDER`;
    :meth:`push` refuses to run while it is empty.
    """

    PROVIDER: str = ""  # Subclasses should override this

    def __init__(self, model_name: str, *args, **kwargs):
        """Store the model name and load the underlying model.

        Args:
            model_name: Identifier of the model this instance wraps.
            *args: Forwarded unchanged to :meth:`load_model`.
            **kwargs: Forwarded unchanged to :meth:`load_model`.
        """
        self.model_name = model_name
        # Whatever load_model() returns is kept on the instance as ``model``.
        self.model = self.load_model(*args, **kwargs)

    @abstractmethod
    def load_model(self, *args, **kwargs):
        """Loads a model

        Returns:
            A model object
        """
        pass

    @abstractmethod
    def generate(self, *args, **kwargs) -> Union[str, Dict[str, Any]]:
        """Runs the model to output LLM response.

        Returns:
            A string or dict (if schema provided).
        """
        pass

    @abstractmethod
    def generate_batch(self, *args, **kwargs) -> List[Union[str, Dict[str, Any]]]:
        """Run model on multiple prompts to output LLM responses.

        Returns:
            A list of strings or dicts (if schema provided).
        """
        pass

    def get_model_name(self, *args, **kwargs) -> str:
        """Return a human-readable label with the class name and model name.

        Positional and keyword arguments are accepted but ignored.
        """
        return f"Class name: {self.__class__.__name__}, model name: {self.model_name}"

    def get_available_models(self) -> List[str]:
        """Get the list of available models for this provider.

        Subclasses should override this method.

        Returns:
            List of model names.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def _unprefixed_model_name(self):
        """Return ``model_name`` with everything up to the first "/" removed.

        A falsy name, or a name without "/", is returned unchanged
        (e.g. "openai/gpt-4" -> "gpt-4", "gpt-4" -> "gpt-4").
        """
        # NOTE(review): the leading segment is dropped whether or not it equals
        # PROVIDER, so a namespaced id such as "org/model" loses "org/" --
        # confirm this is intended for every provider.
        full_name = self.model_name
        if full_name and "/" in full_name:
            return full_name.split("/", 1)[-1]
        return full_name

    def push(self, name: str, description: Optional[str] = None) -> "Model":
        """Save this LLM configuration to the Rhesis platform as a Model entity.

        Creates a Model entity with this LLM's provider, model name, and API key,
        then saves it to the platform.

        Args:
            name: Name for the saved model configuration (required)
            description: Optional description for the model

        Returns:
            Model: The created Model entity (can be used for set_default_generation, etc.)

        Raises:
            ValueError: If provider is not set on this LLM class

        Example:
            >>> from rhesis.sdk.models.factory import get_model
            >>> llm = get_model("openai", "gpt-4", api_key="sk-...")
            >>> model = llm.push(name="My GPT-4 Production")
            >>> model.set_default_generation()
        """
        # Validate first so a misconfigured class fails before any other work.
        provider = getattr(self, "PROVIDER", None)
        if not provider:
            raise ValueError(
                "Cannot push LLM: PROVIDER class variable is not set. "
                "This LLM implementation does not support push()."
            )

        # Runtime import kept local to this method; at module level the name is
        # only imported under TYPE_CHECKING (presumably to avoid an import cycle).
        from rhesis.sdk.entities.model import Model

        model = Model(
            name=name,
            description=description,
            provider=provider,
            model_name=self._unprefixed_model_name(),
            model_type="llm",
            # Subclasses may or may not define api_key; None is passed when absent.
            key=getattr(self, "api_key", None),
        )
        model.push()
        return model
class BaseEmbedder(ABC):
    """Base class for embedding models.

    Concrete subclasses must implement :meth:`generate` and
    :meth:`generate_batch`. They should also override :attr:`PROVIDER`;
    :meth:`push` refuses to run while it is empty.
    """

    PROVIDER: str = ""  # Subclasses should override this

    def __init__(self, model_name: str, *args, **kwargs):
        """Store the model name.

        Args:
            model_name: Identifier of the embedding model this instance wraps.
            *args: Accepted but not used by this base class.
            **kwargs: Accepted but not used by this base class.
        """
        self.model_name = model_name

    @abstractmethod
    def generate(self, text: str, **kwargs) -> Embedding:
        """Generate embedding for a single text.

        Args:
            text: The input text to embed.
            **kwargs: Additional parameters (e.g., dimensions).

        Returns:
            A list of floats representing the embedding vector.
        """
        pass

    @abstractmethod
    def generate_batch(self, texts: List[str], **kwargs) -> List[Embedding]:
        """Generate embeddings for multiple texts.

        Args:
            texts: List of input texts to embed.
            **kwargs: Additional parameters (e.g., dimensions).

        Returns:
            A list of embedding vectors, one for each input text.
        """
        pass

    def get_model_name(self) -> str:
        """Return a human-readable label with the class name and model name."""
        return f"Class name: {self.__class__.__name__}, model name: {self.model_name}"

    def get_available_models(self) -> List[str]:
        """Get the list of available embedding models for this provider.

        Subclasses should override this method to return provider-specific embedding models.

        Returns:
            List of embedding model names

        Raises:
            NotImplementedError: If the subclass doesn't implement this method
        """
        raise NotImplementedError("Subclasses must implement this method")

    def _unprefixed_model_name(self):
        """Return ``model_name`` with everything up to the first "/" removed.

        A falsy name, or a name without "/", is returned unchanged.
        """
        # NOTE(review): the leading segment is dropped whether or not it equals
        # PROVIDER, so a namespaced id such as "org/model" loses "org/" --
        # confirm this is intended for every provider.
        full_name = self.model_name
        if full_name and "/" in full_name:
            return full_name.split("/", 1)[-1]
        return full_name

    def push(self, name: str, description: Optional[str] = None) -> "Model":
        """Save this embedder configuration to the Rhesis platform as a Model entity.

        Creates a Model entity with this embedder's provider, model name, and API key,
        then saves it to the platform with model_type="embedding".

        Args:
            name: Name for the saved model configuration (required)
            description: Optional description for the model

        Returns:
            Model: The created Model entity

        Raises:
            ValueError: If provider is not set on this embedder class

        Example:
            >>> embedder = get_embedder("openai", "text-embedding-3-small", api_key="sk-...")
            >>> model = embedder.push(name="My OpenAI Embeddings")
            >>> model.set_default_embedding()
        """
        # Validate first so a misconfigured class fails before any other work.
        provider = getattr(self, "PROVIDER", None)
        if not provider:
            raise ValueError(
                "Cannot push embedder: PROVIDER class variable is not set. "
                "This embedder implementation does not support push()."
            )

        # Runtime import kept local to this method; at module level the name is
        # only imported under TYPE_CHECKING (presumably to avoid an import cycle).
        from rhesis.sdk.entities.model import Model

        model = Model(
            name=name,
            description=description,
            provider=provider,
            model_name=self._unprefixed_model_name(),
            model_type="embedding",
            # Subclasses may or may not define api_key; None is passed when absent.
            key=getattr(self, "api_key", None),
        )
        model.push()
        return model