"""In-memory knowledge base: records, packs, substring search and coverage."""

from __future__ import annotations

from collections.abc import Iterable
from datetime import datetime, timezone
from enum import Enum

from pydantic import BaseModel, Field

from sir import SemanticNode, SirSnapshot


class KnowledgeScope(str, Enum):
    """Scope a knowledge record is attached to."""

    GLOBAL = "GLOBAL"
    WORKSPACE = "WORKSPACE"
    PROJECT = "PROJECT"
    SESSION = "SESSION"


class KnowledgeRecord(BaseModel):
    """A single searchable knowledge entry."""

    # Key under which the record is stored; upserting the same id replaces it.
    record_id: str
    scope: KnowledgeScope
    title: str
    body: str
    tags: list[str] = Field(default_factory=list)
    # Lineage ids matched against ``node.lineage_id`` by ``coverage``.
    related_lineages: list[str] = Field(default_factory=list)
    # Defaults to the current UTC time; used as a search tie-breaker.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    # Free-form data; searched recursively through nested dicts and lists.
    attributes: dict = Field(default_factory=dict)


class KnowledgePack(BaseModel):
    """A named bundle of records that can be imported as a unit."""

    pack_id: str
    name: str
    vendor: str | None = None
    version: str | None = None
    description: str = ""
    records: list[KnowledgeRecord] = Field(default_factory=list)
    attributes: dict = Field(default_factory=dict)


class KnowledgeSearchResult(BaseModel):
    """One search hit: the record, its score and the fields that matched."""

    record: KnowledgeRecord
    score: float
    matched_fields: list[str]


class KnowledgeCoverageItem(BaseModel):
    """How many records reference a given snapshot node."""

    node: SemanticNode
    record_count: int


class InMemoryKnowledgeBase:
    """Dictionary-backed store of knowledge records and imported packs."""

    def __init__(self) -> None:
        self._records: dict[str, KnowledgeRecord] = {}
        self._packs: dict[str, KnowledgePack] = {}

    def upsert(self, record: KnowledgeRecord) -> KnowledgeRecord:
        """Insert ``record`` or replace the one with the same id; return it."""
        self._records[record.record_id] = record
        return record

    def import_pack(self, pack: KnowledgePack) -> KnowledgePack:
        """Register ``pack`` and upsert an annotated copy of each of its records.

        Each stored copy gains a ``pack:<pack_id>`` tag (plus ``vendor:<vendor>``
        when the pack has a vendor) and the attributes ``pack_id``,
        ``pack_name``, ``pack_version`` and ``vendor``. Pack-derived attributes
        override same-named attributes already on the record. The pack object
        itself is stored unmodified and returned.
        """
        self._packs[pack.pack_id] = pack

        pack_tags = [f"pack:{pack.pack_id}"]
        if pack.vendor:
            pack_tags.append("vendor:" + pack.vendor)
        pack_attributes = {
            "pack_id": pack.pack_id,
            "pack_name": pack.name,
            "pack_version": pack.version,
            "vendor": pack.vendor,
        }

        for record in pack.records:
            # The set removes duplicate tags; sorting keeps the order stable.
            tags = sorted({*record.tags, *pack_tags})
            attributes = {**record.attributes, **pack_attributes}
            self.upsert(record.model_copy(update={"tags": tags, "attributes": attributes}))
        return pack

    def list_packs(self) -> list[KnowledgePack]:
        """Return imported packs ordered by vendor, name, then version."""
        return sorted(
            self._packs.values(),
            key=lambda pack: (pack.vendor or "", pack.name, pack.version or ""),
        )

    def get(self, record_id: str) -> KnowledgeRecord | None:
        """Return the record stored under ``record_id``, or ``None``."""
        return self._records.get(record_id)

    def list_records(self, scope: KnowledgeScope | None = None) -> list[KnowledgeRecord]:
        """Return records (optionally only those in ``scope``) by scope, then title."""
        records = [
            record
            for record in self._records.values()
            if scope is None or record.scope == scope
        ]
        return sorted(records, key=lambda record: (record.scope.value, record.title))

    def search(
        self,
        query: str,
        *,
        scope: KnowledgeScope | None = None,
        limit: int = 20,
    ) -> list[KnowledgeSearchResult]:
        """Return up to ``limit`` records matching ``query``, best match first.

        The query is case-folded and stripped; a blank query or a non-positive
        ``limit`` yields an empty list. Results are ordered by descending
        score, then most recent ``created_at``, then title.
        """
        normalized = query.casefold().strip()
        # A negative limit would otherwise slice results off the tail.
        if not normalized or limit <= 0:
            return []

        results: list[KnowledgeSearchResult] = []
        for record in self.list_records(scope):
            score, fields = _score_record(record, normalized)
            if score > 0:
                results.append(
                    KnowledgeSearchResult(record=record, score=score, matched_fields=fields)
                )

        results.sort(
            key=lambda item: (
                -item.score,
                -item.record.created_at.timestamp(),
                item.record.title,
            )
        )
        return results[:limit]

    def coverage(self, snapshot: SirSnapshot) -> list[KnowledgeCoverageItem]:
        """Count, per snapshot node, the records that reference its lineage id.

        Every node in ``snapshot.nodes`` is reported (zero when unreferenced),
        ordered by ``qualified_name``. A record contributes at most once to a
        given lineage even if it lists that lineage more than once.
        """
        counts: dict[str, int] = {node.lineage_id: 0 for node in snapshot.nodes}
        for record in self._records.values():
            # Deduplicate so repeated ids do not inflate the record count.
            for lineage_id in set(record.related_lineages):
                if lineage_id in counts:
                    counts[lineage_id] += 1
        return [
            KnowledgeCoverageItem(node=node, record_count=counts[node.lineage_id])
            for node in sorted(snapshot.nodes, key=lambda item: item.qualified_name)
        ]


