Skip to content

Commit b814472

Browse files
committed
!squash more
1 parent 2b7b9ed commit b814472

File tree

3 files changed

+381
-209
lines changed

3 files changed

+381
-209
lines changed

src/vcspull/models.py renamed to src/vcspull/schemas.py

Lines changed: 112 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,27 @@
1-
"""Pydantic models for vcspull configuration."""
1+
"""Pydantic schemas for vcspull configuration."""
22

33
from __future__ import annotations
44

5+
import enum
56
import os
67
import pathlib
78
import typing as t
8-
from enum import Enum
9-
from pathlib import Path
10-
from typing import Any, Dict, List, Optional, Union
119

1210
from pydantic import (
1311
BaseModel,
1412
ConfigDict,
15-
Field,
16-
HttpUrl,
13+
RootModel,
1714
field_validator,
18-
model_validator,
1915
)
2016

21-
if t.TYPE_CHECKING:
22-
from libvcs._internal.types import VCSLiteral
23-
from libvcs.sync.git import GitSyncRemoteDict
24-
2517
# Type aliases for better readability
26-
PathLike = Union[str, Path]
18+
PathLike = t.Union[str, pathlib.Path]
2719
ConfigName = str
2820
SectionName = str
2921
ShellCommand = str
3022

3123

32-
class VCSType(str, Enum):
24+
class VCSType(str, enum.Enum):
3325
"""Supported version control systems."""
3426

3527
GIT = "git"
@@ -42,8 +34,8 @@ class GitRemote(BaseModel):
4234

4335
name: str
4436
url: str
45-
fetch: Optional[str] = None
46-
push: Optional[str] = None
37+
fetch: str | None = None
38+
push: str | None = None
4739

4840

4941
class RepositoryModel(BaseModel):
@@ -67,10 +59,10 @@ class RepositoryModel(BaseModel):
6759

6860
vcs: str
6961
name: str
70-
path: Union[str, Path]
62+
path: str | pathlib.Path
7163
url: str
72-
remotes: Optional[Dict[str, GitRemote]] = None
73-
shell_command_after: Optional[List[str]] = None
64+
remotes: dict[str, GitRemote] | None = None
65+
shell_command_after: list[str] | None = None
7466

