diff --git a/dbldatagen/spec/column_spec.py b/dbldatagen/spec/column_spec.py
new file mode 100644
index 00000000..20383e04
--- /dev/null
+++ b/dbldatagen/spec/column_spec.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import Any, Literal
+
+from .compat import BaseModel, root_validator
+
+
+DbldatagenBasicType = Literal[
+ "string",
+ "int",
+ "long",
+ "float",
+ "double",
+ "decimal",
+ "boolean",
+ "date",
+ "timestamp",
+ "short",
+ "byte",
+ "binary",
+ "integer",
+ "bigint",
+ "tinyint",
+]
+
+
+class ColumnDefinition(BaseModel):
+ name: str
+ type: DbldatagenBasicType | None = None
+ primary: bool = False
+ options: dict[str, Any] | None = {}
+ nullable: bool | None = False
+ omit: bool | None = False
+ baseColumn: str | None = "id"
+ baseColumnType: str | None = "auto"
+
+ @root_validator()
+ def check_model_constraints(cls, values: dict[str, Any]) -> dict[str, Any]:
+ """
+ Validates constraints across the entire model after individual fields are processed.
+ """
+ is_primary = values.get("primary")
+ options = values.get("options", {})
+ name = values.get("name")
+ is_nullable = values.get("nullable")
+ column_type = values.get("type")
+
+ if is_primary:
+ if "min" in options or "max" in options:
+ raise ValueError(f"Primary column '{name}' cannot have min/max options.")
+
+ if is_nullable:
+ raise ValueError(f"Primary column '{name}' cannot be nullable.")
+
+ if column_type is None:
+ raise ValueError(f"Primary column '{name}' must have a type defined.")
+ return values
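+
+
+# Illustrative usage sketch (not executed; the variable names below are placeholders), showing a
+# primary key column plus a derived column that satisfies check_model_constraints above:
+#
+#   pk = ColumnDefinition(name="id", type="long", primary=True)
+#   code = ColumnDefinition(name="code", type="string", baseColumn="id", baseColumnType="hash")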
diff --git a/dbldatagen/spec/compat.py b/dbldatagen/spec/compat.py
new file mode 100644
index 00000000..3b604afb
--- /dev/null
+++ b/dbldatagen/spec/compat.py
@@ -0,0 +1,31 @@
+# This module acts as a compatibility layer for Pydantic V1 and V2.
+
+try:
+ # This will succeed on environments with Pydantic V2.x
+ # It imports the V1 API that is bundled within V2.
+ from pydantic.v1 import BaseModel, Field, validator, constr, root_validator
+
+except ImportError:
+ # This will be executed on environments with only Pydantic V1.x
+ from pydantic import BaseModel, Field, validator, constr, root_validator
+
+# In your application code, do this:
+# from .compat import BaseModel
+# NOT this:
+# from pydantic import BaseModel
+
+# FastAPI Notes
+# https://github.com/fastapi/fastapi/blob/master/fastapi/_compat.py
+
+
+"""
+## Why This Approach
+No Installation Required: It directly addresses your core requirement.
+You don't need to %pip install anything, which avoids conflicts with the pre-installed libraries on Databricks.
+Single Codebase: You maintain one set of code that is guaranteed to work with the Pydantic V1 API, which is available in both runtimes.
+
+Environment Agnostic: Your application code in models.py has no idea which version of Pydantic is actually installed. The compat.py module handles that complexity completely.
+
+Future-Ready: When you eventually decide to migrate fully to the Pydantic V2 API (to take advantage of its speed and features),
+you only need to change your application code and your compat.py import statements, making the transition much clearer.
+"""
\ No newline at end of file
diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py
new file mode 100644
index 00000000..c12d9a59
--- /dev/null
+++ b/dbldatagen/spec/generator_spec.py
@@ -0,0 +1,278 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Literal, Union
+
+import pandas as pd
+from IPython.display import HTML, display
+
+from .column_spec import ColumnDefinition
+from .compat import BaseModel, validator
+
+logger = logging.getLogger(__name__)
+
+
+class UCSchemaTarget(BaseModel):
+ catalog: str
+ schema_: str
+ output_format: str = "delta" # Default to delta for UC Schema
+
+ @validator("catalog", "schema_")
+ def validate_identifiers(cls, v): # noqa: N805, pylint: disable=no-self-argument
+ if not v.strip():
+ raise ValueError("Identifier must be non-empty.")
+ if not v.isidentifier():
+ logger.warning(
+ f"'{v}' is not a basic Python identifier. Ensure validity for Unity Catalog.")
+ return v.strip()
+
+ def __str__(self):
+ return f"{self.catalog}.{self.schema_} (Format: {self.output_format}, Type: UC Table)"
+
+
+class FilePathTarget(BaseModel):
+ base_path: str
+ output_format: Literal["csv", "parquet"] # No default, must be specified
+
+ @validator("base_path")
+ def validate_base_path(cls, v): # noqa: N805, pylint: disable=no-self-argument
+ if not v.strip():
+ raise ValueError("base_path must be non-empty.")
+ return v.strip()
+
+ def __str__(self):
+ return f"{self.base_path} (Format: {self.output_format}, Type: File Path)"
+
+
+class TableDefinition(BaseModel):
+ number_of_rows: int
+ partitions: int | None = None
+ columns: list[ColumnDefinition]
+
+
+class ValidationResult:
+ """Container for validation results with errors and warnings."""
+
+ def __init__(self) -> None:
+ self.errors: list[str] = []
+ self.warnings: list[str] = []
+
+ def add_error(self, message: str) -> None:
+ """Add an error message."""
+ self.errors.append(message)
+
+ def add_warning(self, message: str) -> None:
+ """Add a warning message."""
+ self.warnings.append(message)
+
+ def is_valid(self) -> bool:
+ """Returns True if there are no errors."""
+ return len(self.errors) == 0
+
+ def __str__(self) -> str:
+ """String representation of validation results."""
+ lines = []
+ if self.is_valid():
+ lines.append("✓ Validation passed successfully")
+ else:
+ lines.append("✗ Validation failed")
+
+ if self.errors:
+ lines.append(f"\nErrors ({len(self.errors)}):")
+ for i, error in enumerate(self.errors, 1):
+ lines.append(f" {i}. {error}")
+
+ if self.warnings:
+ lines.append(f"\nWarnings ({len(self.warnings)}):")
+ for i, warning in enumerate(self.warnings, 1):
+ lines.append(f" {i}. {warning}")
+
+ return "\n".join(lines)
+
+
+class DatagenSpec(BaseModel):
+    tables: dict[str, TableDefinition]
+    output_destination: Union[UCSchemaTarget, FilePathTarget] | None = None  # there is an existing abstraction; maybe we can use that? talk to Greg
+    generator_options: dict[str, Any] | None = {}
+    intended_for_databricks: bool | None = None  # May be inferred.
+
+ def _check_circular_dependencies(
+ self,
+ table_name: str,
+ columns: list[ColumnDefinition]
+ ) -> list[str]:
+ """
+ Check for circular dependencies in baseColumn references.
+ Returns a list of error messages if circular dependencies are found.
+ """
+ errors = []
+ column_map = {col.name: col for col in columns}
+
+ for col in columns:
+ if col.baseColumn and col.baseColumn != "id":
+                # Track the dependency chain (as an ordered list, so cycle paths are reported deterministically)
+                visited = set()
+                chain: list[str] = []
+                current = col.name
+
+                while current:
+                    if current in visited:
+                        # Found a cycle
+                        cycle_path = " -> ".join(chain + [current])
+                        errors.append(
+                            f"Table '{table_name}': Circular dependency detected in column '{col.name}': {cycle_path}"
+                        )
+                        break
+
+                    visited.add(current)
+                    chain.append(current)
+ current_col = column_map.get(current)
+
+ if not current_col:
+ break
+
+ # Move to the next column in the chain
+ if current_col.baseColumn and current_col.baseColumn != "id":
+ if current_col.baseColumn not in column_map:
+ # baseColumn doesn't exist - we'll catch this in another validation
+ break
+ current = current_col.baseColumn
+ else:
+ # Reached a column that doesn't have a baseColumn or uses "id"
+ break
+
+ return errors
+
+ def validate(self, strict: bool = True) -> ValidationResult:
+ """
+ Validates the entire DatagenSpec configuration.
+ Always runs all validation checks and collects all errors and warnings.
+
+ Args:
+ strict: If True, raises ValueError if any errors or warnings are found.
+ If False, only raises ValueError if errors (not warnings) are found.
+
+ Returns:
+ ValidationResult object containing all errors and warnings found.
+
+ Raises:
+ ValueError: If validation fails based on strict mode setting.
+ The exception message contains all errors and warnings.
+ """
+ result = ValidationResult()
+
+ # 1. Check that there's at least one table
+ if not self.tables:
+ result.add_error("Spec must contain at least one table definition")
+
+ # 2. Validate each table (continue checking all tables even if errors found)
+ for table_name, table_def in self.tables.items():
+ # Check table has at least one column
+ if not table_def.columns:
+ result.add_error(f"Table '{table_name}' must have at least one column")
+ continue # Skip further checks for this table since it has no columns
+
+ # Check row count is positive
+ if table_def.number_of_rows <= 0:
+ result.add_error(
+ f"Table '{table_name}' has invalid number_of_rows: {table_def.number_of_rows}. "
+ "Must be a positive integer."
+ )
+
+ # Check partitions if specified
+            # TODO: this could be a model field check, but checking here lets all errors be
+            # collected and corrected together. Can we find a way to use the default mechanism?
+ if table_def.partitions is not None and table_def.partitions <= 0:
+ result.add_error(
+ f"Table '{table_name}' has invalid partitions: {table_def.partitions}. "
+ "Must be a positive integer or None."
+ )
+
+ # Check for duplicate column names
+            # TODO: not possible if we model this correctly; recheck
+ column_names = [col.name for col in table_def.columns]
+ duplicates = [name for name in set(column_names) if column_names.count(name) > 1]
+ if duplicates:
+ result.add_error(
+ f"Table '{table_name}' has duplicate column names: {', '.join(duplicates)}"
+ )
+
+ # Build column map for reference checking
+ column_map = {col.name: col for col in table_def.columns}
+
+            # TODO: checking baseColumn references is tricky; check the dbldatagen defaults
+ for col in table_def.columns:
+ if col.baseColumn and col.baseColumn != "id":
+ if col.baseColumn not in column_map:
+ result.add_error(
+ f"Table '{table_name}', column '{col.name}': "
+ f"baseColumn '{col.baseColumn}' does not exist in the table"
+ )
+
+ # Check for circular dependencies in baseColumn references
+ circular_errors = self._check_circular_dependencies(table_name, table_def.columns)
+ for error in circular_errors:
+ result.add_error(error)
+
+ # Check primary key constraints
+ primary_columns = [col for col in table_def.columns if col.primary]
+ if len(primary_columns) > 1:
+ primary_names = [col.name for col in primary_columns]
+ result.add_warning(
+ f"Table '{table_name}' has multiple primary columns: {', '.join(primary_names)}. "
+ "This may not be the intended behavior."
+ )
+
+ # Check for columns with no type and not using baseColumn properly
+ for col in table_def.columns:
+ if not col.primary and not col.type and not col.options:
+ result.add_warning(
+ f"Table '{table_name}', column '{col.name}': "
+ "No type specified and no options provided. "
+ "Column may not generate data as expected."
+ )
+
+ # 3. Check output destination
+ if not self.output_destination:
+ result.add_warning(
+ "No output_destination specified. Data will be generated but not persisted. "
+ "Set output_destination to save generated data."
+ )
+
+ # 4. Validate generator options (if any known options)
+ if self.generator_options:
+ known_options = [
+ "random", "randomSeed", "randomSeedMethod", "verbose",
+ "debug", "seedColumnName"
+ ]
+ for key in self.generator_options:
+ if key not in known_options:
+ result.add_warning(
+ f"Unknown generator option: '{key}'. "
+ "This may be ignored during generation."
+ )
+
+ # Now that all validations are complete, decide whether to raise
+ if (strict and (result.errors or result.warnings)) or (not strict and result.errors):
+ raise ValueError(str(result))
+
+ return result
+
+
+ def display_all_tables(self) -> None:
+ for table_name, table_def in self.tables.items():
+ print(f"Table: {table_name}")
+
+ if self.output_destination:
+ output = f"{self.output_destination}"
+ display(HTML(f"Output destination: {output}"))
+ else:
+ message = (
+ "Output destination: "
+ "None
"
+ "Set it using the output_destination "
+ "attribute on your DatagenSpec object "
+ "(e.g., my_spec.output_destination = UCSchemaTarget(...))."
+ )
+ display(HTML(message))
+
+ df = pd.DataFrame([col.dict() for col in table_def.columns])
+ try:
+ display(df)
+ except NameError:
+ print(df.to_string())
diff --git a/dbldatagen/spec/generator_spec_impl.py b/dbldatagen/spec/generator_spec_impl.py
new file mode 100644
index 00000000..a508b1a5
--- /dev/null
+++ b/dbldatagen/spec/generator_spec_impl.py
@@ -0,0 +1,254 @@
+import logging
+import posixpath
+from typing import Any, Dict, Union
+
+import dbldatagen as dg
+from pyspark.sql import SparkSession
+
+from .column_spec import ColumnDefinition
+from .generator_spec import DatagenSpec, FilePathTarget, TableDefinition, UCSchemaTarget
+
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+logger = logging.getLogger(__name__)
+
+INTERNAL_ID_COLUMN_NAME = "id"
+
+
+class Generator:
+ """
+ Main data generation orchestrator that handles configuration, preparation, and writing of data.
+ """
+
+ def __init__(self, spark: SparkSession, app_name: str = "DataGen_ClassBased") -> None:
+ """
+ Initialize the Generator with a SparkSession.
+ Args:
+ spark: An existing SparkSession instance
+ app_name: Application name for logging purposes
+ Raises:
+ RuntimeError: If spark is None
+ """
+ if not spark:
+ logger.error(
+ "SparkSession cannot be None during Generator initialization")
+ raise RuntimeError("SparkSession cannot be None")
+ self.spark = spark
+ self._created_spark_session = False
+ self.app_name = app_name
+ logger.info("Generator initialized with SparkSession")
+
+    def _columnspec_to_datagen_columnspec(self, col_def: ColumnDefinition) -> Dict[str, Any]:
+ """
+ Convert a ColumnDefinition to dbldatagen column specification.
+ Args:
+ col_def: ColumnDefinition object containing column configuration
+ Returns:
+ Dictionary containing dbldatagen column specification
+ """
+ col_name = col_def.name
+ col_type = col_def.type
+ kwargs = col_def.options.copy() if col_def.options is not None else {}
+
+ if col_def.primary:
+ kwargs["colType"] = col_type
+ kwargs["baseColumn"] = INTERNAL_ID_COLUMN_NAME
+
+ if col_type == "string":
+ kwargs["baseColumnType"] = "hash"
+ elif col_type not in ["int", "long", "integer", "bigint", "short"]:
+ kwargs["baseColumnType"] = "auto"
+ logger.warning(
+ f"Primary key '{col_name}' has non-standard type '{col_type}'")
+
+ # Log conflicting options for primary keys
+ conflicting_opts_for_pk = [
+ "distribution", "template", "dataRange", "random", "omit",
+ "min", "max", "uniqueValues", "values", "expr"
+ ]
+
+ for opt_key in conflicting_opts_for_pk:
+ if opt_key in kwargs:
+ logger.warning(
+ f"Primary key '{col_name}': Option '{opt_key}' may be ignored")
+
+ if col_def.omit is not None and col_def.omit:
+ kwargs["omit"] = True
+ else:
+ kwargs = col_def.options.copy() if col_def.options is not None else {}
+
+ if col_type:
+ kwargs["colType"] = col_type
+ if col_def.baseColumn:
+ kwargs["baseColumn"] = col_def.baseColumn
+ if col_def.baseColumnType:
+ kwargs["baseColumnType"] = col_def.baseColumnType
+ if col_def.omit is not None:
+ kwargs["omit"] = col_def.omit
+
+ return kwargs
+
+ def _prepare_data_generators(
+ self,
+ config: DatagenSpec,
+ config_source_name: str = "PydanticConfig"
+ ) -> Dict[str, dg.DataGenerator]:
+ """
+ Prepare DataGenerator specifications for each table based on the configuration.
+ Args:
+ config: DatagenSpec Pydantic object containing table configurations
+ config_source_name: Name for the configuration source (for logging)
+ Returns:
+ Dictionary mapping table names to their configured dbldatagen.DataGenerator objects
+ Raises:
+ RuntimeError: If SparkSession is not available
+ ValueError: If any table preparation fails
+ Exception: If any unexpected error occurs during preparation
+ """
+ logger.info(
+ f"Preparing data generators for {len(config.tables)} tables")
+
+ if not self.spark:
+ logger.error(
+ "SparkSession is not available. Cannot prepare data generators")
+ raise RuntimeError(
+ "SparkSession is not available. Cannot prepare data generators")
+
+ tables_config: Dict[str, TableDefinition] = config.tables
+ global_gen_options = config.generator_options if config.generator_options else {}
+
+ prepared_generators: Dict[str, dg.DataGenerator] = {}
+        generation_order = list(tables_config.keys())  # This becomes important when we get into multi-table generation
+
+ for table_name in generation_order:
+ table_spec = tables_config[table_name]
+ logger.info(f"Preparing table: {table_name}")
+
+ try:
+ # Create DataGenerator instance
+ data_gen = dg.DataGenerator(
+ sparkSession=self.spark,
+ name=f"{table_name}_spec_from_{config_source_name}",
+ rows=table_spec.number_of_rows,
+ partitions=table_spec.partitions,
+ **global_gen_options,
+ )
+
+ # Process each column
+ for col_def in table_spec.columns:
+ kwargs = self._columnspec_to_datagen_columnspec(col_def)
+ data_gen = data_gen.withColumn(colName=col_def.name, **kwargs)
+ # Has performance implications.
+
+ prepared_generators[table_name] = data_gen
+ logger.info(f"Successfully prepared table: {table_name}")
+
+ except Exception as e:
+ logger.error(f"Failed to prepare table '{table_name}': {e}")
+ raise RuntimeError(
+ f"Failed to prepare table '{table_name}': {e}") from e
+
+ logger.info("All data generators prepared successfully")
+ return prepared_generators
+
+ def write_prepared_data(
+ self,
+ prepared_generators: Dict[str, dg.DataGenerator],
+ output_destination: Union[UCSchemaTarget, FilePathTarget, None],
+ config_source_name: str = "PydanticConfig",
+ ) -> None:
+ """
+ Write data from prepared generators to the specified output destination.
+
+ Args:
+ prepared_generators: Dictionary of prepared DataGenerator objects
+ output_destination: Target destination for data output
+ config_source_name: Name for the configuration source (for logging)
+
+ Raises:
+ RuntimeError: If any table write fails
+ ValueError: If output destination is not properly configured
+ """
+ logger.info("Starting data writing phase")
+
+ if not prepared_generators:
+ logger.warning("No prepared data generators to write")
+ return
+
+ for table_name, data_gen in prepared_generators.items():
+ logger.info(f"Writing table: {table_name}")
+
+ try:
+ df = data_gen.build()
+ requested_rows = data_gen.rowCount
+ actual_row_count = df.count()
+ logger.info(
+ f"Built DataFrame for '{table_name}': {actual_row_count} rows (requested: {requested_rows})")
+
+ if actual_row_count == 0 and requested_rows > 0:
+ logger.warning(f"Table '{table_name}': Requested {requested_rows} rows but built 0")
+
+ # Write data based on destination type
+ if isinstance(output_destination, FilePathTarget):
+ output_path = posixpath.join(output_destination.base_path, table_name)
+ df.write.format(output_destination.output_format).mode("overwrite").save(output_path)
+ logger.info(f"Wrote table '{table_name}' to file path: {output_path}")
+
+ elif isinstance(output_destination, UCSchemaTarget):
+ output_table = f"{output_destination.catalog}.{output_destination.schema_}.{table_name}"
+ df.write.mode("overwrite").saveAsTable(output_table)
+ logger.info(f"Wrote table '{table_name}' to Unity Catalog: {output_table}")
+ else:
+ logger.warning("No output destination specified, skipping data write")
+ return
+ except Exception as e:
+ logger.error(f"Failed to write table '{table_name}': {e}")
+ raise RuntimeError(f"Failed to write table '{table_name}': {e}") from e
+ logger.info("All data writes completed successfully")
+
+ def generate_and_write_data(
+ self,
+ config: DatagenSpec,
+ config_source_name: str = "PydanticConfig"
+ ) -> None:
+ """
+ Combined method to prepare data generators and write data in one operation.
+ This method orchestrates the complete data generation workflow:
+ 1. Prepare data generators from configuration
+ 2. Write data to the specified destination
+ Args:
+ config: DatagenSpec Pydantic object containing table configurations
+ config_source_name: Name for the configuration source (for logging)
+ Raises:
+ RuntimeError: If SparkSession is not available or any step fails
+ ValueError: If critical errors occur during preparation or writing
+ """
+ logger.info(f"Starting combined data generation and writing for {len(config.tables)} tables")
+
+ try:
+ # Phase 1: Prepare data generators
+ prepared_generators_map = self._prepare_data_generators(config, config_source_name)
+
+ if not prepared_generators_map and list(config.tables.keys()):
+ logger.warning(
+ "No data generators were successfully prepared, though tables were defined")
+ return
+
+ # Phase 2: Write data
+ self.write_prepared_data(
+ prepared_generators_map,
+ config.output_destination,
+ config_source_name
+ )
+
+ logger.info(
+ "Combined data generation and writing completed successfully")
+
+ except Exception as e:
+ logger.error(
+ f"Error during combined data generation and writing: {e}")
+ raise RuntimeError(
+ f"Error during combined data generation and writing: {e}") from e
\ No newline at end of file
diff --git a/makefile b/makefile
index 772397bf..a5f4486c 100644
--- a/makefile
+++ b/makefile
@@ -8,7 +8,7 @@ clean:
.venv/bin/python:
pip install hatch
- hatch env create
+ hatch env create test-pydantic.pydantic==1.10.6-v1
dev: .venv/bin/python
@hatch run which python
@@ -20,7 +20,7 @@ fmt:
hatch run fmt
test:
- hatch run test
+ hatch run test-pydantic:test
test-coverage:
make test && open htmlcov/index.html
diff --git a/pydantic_compat.md b/pydantic_compat.md
new file mode 100644
index 00000000..abf26e60
--- /dev/null
+++ b/pydantic_compat.md
@@ -0,0 +1,101 @@
+To write code that works on both Pydantic V1 and V2 and ensures a smooth future migration, you should code against the V1 API but import it through a compatibility shim. This approach uses V1's syntax, which Pydantic V2 can understand via its built-in V1 compatibility layer.
+
+-----
+
+## The Golden Rule: Code to V1, Import via a Shim 💡
+
+The core strategy is to **write all your models using Pydantic V1 syntax and features**. You then use a special utility file to handle the imports, which makes your application code completely agnostic to the installed Pydantic version.
+
+-----
+
+## 1. Implement a Compatibility Shim (`compat.py`)
+
+This is the most critical step. Create a file named `compat.py` in your project that intelligently imports Pydantic components. Your application will import everything from this file instead of directly from `pydantic`.
+
+```python
+# compat.py
+# This module acts as a compatibility layer for Pydantic V1 and V2.
+
+try:
+ # This will succeed on environments with Pydantic V2.x
+ # It imports the V1 API that is bundled within V2.
+ from pydantic.v1 import BaseModel, Field, validator, constr
+
+except ImportError:
+ # This will be executed on environments with only Pydantic V1.x
+ from pydantic import BaseModel, Field, validator, constr
+
+# In your application code, do this:
+# from .compat import BaseModel
+# NOT this:
+# from pydantic import BaseModel
+```
+
+-----
+
+## 2. Stick to V1 Features and Syntax (Do's and Don'ts)
+
+By following these rules in your application code, you ensure the logic works on both versions.
+
+#### **✅ Models and Fields: DO**
+
+ * Use standard `BaseModel` and `Field` for all your data structures. This is the most stable part of the API.
+
+#### **❌ Models and Fields: DON'T**
+
+ * **Do not use `__root__` models**. This V1 feature was removed in V2 and the compatibility is not perfect. Instead, model the data explicitly, even if it feels redundant.
+ * **Bad (Avoid):** `class MyList(BaseModel): __root__: list[str]`
+ * **Good (Compatible):** `class MyList(BaseModel): items: list[str]`
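+
+A minimal sketch of the explicit wrapper pattern above (the `MyList` / `items` names are purely illustrative):
+
+```python
+from .compat import BaseModel
+
+class MyList(BaseModel):
+    items: list[str]
+
+# wrap the raw list explicitly instead of relying on __root__
+payload = MyList(items=["a", "b", "c"])
+print(payload.items)
+```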
+
+#### **✅ Configuration: DO**
+
+ * Use the nested `class Config:` for model configuration. This is the V1 way and is fully supported by the V2 compatibility layer.
+ * **Example:**
+ ```python
+ from .compat import BaseModel
+
+ class User(BaseModel):
+ id: int
+ full_name: str
+
+ class Config:
+ orm_mode = True # V2's compatibility layer translates this
+ allow_population_by_field_name = True
+ ```
+
+#### **❌ Configuration: DON'T**
+
+ * **Do not use the V2 `model_config` dictionary**. This is a V2-only feature.
+
+#### **✅ Validators and Data Types: DO**
+
+ * Use the standard V1 `@validator`. It's robust and works perfectly across both versions.
+ * Use V1 constrained types like `constr`, `conint`, `conlist`.
+ * **Example:**
+ ```python
+ from .compat import BaseModel, validator, constr
+
+ class Product(BaseModel):
+ name: constr(min_length=3)
+
+ @validator("name")
+ def name_must_be_alpha(cls, v):
+ if not v.isalpha():
+ raise ValueError("Name must be alphabetic")
+ return v
+ ```
+
+#### **❌ Validators and Data Types: DON'T**
+
+ * **Do not use V2 decorators** like `@field_validator`, `@model_validator`, or `@field_serializer`.
+ * **Do not use the V2 `Annotated` syntax** for validation (e.g., `Annotated[str, StringConstraints(min_length=2)]`).
+
+-----
+
+## 3. The Easy Migration Path
+
+When you're finally ready to leave V1 behind and upgrade your code to be V2-native, the process will be straightforward because your code is already consistent:
+
+1. **Change Imports**: Your first step will be a simple find-and-replace to change all `from .compat import ...` statements to `from pydantic import ...`.
+2. **Run a Migration Tool**: The Pydantic team's `bump-pydantic` codemod can automatically refactor most of your V1 syntax (like `Config` classes and `@validator`s) to the new V2 syntax.
+3. **Manual Refinements**: Address any complex patterns the automated tooling couldn't handle, such as revisiting the places where you worked around `__root__` models. An illustrative sketch of the resulting V2-native syntax follows below.
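+
+For illustration only, here is roughly what the `Product` model from section 2 could look like once the codebase is V2-native (assuming Pydantic V2 is the only supported version at that point):
+
+```python
+from typing import Annotated
+
+from pydantic import BaseModel, StringConstraints, field_validator
+
+
+class Product(BaseModel):
+    name: Annotated[str, StringConstraints(min_length=3)]
+
+    @field_validator("name")
+    @classmethod
+    def name_must_be_alpha(cls, v: str) -> str:
+        if not v.isalpha():
+            raise ValueError("Name must be alphabetic")
+        return v
+```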
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 13728ba2..f0dabcf7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -103,21 +103,35 @@ dependencies = [
"jmespath>=0.10.0",
"py4j>=0.10.9",
"pickleshare>=0.7.5",
+ "ipython>=7.32.0",
]
python="3.10"
-
-# store virtual env as the child of this folder. Helps VSCode (and PyCharm) to run better
path = ".venv"
[tool.hatch.envs.default.scripts]
test = "pytest tests/ -n 10 --cov --cov-report=html --timeout 600 --durations 20"
-fmt = ["ruff check . --fix",
- "mypy .",
- "pylint --output-format=colorized -j 0 dbldatagen tests"]
-verify = ["ruff check .",
- "mypy .",
- "pylint --output-format=colorized -j 0 dbldatagen tests"]
+fmt = [
+ "ruff check . --fix",
+ "mypy .",
+ "pylint --output-format=colorized -j 0 dbldatagen tests"
+]
+verify = [
+ "ruff check .",
+ "mypy .",
+ "pylint --output-format=colorized -j 0 dbldatagen tests"
+]
+
+
+[tool.hatch.envs.test-pydantic]
+template = "default"
+matrix = [
+ { pydantic_version = ["1.10.6", "2.8.2"] }
+]
+extra-dependencies = [
+ "pydantic=={matrix:pydantic_version}"
+]
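+
+# Run the suite against every Pydantic version in the matrix (this is what the makefile `test` target invokes):
+#   hatch run test-pydantic:test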
+
# Ruff configuration - replaces flake8, isort, pydocstyle, etc.
[tool.ruff]
diff --git a/scratch.md b/scratch.md
new file mode 100644
index 00000000..a3afa5c3
--- /dev/null
+++ b/scratch.md
@@ -0,0 +1,4 @@
+Pydantic Notes
+https://docs.databricks.com/aws/en/release-notes/runtime/14.3lts - 1.10.6
+https://docs.databricks.com/aws/en/release-notes/runtime/15.4lts - 1.10.6
+https://docs.databricks.com/aws/en/release-notes/runtime/16.4lts - 2.8.2 (2.20.1 - core)
\ No newline at end of file
diff --git a/tests/test_specs.py b/tests/test_specs.py
new file mode 100644
index 00000000..d3c8ab2c
--- /dev/null
+++ b/tests/test_specs.py
@@ -0,0 +1,466 @@
+import pytest
+
+from dbldatagen.spec.generator_spec import (
+ DatagenSpec,
+ TableDefinition,
+ ColumnDefinition,
+ UCSchemaTarget,
+ FilePathTarget,
+ ValidationResult
+)
+
+class TestValidationResult:
+ """Tests for ValidationResult class"""
+
+ def test_empty_result_is_valid(self):
+ result = ValidationResult()
+ assert result.is_valid()
+ assert len(result.errors) == 0
+ assert len(result.warnings) == 0
+
+ def test_result_with_errors_is_invalid(self):
+ result = ValidationResult()
+ result.add_error("Test error")
+ assert not result.is_valid()
+ assert len(result.errors) == 1
+
+ def test_result_with_only_warnings_is_valid(self):
+ result = ValidationResult()
+ result.add_warning("Test warning")
+ assert result.is_valid()
+ assert len(result.warnings) == 1
+
+ def test_result_string_representation(self):
+ result = ValidationResult()
+ result.add_error("Error 1")
+ result.add_error("Error 2")
+ result.add_warning("Warning 1")
+
+ result_str = str(result)
+ assert "✗ Validation failed" in result_str
+ assert "Errors (2)" in result_str
+ assert "Error 1" in result_str
+ assert "Error 2" in result_str
+ assert "Warnings (1)" in result_str
+ assert "Warning 1" in result_str
+
+ def test_valid_result_string_representation(self):
+ result = ValidationResult()
+ result_str = str(result)
+ assert "✓ Validation passed successfully" in result_str
+
+
+class TestColumnDefinitionValidation:
+ """Tests for ColumnDefinition validation"""
+
+ def test_valid_primary_column(self):
+ col = ColumnDefinition(
+ name="id",
+ type="int",
+ primary=True
+ )
+ assert col.primary
+ assert col.type == "int"
+
+ def test_primary_column_with_min_max_raises_error(self):
+ with pytest.raises(ValueError, match="cannot have min/max options"):
+ ColumnDefinition(
+ name="id",
+ type="int",
+ primary=True,
+ options={"min": 1, "max": 100}
+ )
+
+ def test_primary_column_nullable_raises_error(self):
+ with pytest.raises(ValueError, match="cannot be nullable"):
+ ColumnDefinition(
+ name="id",
+ type="int",
+ primary=True,
+ nullable=True
+ )
+
+ def test_primary_column_without_type_raises_error(self):
+ with pytest.raises(ValueError, match="must have a type defined"):
+ ColumnDefinition(
+ name="id",
+ primary=True
+ )
+
+ def test_non_primary_column_without_type(self):
+ # Should not raise
+ col = ColumnDefinition(
+ name="data",
+ options={"values": ["a", "b", "c"]}
+ )
+ assert col.name == "data"
+
+
+class TestDatagenSpecValidation:
+ """Tests for DatagenSpec.validate() method"""
+
+ def test_valid_spec_passes_validation(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[
+ ColumnDefinition(name="id", type="int", primary=True),
+ ColumnDefinition(name="name", type="string", options={"values": ["Alice", "Bob"]}),
+ ]
+ )
+ },
+ output_destination=UCSchemaTarget(catalog="main", schema_="default")
+ )
+
+ result = spec.validate(strict=False)
+ assert result.is_valid()
+ assert len(result.errors) == 0
+
+ def test_empty_tables_raises_error(self):
+ spec = DatagenSpec(tables={})
+
+ with pytest.raises(ValueError, match="at least one table"):
+ spec.validate(strict=True)
+
+ def test_table_without_columns_raises_error(self):
+ spec = DatagenSpec(
+ tables={
+ "empty_table": TableDefinition(
+ number_of_rows=100,
+ columns=[]
+ )
+ }
+ )
+
+ with pytest.raises(ValueError, match="must have at least one column"):
+ spec.validate()
+
+ def test_negative_row_count_raises_error(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=-10,
+ columns=[ColumnDefinition(name="id", type="int", primary=True)]
+ )
+ }
+ )
+
+ with pytest.raises(ValueError, match="invalid number_of_rows"):
+ spec.validate()
+
+ def test_zero_row_count_raises_error(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=0,
+ columns=[ColumnDefinition(name="id", type="int", primary=True)]
+ )
+ }
+ )
+
+ with pytest.raises(ValueError, match="invalid number_of_rows"):
+ spec.validate()
+
+ def test_invalid_partitions_raises_error(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ partitions=-5,
+ columns=[ColumnDefinition(name="id", type="int", primary=True)]
+ )
+ }
+ )
+
+ with pytest.raises(ValueError, match="invalid partitions"):
+ spec.validate()
+
+ def test_duplicate_column_names_raises_error(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[
+ ColumnDefinition(name="id", type="int", primary=True),
+ ColumnDefinition(name="duplicate", type="string"),
+ ColumnDefinition(name="duplicate", type="int"),
+ ]
+ )
+ }
+ )
+
+ with pytest.raises(ValueError, match="duplicate column names"):
+ spec.validate()
+
+ def test_invalid_base_column_reference_raises_error(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[
+ ColumnDefinition(name="id", type="int", primary=True),
+ ColumnDefinition(name="email", type="string", baseColumn="nonexistent"),
+ ]
+ )
+ }
+ )
+
+ with pytest.raises(ValueError, match="does not exist"):
+ spec.validate()
+
+ def test_circular_dependency_raises_error(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[
+ ColumnDefinition(name="id", type="int", primary=True),
+ ColumnDefinition(name="col_a", type="string", baseColumn="col_b"),
+ ColumnDefinition(name="col_b", type="string", baseColumn="col_c"),
+ ColumnDefinition(name="col_c", type="string", baseColumn="col_a"),
+ ]
+ )
+ }
+ )
+
+ with pytest.raises(ValueError, match="Circular dependency"):
+ spec.validate()
+
+ def test_multiple_primary_columns_warning(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[
+ ColumnDefinition(name="id1", type="int", primary=True),
+ ColumnDefinition(name="id2", type="int", primary=True),
+ ]
+ )
+ }
+ )
+
+ # In strict mode, warnings cause errors
+ with pytest.raises(ValueError, match="multiple primary columns"):
+ spec.validate(strict=True)
+
+ # In non-strict mode, should pass but have warnings
+ result = spec.validate(strict=False)
+ assert result.is_valid()
+ assert len(result.warnings) > 0
+ assert any("multiple primary columns" in w for w in result.warnings)
+
+ def test_column_without_type_or_options_warning(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[
+ ColumnDefinition(name="id", type="int", primary=True),
+ ColumnDefinition(name="empty_col"),
+ ]
+ )
+ }
+ )
+
+ result = spec.validate(strict=False)
+ assert result.is_valid()
+ assert len(result.warnings) > 0
+ assert any("No type specified" in w for w in result.warnings)
+
+ def test_no_output_destination_warning(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[ColumnDefinition(name="id", type="int", primary=True)]
+ )
+ }
+ )
+
+ result = spec.validate(strict=False)
+ assert result.is_valid()
+ assert len(result.warnings) > 0
+ assert any("No output_destination" in w for w in result.warnings)
+
+ def test_unknown_generator_option_warning(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[ColumnDefinition(name="id", type="int", primary=True)]
+ )
+ },
+ generator_options={"unknown_option": "value"}
+ )
+
+ result = spec.validate(strict=False)
+ assert result.is_valid()
+ assert len(result.warnings) > 0
+ assert any("Unknown generator option" in w for w in result.warnings)
+
+ def test_multiple_errors_collected(self):
+ """Test that all errors are collected before raising"""
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=-10, # Error 1
+ partitions=0, # Error 2
+ columns=[
+ ColumnDefinition(name="id", type="int", primary=True),
+ ColumnDefinition(name="id", type="string"), # Error 3: duplicate
+ ColumnDefinition(name="email", baseColumn="phone"), # Error 4: nonexistent
+ ]
+ )
+ }
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ spec.validate()
+
+ error_msg = str(exc_info.value)
+ # Should contain all errors
+ assert "invalid number_of_rows" in error_msg
+ assert "invalid partitions" in error_msg
+ assert "duplicate column names" in error_msg
+ assert "does not exist" in error_msg
+
+ def test_strict_mode_raises_on_warnings(self):
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[ColumnDefinition(name="id", type="int", primary=True)]
+ )
+ }
+ # No output_destination - will generate warning
+ )
+
+ # Strict mode should raise
+ with pytest.raises(ValueError):
+ spec.validate(strict=True)
+
+ # Non-strict mode should pass
+ result = spec.validate(strict=False)
+ assert result.is_valid()
+
+ def test_valid_base_column_chain(self):
+ """Test that valid baseColumn chains work"""
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[
+ ColumnDefinition(name="id", type="int", primary=True),
+ ColumnDefinition(name="code", type="string", baseColumn="id"),
+ ColumnDefinition(name="hash", type="string", baseColumn="code"),
+ ]
+ )
+ },
+ output_destination=FilePathTarget(base_path="/tmp/data", output_format="parquet")
+ )
+
+ result = spec.validate(strict=False)
+ assert result.is_valid()
+
+ def test_multiple_tables_validation(self):
+ """Test validation across multiple tables"""
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=100,
+ columns=[ColumnDefinition(name="id", type="int", primary=True)]
+ ),
+ "orders": TableDefinition(
+ number_of_rows=-50, # Error in second table
+ columns=[ColumnDefinition(name="order_id", type="int", primary=True)]
+ ),
+ "products": TableDefinition(
+ number_of_rows=200,
+ columns=[] # Error: no columns
+ )
+ }
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ spec.validate()
+
+ error_msg = str(exc_info.value)
+ # Should find errors in both tables
+ assert "orders" in error_msg
+ assert "products" in error_msg
+
+
+class TestTargetValidation:
+ """Tests for output target validation"""
+
+ def test_valid_uc_schema_target(self):
+ target = UCSchemaTarget(catalog="main", schema_="default")
+ assert target.catalog == "main"
+ assert target.schema_ == "default"
+
+ def test_uc_schema_empty_catalog_raises_error(self):
+ with pytest.raises(ValueError, match="non-empty"):
+ UCSchemaTarget(catalog="", schema_="default")
+
+ def test_valid_file_path_target(self):
+ target = FilePathTarget(base_path="/tmp/data", output_format="parquet")
+ assert target.base_path == "/tmp/data"
+ assert target.output_format == "parquet"
+
+ def test_file_path_empty_base_path_raises_error(self):
+ with pytest.raises(ValueError, match="non-empty"):
+ FilePathTarget(base_path="", output_format="csv")
+
+ def test_file_path_invalid_format_raises_error(self):
+ with pytest.raises(ValueError):
+ FilePathTarget(base_path="/tmp/data", output_format="json")
+
+
+class TestValidationIntegration:
+ """Integration tests for validation"""
+
+ def test_realistic_valid_spec(self):
+ """Test a realistic, valid specification"""
+ spec = DatagenSpec(
+ tables={
+ "users": TableDefinition(
+ number_of_rows=1000,
+ partitions=4,
+ columns=[
+ ColumnDefinition(name="user_id", type="int", primary=True),
+ ColumnDefinition(name="username", type="string", options={
+ "template": r"\w{8,12}"
+ }),
+ ColumnDefinition(name="email", type="string", options={
+ "template": r"\w.\w@\w.com"
+ }),
+ ColumnDefinition(name="age", type="int", options={
+ "min": 18, "max": 99
+ }),
+ ]
+ ),
+ "orders": TableDefinition(
+ number_of_rows=5000,
+ columns=[
+ ColumnDefinition(name="order_id", type="int", primary=True),
+ ColumnDefinition(name="amount", type="decimal", options={
+ "min": 10.0, "max": 1000.0
+ }),
+ ]
+ )
+ },
+ output_destination=UCSchemaTarget(
+ catalog="main",
+ schema_="synthetic_data"
+ ),
+ generator_options={
+ "random": True,
+ "randomSeed": 42
+ }
+ )
+
+ result = spec.validate(strict=True)
+ assert result.is_valid()
+ assert len(result.errors) == 0
+ assert len(result.warnings) == 0
\ No newline at end of file