Files
sfera/packages/knowledge-base/src/knowledge_base/__init__.py
T
2026-05-16 19:03:49 +03:00

170 lines
5.5 KiB
Python

from __future__ import annotations
from datetime import datetime, timezone
from enum import Enum
from collections.abc import Iterable
from pydantic import BaseModel, Field
from sir import SemanticNode, SirSnapshot
# Scope a knowledge record is filed under.
# The ``str`` mixin makes every member compare equal to its plain string value,
# and each member's value is identical to its name.
class KnowledgeScope(str, Enum):
    GLOBAL = "GLOBAL"
    WORKSPACE = "WORKSPACE"
    PROJECT = "PROJECT"
    SESSION = "SESSION"
# A single knowledge entry: free-text title/body plus tags, links to lineage
# ids and arbitrary extra attributes.
class KnowledgeRecord(BaseModel):
    # Identifier; InMemoryKnowledgeBase uses it as the storage key.
    record_id: str
    # Scope the record is filed under; list_records/search can filter on it.
    scope: KnowledgeScope
    title: str
    body: str
    # Free-form labels; import_pack adds "pack:<id>" / "vendor:<name>" entries.
    tags: list[str] = Field(default_factory=list)
    # Lineage ids this record relates to; coverage() matches them against
    # the ``lineage_id`` of snapshot nodes.
    related_lineages: list[str] = Field(default_factory=list)
    # Defaults to the current time as a timezone-aware UTC datetime.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    # Untyped extra data; searched via _attribute_search_fields.
    attributes: dict = Field(default_factory=dict)
# A named bundle of knowledge records that can be imported in one call.
class KnowledgePack(BaseModel):
    # Identifier; InMemoryKnowledgeBase uses it as the pack storage key and
    # stamps it on imported records as the "pack:<pack_id>" tag.
    pack_id: str
    name: str
    # Optional; when truthy, import_pack also tags records "vendor:<vendor>".
    vendor: str | None = None
    version: str | None = None
    description: str = ""
    # Records carried by the pack; empty by default.
    records: list[KnowledgeRecord] = Field(default_factory=list)
    # Untyped extra data about the pack itself.
    attributes: dict = Field(default_factory=dict)
# One hit returned by InMemoryKnowledgeBase.search.
class KnowledgeSearchResult(BaseModel):
    # The record that matched the query.
    record: KnowledgeRecord
    # Total relevance score summed over all matching fields (always > 0 for
    # results produced by search).
    score: float
    # Names of the fields that matched, e.g. "title" or "attributes.<key>".
    matched_fields: list[str]
# Per-node entry of the report returned by InMemoryKnowledgeBase.coverage.
class KnowledgeCoverageItem(BaseModel):
    # The snapshot node the count refers to.
    node: SemanticNode
    # Count of stored records whose ``related_lineages`` reference this node's
    # ``lineage_id``; 0 when nothing references it.
    record_count: int
class InMemoryKnowledgeBase:
    """Process-local store for knowledge records and the packs they came from.

    Records are keyed by ``record_id`` and packs by ``pack_id``; writing an
    entry under an id that already exists replaces the previous entry.
    Everything lives in plain dictionaries on the instance.
    """

    def __init__(self) -> None:
        self._records: dict[str, KnowledgeRecord] = {}
        self._packs: dict[str, KnowledgePack] = {}

    def upsert(self, record: KnowledgeRecord) -> KnowledgeRecord:
        """Store ``record`` under its ``record_id`` (replacing any previous one) and return it."""
        self._records[record.record_id] = record
        return record

    def import_pack(self, pack: KnowledgePack) -> KnowledgePack:
        """Register ``pack`` and upsert a provenance-stamped copy of each of its records.

        Every stored copy gets a ``pack:<pack_id>`` tag (plus ``vendor:<vendor>``
        when the pack has a vendor) merged into its sorted, de-duplicated tags,
        and ``pack_id`` / ``pack_name`` / ``pack_version`` / ``vendor`` written
        into its attributes, overriding same-named keys of the record.

        The pack object itself is stored as given, so the records inside the
        stored pack are the originals, not the stamped copies.

        Returns:
            The ``pack`` argument, unchanged.
        """
        self._packs[pack.pack_id] = pack

        # Provenance is identical for every record, so build it once.
        provenance_tags = {f"pack:{pack.pack_id}"}
        if pack.vendor:
            provenance_tags.add("vendor:" + pack.vendor)
        provenance_attributes = {
            "pack_id": pack.pack_id,
            "pack_name": pack.name,
            "pack_version": pack.version,
            "vendor": pack.vendor,
        }

        for record in pack.records:
            tags = sorted({*record.tags, *provenance_tags})
            attributes = {**record.attributes, **provenance_attributes}
            self.upsert(record.model_copy(update={"tags": tags, "attributes": attributes}))
        return pack

    def list_packs(self) -> list[KnowledgePack]:
        """Return all imported packs ordered by vendor, name and version (missing values sort as "")."""
        return sorted(self._packs.values(), key=lambda pack: (pack.vendor or "", pack.name, pack.version or ""))

    def get(self, record_id: str) -> KnowledgeRecord | None:
        """Return the record stored under ``record_id``, or ``None`` when there is none."""
        return self._records.get(record_id)

    def list_records(self, scope: KnowledgeScope | None = None) -> list[KnowledgeRecord]:
        """Return stored records ordered by scope value, then title.

        Args:
            scope: When given, only records with exactly this scope are returned.
        """
        records = list(self._records.values())
        if scope is not None:
            records = [record for record in records if record.scope == scope]
        return sorted(records, key=lambda record: (record.scope.value, record.title))

    def search(
        self,
        query: str,
        *,
        scope: KnowledgeScope | None = None,
        limit: int = 20,
    ) -> list[KnowledgeSearchResult]:
        """Return the best-matching records for ``query``.

        The query is casefolded and stripped before matching. Hits are ordered
        by descending score, then newest ``created_at`` first, then title.

        Args:
            query: Text to look for; a blank query yields no results.
            scope: When given, restricts the search to records of this scope.
            limit: Maximum number of results. Zero or a negative value yields
                an empty list (a negative value must not slice from the end).
        """
        normalized = query.casefold().strip()
        if not normalized or limit <= 0:
            return []
        results: list[KnowledgeSearchResult] = []
        for record in self.list_records(scope):
            score, fields = _score_record(record, normalized)
            if score > 0:
                results.append(KnowledgeSearchResult(record=record, score=score, matched_fields=fields))
        results.sort(key=lambda item: (-item.score, -item.record.created_at.timestamp(), item.record.title))
        return results[:limit]

    def coverage(self, snapshot: SirSnapshot) -> list[KnowledgeCoverageItem]:
        """Report, for every node of ``snapshot``, how many stored records reference it.

        A record references a node when the node's ``lineage_id`` appears in the
        record's ``related_lineages``. Each record is counted at most once per
        node, even if it lists the same lineage id several times.

        Returns:
            One item per snapshot node, ordered by the node's ``qualified_name``.
        """
        counts: dict[str, int] = {node.lineage_id: 0 for node in snapshot.nodes}
        for record in self._records.values():
            # The set removes repeated ids so a record cannot inflate a count.
            for lineage_id in set(record.related_lineages):
                if lineage_id in counts:
                    counts[lineage_id] += 1
        return [
            KnowledgeCoverageItem(node=node, record_count=counts[node.lineage_id])
            for node in sorted(snapshot.nodes, key=lambda item: item.qualified_name)
        ]
