# Source code for rhesis.sdk.services.chunker

"""Chunking service and strategies for splitting extracted sources into chunks."""

import re
from abc import ABC, abstractmethod
from typing import List

from pydantic import BaseModel

from rhesis.sdk.services.extractor import ExtractedSource, SourceSpecification
from rhesis.sdk.utils import count_tokens


class Chunk(BaseModel):
    """A piece of text paired with the specification of its source."""

    # Specification describing the source this chunk belongs to.
    source: SourceSpecification
    # The text of the chunk itself.
    content: str


class ChunkingStrategy(ABC):
    """Interface that every text-chunking strategy must implement."""

    @abstractmethod
    def chunk(self, text: str) -> List[str]:
        """Split ``text`` into a list of chunk strings."""
class ChunkingService:
    """Run a chunking strategy over a collection of extracted sources."""

    def __init__(self, sources: list[ExtractedSource], strategy: ChunkingStrategy):
        """Store the sources to split and the strategy used to split them.

        Args:
            sources: Extracted sources whose ``content`` will be chunked.
            strategy: Strategy whose ``chunk`` method splits each source's text.
        """
        self.strategy = strategy
        self.sources = sources

    def chunk(self) -> List[Chunk]:
        """Chunk every source and return the pieces in source order.

        Returns:
            One ``Chunk`` per piece produced by the strategy. All pieces of a
            given source are built with the same ``SourceSpecification``
            instance, constructed from that source's ``model_dump()``.
        """
        result: List[Chunk] = []
        for extracted in self.sources:
            pieces = self.strategy.chunk(extracted.content)
            spec = SourceSpecification(**extracted.model_dump())
            result.extend(Chunk(source=spec, content=piece) for piece in pieces)
        return result
class IdentityChunker(ChunkingStrategy):
    """Strategy that performs no chunking at all."""

    def chunk(self, text: str) -> List[str]:
        """Return ``text`` unchanged as the only element of a one-item list."""
        return [text]
