""" Helper functions """
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
import attr
import psqlgml
from psqlgraph import Node, PsqlGraphDriver, mocks
from pytest_psqlgraph.typings import Literal
from . import models
# Module-level logger shared by the helper functions and classes in this module.
logger = logging.getLogger(__name__)
def truncate_tables(pg_driver: PsqlGraphDriver) -> None:
    """Delete all rows from every table in the database (best effort).

    Tables are processed in the reverse of the order returned by
    ``engine.table_names()``. Each DELETE runs inside its own SAVEPOINT so that
    a failure on one table is logged and rolled back without aborting the
    enclosing transaction; the remaining tables are still cleared and the
    work is committed when the ``engine.begin()`` block exits.

    Args:
        pg_driver: active driver
    """
    engine = pg_driver.engine
    # quote identifiers that need it (mixed case, reserved words, odd characters)
    quote = engine.dialect.identifier_preparer.quote
    with engine.begin() as conn:
        for table in reversed(engine.table_names()):
            try:
                # Without a savepoint, one failed statement leaves the PostgreSQL
                # transaction aborted: every later DELETE would fail and the final
                # commit would silently become a rollback.
                with conn.begin_nested():
                    conn.execute("DELETE FROM {}".format(quote(table)))
            except Exception as e:
                logger.warning(
                    "error while truncating %s - %s", table, e, exc_info=True
                )
            else:
                logger.debug("truncated %s", table)
def create_tables(driver: models.DatabaseDriver) -> None:
    """Create the default graph tables, then the tables of every extra ORM base.

    Args:
        driver: database driver; its ``create_all`` runs first, then
            ``metadata.create_all`` is called for each entry of
            ``driver.extra_bases`` against ``driver.g.engine``.
    """
    driver.create_all()

    extra_bases = driver.extra_bases
    for orm_base in extra_bases:
        metadata = orm_base.metadata
        metadata.create_all(driver.g.engine)
def drop_tables(driver: models.DatabaseDriver) -> None:
    """Drop the tables created by ``create_tables``, in reverse creation order.

    ``create_tables`` creates the default graph tables first and the extra ORM
    bases afterwards, so teardown goes the other way: extra bases first
    (last-created first), then the default graph tables. If an extra base's
    tables reference graph tables, they are gone before their targets are
    dropped; when there are no such dependencies the order is irrelevant.

    Args:
        driver: database driver whose ``extra_bases`` metadata and default
            graph tables are dropped
    """
    # list() so reversal works for any iterable, not only sequences
    for base in reversed(list(driver.extra_bases)):
        base.metadata.drop_all(driver.g.engine)
    driver.drop_all()
@attr.s(auto_attribs=True)
class DatabaseFixture:
    """Setup/teardown hooks for a single configured test database.

    Attributes:
        name: label included in the debug log messages
        driver: database driver the hooks operate on
        volatile: flag stored on the fixture; not read by the methods here
    """

    name: str
    driver: models.DatabaseDriver
    volatile: bool = False

    def pre_test(self) -> PsqlGraphDriver:
        """Empty every table before a test and return the graph driver."""
        logger.debug(f"Running pre test setup for {self.name}")
        truncate_tables(self.driver.g)
        return self.driver.g

    def post_test(self) -> None:
        """Empty every table after a test."""
        logger.debug(f"Running post test clean up for {self.name}")
        truncate_tables(self.driver.g)

    def pre_config(self) -> None:
        """Create the graph tables and any extra base tables."""
        logger.debug(f"Setting up database for {self.name}")
        create_tables(self.driver)

    def post_config(self) -> None:
        """Drop the graph tables and any extra base tables."""
        logger.debug(f"Destroying database for {self.name}")
        drop_tables(self.driver)
@attr.s(auto_attribs=True)
class DataFactory:
    """Builds mock graph nodes from psqlgml data, persists them, and removes them.

    Attributes:
        model: data model handed to ``mocks.GraphFactory``
        pg_driver: driver used to persist and later delete the nodes
        dictionary: optional dictionary handed to ``mocks.GraphFactory``
        extension: hooks invoked around persistence (``pre``/``run``/``post``)
        globals: optional graph globals; ``{}`` is used when this is falsy
        factory: replaced by a fresh ``mocks.GraphFactory`` in ``__attrs_post_init__``
        mock_data: nodes produced by the most recent ``from_source`` call
    """

    model: models.DataModel
    pg_driver: PsqlGraphDriver
    dictionary: Optional[models.Dictionary]
    extension: models.MarkExtension
    globals: Optional[Dict[str, Any]]
    factory: mocks.GraphFactory = None
    mock_data: List[Node] = attr.ib(factory=list)

    def __attrs_post_init__(self) -> None:
        self.factory = mocks.GraphFactory(
            models=self.model,
            dictionary=self.dictionary,
            graph_globals=self.globals or {},
        )

    def from_source(
        self,
        source_data: psqlgml.GmlData,
    ) -> List[Node]:
        """Create nodes/edges described by ``source_data`` and save them.

        Args:
            source_data: psqlgml data. ``nodes`` is required. ``edges`` is
                optional (treated as empty when absent). ``unique_field``
                defaults to ``"submitter_id"`` and ``mock_all_props``
                defaults to ``True``.

        Returns:
            the generated nodes (also stored on ``self.mock_data``)
        """
        unique_key: Literal["node_id", "submitter_id"] = source_data.get(
            "unique_field", "submitter_id"
        )
        mock_all_props = source_data.get("mock_all_props", True)

        self.mock_data = self.factory.create_from_nodes_and_edges(
            unique_key=unique_key,
            all_props=mock_all_props,
            nodes=source_data["nodes"],
            # a resource that declares only nodes is valid input
            edges=source_data.get("edges", []),
        )

        self.extension.pre(self.mock_data)
        # can_inherit=False presumably forces a fresh session rather than
        # joining an enclosing one - confirm against psqlgraph
        with self.pg_driver.session_scope(can_inherit=False) as s:
            for node in self.mock_data:
                self.extension.run(node)
                s.add(node)
        # runs after the session scope has exited
        self.extension.post(self.mock_data)
        return self.mock_data

    def clean(self) -> None:
        """Delete the nodes from the last ``from_source`` call that still exist."""
        with self.pg_driver.session_scope() as sxn:
            for mock_node in self.mock_data:
                # re-fetch by id so the delete targets an instance in this session
                persisted = self.pg_driver.nodes().get(mock_node.node_id)
                if persisted is not None:
                    sxn.delete(persisted)