7567
model_config = ConfigDict(
7668
extra="forbid",
@@ -97,15 +89,14 @@ def validate_vcs(cls, v: str) -> str:
9789
ValueError
9890
If VCS type is invalid
9991
"""
100-
if v.lower() not in ("git", "hg", "svn"):
101-
raise ValueError(
102-
f"Invalid VCS type: {v}. Supported types are: git, hg, svn"
103-
)
92+
if v.lower() not in {"git", "hg", "svn"}:
93+
msg = f"Invalid VCS type: {v}. Supported types are: git, hg, svn"
94+
raise ValueError(msg)
10495
return v.lower()
10596

10697
@field_validator("path")
10798
@classmethod
108-
def validate_path(cls, v: Union[str, Path]) -> Path:
99+
def validate_path(cls, v: str | pathlib.Path) -> pathlib.Path:
109100
"""Validate and convert path to Path object.
110101
111102
Parameters
@@ -127,12 +118,13 @@ def validate_path(cls, v: Union[str, Path]) -> Path:
127118
# Convert to string first to handle Path objects
128119
path_str = str(v)
129120
# Expand environment variables and user directory
130-
expanded_path = os.path.expandvars(path_str)
131-
expanded_path = os.path.expanduser(expanded_path)
132-
# Convert to Path object
133-
return Path(expanded_path)
121+
path_obj = pathlib.Path(path_str)
122+
# Use Path methods instead of os.path
123+
expanded_path = pathlib.Path(os.path.expandvars(str(path_obj)))
124+
return expanded_path.expanduser()
134125
except Exception as e:
135-
raise ValueError(f"Invalid path: {v}. Error: {str(e)}")
126+
msg = f"Invalid path: {v}. Error: {e!s}"
127+
raise ValueError(msg) from e
136128

137129
@field_validator("url")
138130
@classmethod
@@ -157,29 +149,30 @@ def validate_url(cls, v: str, info: t.Any) -> str:
157149
If URL is invalid
158150
"""
159151
if not v:
160-
raise ValueError("URL cannot be empty")
152+
msg = "URL cannot be empty"
153+
raise ValueError(msg)
161154

162155
# Different validation based on VCS type
163-
values = info.data
164-
vcs_type = values.get("vcs", "").lower()
156+
# Keeping this but not using yet - can be expanded later
157+
# vcs_type = values.get("vcs", "").lower()
165158

166159
# Basic validation for all URL types
167160
if v.strip() == "":
168-
raise ValueError("URL cannot be empty or whitespace")
161+
msg = "URL cannot be empty or whitespace"
162+
raise ValueError(msg)
169163

170164
# VCS-specific validation could be added here
171165
# For now, just return the URL as is
172166
return v
173167

174168

175-
class ConfigSectionModel(BaseModel):
169+
class ConfigSectionDictModel(RootModel[dict[str, RepositoryModel]]):
176170
"""Configuration section model containing repositories.
177171
178-
A section is a logical grouping of repositories, typically by project or organization.
172+
A section is a logical grouping of repositories, typically by project or
173+
organization.
179174
"""
180175

181-
__root__: Dict[str, RepositoryModel] = Field(default_factory=dict)
182-
183176
def __getitem__(self, key: str) -> RepositoryModel:
184177
"""Get repository by name.
185178
@@ -193,17 +186,17 @@ def __getitem__(self, key: str) -> RepositoryModel:
193186
RepositoryModel
194187
Repository configuration
195188
"""
196-
return self.__root__[key]
189+
return self.root[key]
197190

198-
def __iter__(self) -> t.Iterator[str]:
199-
"""Iterate over repository names.
191+
def keys(self) -> t.KeysView[str]:
192+
"""Get repository names.
200193
201194
Returns
202195
-------
203-
Iterator[str]
204-
Iterator of repository names
196+
KeysView[str]
197+
View of repository names
205198
"""
206-
return iter(self.__root__)
199+
return self.root.keys()
207200

208201
def items(self) -> t.ItemsView[str, RepositoryModel]:
209202
"""Get items as name-repository pairs.
@@ -213,7 +206,7 @@ def items(self) -> t.ItemsView[str, RepositoryModel]:
213206
ItemsView[str, RepositoryModel]
214207
View of name-repository pairs
215208
"""
216-
return self.__root__.items()
209+
return self.root.items()
217210

218211
def values(self) -> t.ValuesView[RepositoryModel]:
219212
"""Get repository configurations.
@@ -223,18 +216,17 @@ def values(self) -> t.ValuesView[RepositoryModel]:
223216
ValuesView[RepositoryModel]
224217
View of repository configurations
225218
"""
226-
return self.__root__.values()
219+
return self.root.values()
227220

228221

229-
class ConfigModel(BaseModel):
222+
class ConfigDictModel(RootModel[dict[str, ConfigSectionDictModel]]):
230223
"""Complete configuration model containing sections.
231224
232-
A configuration is a collection of sections, where each section contains repositories.
225+
A configuration is a collection of sections, where each section contains
226+
repositories.
233227
"""
234228

235-
__root__: Dict[str, ConfigSectionModel] = Field(default_factory=dict)
236-
237-
def __getitem__(self, key: str) -> ConfigSectionModel:
229+
def __getitem__(self, key: str) -> ConfigSectionDictModel:
238230
"""Get section by name.
239231
240232
Parameters
@@ -244,40 +236,40 @@ def __getitem__(self, key: str) -> ConfigSectionModel:
244236
245237
Returns
246238
-------
247-
ConfigSectionModel
239+
ConfigSectionDictModel
248240
Section configuration
249241
"""
250-
return self.__root__[key]
242+
return self.root[key]
251243

252-
def __iter__(self) -> t.Iterator[str]:
253-
"""Iterate over section names.
244+
def keys(self) -> t.KeysView[str]:
245+
"""Get section names.
254246
255247
Returns
256248
-------
257-
Iterator[str]
258-
Iterator of section names
249+
KeysView[str]
250+
View of section names
259251
"""
260-
return iter(self.__root__)
252+
return self.root.keys()
261253

262-
def items(self) -> t.ItemsView[str, ConfigSectionModel]:
254+
def items(self) -> t.ItemsView[str, ConfigSectionDictModel]:
263255
"""Get items as section-repositories pairs.
264256
265257
Returns
266258
-------
267-
ItemsView[str, ConfigSectionModel]
259+
ItemsView[str, ConfigSectionDictModel]
268260
View of section-repositories pairs
269261
"""
270-
return self.__root__.items()
262+
return self.root.items()
271263

272-
def values(self) -> t.ValuesView[ConfigSectionModel]:
264+
def values(self) -> t.ValuesView[ConfigSectionDictModel]:
273265
"""Get section configurations.
274266
275267
Returns
276268
-------
277-
ValuesView[ConfigSectionModel]
269+
ValuesView[ConfigSectionDictModel]
278270
View of section configurations
279271
"""
280-
return self.__root__.values()
272+
return self.root.values()
281273

282274

283275
# Raw configuration models for initial parsing without validation
@@ -286,50 +278,89 @@ class RawRepositoryModel(BaseModel):
286278

287279
vcs: str
288280
name: str
289-
path: Union[str, Path]
281+
path: str | pathlib.Path
290282
url: str
291-
remotes: Optional[Dict[str, Dict[str, Any]]] = None
292-
shell_command_after: Optional[List[str]] = None
283+
remotes: dict[str, dict[str, t.Any]] | None = None
284+
shell_command_after: list[str] | None = None
293285

294286
model_config = ConfigDict(
295287
extra="allow", # Allow extra fields in raw config
296288
str_strip_whitespace=True,
297289
)
298290

299291

300-
class RawConfigSectionModel(BaseModel):
301-
"""Raw configuration section model before validation."""
292+
# Use a type alias for the complex type in RawConfigSectionDictModel
293+
RawRepoDataType = t.Union[RawRepositoryModel, str, dict[str, t.Any]]
302294

303-
__root__: Dict[str, Union[RawRepositoryModel, str, Dict[str, Any]]] = Field(
304-
default_factory=dict
305-
)
295+
296+
class RawConfigSectionDictModel(RootModel[dict[str, RawRepoDataType]]):
297+
"""Raw configuration section model before validation."""
306298

307299

308-
class RawConfigModel(BaseModel):
300+
class RawConfigDictModel(RootModel[dict[str, RawConfigSectionDictModel]]):
309301
"""Raw configuration model before validation and processing."""
310302

311-
__root__: Dict[str, RawConfigSectionModel] = Field(default_factory=dict)
312-
313303

314304
# Functions to convert between raw and validated models
315305
def convert_raw_to_validated(
316-
raw_config: RawConfigModel,
317-
cwd: t.Callable[[], Path] = Path.cwd,
318-
) -> ConfigModel:
306+
raw_config: RawConfigDictModel,
307+
cwd: t.Callable[[], pathlib.Path] = pathlib.Path.cwd,
308+
) -> ConfigDictModel:
319309
"""Convert raw configuration to validated configuration.
320310
321311
Parameters
322312
----------
323-
raw_config : RawConfigModel
313+
raw_config : RawConfigDictModel
324314
Raw configuration
325315
cwd : Callable[[], Path], optional
326316
Function to get current working directory, by default Path.cwd
327317
328318
Returns
329319
-------
330-
ConfigModel
320+
ConfigDictModel
331321
Validated configuration
332322
"""
333-
# Implementation will go here
334-
# This will handle shorthand syntax, variable resolution, etc.
335-
pass
323+
# Create a new ConfigDictModel
324+
config = ConfigDictModel(root={})
325+
326+
# Process each section in the raw config
327+
for section_name, raw_section in raw_config.root.items():
328+
# Create a new section in the validated config
329+
config.root[section_name] = ConfigSectionDictModel(root={})
330+
331+
# Process each repository in the section
332+
for repo_name, raw_repo_data in raw_section.root.items():
333+
# Handle string shortcuts (URL strings)
334+
if isinstance(raw_repo_data, str):
335+
# Convert string URL to a repository model
336+
repo_model = RepositoryModel(
337+
vcs="git", # Default to git for string URLs
338+
name=repo_name,
339+
path=cwd() / repo_name, # Default path is repo name in current dir
340+
url=raw_repo_data,
341+
)
342+
# Handle direct dictionary data
343+
elif isinstance(raw_repo_data, dict):
344+
# Ensure name is set
345+
if "name" not in raw_repo_data:
346+
raw_repo_data["name"] = repo_name
347+
348+
# Validate and convert path
349+
if "path" in raw_repo_data:
350+
path = raw_repo_data["path"]
351+
# Convert relative paths to absolute using cwd
352+
path_obj = pathlib.Path(os.path.expandvars(str(path))).expanduser()
353+
if not path_obj.is_absolute():
354+
path_obj = cwd() / path_obj
355+
raw_repo_data["path"] = path_obj
356+
357+
# Create repository model
358+
repo_model = RepositoryModel.model_validate(raw_repo_data)
359+
else:
360+
# Skip invalid repository data
361+
continue
362+
363+
# Add repository to the section
364+
config.root[section_name].root[repo_name] = repo_model
365+
366+
return config

0 commit comments

Comments
 (0)