from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field
import json

from sir import EdgeKind, SemanticEdge, SemanticNode, SirDelta, SirSnapshot

# Values of ``node.kind.value`` that identify a callable routine.
_ROUTINE_KINDS = frozenset({"PROCEDURE", "FUNCTION"})


@dataclass
class InMemoryProjection:
    """Dictionary-backed projection of a SIR graph with simple query helpers."""

    # Nodes keyed by ``lineage_id``.
    nodes: dict[str, SemanticNode] = field(default_factory=dict)
    # Edges keyed by ``edge_id``.
    edges: dict[str, SemanticEdge] = field(default_factory=dict)

    def project_snapshot(self, snapshot: SirSnapshot) -> None:
        """Replace the entire projection with the nodes and edges of *snapshot*."""
        self.nodes = {node.lineage_id: node for node in snapshot.nodes}
        self.edges = {edge.edge_id: edge for edge in snapshot.edges}

    def apply_delta(self, delta: SirDelta) -> None:
        """Apply *delta* to the projection in place.

        Order of application:

        1. removed nodes are dropped, together with every edge whose source or
           target is one of them;
        2. removed edges are dropped (missing ids are ignored);
        3. added and updated nodes are upserted by ``lineage_id``;
        4. added edges are upserted by ``edge_id``.
        """
        removed = set(delta.removed_nodes)
        if removed:
            for lineage_id in removed:
                self.nodes.pop(lineage_id, None)
            # A single pass over the edges, instead of one full rebuild of the
            # edge dict per removed node.
            self.edges = {
                edge_id: edge
                for edge_id, edge in self.edges.items()
                if edge.source_lineage not in removed
                and edge.target_lineage not in removed
            }

        for edge_id in delta.removed_edges:
            self.edges.pop(edge_id, None)

        for node in [*delta.added_nodes, *delta.updated_nodes]:
            self.nodes[node.lineage_id] = node

        for edge in delta.added_edges:
            self.edges[edge.edge_id] = edge

    def find_procedures(self) -> list[SemanticNode]:
        """Return every node whose kind is ``PROCEDURE`` or ``FUNCTION``."""
        return [node for node in self.nodes.values() if node.kind.value in _ROUTINE_KINDS]

    def find_callers(self, routine_name: str) -> list[SemanticNode]:
        """Return the source nodes of ``CALLS`` edges into routines named *routine_name*.

        The name match is case-insensitive. A caller appears once per matching
        edge, so callers with several call edges are repeated. Edges whose
        source node is not projected are skipped.
        """
        target_ids = self._routine_lineages(routine_name)
        caller_ids = [
            edge.source_lineage
            for edge in self.edges.values()
            if edge.kind == EdgeKind.CALLS and edge.target_lineage in target_ids
        ]
        return self._resolve(caller_ids)

    def find_callees(self, routine_name: str) -> list[SemanticNode]:
        """Return the target nodes of ``CALLS`` edges out of routines named *routine_name*.

        Matching is case-insensitive; one entry per matching edge; dangling
        targets are skipped.
        """
        source_ids = self._routine_lineages(routine_name)
        return self._resolve(self._edge_targets(EdgeKind.CALLS, source_ids))

    def find_query_tables(self, routine_name: str) -> list[SemanticNode]:
        """Return tables read by queries owned by routines named *routine_name*.

        Follows ``OWNS_QUERY`` edges from the routine(s) to their queries, then
        ``READS_TABLE`` edges from those queries to tables. One entry per
        ``READS_TABLE`` edge; dangling targets are skipped.
        """
        routine_ids = self._routine_lineages(routine_name)
        # A set gives O(1) membership tests in the second pass. Output order is
        # still driven by edge iteration order, not by this collection.
        query_ids = set(self._edge_targets(EdgeKind.OWNS_QUERY, routine_ids))
        table_ids = self._edge_targets(EdgeKind.READS_TABLE, query_ids)
        return self._resolve(table_ids)

    def find_writes(self, routine_name: str) -> list[SemanticNode]:
        """Return the targets of ``WRITES`` edges out of routines named *routine_name*."""
        routine_ids = self._routine_lineages(routine_name)
        return self._resolve(self._edge_targets(EdgeKind.WRITES, routine_ids))

    def _edge_targets(self, kind: EdgeKind, source_ids: set[str]) -> list[str]:
        """Target lineages of edges of *kind* whose source is in *source_ids*, in edge order."""
        return [
            edge.target_lineage
            for edge in self.edges.values()
            if edge.kind == kind and edge.source_lineage in source_ids
        ]

    def _resolve(self, lineage_ids: list[str]) -> list[SemanticNode]:
        """Map lineage ids to projected nodes, skipping ids with no projected node."""
        return [self.nodes[lineage] for lineage in lineage_ids if lineage in self.nodes]

    def _routine_lineages(self, routine_name: str) -> set[str]:
        """Lineage ids of routine nodes whose name equals *routine_name*, ignoring case."""
        wanted = routine_name.lower()
        return {
            node.lineage_id
            for node in self.nodes.values()
            if node.name.lower() == wanted and node.kind.value in _ROUTINE_KINDS
        }