def _score_record(record: KnowledgeRecord, query: str) -> tuple[float, list[str]]:
    """Score ``record`` against an already case-folded ``query``.

    Returns the summed score and the names of the fields that matched. Tags
    and related lineages are each scored as one space-joined string.
    Attribute matches are worth one point less than core-field matches, but
    never less than 1.0.
    """
    fields = {
        "title": record.title,
        "body": record.body,
        "tags": " ".join(record.tags),
        "related_lineages": " ".join(record.related_lineages),
    }
    score = 0.0
    matched: list[str] = []

    for field, value in fields.items():
        field_score = _score_text(value, query)
        if field_score:
            score += field_score
            matched.append(field)

    for field, value in _attribute_search_fields(record.attributes):
        field_score = _score_text(value, query)
        if field_score:
            score += max(field_score - 1.0, 1.0)
            matched.append(field)

    return score, matched


def _score_text(value: object, query: str) -> float:
    """Score ``str(value)`` against ``query``, ignoring case.

    Returns 10.0 for an exact match, 5.0 for a prefix match, 2.0 for a
    substring match and 0.0 otherwise.
    """
    normalized = str(value).casefold()
    if normalized == query:
        return 10.0
    if normalized.startswith(query):
        return 5.0
    if query in normalized:
        return 2.0
    return 0.0


def _attribute_search_fields(attributes: dict) -> Iterable[tuple[str, object]]:
    """Yield ``(path, value)`` for every searchable leaf in ``attributes``.

    Paths look like ``attributes.key``, ``attributes.key.nested`` and
    ``attributes.key[0]``. Dict keys are visited in sorted order. ``None``
    values are skipped so that they are not matched as the text "none".
    """
    for key, value in sorted(attributes.items()):
        yield from _flatten_attribute(f"attributes.{key}", value)


def _flatten_attribute(path: str, value: object) -> Iterable[tuple[str, object]]:
    """Recursively yield the leaves of ``value``, each labelled with its path."""
    if value is None:
        return
    if isinstance(value, dict):
        for key, nested in sorted(value.items()):
            yield from _flatten_attribute(f"{path}.{key}", nested)
    elif isinstance(value, list):
        for index, item in enumerate(value):
            yield from _flatten_attribute(f"{path}[{index}]", item)
    else:
        yield path, value


__all__ = [
    "InMemoryKnowledgeBase",
    "KnowledgeCoverageItem",
    "KnowledgePack",
    "KnowledgeRecord",
    "KnowledgeScope",
    "KnowledgeSearchResult",
]