@attr.s(auto_attribs=True)
class MarkHandler:
    """Loads, validates, and persists the data referenced by a test mark.

    Attributes:
        mark: mark contents; reads ``resource``, ``data_dir`` (only for
            non-dict resources) and the optional ``extension`` class
        fixture: database fixture whose driver receives the data
        factory: replaced by a ``DataFactory`` in ``__attrs_post_init__``
    """

    mark: models.PsqlgraphDataMark
    fixture: DatabaseFixture
    factory: DataFactory = attr.ib(default=None)

    def __attrs_post_init__(self) -> None:
        extension_cls = self.mark.get("extension") or models.MarkExtension
        self.factory = DataFactory(
            pg_driver=self.driver.g,
            model=self.driver.model,
            globals=self.driver.globals,
            dictionary=self.driver.dictionary,
            extension=extension_cls(g=self.driver.g),
        )

    @property
    def driver(self) -> models.DatabaseDriver:
        """Driver of the wrapped database fixture."""
        return self.fixture.driver

    def pre(self) -> List[Node]:
        """Validate the mark's resource, then create and persist its data.

        A dict resource is validated and used directly. Any other resource
        is treated as a file inside ``mark["data_dir"]``.

        Returns:
            the nodes created by the data factory

        Raises:
            ValueError: when validation reports any violation
        """
        resource = self.mark["resource"]
        if isinstance(resource, dict):
            violations = validate_resource(resource, self.driver.dictionary)
            if violations:
                raise ValueError(
                    f"Data Error: {len(violations)} violation(s) in inline resource"
                )
            return self.factory.from_source(resource)

        data_dir = self.mark["data_dir"]
        # validate before loading so invalid files are reported, not consumed
        violations = validate_file_resource(
            resource, data_dir, self.driver.dictionary
        )
        if violations:
            raise ValueError(
                f"Invalid data specified: {len(violations)} violation(s) "
                f"in {resource!r}"
            )
        source_data = psqlgml.load_resource(data_dir, resource)
        return self.factory.from_source(source_data)

    def post(self) -> None:
        """Remove the data created by ``pre``."""
        self.factory.clean()
def read_schema(
    dictionary: models.Dictionary,
) -> Tuple[psqlgml.Dictionary, psqlgml.GmlSchema]:
    """Convert a dictionary into psqlgml form and load its generated schema.

    The psqlgml dictionary is named ``<module>.<ClassName>`` of ``dictionary``
    and tagged with version ``"pytest_psqlgraph"``. ``psqlgml.generate`` is
    invoked on it before the schema of that name/version is read back.

    Args:
        dictionary: dictionary object exposing a ``schema`` attribute

    Returns:
        the psqlgml dictionary and the schema read for it
    """
    version = "pytest_psqlgraph"
    name = "{}.{}".format(dictionary.__module__, dictionary.__class__.__name__)

    gml_dictionary = psqlgml.from_object(dictionary.schema, name=name, version=version)
    psqlgml.generate(gml_dictionary)
    gml_schema = psqlgml.read_schema(name, version=version)
    return gml_dictionary, gml_schema
def validate_file_resource(
    data_file: str, data_dir: str, dictionary: models.Dictionary
) -> Set[psqlgml.DataViolation]:
    """Validate a GML data file against the schema derived from ``dictionary``.

    Violations are also printed by psqlgml (``print_error=True``).

    Args:
        data_file: name of the data file to validate
        data_dir: directory containing ``data_file``
        dictionary: dictionary the data must conform to

    Returns:
        all reported violations merged into one set; empty when none are reported
    """
    di, schema = read_schema(dictionary)
    req = psqlgml.ValidationRequest(
        data_file=data_file, data_dir=data_dir, schema=schema, dictionary=di
    )
    grouped_violations = psqlgml.validate(req, print_error=True)
    # set().union(...) also handles an empty mapping, where
    # set.union(*()) would raise TypeError
    violations: Set[psqlgml.DataViolation] = set().union(
        *grouped_violations.values()
    )
    return violations
def validate_resource(
    resource: psqlgml.GmlData, dictionary: models.Dictionary
) -> Set[psqlgml.DataViolation]:
    """Validate in-memory GML data against the schema derived from ``dictionary``.

    The data is passed to psqlgml as a payload under the placeholder file
    name ``"data object"``. Violations are also printed by psqlgml
    (``print_error=True``).

    Args:
        resource: GML data to validate
        dictionary: dictionary the data must conform to

    Returns:
        all reported violations merged into one set; empty when none are reported
    """
    di, schema = read_schema(dictionary)
    req = psqlgml.ValidationRequest(
        data_file="data object",
        data_dir="",
        schema=schema,
        dictionary=di,
        payload={"": resource},
    )
    grouped_violations = psqlgml.validate(req, print_error=True)
    # set().union(...) also handles an empty mapping, where
    # set.union(*()) would raise TypeError
    violations: Set[psqlgml.DataViolation] = set().union(
        *grouped_violations.values()
    )
    return violations