Files
sfera/packages/projection-engine/src/projection_engine/__init__.py
T
2026-05-16 19:03:49 +03:00

238 lines
9.3 KiB
Python

from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass, field
import json
from sir import EdgeKind, SemanticEdge, SemanticNode, SirDelta, SirSnapshot
@dataclass
class InMemoryProjection:
    """Dictionary-backed projection of SIR snapshots and deltas.

    ``nodes`` maps ``lineage_id`` to node and ``edges`` maps ``edge_id`` to edge.
    The ``find_*`` helpers look routines up by case-insensitive name and follow
    edges of a specific ``EdgeKind``. Their result lists follow edge insertion
    order and may contain the same node more than once when several edges lead
    to it. Edges whose endpoint node is absent are skipped when resolving nodes.
    """

    # ``kind.value`` strings that identify callable routines. Deliberately not
    # annotated so that ``@dataclass`` does not turn it into a field.
    _ROUTINE_KINDS = frozenset({"PROCEDURE", "FUNCTION"})

    nodes: dict[str, SemanticNode] = field(default_factory=dict)
    edges: dict[str, SemanticEdge] = field(default_factory=dict)

    def project_snapshot(self, snapshot: SirSnapshot) -> None:
        """Replace the whole projection with the nodes and edges of ``snapshot``."""
        self.nodes = {node.lineage_id: node for node in snapshot.nodes}
        self.edges = {edge.edge_id: edge for edge in snapshot.edges}

    def apply_delta(self, delta: SirDelta) -> None:
        """Apply an incremental change set in place.

        The steps run in this order:

        1. Removed nodes are dropped together with every edge touching them.
        2. Explicitly removed edges are dropped.
        3. Added and updated nodes are upserted by ``lineage_id``.
        4. Added edges are upserted by ``edge_id``.

        Missing ids are ignored. Added edges are stored even if an endpoint
        node is not present.
        """
        removed_nodes = set(delta.removed_nodes)
        for lineage_id in removed_nodes:
            self.nodes.pop(lineage_id, None)
        if removed_nodes:
            # Prune dangling edges in a single pass instead of rebuilding the
            # edge dict once per removed node.
            self.edges = {
                edge_id: edge
                for edge_id, edge in self.edges.items()
                if edge.source_lineage not in removed_nodes
                and edge.target_lineage not in removed_nodes
            }
        for edge_id in delta.removed_edges:
            self.edges.pop(edge_id, None)
        for node in [*delta.added_nodes, *delta.updated_nodes]:
            self.nodes[node.lineage_id] = node
        for edge in delta.added_edges:
            self.edges[edge.edge_id] = edge

    def find_procedures(self) -> list[SemanticNode]:
        """Return every node whose kind is PROCEDURE or FUNCTION."""
        return [node for node in self.nodes.values() if node.kind.value in self._ROUTINE_KINDS]

    def find_callers(self, routine_name: str) -> list[SemanticNode]:
        """Return the source nodes of ``CALLS`` edges into routines named ``routine_name``."""
        target_ids = self._routine_lineages(routine_name)
        caller_ids = [
            edge.source_lineage
            for edge in self.edges.values()
            if edge.kind == EdgeKind.CALLS and edge.target_lineage in target_ids
        ]
        return self._existing_nodes(caller_ids)

    def find_callees(self, routine_name: str) -> list[SemanticNode]:
        """Return the target nodes of ``CALLS`` edges out of routines named ``routine_name``."""
        routine_ids = self._routine_lineages(routine_name)
        return self._existing_nodes(self._edge_targets(EdgeKind.CALLS, routine_ids))

    def find_query_tables(self, routine_name: str) -> list[SemanticNode]:
        """Return tables read by queries owned by routines named ``routine_name``.

        The lookup follows ``OWNS_QUERY`` edges from the routine and then
        ``READS_TABLE`` edges from each owned query.
        """
        routine_ids = self._routine_lineages(routine_name)
        # A set gives O(1) membership tests in the second edge pass.
        query_ids = set(self._edge_targets(EdgeKind.OWNS_QUERY, routine_ids))
        return self._existing_nodes(self._edge_targets(EdgeKind.READS_TABLE, query_ids))

    def find_writes(self, routine_name: str) -> list[SemanticNode]:
        """Return the target nodes of ``WRITES`` edges out of routines named ``routine_name``."""
        routine_ids = self._routine_lineages(routine_name)
        return self._existing_nodes(self._edge_targets(EdgeKind.WRITES, routine_ids))

    def _routine_lineages(self, routine_name: str) -> set[str]:
        """Return lineage ids of routines whose name matches ``routine_name``.

        The name comparison is case-insensitive.
        """
        wanted = routine_name.lower()
        return {
            node.lineage_id
            for node in self.nodes.values()
            if node.name.lower() == wanted and node.kind.value in self._ROUTINE_KINDS
        }

    def _edge_targets(self, kind: EdgeKind, source_ids: set[str]) -> list[str]:
        """Return ``target_lineage`` of every ``kind`` edge whose source is in ``source_ids``."""
        return [
            edge.target_lineage
            for edge in self.edges.values()
            if edge.kind == kind and edge.source_lineage in source_ids
        ]

    def _existing_nodes(self, lineage_ids: list[str]) -> list[SemanticNode]:
        """Map lineage ids to nodes, skipping ids with no node in the projection."""
        return [self.nodes[lineage] for lineage in lineage_ids if lineage in self.nodes]
