86 lines
2.5 KiB
Python
86 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from sir import SemanticNode, SirSnapshot
|
|
|
|
|
|
class SearchResult(BaseModel):
    """One node matched by ``search_snapshot`` together with its relevance."""

    # The snapshot node that matched the query.
    node: SemanticNode

    # Sum of the per-field match scores computed for this node; results are
    # ranked by this value in descending order.
    score: float

    # Names of the fields that contributed to ``score``, e.g. "name",
    # "qualified_name" or flattened attribute paths like "attributes.<key>".
    matched_fields: list[str]
|
|
|
|
|
|
def search_snapshot(
    snapshot: SirSnapshot,
    query: str,
    *,
    kinds: set[str] | None = None,
    limit: int | None = 20,
) -> list[SearchResult]:
    """Rank the nodes of *snapshot* against a free-text *query*.

    The query is case-folded and stripped; a blank query yields no results.
    Every node is scored on its name, qualified name, kind, source path and
    attribute values, and nodes with a zero score are dropped.

    Args:
        snapshot: Snapshot whose ``nodes`` are searched.
        query: Free-text search string (case-insensitive).
        kinds: When given, only nodes whose ``kind.value`` is in this set are
            considered.
        limit: Maximum number of results to return. ``None`` returns every
            match; zero or a negative value returns an empty list (a plain
            slice would otherwise drop results from the *end* of the ranking).

    Returns:
        Matching nodes sorted by descending score, ties broken by ascending
        ``qualified_name`` so the order is deterministic.
    """
    normalized_query = query.casefold().strip()
    if not normalized_query:
        return []
    # Reject non-positive limits up front: ``results[:-n]`` would silently
    # return "all but the last n" instead of limiting, and scanning is wasted.
    if limit is not None and limit <= 0:
        return []

    results: list[SearchResult] = []
    for node in snapshot.nodes:
        if kinds is not None and node.kind.value not in kinds:
            continue
        score, fields = _score_node(node, normalized_query)
        if score > 0:
            results.append(SearchResult(node=node, score=score, matched_fields=fields))

    results.sort(key=lambda result: (-result.score, result.node.qualified_name))
    return results if limit is None else results[:limit]
|
|
|
|
|
|
def _score_node(node: SemanticNode, query: str) -> tuple[float, list[str]]:
    """Score how well *node* matches the already-normalized *query*.

    Core fields (name, qualified name, kind value, source path) contribute
    their raw text score. Each matching attribute value contributes one point
    less than its raw score, with a floor of 1.0, so an attribute hit always
    ranks below an equally strong hit on a core field.

    Returns:
        ``(total, hits)``: the accumulated score (0.0 when nothing matched)
        and the names of the matching fields in the order they were checked.
    """
    core_fields = (
        ("name", node.name),
        ("qualified_name", node.qualified_name),
        ("kind", node.kind.value),
        ("source_path", node.source_ref.source_path),
    )

    total = 0.0
    hits: list[str] = []

    for label, text in core_fields:
        raw = _score_text(text, query)
        if not raw:
            continue
        total += raw
        hits.append(label)

    for label, text in _attribute_search_fields(node.attributes):
        raw = _score_text(text, query)
        if not raw:
            continue
        # Discount attribute matches by one point (never below 1.0).
        total += max(raw - 1.0, 1.0)
        hits.append(label)

    return total, hits
|
|
|
|
|
|
def _score_text(value: object, query: str) -> float:
|
|
normalized = str(value).casefold()
|
|
if normalized == query:
|
|
return 10.0
|
|
if normalized.startswith(query):
|
|
return 5.0
|
|
if query in normalized:
|
|
return 2.0
|
|
return 0.0
|
|
|
|
|
|
def _attribute_search_fields(attributes: dict) -> Iterable[tuple[str, object]]:
|
|
for key, value in sorted(attributes.items()):
|
|
field = f"attributes.{key}"
|
|
if isinstance(value, dict):
|
|
for nested_key, nested_value in _attribute_search_fields(value):
|
|
yield f"{field}.{nested_key.removeprefix('attributes.')}", nested_value
|
|
elif isinstance(value, list):
|
|
for index, item in enumerate(value):
|
|
yield f"{field}[{index}]", item
|
|
else:
|
|
yield field, value
|
|
|
|
|
|
# Public API of this module: the result model and the search entry point.
__all__ = [
    "SearchResult",
    "search_snapshot",
]
|