class SemanticChunker(ChunkingStrategy):
    """Split text into token-capped chunks that follow semantic boundaries.

    Boundaries are markdown headers, separator lines made of ``-``, ``*`` or
    ``_`` and blank-line paragraph breaks. Adjacent spans are merged greedily
    while they fit in the token budget; a single span that does not fit is cut
    into token-capped windows without regard to structure.
    """

    # Upper bound applied to whatever chunk size the caller requests.
    MAX_TOKENS_HARD_LIMIT = 3000

    # Markdown header: one to six '#' followed by whitespace.
    _HEADER_RE = re.compile(r"^#{1,6}\s+")
    # Separator line: three or more of '-', '*' or '_' and nothing else.
    _SEPARATOR_RE = re.compile(r"^[-*_]{3,}$")
    # A blank line followed by text starting with one of these characters is
    # not recorded as a paragraph boundary.
    _STRUCTURAL_PREFIXES = ("#", "-", "*", "_")

    def __init__(
        self,
        max_tokens_per_chunk: int = 1500,
    ):
        """Initialize the chunker.

        Args:
            max_tokens_per_chunk: Preferred maximum number of tokens per chunk.
                Values above ``MAX_TOKENS_HARD_LIMIT`` are capped to it and a
                notice is printed.
        """
        hard_limit = self.MAX_TOKENS_HARD_LIMIT
        self.max_context_tokens = min(max_tokens_per_chunk, hard_limit)
        if max_tokens_per_chunk > hard_limit:
            print(
                f"⚠️ Context size capped at {hard_limit} tokens "
                f"(you requested {max_tokens_per_chunk})"
            )

    def chunk(self, text: str) -> List[str]:
        """Split ``text`` into chunks using semantic boundaries and a hard token cap.

        Strategy:
        1. Identify semantic boundaries (headers, separators, paragraph breaks).
        2. Merge consecutive spans while they stay within the token limit.
        3. If a single span exceeds the limit, split it abruptly into
           token-capped windows.
        4. If there are no internal boundaries, slice the whole text into
           token-capped windows.

        Args:
            text: The text to chunk. Leading and trailing whitespace is removed
                before processing.

        Returns:
            The non-empty, stripped chunks in document order. Input that is
            only whitespace yields an empty list.

        Raises:
            ValueError: If ``text`` is empty, or if ``count_tokens`` returns
                ``None`` for a piece of the text.
        """
        if not text:
            raise ValueError("Cannot generate contexts from empty text")

        text = text.strip()
        semantic_boundaries = self._identify_semantic_boundaries(text)

        # Only the start and end markers: there is no structure to respect.
        if len(semantic_boundaries) <= 2:
            return self._split_into_windows(text, 0, len(text))

        return self._create_contexts_from_boundaries(text, semantic_boundaries)

    def _count_tokens(self, segment: str) -> int:
        """Return the token count of ``segment``; raise if it cannot be counted."""
        tokens = count_tokens(segment)
        if tokens is None:
            raise ValueError("Failed to count tokens - text may be malformed or invalid")
        return tokens

    def _split_into_windows(self, text: str, start_pos: int, end_pos: int) -> List[str]:
        """Cut ``text[start_pos:end_pos]`` into consecutive token-capped pieces.

        Each piece is stripped, and pieces that are empty after stripping are
        dropped.
        """
        pieces: List[str] = []
        cursor = start_pos
        while cursor < end_pos:
            window_end = self._find_token_capped_end(text, cursor, end_pos)
            piece = text[cursor:window_end].strip()
            if piece:
                pieces.append(piece)
            if window_end <= cursor:
                # Defensive stop: never loop without making progress.
                break
            cursor = window_end
        return pieces

    def _identify_semantic_boundaries(self, text: str) -> List[int]:
        """Return character offsets of semantic boundaries in ``text``.

        The list always starts with ``0`` and ends with ``len(text)``. Offsets
        are appended in document order and may repeat (for example when the
        first line is a header).
        """
        boundaries = [0]  # Start of text
        text_len = len(text)
        current_pos = 0

        for line in text.split("\n"):
            line_length = len(line) + 1  # +1 for the newline removed by split

            if self._HEADER_RE.match(line) or self._SEPARATOR_RE.match(line):
                boundaries.append(current_pos)
            elif current_pos > 0 and not line.strip():
                # Blank line: look ahead to the next non-whitespace character
                # to decide whether this is a paragraph break.
                next_non_empty = current_pos + line_length
                while next_non_empty < text_len and text[next_non_empty].isspace():
                    next_non_empty += 1
                if (
                    next_non_empty < text_len
                    and text[next_non_empty] not in self._STRUCTURAL_PREFIXES
                ):
                    boundaries.append(current_pos)

            current_pos += line_length

        boundaries.append(text_len)  # End of text
        return boundaries

    def _create_contexts_from_boundaries(self, text: str, boundaries: List[int]) -> List[str]:
        """Build chunks by walking ``boundaries`` and merging spans that fit.

        A span between two adjacent boundaries that is over the token limit is
        split into token-capped windows; every other span is emitted stripped.
        """
        contexts: List[str] = []
        last_idx = len(boundaries) - 1
        start_idx = 0

        while start_idx < last_idx:
            end_idx = self._find_best_end_boundary(boundaries, start_idx, text)
            if end_idx <= start_idx:
                break

            start_pos = boundaries[start_idx]
            end_pos = boundaries[end_idx]
            span = text[start_pos:end_pos]
            context_text = span.strip()

            if context_text:
                # Spans covering several boundaries were already found to fit
                # by _find_best_end_boundary, so only a span between adjacent
                # boundaries can be oversized and needs to be measured here.
                is_single_span = end_idx == start_idx + 1
                if is_single_span and self._count_tokens(span) > self.max_context_tokens:
                    contexts.extend(self._split_into_windows(text, start_pos, end_pos))
                else:
                    contexts.append(context_text)

            start_idx = end_idx

        return contexts

    def _find_best_end_boundary(self, boundaries: List[int], start_idx: int, text: str) -> int:
        """Return the index of the furthest boundary that keeps the span in budget.

        If even the span up to the very next boundary is over the limit, that
        next boundary is returned anyway so the caller can split the span.
        """
        start_pos = boundaries[start_idx]
        first_candidate = start_idx + 1

        for idx in range(first_candidate, len(boundaries)):
            if self._count_tokens(text[start_pos : boundaries[idx]]) > self.max_context_tokens:
                # Step back to the last boundary that fit, but never behind
                # the first candidate.
                return max(idx - 1, first_candidate)

        # Everything up to the end of the text fits.
        return len(boundaries) - 1

    def _find_token_capped_end(self, text: str, start_pos: int, hard_end: int) -> int:
        """Find the furthest end index up to ``hard_end`` that fits the token limit.

        Binary-searches end indices in ``start_pos + 1 .. hard_end`` and keeps
        the largest one whose slice ``text[start_pos:end]`` is within
        ``max_context_tokens``. If no end index fits, returns
        ``min(start_pos + 1, hard_end)``.
        """
        low = start_pos + 1
        high = hard_end
        best = None

        while low <= high:
            mid = (low + high) // 2
            if self._count_tokens(text[start_pos:mid]) <= self.max_context_tokens:
                best = mid
                low = mid + 1
            else:
                high = mid - 1

        if best is None:
            return min(start_pos + 1, hard_end)
        return best