class Neo4jProjection:
    """Projects SIR snapshots and deltas into a Neo4j database.

    Each semantic node becomes a ``:SferaNode`` MERGEd on ``lineage_id``. Each
    semantic edge becomes a ``[:SEMANTIC_EDGE]`` relationship MERGEd on
    ``edge_id``. Attributes and source references are stored as JSON strings in
    ``attributes_json`` / ``source_ref_json``.
    """

    def __init__(self, driver) -> None:
        """Wrap an async driver.

        ``driver.session()`` is used as an async context manager, and the
        session's ``run`` is awaited. NOTE(review): presumably a
        ``neo4j.AsyncDriver``; confirm against callers.
        """
        self._driver = driver

    async def ensure_schema(self) -> None:
        """Create the lineage uniqueness constraint and lookup indexes.

        Every statement uses ``IF NOT EXISTS``, so repeated calls are harmless.
        """
        async with self._driver.session() as session:
            await session.run(
                """
                CREATE CONSTRAINT sfera_node_lineage IF NOT EXISTS
                FOR (n:SferaNode) REQUIRE n.lineage_id IS UNIQUE
                """
            )
            await session.run(
                """
                CREATE INDEX sfera_node_project_name IF NOT EXISTS
                FOR (n:SferaNode) ON (n.project_id, n.name)
                """
            )
            await session.run(
                """
                CREATE INDEX sfera_edge_kind IF NOT EXISTS
                FOR ()-[r:SEMANTIC_EDGE]-() ON (r.kind)
                """
            )

    async def project_snapshot(self, snapshot: SirSnapshot) -> None:
        """MERGE every node, then every edge, of ``snapshot`` into the graph.

        NOTE(review): this only upserts. Nodes and edges already stored for the
        project but absent from ``snapshot`` are left in place. That differs from
        ``InMemoryProjection.project_snapshot``, which replaces everything.
        Confirm whether callers are expected to call ``clear_project`` first.
        """
        await self.ensure_schema()
        async with self._driver.session() as session:
            # Nodes first: _merge_edge MATCHes both endpoints and creates nothing
            # if either is missing.
            for node in snapshot.nodes:
                await self._merge_node(session, node, snapshot.project_id)
            for edge in snapshot.edges:
                await self._merge_edge(session, edge)

    async def apply_delta(self, delta: SirDelta, *, project_id: str) -> None:
        """Apply an incremental change set for ``project_id``.

        Removed nodes are DETACH DELETEd, which also drops their relationships.
        Removed edges are then deleted. Finally, added and updated nodes and
        added edges are MERGEd.
        """
        await self.ensure_schema()
        async with self._driver.session() as session:
            for lineage_id in delta.removed_nodes:
                await session.run(
                    """
                    MATCH (n:SferaNode {project_id: $project_id, lineage_id: $lineage_id})
                    DETACH DELETE n
                    """,
                    project_id=project_id,
                    lineage_id=lineage_id,
                )
            for edge_id in delta.removed_edges:
                # NOTE(review): unlike node removal, this match is not scoped to
                # project_id. It is only safe if edge ids are globally unique.
                await session.run(
                    """
                    MATCH ()-[r:SEMANTIC_EDGE {edge_id: $edge_id}]->()
                    DELETE r
                    """,
                    edge_id=edge_id,
                )
            for node in [*delta.added_nodes, *delta.updated_nodes]:
                await self._merge_node(session, node, project_id)
            for edge in delta.added_edges:
                await self._merge_edge(session, edge)

    async def clear_project(self, project_id: str) -> None:
        """Delete every node of ``project_id`` together with its relationships."""
        async with self._driver.session() as session:
            await session.run(
                """
                MATCH (n:SferaNode {project_id: $project_id})
                DETACH DELETE n
                """,
                project_id=project_id,
            )

    async def counts(self, project_id: str | None = None) -> dict[str, int]:
        """Return ``{"nodes": int, "edges": int}``.

        With ``project_id=None``, every ``SferaNode`` and ``SEMANTIC_EDGE`` in the
        database is counted. Otherwise only nodes of that project are counted,
        plus edges whose two endpoints both belong to it.
        """
        if project_id is None:
            node_query = "MATCH (n:SferaNode) RETURN count(n) AS count"
            edge_query = "MATCH ()-[r:SEMANTIC_EDGE]->() RETURN count(r) AS count"
            params: dict[str, str] = {}
        else:
            node_query = """
                MATCH (n:SferaNode {project_id: $project_id})
                RETURN count(n) AS count
                """
            edge_query = """
                MATCH (:SferaNode {project_id: $project_id})-[r:SEMANTIC_EDGE]->
                      (:SferaNode {project_id: $project_id})
                RETURN count(r) AS count
                """
            params = {"project_id": project_id}
        async with self._driver.session() as session:
            # Fetch each record before issuing the next query, so the driver
            # never has to buffer an unconsumed result.
            node_count = await self._fetch_count(session, node_query, params)
            edge_count = await self._fetch_count(session, edge_query, params)
        return {"nodes": node_count, "edges": edge_count}

    @staticmethod
    async def _fetch_count(session, query: str, params: dict[str, str]) -> int:
        """Run a ``... RETURN count(x) AS count`` query and return the count (0 if no record)."""
        result = await session.run(query, **params)
        record = await result.single()
        return int(record["count"]) if record else 0

    @staticmethod
    def _to_json(value) -> str:
        """Encode ``value`` as JSON with sorted keys and non-ASCII text kept as-is."""
        return json.dumps(value, ensure_ascii=False, sort_keys=True)

    async def _merge_node(self, session, node: SemanticNode, project_id: str) -> None:
        """Upsert ``node`` as a ``:SferaNode`` keyed by ``lineage_id`` and overwrite its properties."""
        # Guard a missing source_ref the same way _merge_edge does, so it is
        # stored as null instead of raising AttributeError.
        source_ref_json = (
            node.source_ref.model_dump_json(exclude_none=True) if node.source_ref else None
        )
        await session.run(
            """
            MERGE (n:SferaNode {lineage_id: $lineage_id})
            SET n.semantic_id = $semantic_id,
                n.project_id = $project_id,
                n.kind = $kind,
                n.name = $name,
                n.qualified_name = $qualified_name,
                n.attributes_json = $attributes_json,
                n.source_ref_json = $source_ref_json
            """,
            lineage_id=node.lineage_id,
            semantic_id=node.semantic_id,
            project_id=project_id,
            kind=node.kind.value,
            name=node.name,
            qualified_name=node.qualified_name,
            attributes_json=self._to_json(node.attributes),
            source_ref_json=source_ref_json,
        )

    async def _merge_edge(self, session, edge: SemanticEdge) -> None:
        """Upsert ``edge`` as a ``SEMANTIC_EDGE`` keyed by ``edge_id`` between its endpoints.

        If either endpoint node does not exist, the MATCH yields no rows and no
        relationship is written.
        """
        await session.run(
            """
            MATCH (source:SferaNode {lineage_id: $source_lineage})
            MATCH (target:SferaNode {lineage_id: $target_lineage})
            MERGE (source)-[r:SEMANTIC_EDGE {edge_id: $edge_id}]->(target)
            SET r.kind = $kind,
                r.attributes_json = $attributes_json,
                r.source_ref_json = $source_ref_json
            """,
            source_lineage=edge.source_lineage,
            target_lineage=edge.target_lineage,
            edge_id=edge.edge_id,
            kind=edge.kind.value,
            attributes_json=self._to_json(edge.attributes),
            source_ref_json=(
                edge.source_ref.model_dump_json(exclude_none=True) if edge.source_ref else None
            ),
        )
def build_adjacency(snapshot: SirSnapshot) -> dict[str, list[SemanticEdge]]:
    """Group the edges of ``snapshot`` by their ``source_lineage``.

    Returns a plain dict whose keys appear in first-seen order. Each value
    lists that source's outgoing edges in snapshot order. Nodes with no
    outgoing edge have no key.
    """
    outgoing = defaultdict(list)
    for edge in snapshot.edges:
        outgoing[edge.source_lineage].append(edge)
    # Convert back to a plain dict so callers don't get auto-created keys.
    return dict(outgoing)
# Public names exported by ``from projection_engine import *``.
__all__ = ["InMemoryProjection", "Neo4jProjection", "build_adjacency"]