def _score_record(record: KnowledgeRecord, query: str) -> tuple[float, list[str]]:
    """Score ``record`` against ``query`` and list the fields that matched.

    The title, body, space-joined tags and space-joined related lineages are
    scored with ``_score_text`` and contribute their full score. Every
    flattened attribute value is scored the same way but contributes one point
    less, never less than 1.0, so attribute hits rank below core-field hits.

    Args:
        record: The record to evaluate.
        query: Search text; passed to ``_score_text`` as-is.

    Returns:
        ``(total_score, matched_field_names)``; the score is 0.0 and the list
        empty when nothing matched.
    """
    total = 0.0
    hits: list[str] = []

    core_fields = (
        ("title", record.title),
        ("body", record.body),
        ("tags", " ".join(record.tags)),
        ("related_lineages", " ".join(record.related_lineages)),
    )
    for name, text in core_fields:
        points = _score_text(text, query)
        if not points:
            continue
        total += points
        hits.append(name)

    for name, raw_value in _attribute_search_fields(record.attributes):
        points = _score_text(raw_value, query)
        if not points:
            continue
        # Attribute matches are discounted by one point, floored at 1.0.
        total += max(points - 1.0, 1.0)
        hits.append(name)

    return total, hits
def _score_text(value: object, query: str) -> float:
    """Score how well ``value`` matches ``query``.

    ``value`` is converted with ``str()`` and casefolded; ``query`` is compared
    as given, so callers should pass it already casefolded.

    Args:
        value: Field value to test. ``None`` never matches: it stands for a
            missing value, and its ``str()`` form ("None") would otherwise make
            the query "none" hit every unset attribute.
        query: Text to look for. An empty query never matches.

    Returns:
        10.0 for an exact match, 5.0 when the value starts with the query,
        2.0 when the value merely contains it, otherwise 0.0.
    """
    if value is None or not query:
        return 0.0
    normalized = str(value).casefold()
    if normalized == query:
        return 10.0
    if normalized.startswith(query):
        return 5.0
    if query in normalized:
        return 2.0
    return 0.0
def _attribute_search_fields(attributes: dict) -> Iterable[tuple[str, object]]:
    """Flatten ``attributes`` into ``(field_path, leaf_value)`` pairs.

    Paths start with ``attributes.``; nested dict keys are joined with ``.``
    and list positions are written as ``[index]``, for example
    ``attributes.owner.name`` or ``attributes.links[0].url``. Dicts and lists
    are followed to any depth, so only non-container leaves are yielded.

    Keys are visited in sorted order. If the keys are not mutually comparable
    (e.g. a mix of ``int`` and ``str``), they are ordered by their ``str()``
    form instead of raising ``TypeError``.
    """
    try:
        items = sorted(attributes.items(), key=lambda item: item[0])
    except TypeError:
        items = sorted(attributes.items(), key=lambda item: str(item[0]))
    for key, value in items:
        yield from _flatten_attribute(f"attributes.{key}", value)


def _flatten_attribute(field: str, value: object) -> Iterable[tuple[str, object]]:
    """Yield ``(path, leaf)`` pairs for ``value`` rooted at the path ``field``."""
    if isinstance(value, dict):
        for nested_field, nested_value in _attribute_search_fields(value):
            # The recursive call re-adds the "attributes." root; strip it so
            # the nested key is appended directly to the parent path.
            yield f"{field}.{nested_field.removeprefix('attributes.')}", nested_value
    elif isinstance(value, list):
        for index, item in enumerate(value):
            yield from _flatten_attribute(f"{field}[{index}]", item)
    else:
        yield field, value
# Names exported by ``from knowledge_base import *``: the store and its models.
# The underscore-prefixed scoring helpers are deliberately not listed.
__all__ = [
    "InMemoryKnowledgeBase",
    "KnowledgeCoverageItem",
    "KnowledgePack",
    "KnowledgeRecord",
    "KnowledgeScope",
    "KnowledgeSearchResult",
]