diff --git a/dbldatagen/spec/column_spec.py b/dbldatagen/spec/column_spec.py new file mode 100644 index 00000000..20383e04 --- /dev/null +++ b/dbldatagen/spec/column_spec.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import Any, Literal + +from .compat import BaseModel, root_validator + + +DbldatagenBasicType = Literal[ + "string", + "int", + "long", + "float", + "double", + "decimal", + "boolean", + "date", + "timestamp", + "short", + "byte", + "binary", + "integer", + "bigint", + "tinyint", +] +class ColumnDefinition(BaseModel): + name: str + type: DbldatagenBasicType | None = None + primary: bool = False + options: dict[str, Any] | None = {} + nullable: bool | None = False + omit: bool | None = False + baseColumn: str | None = "id" + baseColumnType: str | None = "auto" + + @root_validator() + def check_model_constraints(cls, values: dict[str, Any]) -> dict[str, Any]: + """ + Validates constraints across the entire model after individual fields are processed. + """ + is_primary = values.get("primary") + options = values.get("options", {}) + name = values.get("name") + is_nullable = values.get("nullable") + column_type = values.get("type") + + if is_primary: + if "min" in options or "max" in options: + raise ValueError(f"Primary column '{name}' cannot have min/max options.") + + if is_nullable: + raise ValueError(f"Primary column '{name}' cannot be nullable.") + + if column_type is None: + raise ValueError(f"Primary column '{name}' must have a type defined.") + return values diff --git a/dbldatagen/spec/compat.py b/dbldatagen/spec/compat.py new file mode 100644 index 00000000..3b604afb --- /dev/null +++ b/dbldatagen/spec/compat.py @@ -0,0 +1,31 @@ +# This module acts as a compatibility layer for Pydantic V1 and V2. + +try: + # This will succeed on environments with Pydantic V2.x + # It imports the V1 API that is bundled within V2. + from pydantic.v1 import BaseModel, Field, validator, constr, root_validator + +except ImportError: + # This will be executed on environments with only Pydantic V1.x + from pydantic import BaseModel, Field, validator, constr, root_validator + +# In your application code, do this: +# from .compat import BaseModel +# NOT this: +# from pydantic import BaseModel + +# FastAPI Notes +# https://github.com/fastapi/fastapi/blob/master/fastapi/_compat.py + + +""" +## Why This Approach +No Installation Required: It directly addresses your core requirement. +You don't need to %pip install anything, which avoids conflicts with the pre-installed libraries on Databricks. +Single Codebase: You maintain one set of code that is guaranteed to work with the Pydantic V1 API, which is available in both runtimes. + +Environment Agnostic: Your application code in models.py has no idea which version of Pydantic is actually installed. The compat.py module handles that complexity completely. + +Future-Ready: When you eventually decide to migrate fully to the Pydantic V2 API (to take advantage of its speed and features), +you only need to change your application code and your compat.py import statements, making the transition much clearer. 
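+
+A minimal usage sketch (the module and model names below are illustrative,
+not part of this package):
+
+    # my_models.py
+    from .compat import BaseModel, Field, validator
+
+    class Sensor(BaseModel):
+        name: str
+        reading: float = Field(..., ge=0.0)
+
+        @validator("name")
+        def name_not_blank(cls, v):
+            if not v.strip():
+                raise ValueError("name must be non-empty")
+            return v
+
+The same module imports and validates unchanged whether the runtime ships
+Pydantic 1.10.x or 2.x, because both branches above expose the V1 API.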
+""" \ No newline at end of file diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py new file mode 100644 index 00000000..c12d9a59 --- /dev/null +++ b/dbldatagen/spec/generator_spec.py @@ -0,0 +1,278 @@ +from __future__ import annotations + +from typing import Any, Literal, Union + +import pandas as pd +from IPython.display import HTML, display + +from dbldatagen.spec.column_spec import ColumnDefinition + +from .compat import BaseModel, validator + + +class UCSchemaTarget(BaseModel): + catalog: str + schema_: str + output_format: str = "delta" # Default to delta for UC Schema + + @validator("catalog", "schema_") + def validate_identifiers(cls, v): # noqa: N805, pylint: disable=no-self-argument + if not v.strip(): + raise ValueError("Identifier must be non-empty.") + if not v.isidentifier(): + logger.warning( + f"'{v}' is not a basic Python identifier. Ensure validity for Unity Catalog.") + return v.strip() + + def __str__(self): + return f"{self.catalog}.{self.schema_} (Format: {self.output_format}, Type: UC Table)" + + +class FilePathTarget(BaseModel): + base_path: str + output_format: Literal["csv", "parquet"] # No default, must be specified + + @validator("base_path") + def validate_base_path(cls, v): # noqa: N805, pylint: disable=no-self-argument + if not v.strip(): + raise ValueError("base_path must be non-empty.") + return v.strip() + + def __str__(self): + return f"{self.base_path} (Format: {self.output_format}, Type: File Path)" + + +class TableDefinition(BaseModel): + number_of_rows: int + partitions: int | None = None + columns: list[ColumnDefinition] + + +class ValidationResult: + """Container for validation results with errors and warnings.""" + + def __init__(self) -> None: + self.errors: list[str] = [] + self.warnings: list[str] = [] + + def add_error(self, message: str) -> None: + """Add an error message.""" + self.errors.append(message) + + def add_warning(self, message: str) -> None: + """Add a warning message.""" + self.warnings.append(message) + + def is_valid(self) -> bool: + """Returns True if there are no errors.""" + return len(self.errors) == 0 + + def __str__(self) -> str: + """String representation of validation results.""" + lines = [] + if self.is_valid(): + lines.append("✓ Validation passed successfully") + else: + lines.append("✗ Validation failed") + + if self.errors: + lines.append(f"\nErrors ({len(self.errors)}):") + for i, error in enumerate(self.errors, 1): + lines.append(f" {i}. {error}") + + if self.warnings: + lines.append(f"\nWarnings ({len(self.warnings)}):") + for i, warning in enumerate(self.warnings, 1): + lines.append(f" {i}. {warning}") + + return "\n".join(lines) + +class DatagenSpec(BaseModel): + tables: dict[str, TableDefinition] + output_destination: Union[UCSchemaTarget, FilePathTarget] | None = None # there is a abstraction, may be we can use that? talk to Greg + generator_options: dict[str, Any] | None = {} + intended_for_databricks: bool | None = None # May be infered. + + def _check_circular_dependencies( + self, + table_name: str, + columns: list[ColumnDefinition] + ) -> list[str]: + """ + Check for circular dependencies in baseColumn references. + Returns a list of error messages if circular dependencies are found. 
+ """ + errors = [] + column_map = {col.name: col for col in columns} + + for col in columns: + if col.baseColumn and col.baseColumn != "id": + # Track the dependency chain + visited = set() + current = col.name + + while current: + if current in visited: + # Found a cycle + cycle_path = " -> ".join(list(visited) + [current]) + errors.append( + f"Table '{table_name}': Circular dependency detected in column '{col.name}': {cycle_path}" + ) + break + + visited.add(current) + current_col = column_map.get(current) + + if not current_col: + break + + # Move to the next column in the chain + if current_col.baseColumn and current_col.baseColumn != "id": + if current_col.baseColumn not in column_map: + # baseColumn doesn't exist - we'll catch this in another validation + break + current = current_col.baseColumn + else: + # Reached a column that doesn't have a baseColumn or uses "id" + break + + return errors + + def validate(self, strict: bool = True) -> ValidationResult: + """ + Validates the entire DatagenSpec configuration. + Always runs all validation checks and collects all errors and warnings. + + Args: + strict: If True, raises ValueError if any errors or warnings are found. + If False, only raises ValueError if errors (not warnings) are found. + + Returns: + ValidationResult object containing all errors and warnings found. + + Raises: + ValueError: If validation fails based on strict mode setting. + The exception message contains all errors and warnings. + """ + result = ValidationResult() + + # 1. Check that there's at least one table + if not self.tables: + result.add_error("Spec must contain at least one table definition") + + # 2. Validate each table (continue checking all tables even if errors found) + for table_name, table_def in self.tables.items(): + # Check table has at least one column + if not table_def.columns: + result.add_error(f"Table '{table_name}' must have at least one column") + continue # Skip further checks for this table since it has no columns + + # Check row count is positive + if table_def.number_of_rows <= 0: + result.add_error( + f"Table '{table_name}' has invalid number_of_rows: {table_def.number_of_rows}. " + "Must be a positive integer." + ) + + # Check partitions if specified + #TODO: though this can be a model field check, we are checking here so that one can correct + # Can we find a way to use the default way? + if table_def.partitions is not None and table_def.partitions <= 0: + result.add_error( + f"Table '{table_name}' has invalid partitions: {table_def.partitions}. " + "Must be a positive integer or None." + ) + + # Check for duplicate column names + # TODO: Not something possible if we right model, recheck + column_names = [col.name for col in table_def.columns] + duplicates = [name for name in set(column_names) if column_names.count(name) > 1] + if duplicates: + result.add_error( + f"Table '{table_name}' has duplicate column names: {', '.join(duplicates)}" + ) + + # Build column map for reference checking + column_map = {col.name: col for col in table_def.columns} + + # TODO: Check baseColumn references, this is tricky? 
check the dbldefaults + for col in table_def.columns: + if col.baseColumn and col.baseColumn != "id": + if col.baseColumn not in column_map: + result.add_error( + f"Table '{table_name}', column '{col.name}': " + f"baseColumn '{col.baseColumn}' does not exist in the table" + ) + + # Check for circular dependencies in baseColumn references + circular_errors = self._check_circular_dependencies(table_name, table_def.columns) + for error in circular_errors: + result.add_error(error) + + # Check primary key constraints + primary_columns = [col for col in table_def.columns if col.primary] + if len(primary_columns) > 1: + primary_names = [col.name for col in primary_columns] + result.add_warning( + f"Table '{table_name}' has multiple primary columns: {', '.join(primary_names)}. " + "This may not be the intended behavior." + ) + + # Check for columns with no type and not using baseColumn properly + for col in table_def.columns: + if not col.primary and not col.type and not col.options: + result.add_warning( + f"Table '{table_name}', column '{col.name}': " + "No type specified and no options provided. " + "Column may not generate data as expected." + ) + + # 3. Check output destination + if not self.output_destination: + result.add_warning( + "No output_destination specified. Data will be generated but not persisted. " + "Set output_destination to save generated data." + ) + + # 4. Validate generator options (if any known options) + if self.generator_options: + known_options = [ + "random", "randomSeed", "randomSeedMethod", "verbose", + "debug", "seedColumnName" + ] + for key in self.generator_options: + if key not in known_options: + result.add_warning( + f"Unknown generator option: '{key}'. " + "This may be ignored during generation." + ) + + # Now that all validations are complete, decide whether to raise + if (strict and (result.errors or result.warnings)) or (not strict and result.errors): + raise ValueError(str(result)) + + return result + + + def display_all_tables(self) -> None: + for table_name, table_def in self.tables.items(): + print(f"Table: {table_name}") + + if self.output_destination: + output = f"{self.output_destination}" + display(HTML(f"Output destination: {output}")) + else: + message = ( + "Output destination: " + "None
" + "Set it using the output_destination " + "attribute on your DatagenSpec object " + "(e.g., my_spec.output_destination = UCSchemaTarget(...))." + ) + display(HTML(message)) + + df = pd.DataFrame([col.dict() for col in table_def.columns]) + try: + display(df) + except NameError: + print(df.to_string()) diff --git a/dbldatagen/spec/generator_spec_impl.py b/dbldatagen/spec/generator_spec_impl.py new file mode 100644 index 00000000..a508b1a5 --- /dev/null +++ b/dbldatagen/spec/generator_spec_impl.py @@ -0,0 +1,254 @@ +import logging +from typing import Dict, Union +import posixpath + +from dbldatagen.spec.generator_spec import TableDefinition +from pyspark.sql import SparkSession +import dbldatagen as dg +from .generator_spec import DatagenSpec, UCSchemaTarget, FilePathTarget, ColumnDefinition + + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) + +INTERNAL_ID_COLUMN_NAME = "id" + + +class Generator: + """ + Main data generation orchestrator that handles configuration, preparation, and writing of data. + """ + + def __init__(self, spark: SparkSession, app_name: str = "DataGen_ClassBased") -> None: + """ + Initialize the Generator with a SparkSession. + Args: + spark: An existing SparkSession instance + app_name: Application name for logging purposes + Raises: + RuntimeError: If spark is None + """ + if not spark: + logger.error( + "SparkSession cannot be None during Generator initialization") + raise RuntimeError("SparkSession cannot be None") + self.spark = spark + self._created_spark_session = False + self.app_name = app_name + logger.info("Generator initialized with SparkSession") + + def _columnspec_to_datagen_columnspec(self, col_def: ColumnDefinition) -> Dict[str, str]: + """ + Convert a ColumnDefinition to dbldatagen column specification. + Args: + col_def: ColumnDefinition object containing column configuration + Returns: + Dictionary containing dbldatagen column specification + """ + col_name = col_def.name + col_type = col_def.type + kwargs = col_def.options.copy() if col_def.options is not None else {} + + if col_def.primary: + kwargs["colType"] = col_type + kwargs["baseColumn"] = INTERNAL_ID_COLUMN_NAME + + if col_type == "string": + kwargs["baseColumnType"] = "hash" + elif col_type not in ["int", "long", "integer", "bigint", "short"]: + kwargs["baseColumnType"] = "auto" + logger.warning( + f"Primary key '{col_name}' has non-standard type '{col_type}'") + + # Log conflicting options for primary keys + conflicting_opts_for_pk = [ + "distribution", "template", "dataRange", "random", "omit", + "min", "max", "uniqueValues", "values", "expr" + ] + + for opt_key in conflicting_opts_for_pk: + if opt_key in kwargs: + logger.warning( + f"Primary key '{col_name}': Option '{opt_key}' may be ignored") + + if col_def.omit is not None and col_def.omit: + kwargs["omit"] = True + else: + kwargs = col_def.options.copy() if col_def.options is not None else {} + + if col_type: + kwargs["colType"] = col_type + if col_def.baseColumn: + kwargs["baseColumn"] = col_def.baseColumn + if col_def.baseColumnType: + kwargs["baseColumnType"] = col_def.baseColumnType + if col_def.omit is not None: + kwargs["omit"] = col_def.omit + + return kwargs + + def _prepare_data_generators( + self, + config: DatagenSpec, + config_source_name: str = "PydanticConfig" + ) -> Dict[str, dg.DataGenerator]: + """ + Prepare DataGenerator specifications for each table based on the configuration. 
+ Args: + config: DatagenSpec Pydantic object containing table configurations + config_source_name: Name for the configuration source (for logging) + Returns: + Dictionary mapping table names to their configured dbldatagen.DataGenerator objects + Raises: + RuntimeError: If SparkSession is not available + ValueError: If any table preparation fails + Exception: If any unexpected error occurs during preparation + """ + logger.info( + f"Preparing data generators for {len(config.tables)} tables") + + if not self.spark: + logger.error( + "SparkSession is not available. Cannot prepare data generators") + raise RuntimeError( + "SparkSession is not available. Cannot prepare data generators") + + tables_config: Dict[str, TableDefinition] = config.tables + global_gen_options = config.generator_options if config.generator_options else {} + + prepared_generators: Dict[str, dg.DataGenerator] = {} + generation_order = list(tables_config.keys()) # This becomes impotant when we get into multitable + + for table_name in generation_order: + table_spec = tables_config[table_name] + logger.info(f"Preparing table: {table_name}") + + try: + # Create DataGenerator instance + data_gen = dg.DataGenerator( + sparkSession=self.spark, + name=f"{table_name}_spec_from_{config_source_name}", + rows=table_spec.number_of_rows, + partitions=table_spec.partitions, + **global_gen_options, + ) + + # Process each column + for col_def in table_spec.columns: + kwargs = self._columnspec_to_datagen_columnspec(col_def) + data_gen = data_gen.withColumn(colName=col_def.name, **kwargs) + # Has performance implications. + + prepared_generators[table_name] = data_gen + logger.info(f"Successfully prepared table: {table_name}") + + except Exception as e: + logger.error(f"Failed to prepare table '{table_name}': {e}") + raise RuntimeError( + f"Failed to prepare table '{table_name}': {e}") from e + + logger.info("All data generators prepared successfully") + return prepared_generators + + def write_prepared_data( + self, + prepared_generators: Dict[str, dg.DataGenerator], + output_destination: Union[UCSchemaTarget, FilePathTarget, None], + config_source_name: str = "PydanticConfig", + ) -> None: + """ + Write data from prepared generators to the specified output destination. 
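+
+        For a FilePathTarget each table is written (mode "overwrite") to
+        <base_path>/<table_name> in the configured format; for a
+        UCSchemaTarget it is saved as <catalog>.<schema>.<table_name>.
+        If no destination is given, the write phase is skipped with a warning.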
+ + Args: + prepared_generators: Dictionary of prepared DataGenerator objects + output_destination: Target destination for data output + config_source_name: Name for the configuration source (for logging) + + Raises: + RuntimeError: If any table write fails + ValueError: If output destination is not properly configured + """ + logger.info("Starting data writing phase") + + if not prepared_generators: + logger.warning("No prepared data generators to write") + return + + for table_name, data_gen in prepared_generators.items(): + logger.info(f"Writing table: {table_name}") + + try: + df = data_gen.build() + requested_rows = data_gen.rowCount + actual_row_count = df.count() + logger.info( + f"Built DataFrame for '{table_name}': {actual_row_count} rows (requested: {requested_rows})") + + if actual_row_count == 0 and requested_rows > 0: + logger.warning(f"Table '{table_name}': Requested {requested_rows} rows but built 0") + + # Write data based on destination type + if isinstance(output_destination, FilePathTarget): + output_path = posixpath.join(output_destination.base_path, table_name) + df.write.format(output_destination.output_format).mode("overwrite").save(output_path) + logger.info(f"Wrote table '{table_name}' to file path: {output_path}") + + elif isinstance(output_destination, UCSchemaTarget): + output_table = f"{output_destination.catalog}.{output_destination.schema_}.{table_name}" + df.write.mode("overwrite").saveAsTable(output_table) + logger.info(f"Wrote table '{table_name}' to Unity Catalog: {output_table}") + else: + logger.warning("No output destination specified, skipping data write") + return + except Exception as e: + logger.error(f"Failed to write table '{table_name}': {e}") + raise RuntimeError(f"Failed to write table '{table_name}': {e}") from e + logger.info("All data writes completed successfully") + + def generate_and_write_data( + self, + config: DatagenSpec, + config_source_name: str = "PydanticConfig" + ) -> None: + """ + Combined method to prepare data generators and write data in one operation. + This method orchestrates the complete data generation workflow: + 1. Prepare data generators from configuration + 2. 
Write data to the specified destination + Args: + config: DatagenSpec Pydantic object containing table configurations + config_source_name: Name for the configuration source (for logging) + Raises: + RuntimeError: If SparkSession is not available or any step fails + ValueError: If critical errors occur during preparation or writing + """ + logger.info(f"Starting combined data generation and writing for {len(config.tables)} tables") + + try: + # Phase 1: Prepare data generators + prepared_generators_map = self._prepare_data_generators(config, config_source_name) + + if not prepared_generators_map and list(config.tables.keys()): + logger.warning( + "No data generators were successfully prepared, though tables were defined") + return + + # Phase 2: Write data + self.write_prepared_data( + prepared_generators_map, + config.output_destination, + config_source_name + ) + + logger.info( + "Combined data generation and writing completed successfully") + + except Exception as e: + logger.error( + f"Error during combined data generation and writing: {e}") + raise RuntimeError( + f"Error during combined data generation and writing: {e}") from e \ No newline at end of file diff --git a/makefile b/makefile index 772397bf..a5f4486c 100644 --- a/makefile +++ b/makefile @@ -8,7 +8,7 @@ clean: .venv/bin/python: pip install hatch - hatch env create + hatch env create test-pydantic.pydantic==1.10.6-v1 dev: .venv/bin/python @hatch run which python @@ -20,7 +20,7 @@ fmt: hatch run fmt test: - hatch run test + hatch run test-pydantic:test test-coverage: make test && open htmlcov/index.html diff --git a/pydantic_compat.md b/pydantic_compat.md new file mode 100644 index 00000000..abf26e60 --- /dev/null +++ b/pydantic_compat.md @@ -0,0 +1,101 @@ +To write code that works on both Pydantic V1 and V2 and ensures a smooth future migration, you should code against the V1 API but import it through a compatibility shim. This approach uses V1's syntax, which Pydantic V2 can understand via its built-in V1 compatibility layer. + +----- + +### \#\# The Golden Rule: Code to V1, Import via a Shim 💡 + +The core strategy is to **write all your models using Pydantic V1 syntax and features**. You then use a special utility file to handle the imports, which makes your application code completely agnostic to the installed Pydantic version. + +----- + +### \#\# 1. Implement a Compatibility Shim (`compat.py`) + +This is the most critical step. Create a file named `compat.py` in your project that intelligently imports Pydantic components. Your application will import everything from this file instead of directly from `pydantic`. + +```python +# compat.py +# This module acts as a compatibility layer for Pydantic V1 and V2. + +try: + # This will succeed on environments with Pydantic V2.x + # It imports the V1 API that is bundled within V2. + from pydantic.v1 import BaseModel, Field, validator, constr + +except ImportError: + # This will be executed on environments with only Pydantic V1.x + from pydantic import BaseModel, Field, validator, constr + +# In your application code, do this: +# from .compat import BaseModel +# NOT this: +# from pydantic import BaseModel +``` + +----- + +### \#\# 2. Stick to V1 Features and Syntax (Do's and Don'ts) + +By following these rules in your application code, you ensure the logic works on both versions. + +#### **✅ Models and Fields: DO** + + * Use standard `BaseModel` and `Field` for all your data structures. This is the most stable part of the API. 
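+
+A minimal sketch of that pattern (the model below is illustrative, not part of
+any real codebase):
+
+```python
+from .compat import BaseModel, Field
+
+class Order(BaseModel):
+    order_id: int
+    quantity: int = Field(1, gt=0, description="Number of units ordered")
+    note: str = ""
+```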
+ +#### **❌ Models and Fields: DON'T** + + * **Do not use `__root__` models**. This V1 feature was removed in V2 and the compatibility is not perfect. Instead, model the data explicitly, even if it feels redundant. + * **Bad (Avoid):** `class MyList(BaseModel): __root__: list[str]` + * **Good (Compatible):** `class MyList(BaseModel): items: list[str]` + +#### **✅ Configuration: DO** + + * Use the nested `class Config:` for model configuration. This is the V1 way and is fully supported by the V2 compatibility layer. + * **Example:** + ```python + from .compat import BaseModel + + class User(BaseModel): + id: int + full_name: str + + class Config: + orm_mode = True # V2's compatibility layer translates this + allow_population_by_field_name = True + ``` + +#### **❌ Configuration: DON'T** + + * **Do not use the V2 `model_config` dictionary**. This is a V2-only feature. + +#### **✅ Validators and Data Types: DO** + + * Use the standard V1 `@validator`. It's robust and works perfectly across both versions. + * Use V1 constrained types like `constr`, `conint`, `conlist`. + * **Example:** + ```python + from .compat import BaseModel, validator, constr + + class Product(BaseModel): + name: constr(min_length=3) + + @validator("name") + def name_must_be_alpha(cls, v): + if not v.isalpha(): + raise ValueError("Name must be alphabetic") + return v + ``` + +#### **❌ Validators and Data Types: DON'T** + + * **Do not use V2 decorators** like `@field_validator`, `@model_validator`, or `@field_serializer`. + * **Do not use the V2 `Annotated` syntax** for validation (e.g., `Annotated[str, StringConstraints(min_length=2)]`). + +----- + +### \#\# 3. The Easy Migration Path + +When you're finally ready to leave V1 behind and upgrade your code to be V2-native, the process will be straightforward because your code is already consistent: + +1. **Change Imports**: Your first step will be a simple find-and-replace to change all `from .compat import ...` statements to `from pydantic import ...`. +2. **Run a Codelinter**: Tools like **Ruff** have built-in rules that can automatically refactor most of your V1 syntax (like `Config` classes and `@validator`s) to the new V2 syntax. +3. **Manual Refinements**: Address any complex patterns the automated tools couldn't handle, like replacing your `__root__` model alternatives. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 13728ba2..f0dabcf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,21 +103,35 @@ dependencies = [ "jmespath>=0.10.0", "py4j>=0.10.9", "pickleshare>=0.7.5", + "ipython>=7.32.0", ] python="3.10" - -# store virtual env as the child of this folder. Helps VSCode (and PyCharm) to run better path = ".venv" [tool.hatch.envs.default.scripts] test = "pytest tests/ -n 10 --cov --cov-report=html --timeout 600 --durations 20" -fmt = ["ruff check . --fix", - "mypy .", - "pylint --output-format=colorized -j 0 dbldatagen tests"] -verify = ["ruff check .", - "mypy .", - "pylint --output-format=colorized -j 0 dbldatagen tests"] +fmt = [ + "ruff check . --fix", + "mypy .", + "pylint --output-format=colorized -j 0 dbldatagen tests" +] +verify = [ + "ruff check .", + "mypy .", + "pylint --output-format=colorized -j 0 dbldatagen tests" +] + + +[tool.hatch.envs.test-pydantic] +template = "default" +matrix = [ + { pydantic_version = ["1.10.6", "2.8.2"] } +] +extra-dependencies = [ + "pydantic=={matrix:pydantic_version}" +] + # Ruff configuration - replaces flake8, isort, pydocstyle, etc. 
[tool.ruff] diff --git a/scratch.md b/scratch.md new file mode 100644 index 00000000..a3afa5c3 --- /dev/null +++ b/scratch.md @@ -0,0 +1,4 @@ +Pydantic Notes +https://docs.databricks.com/aws/en/release-notes/runtime/14.3lts - 1.10.6 +https://docs.databricks.com/aws/en/release-notes/runtime/15.4lts - 1.10.6 +https://docs.databricks.com/aws/en/release-notes/runtime/16.4lts - 2.8.2 (2.20.1 - core) \ No newline at end of file diff --git a/tests/test_specs.py b/tests/test_specs.py new file mode 100644 index 00000000..d3c8ab2c --- /dev/null +++ b/tests/test_specs.py @@ -0,0 +1,466 @@ +from dbldatagen.spec.generator_spec import DatagenSpec +import pytest +from dbldatagen.spec.generator_spec import ( + DatagenSpec, + TableDefinition, + ColumnDefinition, + UCSchemaTarget, + FilePathTarget, + ValidationResult +) + +class TestValidationResult: + """Tests for ValidationResult class""" + + def test_empty_result_is_valid(self): + result = ValidationResult() + assert result.is_valid() + assert len(result.errors) == 0 + assert len(result.warnings) == 0 + + def test_result_with_errors_is_invalid(self): + result = ValidationResult() + result.add_error("Test error") + assert not result.is_valid() + assert len(result.errors) == 1 + + def test_result_with_only_warnings_is_valid(self): + result = ValidationResult() + result.add_warning("Test warning") + assert result.is_valid() + assert len(result.warnings) == 1 + + def test_result_string_representation(self): + result = ValidationResult() + result.add_error("Error 1") + result.add_error("Error 2") + result.add_warning("Warning 1") + + result_str = str(result) + assert "✗ Validation failed" in result_str + assert "Errors (2)" in result_str + assert "Error 1" in result_str + assert "Error 2" in result_str + assert "Warnings (1)" in result_str + assert "Warning 1" in result_str + + def test_valid_result_string_representation(self): + result = ValidationResult() + result_str = str(result) + assert "✓ Validation passed successfully" in result_str + + +class TestColumnDefinitionValidation: + """Tests for ColumnDefinition validation""" + + def test_valid_primary_column(self): + col = ColumnDefinition( + name="id", + type="int", + primary=True + ) + assert col.primary + assert col.type == "int" + + def test_primary_column_with_min_max_raises_error(self): + with pytest.raises(ValueError, match="cannot have min/max options"): + ColumnDefinition( + name="id", + type="int", + primary=True, + options={"min": 1, "max": 100} + ) + + def test_primary_column_nullable_raises_error(self): + with pytest.raises(ValueError, match="cannot be nullable"): + ColumnDefinition( + name="id", + type="int", + primary=True, + nullable=True + ) + + def test_primary_column_without_type_raises_error(self): + with pytest.raises(ValueError, match="must have a type defined"): + ColumnDefinition( + name="id", + primary=True + ) + + def test_non_primary_column_without_type(self): + # Should not raise + col = ColumnDefinition( + name="data", + options={"values": ["a", "b", "c"]} + ) + assert col.name == "data" + + +class TestDatagenSpecValidation: + """Tests for DatagenSpec.validate() method""" + + def test_valid_spec_passes_validation(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="name", type="string", options={"values": ["Alice", "Bob"]}), + ] + ) + }, + output_destination=UCSchemaTarget(catalog="main", schema_="default") + ) + + result = 
spec.validate(strict=False) + assert result.is_valid() + assert len(result.errors) == 0 + + def test_empty_tables_raises_error(self): + spec = DatagenSpec(tables={}) + + with pytest.raises(ValueError, match="at least one table"): + spec.validate(strict=True) + + def test_table_without_columns_raises_error(self): + spec = DatagenSpec( + tables={ + "empty_table": TableDefinition( + number_of_rows=100, + columns=[] + ) + } + ) + + with pytest.raises(ValueError, match="must have at least one column"): + spec.validate() + + def test_negative_row_count_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=-10, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + } + ) + + with pytest.raises(ValueError, match="invalid number_of_rows"): + spec.validate() + + def test_zero_row_count_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=0, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + } + ) + + with pytest.raises(ValueError, match="invalid number_of_rows"): + spec.validate() + + def test_invalid_partitions_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + partitions=-5, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + } + ) + + with pytest.raises(ValueError, match="invalid partitions"): + spec.validate() + + def test_duplicate_column_names_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="duplicate", type="string"), + ColumnDefinition(name="duplicate", type="int"), + ] + ) + } + ) + + with pytest.raises(ValueError, match="duplicate column names"): + spec.validate() + + def test_invalid_base_column_reference_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="email", type="string", baseColumn="nonexistent"), + ] + ) + } + ) + + with pytest.raises(ValueError, match="does not exist"): + spec.validate() + + def test_circular_dependency_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="col_a", type="string", baseColumn="col_b"), + ColumnDefinition(name="col_b", type="string", baseColumn="col_c"), + ColumnDefinition(name="col_c", type="string", baseColumn="col_a"), + ] + ) + } + ) + + with pytest.raises(ValueError, match="Circular dependency"): + spec.validate() + + def test_multiple_primary_columns_warning(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id1", type="int", primary=True), + ColumnDefinition(name="id2", type="int", primary=True), + ] + ) + } + ) + + # In strict mode, warnings cause errors + with pytest.raises(ValueError, match="multiple primary columns"): + spec.validate(strict=True) + + # In non-strict mode, should pass but have warnings + result = spec.validate(strict=False) + assert result.is_valid() + assert len(result.warnings) > 0 + assert any("multiple primary columns" in w for w in result.warnings) + + def test_column_without_type_or_options_warning(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + 
ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="empty_col"), + ] + ) + } + ) + + result = spec.validate(strict=False) + assert result.is_valid() + assert len(result.warnings) > 0 + assert any("No type specified" in w for w in result.warnings) + + def test_no_output_destination_warning(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + } + ) + + result = spec.validate(strict=False) + assert result.is_valid() + assert len(result.warnings) > 0 + assert any("No output_destination" in w for w in result.warnings) + + def test_unknown_generator_option_warning(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + }, + generator_options={"unknown_option": "value"} + ) + + result = spec.validate(strict=False) + assert result.is_valid() + assert len(result.warnings) > 0 + assert any("Unknown generator option" in w for w in result.warnings) + + def test_multiple_errors_collected(self): + """Test that all errors are collected before raising""" + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=-10, # Error 1 + partitions=0, # Error 2 + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="id", type="string"), # Error 3: duplicate + ColumnDefinition(name="email", baseColumn="phone"), # Error 4: nonexistent + ] + ) + } + ) + + with pytest.raises(ValueError) as exc_info: + spec.validate() + + error_msg = str(exc_info.value) + # Should contain all errors + assert "invalid number_of_rows" in error_msg + assert "invalid partitions" in error_msg + assert "duplicate column names" in error_msg + assert "does not exist" in error_msg + + def test_strict_mode_raises_on_warnings(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + } + # No output_destination - will generate warning + ) + + # Strict mode should raise + with pytest.raises(ValueError): + spec.validate(strict=True) + + # Non-strict mode should pass + result = spec.validate(strict=False) + assert result.is_valid() + + def test_valid_base_column_chain(self): + """Test that valid baseColumn chains work""" + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="code", type="string", baseColumn="id"), + ColumnDefinition(name="hash", type="string", baseColumn="code"), + ] + ) + }, + output_destination=FilePathTarget(base_path="/tmp/data", output_format="parquet") + ) + + result = spec.validate(strict=False) + assert result.is_valid() + + def test_multiple_tables_validation(self): + """Test validation across multiple tables""" + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ), + "orders": TableDefinition( + number_of_rows=-50, # Error in second table + columns=[ColumnDefinition(name="order_id", type="int", primary=True)] + ), + "products": TableDefinition( + number_of_rows=200, + columns=[] # Error: no columns + ) + } + ) + + with pytest.raises(ValueError) as exc_info: + spec.validate() + + error_msg = str(exc_info.value) + # Should find errors in both tables + assert "orders" in error_msg + assert "products" in error_msg 
+ + +class TestTargetValidation: + """Tests for output target validation""" + + def test_valid_uc_schema_target(self): + target = UCSchemaTarget(catalog="main", schema_="default") + assert target.catalog == "main" + assert target.schema_ == "default" + + def test_uc_schema_empty_catalog_raises_error(self): + with pytest.raises(ValueError, match="non-empty"): + UCSchemaTarget(catalog="", schema_="default") + + def test_valid_file_path_target(self): + target = FilePathTarget(base_path="/tmp/data", output_format="parquet") + assert target.base_path == "/tmp/data" + assert target.output_format == "parquet" + + def test_file_path_empty_base_path_raises_error(self): + with pytest.raises(ValueError, match="non-empty"): + FilePathTarget(base_path="", output_format="csv") + + def test_file_path_invalid_format_raises_error(self): + with pytest.raises(ValueError): + FilePathTarget(base_path="/tmp/data", output_format="json") + + +class TestValidationIntegration: + """Integration tests for validation""" + + def test_realistic_valid_spec(self): + """Test a realistic, valid specification""" + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=1000, + partitions=4, + columns=[ + ColumnDefinition(name="user_id", type="int", primary=True), + ColumnDefinition(name="username", type="string", options={ + "template": r"\w{8,12}" + }), + ColumnDefinition(name="email", type="string", options={ + "template": r"\w.\w@\w.com" + }), + ColumnDefinition(name="age", type="int", options={ + "min": 18, "max": 99 + }), + ] + ), + "orders": TableDefinition( + number_of_rows=5000, + columns=[ + ColumnDefinition(name="order_id", type="int", primary=True), + ColumnDefinition(name="amount", type="decimal", options={ + "min": 10.0, "max": 1000.0 + }), + ] + ) + }, + output_destination=UCSchemaTarget( + catalog="main", + schema_="synthetic_data" + ), + generator_options={ + "random": True, + "randomSeed": 42 + } + ) + + result = spec.validate(strict=True) + assert result.is_valid() + assert len(result.errors) == 0 + assert len(result.warnings) == 0 \ No newline at end of file
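
For reference, a minimal end-to-end sketch of the API added in this diff (the Spark session, catalog, and schema names are assumptions for illustration):

```python
from pyspark.sql import SparkSession

from dbldatagen.spec.generator_spec import (
    ColumnDefinition, DatagenSpec, TableDefinition, UCSchemaTarget,
)
from dbldatagen.spec.generator_spec_impl import Generator

spec = DatagenSpec(
    tables={
        "users": TableDefinition(
            number_of_rows=1000,
            columns=[
                ColumnDefinition(name="user_id", type="int", primary=True),
                ColumnDefinition(name="age", type="int", options={"min": 18, "max": 99}),
            ],
        )
    },
    output_destination=UCSchemaTarget(catalog="main", schema_="synthetic_data"),
)

spec.validate(strict=False)  # collects all errors/warnings; raises only on errors

spark = SparkSession.builder.getOrCreate()  # assumed pre-existing on Databricks
Generator(spark).generate_and_write_data(spec, config_source_name="example")
```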