from __future__ import annotations

from collections.abc import Iterator

from pydantic import BaseModel, Field

from sir import EdgeKind, SemanticNode, SirSnapshot


class TableUsage(BaseModel):
    """Aggregated read/write access information for one table node.

    Attributes (all populated by ``table_usage``):
        table: The table node this record describes.
        queries: Source nodes of ``READS_TABLE`` edges targeting ``table``,
            one entry per edge.
        readers: Owners (via ``OWNS_QUERY`` edges) of those queries, one entry
            per reading edge whose query has a known owner. The same owner may
            therefore appear more than once.
        writers: Source nodes of ``WRITES`` edges targeting ``table``, one
            entry per edge.
    """

    table: SemanticNode
    queries: list[SemanticNode] = Field(default_factory=list)
    readers: list[SemanticNode] = Field(default_factory=list)
    writers: list[SemanticNode] = Field(default_factory=list)

    @property
    def has_read_write_conflict(self) -> bool:
        """Return True when the table has at least one reader and one writer.

        Reading queries without a known owner do not count as readers.
        """
        return bool(self.readers and self.writers)


def _matches_table_name(table: SemanticNode, needle: str | None) -> bool:
    """Return True if *needle* is None or equals the table's name or qualified name.

    *needle* must already be casefolded; the table's names are casefolded here,
    so the comparison is case-insensitive.
    """
    if needle is None:
        return True
    # `name` is checked first so `qualified_name` is only casefolded on a miss.
    return table.name.casefold() == needle or table.qualified_name.casefold() == needle


def _table_edges(
    snapshot: SirSnapshot,
    nodes: dict[str, SemanticNode],
    kind: EdgeKind,
    needle: str | None,
) -> Iterator[tuple[SemanticNode, SemanticNode]]:
    """Yield ``(source, table)`` node pairs for edges of *kind* in snapshot order.

    Edges whose source or target lineage is not present in *nodes* are skipped,
    as are edges whose target does not match *needle* (see ``_matches_table_name``).
    """
    for edge in snapshot.edges:
        if edge.kind != kind:
            continue
        source = nodes.get(edge.source_lineage)
        table = nodes.get(edge.target_lineage)
        if source is None or table is None:
            continue
        if not _matches_table_name(table, needle):
            continue
        yield source, table


def _usage_for(usage_by_table: dict[str, TableUsage], table: SemanticNode) -> TableUsage:
    """Return the ``TableUsage`` for *table*, creating and registering it if absent."""
    usage = usage_by_table.get(table.lineage_id)
    if usage is None:
        # Build the model only on first sight instead of on every edge.
        usage = TableUsage(table=table)
        usage_by_table[table.lineage_id] = usage
    return usage


def table_usage(snapshot: SirSnapshot, table_name: str | None = None) -> list[TableUsage]:
    """Collect read/write usage for every table touched by an edge in *snapshot*.

    Args:
        snapshot: Graph whose ``nodes`` and ``edges`` are scanned.
        table_name: Optional case-insensitive filter. It is matched against each
            table node's ``name`` or ``qualified_name``. ``None`` keeps all tables.

    Returns:
        One ``TableUsage`` per table node that is the target of at least one
        ``READS_TABLE`` or ``WRITES`` edge (after filtering), sorted by the
        table's ``qualified_name``. Tables with no such edges are not included.
        Edges that reference lineage ids missing from ``snapshot.nodes`` are
        ignored.
    """
    nodes = {node.lineage_id: node for node in snapshot.nodes}
    # Casefold the filter once rather than on every edge.
    needle = table_name.casefold() if table_name is not None else None

    # Map each query lineage id to its owning node. If a query has several
    # OWNS_QUERY edges, the last one in snapshot order wins.
    query_owner: dict[str, SemanticNode] = {}
    for edge in snapshot.edges:
        if (
            edge.kind == EdgeKind.OWNS_QUERY
            and edge.target_lineage in nodes
            and edge.source_lineage in nodes
        ):
            query_owner[edge.target_lineage] = nodes[edge.source_lineage]

    usage_by_table: dict[str, TableUsage] = {}

    for query, table in _table_edges(snapshot, nodes, EdgeKind.READS_TABLE, needle):
        usage = _usage_for(usage_by_table, table)
        usage.queries.append(query)
        if owner := query_owner.get(query.lineage_id):
            usage.readers.append(owner)

    for writer, table in _table_edges(snapshot, nodes, EdgeKind.WRITES, needle):
        _usage_for(usage_by_table, table).writers.append(writer)

    # `sorted` is stable, so tables sharing a qualified_name keep discovery order.
    return sorted(usage_by_table.values(), key=lambda item: item.table.qualified_name)


def tables_with_read_write_conflicts(snapshot: SirSnapshot) -> list[TableUsage]:
    """Return the tables in *snapshot* that have both readers and writers.

    Results keep the ``qualified_name`` ordering produced by ``table_usage``.
    """
    return [usage for usage in table_usage(snapshot) if usage.has_read_write_conflict]


__all__ = ["TableUsage", "table_usage", "tables_with_read_write_conflicts"]