class Neo4jProjection:
    """Projects SIR snapshots and deltas into Neo4j.

    Nodes are stored as ``:SferaNode`` and edges as ``:SEMANTIC_EDGE``
    relationships. ``attributes`` and ``source_ref`` are stored as JSON strings.

    *driver* must expose ``session()`` as an async context manager whose
    sessions support ``await session.run(query, **params)``. This is
    presumably a ``neo4j.AsyncDriver``; confirm against the callers.
    """

    # Idempotent schema statements, run one by one by ``ensure_schema``.
    _SCHEMA_STATEMENTS = (
        """
        CREATE CONSTRAINT sfera_node_lineage IF NOT EXISTS
        FOR (n:SferaNode) REQUIRE n.lineage_id IS UNIQUE
        """,
        """
        CREATE INDEX sfera_node_project_name IF NOT EXISTS
        FOR (n:SferaNode) ON (n.project_id, n.name)
        """,
        """
        CREATE INDEX sfera_edge_kind IF NOT EXISTS
        FOR ()-[r:SEMANTIC_EDGE]-() ON (r.kind)
        """,
    )

    def __init__(self, driver) -> None:
        self._driver = driver

    async def ensure_schema(self) -> None:
        """Create the lineage uniqueness constraint and lookup indexes if absent."""
        async with self._driver.session() as session:
            for statement in self._SCHEMA_STATEMENTS:
                await session.run(statement)

    async def project_snapshot(self, snapshot: SirSnapshot) -> None:
        """Upsert every node and edge of *snapshot*.

        Nodes already stored for the project but absent from *snapshot* are not
        removed. Call ``clear_project`` first if a full replacement is wanted.
        """
        await self.ensure_schema()
        async with self._driver.session() as session:
            for node in snapshot.nodes:
                await self._merge_node(session, node, snapshot.project_id)
            for edge in snapshot.edges:
                await self._merge_edge(session, edge)

    async def apply_delta(self, delta: SirDelta, *, project_id: str) -> None:
        """Apply *delta* for *project_id*.

        Removals are applied first: node removal is scoped to the project and
        uses ``DETACH DELETE``, so edges touching the node are dropped too.
        Upserts of added/updated nodes and added edges follow.
        """
        await self.ensure_schema()
        async with self._driver.session() as session:
            for lineage_id in delta.removed_nodes:
                await session.run(
                    """
                    MATCH (n:SferaNode {project_id: $project_id, lineage_id: $lineage_id})
                    DETACH DELETE n
                    """,
                    project_id=project_id,
                    lineage_id=lineage_id,
                )
            for edge_id in delta.removed_edges:
                await session.run(
                    """
                    MATCH ()-[r:SEMANTIC_EDGE {edge_id: $edge_id}]->()
                    DELETE r
                    """,
                    edge_id=edge_id,
                )
            for node in [*delta.added_nodes, *delta.updated_nodes]:
                await self._merge_node(session, node, project_id)
            for edge in delta.added_edges:
                await self._merge_edge(session, edge)

    async def clear_project(self, project_id: str) -> None:
        """Delete every node of *project_id*, together with its relationships."""
        async with self._driver.session() as session:
            await session.run(
                """
                MATCH (n:SferaNode {project_id: $project_id})
                DETACH DELETE n
                """,
                project_id=project_id,
            )

    async def counts(self, project_id: str | None = None) -> dict[str, int]:
        """Return ``{"nodes": ..., "edges": ...}`` counts.

        With *project_id*, only nodes of that project and edges whose both
        endpoints belong to it are counted. Without it, counts are global.
        """
        async with self._driver.session() as session:
            if project_id is None:
                nodes = await self._fetch_count(
                    session, "MATCH (n:SferaNode) RETURN count(n) AS count"
                )
                edges = await self._fetch_count(
                    session, "MATCH ()-[r:SEMANTIC_EDGE]->() RETURN count(r) AS count"
                )
            else:
                nodes = await self._fetch_count(
                    session,
                    """
                    MATCH (n:SferaNode {project_id: $project_id})
                    RETURN count(n) AS count
                    """,
                    project_id=project_id,
                )
                edges = await self._fetch_count(
                    session,
                    """
                    MATCH (:SferaNode {project_id: $project_id})-[r:SEMANTIC_EDGE]->
                          (:SferaNode {project_id: $project_id})
                    RETURN count(r) AS count
                    """,
                    project_id=project_id,
                )
        return {"nodes": nodes, "edges": edges}

    @staticmethod
    async def _fetch_count(session, query: str, **params) -> int:
        """Run a query returning one ``count`` column and return it (0 if no row).

        The result is consumed immediately, rather than relying on the
        driver to buffer it when the next query runs on the same session.
        """
        result = await session.run(query, **params)
        record = await result.single()
        return int(record["count"]) if record else 0

    async def _merge_node(self, session, node: SemanticNode, project_id: str) -> None:
        """Upsert *node* by ``lineage_id`` and overwrite its stored properties."""
        await session.run(
            """
            MERGE (n:SferaNode {lineage_id: $lineage_id})
            SET n.semantic_id = $semantic_id,
                n.project_id = $project_id,
                n.kind = $kind,
                n.name = $name,
                n.qualified_name = $qualified_name,
                n.attributes_json = $attributes_json,
                n.source_ref_json = $source_ref_json
            """,
            lineage_id=node.lineage_id,
            semantic_id=node.semantic_id,
            project_id=project_id,
            kind=node.kind.value,
            name=node.name,
            qualified_name=node.qualified_name,
            attributes_json=json.dumps(node.attributes, ensure_ascii=False, sort_keys=True),
            # Guard a missing source_ref the same way _merge_edge does.
            source_ref_json=(
                node.source_ref.model_dump_json(exclude_none=True) if node.source_ref else None
            ),
        )

    async def _merge_edge(self, session, edge: SemanticEdge) -> None:
        """Upsert *edge* by ``edge_id`` between its source and target nodes.

        If either endpoint node is not stored, the ``MATCH`` finds no row and
        nothing is written; no error is raised for this case.
        """
        await session.run(
            """
            MATCH (source:SferaNode {lineage_id: $source_lineage})
            MATCH (target:SferaNode {lineage_id: $target_lineage})
            MERGE (source)-[r:SEMANTIC_EDGE {edge_id: $edge_id}]->(target)
            SET r.kind = $kind,
                r.attributes_json = $attributes_json,
                r.source_ref_json = $source_ref_json
            """,
            source_lineage=edge.source_lineage,
            target_lineage=edge.target_lineage,
            edge_id=edge.edge_id,
            kind=edge.kind.value,
            attributes_json=json.dumps(edge.attributes, ensure_ascii=False, sort_keys=True),
            source_ref_json=(
                edge.source_ref.model_dump_json(exclude_none=True) if edge.source_ref else None
            ),
        )


def build_adjacency(snapshot: SirSnapshot) -> dict[str, list[SemanticEdge]]:
    """Group the edges of *snapshot* by ``source_lineage``, keeping edge order."""
    adjacency: dict[str, list[SemanticEdge]] = defaultdict(list)
    for edge in snapshot.edges:
        adjacency[edge.source_lineage].append(edge)
    return dict(adjacency)


__all__ = ["InMemoryProjection", "Neo4jProjection", "build_adjacency"]