diff --git a/commitizen/bump.py b/commitizen/bump.py index 76a8e15893..6d6b6dc069 100644 --- a/commitizen/bump.py +++ b/commitizen/bump.py @@ -3,6 +3,7 @@ import os import re from collections import OrderedDict +from collections.abc import Iterable from glob import iglob from logging import getLogger from string import Template @@ -61,7 +62,7 @@ def find_increment( def update_version_in_files( current_version: str, new_version: str, - files: list[str], + files: Iterable[str], *, check_consistency: bool = False, encoding: str = ENCODING, @@ -76,7 +77,7 @@ def update_version_in_files( """ # TODO: separate check step and write step updated = [] - for path, regex in files_and_regexs(files, current_version): + for path, regex in _files_and_regexes(files, current_version): current_version_found, version_file = _bump_with_regex( path, current_version, @@ -99,21 +100,22 @@ def update_version_in_files( return updated -def files_and_regexs(patterns: list[str], version: str) -> list[tuple[str, str]]: +def _files_and_regexes(patterns: Iterable[str], version: str) -> list[tuple[str, str]]: """ Resolve all distinct files with their regexp from a list of glob patterns with optional regexp """ - out = [] + out: set[tuple[str, str]] = set() for pattern in patterns: drive, tail = os.path.splitdrive(pattern) path, _, regex = tail.partition(":") filepath = drive + path if not regex: - regex = _version_to_regex(version) + regex = re.escape(version) - for path in iglob(filepath): - out.append((path, regex)) - return sorted(list(set(out))) + for file in iglob(filepath): + out.add((file, regex)) + + return sorted(out) def _bump_with_regex( @@ -128,18 +130,16 @@ def _bump_with_regex( pattern = re.compile(regex) with open(version_filepath, encoding=encoding) as f: for line in f: - if pattern.search(line): - bumped_line = line.replace(current_version, new_version) - if bumped_line != line: - current_version_found = True - lines.append(bumped_line) - else: + if not pattern.search(line): lines.append(line) - return current_version_found, "".join(lines) + continue + bumped_line = line.replace(current_version, new_version) + if bumped_line != line: + current_version_found = True + lines.append(bumped_line) -def _version_to_regex(version: str) -> str: - return version.replace(".", r"\.").replace("+", r"\+") + return current_version_found, "".join(lines) def create_commit_message( diff --git a/commitizen/changelog.py b/commitizen/changelog.py index 704efe6071..ba6fbbc6b3 100644 --- a/commitizen/changelog.py +++ b/commitizen/changelog.py @@ -29,10 +29,10 @@ import re from collections import OrderedDict, defaultdict -from collections.abc import Iterable +from collections.abc import Generator, Iterable, Mapping, Sequence from dataclasses import dataclass from datetime import date -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from jinja2 import ( BaseLoader, @@ -63,7 +63,7 @@ class Metadata: latest_version_position: int | None = None latest_version_tag: str | None = None - def __post_init__(self): + def __post_init__(self) -> None: if self.latest_version and not self.latest_version_tag: # Test syntactic sugar # latest version tag is optional if same as latest version @@ -84,7 +84,7 @@ def generate_tree_from_commits( changelog_message_builder_hook: MessageBuilderHook | None = None, changelog_release_hook: ChangelogReleaseHook | None = None, rules: TagRules | None = None, -) -> Iterable[dict]: +) -> Generator[dict[str, Any], None, None]: pat = re.compile(changelog_pattern) map_pat = re.compile(commit_parser, re.MULTILINE) body_map_pat = re.compile(commit_parser, re.MULTILINE | re.DOTALL) @@ -169,8 +169,8 @@ def process_commit_message( commit: GitCommit, changes: dict[str | None, list], change_type_map: dict[str, str] | None = None, -): - message: dict = { +) -> None: + message: dict[str, Any] = { "sha1": commit.rev, "parents": commit.parents, "author": commit.author, @@ -187,24 +187,27 @@ def process_commit_message( changes[change_type].append(msg) -def order_changelog_tree(tree: Iterable, change_type_order: list[str]) -> Iterable: +def generate_ordered_changelog_tree( + tree: Iterable[Mapping[str, Any]], change_type_order: list[str] +) -> Generator[dict[str, Any], None, None]: if len(set(change_type_order)) != len(change_type_order): raise InvalidConfigurationError( - f"Change types contain duplicates types ({change_type_order})" + f"Change types contain duplicated types ({change_type_order})" ) - sorted_tree = [] for entry in tree: - ordered_change_types = change_type_order + sorted( - set(entry["changes"].keys()) - set(change_type_order) - ) - changes = [ - (ct, entry["changes"][ct]) - for ct in ordered_change_types - if ct in entry["changes"] - ] - sorted_tree.append({**entry, **{"changes": OrderedDict(changes)}}) - return sorted_tree + yield { + **entry, + "changes": _calculate_sorted_changes(change_type_order, entry["changes"]), + } + + +def _calculate_sorted_changes( + change_type_order: list[str], changes: Mapping[str, Any] +) -> OrderedDict[str, Any]: + remaining_change_types = set(changes.keys()) - set(change_type_order) + sorted_change_types = change_type_order + sorted(remaining_change_types) + return OrderedDict((ct, changes[ct]) for ct in sorted_change_types if ct in changes) def get_changelog_template(loader: BaseLoader, template: str) -> Template: @@ -222,7 +225,7 @@ def render_changelog( tree: Iterable, loader: BaseLoader, template: str, - **kwargs, + **kwargs: Any, ) -> str: jinja_template = get_changelog_template(loader, template) changelog: str = jinja_template.render(tree=tree, **kwargs) @@ -279,7 +282,7 @@ def incremental_build( def get_smart_tag_range( - tags: list[GitTag], newest: str, oldest: str | None = None + tags: Sequence[GitTag], newest: str, oldest: str | None = None ) -> list[GitTag]: """Smart because it finds the N+1 tag. @@ -305,10 +308,10 @@ def get_smart_tag_range( def get_oldest_and_newest_rev( - tags: list[GitTag], + tags: Sequence[GitTag], version: str, rules: TagRules, -) -> tuple[str | None, str | None]: +) -> tuple[str | None, str]: """Find the tags for the given version. `version` may come in different formats: diff --git a/commitizen/changelog_formats/__init__.py b/commitizen/changelog_formats/__init__.py index 782bfb24cb..b7b3cac01d 100644 --- a/commitizen/changelog_formats/__init__.py +++ b/commitizen/changelog_formats/__init__.py @@ -25,7 +25,7 @@ class ChangelogFormat(Protocol): config: BaseConfig - def __init__(self, config: BaseConfig): + def __init__(self, config: BaseConfig) -> None: self.config = config @property diff --git a/commitizen/changelog_formats/base.py b/commitizen/changelog_formats/base.py index f69cf8f00f..cb5d385bf8 100644 --- a/commitizen/changelog_formats/base.py +++ b/commitizen/changelog_formats/base.py @@ -20,7 +20,7 @@ class BaseFormat(ChangelogFormat, metaclass=ABCMeta): extension: ClassVar[str] = "" alternative_extensions: ClassVar[set[str]] = set() - def __init__(self, config: BaseConfig): + def __init__(self, config: BaseConfig) -> None: # Constructor needs to be redefined because `Protocol` prevent instantiation by default # See: https://bugs.python.org/issue44807 self.config = config diff --git a/commitizen/cli.py b/commitizen/cli.py index cb834c5d6f..12db01711c 100644 --- a/commitizen/cli.py +++ b/commitizen/cli.py @@ -3,12 +3,11 @@ import argparse import logging import sys -from collections.abc import Sequence from copy import deepcopy from functools import partial from pathlib import Path from types import TracebackType -from typing import Any +from typing import TYPE_CHECKING, cast import argcomplete from decli import cli @@ -48,17 +47,17 @@ def __call__( self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, - kwarg: str | Sequence[Any] | None, + values: object, option_string: str | None = None, - ): - if not isinstance(kwarg, str): + ) -> None: + if not isinstance(values, str): return - if "=" not in kwarg: + if "=" not in values: raise InvalidCommandArgumentError( f"Option {option_string} expect a key=value format" ) kwargs = getattr(namespace, self.dest, None) or {} - key, value = kwarg.split("=", 1) + key, value = values.split("=", 1) if not key: raise InvalidCommandArgumentError( f"Option {option_string} expect a key=value format" @@ -550,22 +549,27 @@ def __call__( def commitizen_excepthook( - type, value, traceback, debug=False, no_raise: list[int] | None = None -): + type: type[BaseException], + value: BaseException, + traceback: TracebackType | None, + debug: bool = False, + no_raise: list[int] | None = None, +) -> None: traceback = traceback if isinstance(traceback, TracebackType) else None + if not isinstance(value, CommitizenException): + original_excepthook(type, value, traceback) + return + if not no_raise: no_raise = [] - if isinstance(value, CommitizenException): - if value.message: - value.output_method(value.message) - if debug: - original_excepthook(type, value, traceback) - exit_code = value.exit_code - if exit_code in no_raise: - exit_code = ExitCode.EXPECTED_EXIT - sys.exit(exit_code) - else: + if value.message: + value.output_method(value.message) + if debug: original_excepthook(type, value, traceback) + exit_code = value.exit_code + if exit_code in no_raise: + exit_code = ExitCode.EXPECTED_EXIT + sys.exit(exit_code) commitizen_debug_excepthook = partial(commitizen_excepthook, debug=True) @@ -580,7 +584,7 @@ def parse_no_raise(comma_separated_no_raise: str) -> list[int]: represents the exit code found in exceptions. """ no_raise_items: list[str] = comma_separated_no_raise.split(",") - no_raise_codes = [] + no_raise_codes: list[int] = [] for item in no_raise_items: if item.isdecimal(): no_raise_codes.append(int(item)) @@ -595,8 +599,33 @@ def parse_no_raise(comma_separated_no_raise: str) -> list[int]: return no_raise_codes -def main(): - parser = cli(data) +if TYPE_CHECKING: + + class Args(argparse.Namespace): + config: str | None = None + debug: bool = False + name: str | None = None + no_raise: str | None = None # comma-separated string, later parsed as list[int] + report: bool = False + project: bool = False + commitizen: bool = False + verbose: bool = False + func: type[ + commands.Init # init + | commands.Commit # commit (c) + | commands.ListCz # ls + | commands.Example # example + | commands.Info # info + | commands.Schema # schema + | commands.Bump # bump + | commands.Changelog # changelog (ch) + | commands.Check # check + | commands.Version # version + ] + + +def main() -> None: + parser: argparse.ArgumentParser = cli(data) argcomplete.autocomplete(parser) # Show help if no arg provided if len(sys.argv) == 1: @@ -606,11 +635,8 @@ def main(): # This is for the command required constraint in 2.0 try: args, unknown_args = parser.parse_known_args() - except (TypeError, SystemExit) as e: - # https://github.com/commitizen-tools/commitizen/issues/429 - # argparse raises TypeError when non exist command is provided on Python < 3.9 - # but raise SystemExit with exit code == 2 on Python 3.9 - if isinstance(e, TypeError) or (isinstance(e, SystemExit) and e.code == 2): + except SystemExit as e: + if e.code == 2: raise NoCommandFoundError() raise e @@ -636,14 +662,11 @@ def main(): extra_args = " ".join(unknown_args[1:]) arguments["extra_cli_args"] = extra_args - if args.config: - conf = config.read_cfg(args.config) - else: - conf = config.read_cfg() - + conf = config.read_cfg(args.config) + args = cast("Args", args) if args.name: conf.update({"name": args.name}) - elif not args.name and not conf.path: + elif not conf.path: conf.update({"name": "cz_conventional_commits"}) if args.debug: @@ -656,7 +679,7 @@ def main(): ) sys.excepthook = no_raise_debug_excepthook - args.func(conf, arguments)() + args.func(conf, arguments)() # type: ignore[arg-type] if __name__ == "__main__": diff --git a/commitizen/cmd.py b/commitizen/cmd.py index ba48ac7881..3f13087233 100644 --- a/commitizen/cmd.py +++ b/commitizen/cmd.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import os import subprocess +from collections.abc import Mapping from typing import NamedTuple from charset_normalizer import from_bytes @@ -28,7 +31,7 @@ def _try_decode(bytes_: bytes) -> str: raise CharacterSetDecodeError() from e -def run(cmd: str, env=None) -> Command: +def run(cmd: str, env: Mapping[str, str] | None = None) -> Command: if env is not None: env = {**os.environ, **env} process = subprocess.Popen( diff --git a/commitizen/commands/__init__.py b/commitizen/commands/__init__.py index 806e384522..58b18297dc 100644 --- a/commitizen/commands/__init__.py +++ b/commitizen/commands/__init__.py @@ -11,13 +11,13 @@ __all__ = ( "Bump", + "Changelog", "Check", "Commit", - "Changelog", "Example", "Info", + "Init", "ListCz", "Schema", "Version", - "Init", ) diff --git a/commitizen/commands/bump.py b/commitizen/commands/bump.py index 0a2bbe37fc..a4024eee35 100644 --- a/commitizen/commands/bump.py +++ b/commitizen/commands/bump.py @@ -37,37 +37,63 @@ logger = getLogger("commitizen") +class BumpArgs(Settings, total=False): + allow_no_commit: bool | None + annotated_tag_message: str | None + build_metadata: str | None + changelog_to_stdout: bool + changelog: bool + check_consistency: bool + devrelease: int | None + dry_run: bool + file_name: str + files_only: bool | None + get_next: bool + git_output_to_stderr: bool + increment_mode: str + increment: Increment | None + local_version: bool + manual_version: str | None + no_verify: bool + prerelease: Prerelease | None + retry: bool + yes: bool + + class Bump: """Show prompt for the user to create a guided commit.""" - def __init__(self, config: BaseConfig, arguments: dict): + def __init__(self, config: BaseConfig, arguments: BumpArgs) -> None: if not git.is_git_project(): raise NotAGitProjectError() self.config: BaseConfig = config self.encoding = config.settings["encoding"] - self.arguments: dict = arguments - self.bump_settings: dict = { - **config.settings, - **{ - key: arguments[key] - for key in [ - "tag_format", - "prerelease", - "increment", - "increment_mode", - "bump_message", - "gpg_sign", - "annotated_tag", - "annotated_tag_message", - "major_version_zero", - "prerelease_offset", - "template", - "file_name", - ] - if arguments[key] is not None + self.arguments = arguments + self.bump_settings = cast( + BumpArgs, + { + **config.settings, + **{ + k: v + for k in ( + "annotated_tag_message", + "annotated_tag", + "bump_message", + "file_name", + "gpg_sign", + "increment_mode", + "increment", + "major_version_zero", + "prerelease_offset", + "prerelease", + "tag_format", + "template", + ) + if (v := arguments.get(k)) is not None + }, }, - } + ) self.cz = factory.committer_factory(self.config) self.changelog_flag = arguments["changelog"] self.changelog_config = self.config.settings.get("update_changelog_on_bump") @@ -82,7 +108,7 @@ def __init__(self, config: BaseConfig, arguments: dict): if deprecated_version_type: warnings.warn( DeprecationWarning( - "`--version-type` parameter is deprecated and will be removed in commitizen 4. " + "`--version-type` parameter is deprecated and will be removed in v5. " "Please use `--version-scheme` instead" ) ) @@ -101,31 +127,29 @@ def __init__(self, config: BaseConfig, arguments: dict): ) self.extras = arguments["extras"] - def is_initial_tag( + def _is_initial_tag( self, current_tag: git.GitTag | None, is_yes: bool = False ) -> bool: """Check if reading the whole git tree up to HEAD is needed.""" - is_initial = False - if not current_tag: - if is_yes: - is_initial = True - else: - out.info("No tag matching configuration could not be found.") - out.info( - "Possible causes:\n" - "- version in configuration is not the current version\n" - "- tag_format or legacy_tag_formats is missing, check them using 'git tag --list'\n" - ) - is_initial = questionary.confirm("Is this the first tag created?").ask() - return is_initial + if current_tag: + return False + if is_yes: + return True + + out.info("No tag matching configuration could be found.") + out.info( + "Possible causes:\n" + "- version in configuration is not the current version\n" + "- tag_format or legacy_tag_formats is missing, check them using 'git tag --list'\n" + ) + return bool(questionary.confirm("Is this the first tag created?").ask()) - def find_increment(self, commits: list[git.GitCommit]) -> Increment | None: + def _find_increment(self, commits: list[git.GitCommit]) -> Increment | None: # Update the bump map to ensure major version doesn't increment. - is_major_version_zero: bool = self.bump_settings["major_version_zero"] # self.cz.bump_map = defaults.bump_map_major_version_zero bump_map = ( self.cz.bump_map_major_version_zero - if is_major_version_zero + if self.bump_settings["major_version_zero"] else self.cz.bump_map ) bump_pattern = self.cz.bump_pattern @@ -134,12 +158,9 @@ def find_increment(self, commits: list[git.GitCommit]) -> Increment | None: raise NoPatternMapError( f"'{self.config.settings['name']}' rule does not support bump" ) - increment = bump.find_increment( - commits, regex=bump_pattern, increments_map=bump_map - ) - return increment + return bump.find_increment(commits, regex=bump_pattern, increments_map=bump_map) - def __call__(self) -> None: # noqa: C901 + def __call__(self) -> None: """Steps executed to bump.""" provider = get_provider(self.config) @@ -148,23 +169,14 @@ def __call__(self) -> None: # noqa: C901 except TypeError: raise NoVersionSpecifiedError() - bump_commit_message: str = self.bump_settings["bump_message"] - version_files: list[str] = self.bump_settings["version_files"] - major_version_zero: bool = self.bump_settings["major_version_zero"] - prerelease_offset: int = self.bump_settings["prerelease_offset"] - - dry_run: bool = self.arguments["dry_run"] - is_yes: bool = self.arguments["yes"] - increment: Increment | None = self.arguments["increment"] - prerelease: Prerelease | None = self.arguments["prerelease"] - devrelease: int | None = self.arguments["devrelease"] - is_files_only: bool | None = self.arguments["files_only"] - is_local_version: bool = self.arguments["local_version"] + increment = self.arguments["increment"] + prerelease = self.arguments["prerelease"] + devrelease = self.arguments["devrelease"] + is_local_version = self.arguments["local_version"] manual_version = self.arguments["manual_version"] build_metadata = self.arguments["build_metadata"] - increment_mode: str = self.arguments["increment_mode"] - get_next: bool = self.arguments["get_next"] - allow_no_commit: bool | None = self.arguments["allow_no_commit"] + get_next = self.arguments["get_next"] + allow_no_commit = self.arguments["allow_no_commit"] if manual_version: if increment: @@ -186,7 +198,7 @@ def __call__(self) -> None: # noqa: C901 "--build-metadata cannot be combined with MANUAL_VERSION" ) - if major_version_zero: + if self.bump_settings["major_version_zero"]: raise NotAllowed( "--major-version-zero cannot be combined with MANUAL_VERSION" ) @@ -194,7 +206,7 @@ def __call__(self) -> None: # noqa: C901 if get_next: raise NotAllowed("--get-next cannot be combined with MANUAL_VERSION") - if major_version_zero: + if self.bump_settings["major_version_zero"]: if not current_version.release[0] == 0: raise NotAllowed( f"--major-version-zero is meaningless for current version {current_version}" @@ -208,7 +220,7 @@ def __call__(self) -> None: # noqa: C901 if get_next: # if trying to use --get-next, we should not allow --changelog or --changelog-to-stdout - if self.changelog_flag or bool(self.changelog_to_stdout): + if self.changelog_flag or self.changelog_to_stdout: raise NotAllowed( "--changelog or --changelog-to-stdout is not allowed with --get-next" ) @@ -219,10 +231,8 @@ def __call__(self) -> None: # noqa: C901 else: # If user specified changelog_to_stdout, they probably want the # changelog to be generated as well, this is the most intuitive solution - self.changelog_flag = ( - self.changelog_flag - or bool(self.changelog_to_stdout) - or self.changelog_config + self.changelog_flag = any( + (self.changelog_flag, self.changelog_to_stdout, self.changelog_config) ) rules = TagRules.from_settings(cast(Settings, self.bump_settings)) @@ -231,7 +241,7 @@ def __call__(self) -> None: # noqa: C901 current_tag, "name", rules.normalize_tag(current_version) ) - is_initial = self.is_initial_tag(current_tag, is_yes) + is_initial = self._is_initial_tag(current_tag, self.arguments["yes"]) if manual_version: try: @@ -259,7 +269,7 @@ def __call__(self) -> None: # noqa: C901 "[NO_COMMITS_FOUND]\nNo new commits found." ) - increment = self.find_increment(commits) + increment = self._find_increment(commits) # It may happen that there are commits, but they are not eligible # for an increment, this generates a problem when using prerelease (#281) @@ -277,16 +287,16 @@ def __call__(self) -> None: # noqa: C901 new_version = current_version.bump( increment, prerelease=prerelease, - prerelease_offset=prerelease_offset, + prerelease_offset=self.bump_settings["prerelease_offset"], devrelease=devrelease, is_local_version=is_local_version, build_metadata=build_metadata, - exact_increment=increment_mode == "exact", + exact_increment=self.arguments["increment_mode"] == "exact", ) new_tag_version = rules.normalize_tag(new_version) message = bump.create_commit_message( - current_version, new_version, bump_commit_message + current_version, new_version, self.bump_settings["bump_message"] ) if get_next: @@ -318,6 +328,7 @@ def __call__(self) -> None: # noqa: C901 ) files: list[str] = [] + dry_run = self.arguments["dry_run"] if self.changelog_flag: args = { "unreleased_version": new_tag_version, @@ -327,14 +338,14 @@ def __call__(self) -> None: # noqa: C901 "dry_run": dry_run, } if self.changelog_to_stdout: - changelog_cmd = Changelog(self.config, {**args, "dry_run": True}) + changelog_cmd = Changelog(self.config, {**args, "dry_run": True}) # type: ignore try: changelog_cmd() except DryRunExit: pass args["file_name"] = self.file_name - changelog_cmd = Changelog(self.config, args) + changelog_cmd = Changelog(self.config, args) # type: ignore changelog_cmd() files.append(changelog_cmd.file_name) @@ -346,7 +357,7 @@ def __call__(self) -> None: # noqa: C901 bump.update_version_in_files( str(current_version), str(new_version), - version_files, + self.bump_settings["version_files"], check_consistency=self.check_consistency, encoding=self.encoding, ) @@ -370,7 +381,7 @@ def __call__(self) -> None: # noqa: C901 else None, ) - if is_files_only: + if self.arguments["files_only"]: raise ExpectedExit() # FIXME: check if any changes have been staged @@ -399,11 +410,19 @@ def __call__(self) -> None: # noqa: C901 c = git.tag( new_tag_version, - signed=self.bump_settings.get("gpg_sign", False) - or bool(self.config.settings.get("gpg_sign", False)), - annotated=self.bump_settings.get("annotated_tag", False) - or bool(self.config.settings.get("annotated_tag", False)) - or bool(self.bump_settings.get("annotated_tag_message", False)), + signed=any( + ( + self.bump_settings.get("gpg_sign"), + self.config.settings.get("gpg_sign"), + ) + ), + annotated=any( + ( + self.bump_settings.get("annotated_tag"), + self.config.settings.get("annotated_tag"), + self.bump_settings.get("annotated_tag_message"), + ) + ), msg=self.bump_settings.get("annotated_tag_message", None), # TODO: also get from self.config.settings? ) diff --git a/commitizen/commands/changelog.py b/commitizen/commands/changelog.py index 0e4efabfa1..8ca36bdb9a 100644 --- a/commitizen/commands/changelog.py +++ b/commitizen/commands/changelog.py @@ -2,10 +2,11 @@ import os import os.path +from collections.abc import Generator, Iterable from difflib import SequenceMatcher from operator import itemgetter from pathlib import Path -from typing import Callable, cast +from typing import Any, TypedDict, cast from commitizen import changelog, defaults, factory, git, out from commitizen.changelog_formats import get_changelog_format @@ -25,16 +26,35 @@ from commitizen.version_schemes import get_version_scheme +class ChangelogArgs(TypedDict, total=False): + change_type_map: dict[str, str] + change_type_order: list[str] + current_version: str + dry_run: bool + file_name: str + incremental: bool + merge_prerelease: bool + rev_range: str + start_rev: str + tag_format: str + unreleased_version: str | None + version_scheme: str + template: str + extras: dict[str, Any] + export_template: str + + class Changelog: """Generate a changelog based on the commit history.""" - def __init__(self, config: BaseConfig, args): + def __init__(self, config: BaseConfig, arguments: ChangelogArgs) -> None: if not git.is_git_project(): raise NotAGitProjectError() - self.config: BaseConfig = config - changelog_file_name = args.get("file_name") or cast( - str, self.config.settings.get("changelog_file") + self.config = config + + changelog_file_name = arguments.get("file_name") or self.config.settings.get( + "changelog_file" ) if not isinstance(changelog_file_name, str): raise NotAllowed( @@ -51,57 +71,61 @@ def __init__(self, config: BaseConfig, args): self.encoding = self.config.settings["encoding"] self.cz = factory.committer_factory(self.config) - self.start_rev = args.get("start_rev") or self.config.settings.get( + self.start_rev = arguments.get("start_rev") or self.config.settings.get( "changelog_start_rev" ) self.changelog_format = get_changelog_format(self.config, self.file_name) - self.incremental = args["incremental"] or self.config.settings.get( - "changelog_incremental" + self.incremental = bool( + arguments.get("incremental") + or self.config.settings.get("changelog_incremental") ) - self.dry_run = args["dry_run"] + self.dry_run = bool(arguments.get("dry_run")) self.scheme = get_version_scheme( - self.config.settings, args.get("version_scheme") + self.config.settings, arguments.get("version_scheme") ) current_version = ( - args.get("current_version", config.settings.get("version")) or "" + arguments.get("current_version") + or self.config.settings.get("version") + or "" ) self.current_version = self.scheme(current_version) if current_version else None - self.unreleased_version = args["unreleased_version"] + self.unreleased_version = arguments["unreleased_version"] self.change_type_map = ( self.config.settings.get("change_type_map") or self.cz.change_type_map ) - self.change_type_order = ( + self.change_type_order = cast( + list[str], self.config.settings.get("change_type_order") or self.cz.change_type_order - or defaults.CHANGE_TYPE_ORDER + or defaults.CHANGE_TYPE_ORDER, ) - self.rev_range = args.get("rev_range") - self.tag_format: str = ( - args.get("tag_format") or self.config.settings["tag_format"] + self.rev_range = arguments.get("rev_range") + self.tag_format = ( + arguments.get("tag_format") or self.config.settings["tag_format"] ) self.tag_rules = TagRules( scheme=self.scheme, tag_format=self.tag_format, legacy_tag_formats=self.config.settings["legacy_tag_formats"], ignored_tag_formats=self.config.settings["ignored_tag_formats"], - merge_prereleases=args.get("merge_prerelease") + merge_prereleases=arguments.get("merge_prerelease") or self.config.settings["changelog_merge_prerelease"], ) self.template = ( - args.get("template") + arguments.get("template") or self.config.settings.get("template") or self.changelog_format.template ) - self.extras = args.get("extras") or {} - self.export_template_to = args.get("export_template") + self.extras = arguments.get("extras") or {} + self.export_template_to = arguments.get("export_template") - def _find_incremental_rev(self, latest_version: str, tags: list[GitTag]) -> str: + def _find_incremental_rev(self, latest_version: str, tags: Iterable[GitTag]) -> str: """Try to find the 'start_rev'. We use a similarity approach. We know how to parse the version from the markdown @@ -114,28 +138,28 @@ def _find_incremental_rev(self, latest_version: str, tags: list[GitTag]) -> str: on our experience. """ SIMILARITY_THRESHOLD = 0.89 - tag_ratio = map( - lambda tag: ( - SequenceMatcher( + scores_and_tag_names: Generator[tuple[float, str]] = ( + ( + score, + tag.name, + ) + for tag in tags + if ( + score := SequenceMatcher( None, latest_version, strip_local_version(tag.name) - ).ratio(), - tag, - ), - tags, + ).ratio() + ) + >= SIMILARITY_THRESHOLD ) try: - score, tag = max(tag_ratio, key=itemgetter(0)) + _, start_rev = max(scores_and_tag_names, key=itemgetter(0)) except ValueError: raise NoRevisionError() - if score < SIMILARITY_THRESHOLD: - raise NoRevisionError() - start_rev = tag.name return start_rev - def write_changelog( + def _write_changelog( self, changelog_out: str, lines: list[str], changelog_meta: changelog.Metadata - ): - changelog_hook: Callable | None = self.cz.changelog_hook + ) -> None: with smart_open(self.file_name, "w", encoding=self.encoding) as changelog_file: partial_changelog: str | None = None if self.incremental: @@ -145,23 +169,24 @@ def write_changelog( changelog_out = "".join(new_lines) partial_changelog = changelog_out - if changelog_hook: - changelog_out = changelog_hook(changelog_out, partial_changelog) + if self.cz.changelog_hook: + changelog_out = self.cz.changelog_hook(changelog_out, partial_changelog) changelog_file.write(changelog_out) - def export_template(self): + def _export_template(self) -> None: tpl = changelog.get_changelog_template(self.cz.template_loader, self.template) - src = Path(tpl.filename) - Path(self.export_template_to).write_text(src.read_text()) + # TODO: fix the following type ignores + src = Path(tpl.filename) # type: ignore + Path(self.export_template_to).write_text(src.read_text()) # type: ignore - def __call__(self): + def __call__(self) -> None: commit_parser = self.cz.commit_parser changelog_pattern = self.cz.changelog_pattern start_rev = self.start_rev unreleased_version = self.unreleased_version changelog_meta = changelog.Metadata() - change_type_map: dict | None = self.change_type_map + change_type_map: dict[str, str] | None = self.change_type_map changelog_message_builder_hook: MessageBuilderHook | None = ( self.cz.changelog_message_builder_hook ) @@ -170,7 +195,7 @@ def __call__(self): ) if self.export_template_to: - return self.export_template() + return self._export_template() if not changelog_pattern or not commit_parser: raise NoPatternMapError( @@ -189,7 +214,7 @@ def __call__(self): changelog_meta = self.changelog_format.get_metadata(self.file_name) if changelog_meta.latest_version: start_rev = self._find_incremental_rev( - strip_local_version(changelog_meta.latest_version_tag), tags + strip_local_version(changelog_meta.latest_version_tag or ""), tags ) if self.rev_range: start_rev, end_rev = changelog.get_oldest_and_newest_rev( @@ -214,21 +239,21 @@ def __call__(self): rules=self.tag_rules, ) if self.change_type_order: - tree = changelog.order_changelog_tree(tree, self.change_type_order) + tree = changelog.generate_ordered_changelog_tree( + tree, self.change_type_order + ) extras = self.cz.template_extras.copy() extras.update(self.config.settings["extras"]) extras.update(self.extras) changelog_out = changelog.render_changelog( tree, loader=self.cz.template_loader, template=self.template, **extras - ) - changelog_out = changelog_out.lstrip("\n") + ).lstrip("\n") # Dry_run is executed here to avoid checking and reading the files if self.dry_run: - changelog_hook: Callable | None = self.cz.changelog_hook - if changelog_hook: - changelog_out = changelog_hook(changelog_out, "") + if self.cz.changelog_hook: + changelog_out = self.cz.changelog_hook(changelog_out, "") out.write(changelog_out) raise DryRunExit() @@ -237,4 +262,4 @@ def __call__(self): with open(self.file_name, encoding=self.encoding) as changelog_file: lines = changelog_file.readlines() - self.write_changelog(changelog_out, lines, changelog_meta) + self._write_changelog(changelog_out, lines, changelog_meta) diff --git a/commitizen/commands/check.py b/commitizen/commands/check.py index 1e3b8464e1..8a7d0dd019 100644 --- a/commitizen/commands/check.py +++ b/commitizen/commands/check.py @@ -1,9 +1,8 @@ from __future__ import annotations -import os import re import sys -from typing import Any +from typing import TypedDict from commitizen import factory, git, out from commitizen.config import BaseConfig @@ -14,10 +13,20 @@ ) +class CheckArgs(TypedDict, total=False): + commit_msg_file: str + commit_msg: str + rev_range: str + allow_abort: bool + message_length_limit: int + allowed_prefixes: list[str] + message: str + + class Check: """Check if the current commit msg matches the commitizen format.""" - def __init__(self, config: BaseConfig, arguments: dict[str, Any], cwd=os.getcwd()): + def __init__(self, config: BaseConfig, arguments: CheckArgs, *args: object) -> None: """Initial check command. Args: @@ -25,16 +34,15 @@ def __init__(self, config: BaseConfig, arguments: dict[str, Any], cwd=os.getcwd( arguments: All the flags provided by the user cwd: Current work directory """ - self.commit_msg_file: str | None = arguments.get("commit_msg_file") - self.commit_msg: str | None = arguments.get("message") - self.rev_range: str | None = arguments.get("rev_range") - self.allow_abort: bool = bool( + self.commit_msg_file = arguments.get("commit_msg_file") + self.commit_msg = arguments.get("message") + self.rev_range = arguments.get("rev_range") + self.allow_abort = bool( arguments.get("allow_abort", config.settings["allow_abort"]) ) - self.max_msg_length: int = arguments.get("message_length_limit", 0) + self.max_msg_length = arguments.get("message_length_limit", 0) # we need to distinguish between None and [], which is a valid value - allowed_prefixes = arguments.get("allowed_prefixes") self.allowed_prefixes: list[str] = ( allowed_prefixes @@ -48,7 +56,7 @@ def __init__(self, config: BaseConfig, arguments: dict[str, Any], cwd=os.getcwd( self.encoding = config.settings["encoding"] self.cz = factory.committer_factory(self.config) - def _valid_command_argument(self): + def _valid_command_argument(self) -> None: num_exclusive_args_provided = sum( arg is not None for arg in (self.commit_msg_file, self.commit_msg, self.rev_range) @@ -61,7 +69,7 @@ def _valid_command_argument(self): "See 'cz check -h' for more information" ) - def __call__(self): + def __call__(self) -> None: """Validate if commit messages follows the conventional pattern. Raises: @@ -97,7 +105,7 @@ def _get_commit_message(self) -> str | None: # Get commit message from file (--commit-msg-file) return commit_file.read() - def _get_commits(self): + def _get_commits(self) -> list[git.GitCommit]: if (msg := self._get_commit_message()) is not None: return [git.GitCommit(rev="", title="", body=self._filter_comments(msg))] diff --git a/commitizen/commands/commit.py b/commitizen/commands/commit.py index cb34c41a50..43feb272ba 100644 --- a/commitizen/commands/commit.py +++ b/commitizen/commands/commit.py @@ -5,6 +5,8 @@ import shutil import subprocess import tempfile +from pathlib import Path +from typing import TypedDict import questionary @@ -26,10 +28,22 @@ from commitizen.git import smart_open +class CommitArgs(TypedDict, total=False): + all: bool + dry_run: bool + edit: bool + extra_cli_args: str + message_length_limit: int + no_retry: bool + signoff: bool + write_message_to_file: Path | None + retry: bool + + class Commit: """Show prompt for the user to create a guided commit.""" - def __init__(self, config: BaseConfig, arguments: dict): + def __init__(self, config: BaseConfig, arguments: CommitArgs) -> None: if not git.is_git_project(): raise NotAGitProjectError() @@ -39,7 +53,7 @@ def __init__(self, config: BaseConfig, arguments: dict): self.arguments = arguments self.temp_file: str = get_backup_file_path() - def read_backup_message(self) -> str | None: + def _read_backup_message(self) -> str | None: # Check the commit backup file exists if not os.path.isfile(self.temp_file): return None @@ -48,7 +62,7 @@ def read_backup_message(self) -> str | None: with open(self.temp_file, encoding=self.encoding) as f: return f.read().strip() - def prompt_commit_questions(self) -> str: + def _prompt_commit_questions(self) -> str: # Prompt user for the commit message cz = self.cz questions = cz.questions() @@ -67,7 +81,7 @@ def prompt_commit_questions(self) -> str: message = cz.message(answers) message_len = len(message.partition("\n")[0].strip()) - message_length_limit: int = self.arguments.get("message_length_limit", 0) + message_length_limit = self.arguments.get("message_length_limit", 0) if 0 < message_length_limit < message_len: raise CommitMessageLengthExceededError( f"Length of commit message exceeds limit ({message_len}/{message_length_limit})" @@ -92,41 +106,36 @@ def manual_edit(self, message: str) -> str: os.unlink(file.name) return message - def __call__(self): - extra_args: str = self.arguments.get("extra_cli_args", "") + def _get_message(self) -> str: + if self.arguments.get("retry"): + m = self._read_backup_message() + if m is None: + raise NoCommitBackupError() + return m - allow_empty: bool = "--allow-empty" in extra_args + if self.config.settings.get("retry_after_failure") and not self.arguments.get( + "no_retry" + ): + return self._read_backup_message() or self._prompt_commit_questions() + return self._prompt_commit_questions() - dry_run: bool = self.arguments.get("dry_run") - write_message_to_file: bool = self.arguments.get("write_message_to_file") - manual_edit: bool = self.arguments.get("edit") + def __call__(self) -> None: + extra_args = self.arguments.get("extra_cli_args", "") + dry_run = bool(self.arguments.get("dry_run")) + write_message_to_file = self.arguments.get("write_message_to_file") + signoff = bool(self.arguments.get("signoff")) - is_all: bool = self.arguments.get("all") - if is_all: - c = git.add("-u") + if self.arguments.get("all"): + git.add("-u") - if git.is_staging_clean() and not (dry_run or allow_empty): + if git.is_staging_clean() and not (dry_run or "--allow-empty" in extra_args): raise NothingToCommitError("No files added to staging!") if write_message_to_file is not None and write_message_to_file.is_dir(): raise NotAllowed(f"{write_message_to_file} is a directory") - retry: bool = self.arguments.get("retry") - no_retry: bool = self.arguments.get("no_retry") - retry_after_failure: bool = self.config.settings.get("retry_after_failure") - - if retry: - m = self.read_backup_message() - if m is None: - raise NoCommitBackupError() - elif retry_after_failure and not no_retry: - m = self.read_backup_message() - if m is None: - m = self.prompt_commit_questions() - else: - m = self.prompt_commit_questions() - - if manual_edit: + m = self._get_message() + if self.arguments.get("edit"): m = self.manual_edit(m) out.info(f"\n{m}\n") @@ -138,19 +147,15 @@ def __call__(self): if dry_run: raise DryRunExit() - always_signoff: bool = self.config.settings["always_signoff"] - signoff: bool = self.arguments.get("signoff") - if signoff: out.warn( "signoff mechanic is deprecated, please use `cz commit -- -s` instead." ) - if always_signoff or signoff: + if self.config.settings["always_signoff"] or signoff: extra_args = f"{extra_args} -s".strip() c = git.commit(m, args=extra_args) - if c.return_code != 0: out.error(c.err) @@ -160,11 +165,12 @@ def __call__(self): raise CommitError() - if "nothing added" in c.out or "no changes added to commit" in c.out: + if any(s in c.out for s in ("nothing added", "no changes added to commit")): out.error(c.out) - else: - with contextlib.suppress(FileNotFoundError): - os.remove(self.temp_file) - out.write(c.err) - out.write(c.out) - out.success("Commit successful!") + return + + with contextlib.suppress(FileNotFoundError): + os.remove(self.temp_file) + out.write(c.err) + out.write(c.out) + out.success("Commit successful!") diff --git a/commitizen/commands/example.py b/commitizen/commands/example.py index a28ad85f16..ba9f34adc4 100644 --- a/commitizen/commands/example.py +++ b/commitizen/commands/example.py @@ -5,9 +5,9 @@ class Example: """Show an example so people understands the rules.""" - def __init__(self, config: BaseConfig, *args): + def __init__(self, config: BaseConfig, *args: object) -> None: self.config: BaseConfig = config self.cz = factory.committer_factory(self.config) - def __call__(self): + def __call__(self) -> None: out.write(self.cz.example()) diff --git a/commitizen/commands/info.py b/commitizen/commands/info.py index abd4197e7f..5ea8227313 100644 --- a/commitizen/commands/info.py +++ b/commitizen/commands/info.py @@ -5,9 +5,9 @@ class Info: """Show in depth explanation of your rules.""" - def __init__(self, config: BaseConfig, *args): + def __init__(self, config: BaseConfig, *args: object) -> None: self.config: BaseConfig = config self.cz = factory.committer_factory(self.config) - def __call__(self): + def __call__(self) -> None: out.write(self.cz.info()) diff --git a/commitizen/commands/init.py b/commitizen/commands/init.py index 0eb3d99d17..1207cd02ee 100644 --- a/commitizen/commands/init.py +++ b/commitizen/commands/init.py @@ -79,13 +79,13 @@ def is_pre_commit_installed(self) -> bool: class Init: - def __init__(self, config: BaseConfig, *args): + def __init__(self, config: BaseConfig, *args: object) -> None: self.config: BaseConfig = config self.encoding = config.settings["encoding"] self.cz = factory.committer_factory(self.config) self.project_info = ProjectInfo() - def __call__(self): + def __call__(self) -> None: if self.config.path: out.line(f"Config file {self.config.path} already exists") return @@ -120,7 +120,7 @@ def __call__(self): self.config = JsonConfig(data="{}", path=config_path) elif "yaml" in config_path: self.config = YAMLConfig(data="", path=config_path) - values_to_add = {} + values_to_add: dict[str, Any] = {} values_to_add["name"] = cz_name values_to_add["tag_format"] = tag_format values_to_add["version_scheme"] = version_scheme @@ -207,7 +207,7 @@ def _ask_tag(self) -> str: raise NoAnswersError("Tag is required!") return latest_tag - def _ask_tag_format(self, latest_tag) -> str: + def _ask_tag_format(self, latest_tag: str) -> str: is_correct_format = False if latest_tag.startswith("v"): tag_format = r"v$version" @@ -302,7 +302,7 @@ def _ask_update_changelog_on_bump(self) -> bool: ).unsafe_ask() return update_changelog_on_bump - def _exec_install_pre_commit_hook(self, hook_types: list[str]): + def _exec_install_pre_commit_hook(self, hook_types: list[str]) -> None: cmd_str = self._gen_pre_commit_cmd(hook_types) c = cmd.run(cmd_str) if c.return_code != 0: @@ -323,7 +323,7 @@ def _gen_pre_commit_cmd(self, hook_types: list[str]) -> str: ) return cmd_str - def _install_pre_commit_hook(self, hook_types: list[str] | None = None): + def _install_pre_commit_hook(self, hook_types: list[str] | None = None) -> None: pre_commit_config_filename = ".pre-commit-config.yaml" cz_hook_config = { "repo": "https://github.com/commitizen-tools/commitizen", @@ -369,6 +369,6 @@ def _install_pre_commit_hook(self, hook_types: list[str] | None = None): self._exec_install_pre_commit_hook(hook_types) out.write("commitizen pre-commit hook is now installed in your '.git'\n") - def _update_config_file(self, values: dict[str, Any]): + def _update_config_file(self, values: dict[str, Any]) -> None: for key, value in values.items(): self.config.set_key(key, value) diff --git a/commitizen/commands/list_cz.py b/commitizen/commands/list_cz.py index 99701865af..412266f6c3 100644 --- a/commitizen/commands/list_cz.py +++ b/commitizen/commands/list_cz.py @@ -6,8 +6,8 @@ class ListCz: """List currently installed rules.""" - def __init__(self, config: BaseConfig, *args): + def __init__(self, config: BaseConfig, *args: object) -> None: self.config: BaseConfig = config - def __call__(self): + def __call__(self) -> None: out.write("\n".join(registry.keys())) diff --git a/commitizen/commands/schema.py b/commitizen/commands/schema.py index 4af5679cf5..a7aeb53569 100644 --- a/commitizen/commands/schema.py +++ b/commitizen/commands/schema.py @@ -5,9 +5,9 @@ class Schema: """Show structure of the rule.""" - def __init__(self, config: BaseConfig, *args): + def __init__(self, config: BaseConfig, *args: object) -> None: self.config: BaseConfig = config self.cz = factory.committer_factory(self.config) - def __call__(self): + def __call__(self) -> None: out.write(self.cz.schema()) diff --git a/commitizen/commands/version.py b/commitizen/commands/version.py index 45d553c710..6b0aa331ad 100644 --- a/commitizen/commands/version.py +++ b/commitizen/commands/version.py @@ -1,5 +1,6 @@ import platform import sys +from typing import TypedDict from commitizen import out from commitizen.__version__ import __version__ @@ -7,16 +8,22 @@ from commitizen.providers import get_provider +class VersionArgs(TypedDict, total=False): + report: bool + project: bool + verbose: bool + + class Version: """Get the version of the installed commitizen or the current project.""" - def __init__(self, config: BaseConfig, *args): + def __init__(self, config: BaseConfig, arguments: VersionArgs) -> None: self.config: BaseConfig = config - self.parameter = args[0] + self.parameter = arguments self.operating_system = platform.system() self.python_version = sys.version - def __call__(self): + def __call__(self) -> None: if self.parameter.get("report"): out.write(f"Commitizen Version: {__version__}") out.write(f"Python Version: {self.python_version}") diff --git a/commitizen/config/base_config.py b/commitizen/config/base_config.py index 478691aa14..59c0d16a06 100644 --- a/commitizen/config/base_config.py +++ b/commitizen/config/base_config.py @@ -1,12 +1,22 @@ from __future__ import annotations from pathlib import Path +from typing import TYPE_CHECKING, Any from commitizen.defaults import DEFAULT_SETTINGS, Settings +if TYPE_CHECKING: + import sys + + # Self is Python 3.11+ but backported in typing-extensions + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + class BaseConfig: - def __init__(self): + def __init__(self) -> None: self._settings: Settings = DEFAULT_SETTINGS.copy() self.encoding = self.settings["encoding"] self._path: Path | None = None @@ -16,10 +26,14 @@ def settings(self) -> Settings: return self._settings @property - def path(self) -> Path | None: - return self._path + def path(self) -> Path: + return self._path # type: ignore + + @path.setter + def path(self, path: str | Path) -> None: + self._path = Path(path) - def set_key(self, key, value): + def set_key(self, key: str, value: Any) -> Self: """Set or update a key in the conf. For now only strings are supported. @@ -30,8 +44,8 @@ def set_key(self, key, value): def update(self, data: Settings) -> None: self._settings.update(data) - def add_path(self, path: str | Path) -> None: - self._path = Path(path) - def _parse_setting(self, data: bytes | str) -> None: raise NotImplementedError() + + def init_empty_config_content(self) -> None: + raise NotImplementedError() diff --git a/commitizen/config/json_config.py b/commitizen/config/json_config.py index b6a07f4ced..be1f1c36b0 100644 --- a/commitizen/config/json_config.py +++ b/commitizen/config/json_config.py @@ -2,25 +2,35 @@ import json from pathlib import Path +from typing import TYPE_CHECKING, Any from commitizen.exceptions import InvalidConfigurationError from commitizen.git import smart_open from .base_config import BaseConfig +if TYPE_CHECKING: + import sys + + # Self is Python 3.11+ but backported in typing-extensions + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + class JsonConfig(BaseConfig): - def __init__(self, *, data: bytes | str, path: Path | str): + def __init__(self, *, data: bytes | str, path: Path | str) -> None: super().__init__() self.is_empty_config = False - self.add_path(path) + self.path = path self._parse_setting(data) - def init_empty_config_content(self): + def init_empty_config_content(self) -> None: with smart_open(self.path, "a", encoding=self.encoding) as json_file: json.dump({"commitizen": {}}, json_file) - def set_key(self, key, value): + def set_key(self, key: str, value: Any) -> Self: """Set or update a key in the conf. For now only strings are supported. diff --git a/commitizen/config/toml_config.py b/commitizen/config/toml_config.py index 813389cbcf..3571c9c882 100644 --- a/commitizen/config/toml_config.py +++ b/commitizen/config/toml_config.py @@ -2,6 +2,7 @@ import os from pathlib import Path +from typing import TYPE_CHECKING, Any from tomlkit import exceptions, parse, table @@ -9,15 +10,24 @@ from .base_config import BaseConfig +if TYPE_CHECKING: + import sys + + # Self is Python 3.11+ but backported in typing-extensions + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + class TomlConfig(BaseConfig): - def __init__(self, *, data: bytes | str, path: Path | str): + def __init__(self, *, data: bytes | str, path: Path | str) -> None: super().__init__() self.is_empty_config = False - self.add_path(path) + self.path = path self._parse_setting(data) - def init_empty_config_content(self): + def init_empty_config_content(self) -> None: if os.path.isfile(self.path): with open(self.path, "rb") as input_toml_file: parser = parse(input_toml_file.read()) @@ -27,10 +37,10 @@ def init_empty_config_content(self): with open(self.path, "wb") as output_toml_file: if parser.get("tool") is None: parser["tool"] = table() - parser["tool"]["commitizen"] = table() + parser["tool"]["commitizen"] = table() # type: ignore output_toml_file.write(parser.as_string().encode(self.encoding)) - def set_key(self, key, value): + def set_key(self, key: str, value: Any) -> Self: """Set or update a key in the conf. For now only strings are supported. @@ -39,7 +49,7 @@ def set_key(self, key, value): with open(self.path, "rb") as f: parser = parse(f.read()) - parser["tool"]["commitizen"][key] = value + parser["tool"]["commitizen"][key] = value # type: ignore with open(self.path, "wb") as f: f.write(parser.as_string().encode(self.encoding)) return self diff --git a/commitizen/config/yaml_config.py b/commitizen/config/yaml_config.py index 2bb6fe3af8..f2a79e6937 100644 --- a/commitizen/config/yaml_config.py +++ b/commitizen/config/yaml_config.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +from typing import TYPE_CHECKING, Any import yaml @@ -9,15 +10,24 @@ from .base_config import BaseConfig +if TYPE_CHECKING: + import sys + + # Self is Python 3.11+ but backported in typing-extensions + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + class YAMLConfig(BaseConfig): - def __init__(self, *, data: bytes | str, path: Path | str): + def __init__(self, *, data: bytes | str, path: Path | str) -> None: super().__init__() self.is_empty_config = False - self.add_path(path) + self.path = path self._parse_setting(data) - def init_empty_config_content(self): + def init_empty_config_content(self) -> None: with smart_open(self.path, "a", encoding=self.encoding) as json_file: yaml.dump({"commitizen": {}}, json_file, explicit_start=True) @@ -41,7 +51,7 @@ def _parse_setting(self, data: bytes | str) -> None: except (KeyError, TypeError): self.is_empty_config = True - def set_key(self, key, value): + def set_key(self, key: str, value: Any) -> Self: """Set or update a key in the conf. For now only strings are supported. diff --git a/commitizen/cz/base.py b/commitizen/cz/base.py index 43455a74ca..9e803c3d01 100644 --- a/commitizen/cz/base.py +++ b/commitizen/cz/base.py @@ -76,13 +76,13 @@ def message(self, answers: dict) -> str: """Format your git message.""" @property - def style(self): + def style(self) -> Style: return merge_styles( [ Style(BaseCommitizen.default_style_config), Style(self.config.settings["style"]), ] - ) + ) # type: ignore[return-value] def example(self) -> str: """Example of the commit message.""" @@ -99,10 +99,3 @@ def schema_pattern(self) -> str: def info(self) -> str: """Information about the standardized commit message.""" raise NotImplementedError("Not Implemented yet") - - def process_commit(self, commit: str) -> str: - """Process commit for changelog. - - If not overwritten, it returns the first line of commit. - """ - return commit.split("\n")[0] diff --git a/commitizen/cz/conventional_commits/conventional_commits.py b/commitizen/cz/conventional_commits/conventional_commits.py index af29a209fc..912770a726 100644 --- a/commitizen/cz/conventional_commits/conventional_commits.py +++ b/commitizen/cz/conventional_commits/conventional_commits.py @@ -1,5 +1,4 @@ import os -import re from commitizen import defaults from commitizen.cz.base import BaseCommitizen @@ -9,29 +8,19 @@ __all__ = ["ConventionalCommitsCz"] -def parse_scope(text): - if not text: - return "" +def _parse_scope(text: str) -> str: + return "-".join(text.strip().split()) - scope = text.strip().split() - if len(scope) == 1: - return scope[0] - return "-".join(scope) - - -def parse_subject(text): - if isinstance(text, str): - text = text.strip(".").strip() - - return required_validator(text, msg="Subject is required.") +def _parse_subject(text: str) -> str: + return required_validator(text.strip(".").strip(), msg="Subject is required.") class ConventionalCommitsCz(BaseCommitizen): bump_pattern = defaults.BUMP_PATTERN bump_map = defaults.BUMP_MAP bump_map_major_version_zero = defaults.BUMP_MAP_MAJOR_VERSION_ZERO - commit_parser = r"^((?Pfeat|fix|refactor|perf|BREAKING CHANGE)(?:\((?P[^()\r\n]*)\)|\()?(?P!)?|\w+!):\s(?P.*)?" # noqa + commit_parser = r"^((?Pfeat|fix|refactor|perf|BREAKING CHANGE)(?:\((?P[^()\r\n]*)\)|\()?(?P!)?|\w+!):\s(?P.*)?" change_type_map = { "feat": "Feat", "fix": "Fix", @@ -41,7 +30,7 @@ class ConventionalCommitsCz(BaseCommitizen): changelog_pattern = defaults.BUMP_PATTERN def questions(self) -> Questions: - questions: Questions = [ + return [ { "type": "list", "name": "prefix", @@ -113,12 +102,12 @@ def questions(self) -> Questions: "message": ( "What is the scope of this change? (class or file name): (press [enter] to skip)\n" ), - "filter": parse_scope, + "filter": _parse_scope, }, { "type": "input", "name": "subject", - "filter": parse_subject, + "filter": _parse_subject, "message": ( "Write a short and imperative summary of the code changes: (lower case and no period)\n" ), @@ -146,7 +135,6 @@ def questions(self) -> Questions: ), }, ] - return questions def message(self, answers: dict) -> str: prefix = answers["prefix"] @@ -165,9 +153,7 @@ def message(self, answers: dict) -> str: if footer: footer = f"\n\n{footer}" - message = f"{prefix}{scope}: {subject}{body}{footer}" - - return message + return f"{prefix}{scope}: {subject}{body}{footer}" def example(self) -> str: return ( @@ -188,25 +174,16 @@ def schema(self) -> str: ) def schema_pattern(self) -> str: - PATTERN = ( + return ( r"(?s)" # To explicitly make . match new line r"(build|ci|docs|feat|fix|perf|refactor|style|test|chore|revert|bump)" # type r"(\(\S+\))?!?:" # scope r"( [^\n\r]+)" # subject r"((\n\n.*)|(\s*))?$" ) - return PATTERN def info(self) -> str: dir_path = os.path.dirname(os.path.realpath(__file__)) filepath = os.path.join(dir_path, "conventional_commits_info.txt") with open(filepath, encoding=self.config.settings["encoding"]) as f: - content = f.read() - return content - - def process_commit(self, commit: str) -> str: - pat = re.compile(self.schema_pattern()) - m = re.match(pat, commit) - if m is None: - return "" - return m.group(3).strip() + return f.read() diff --git a/commitizen/cz/customize/customize.py b/commitizen/cz/customize/customize.py index 53ada4b2c0..8f844501ec 100644 --- a/commitizen/cz/customize/customize.py +++ b/commitizen/cz/customize/customize.py @@ -26,7 +26,7 @@ class CustomizeCommitsCz(BaseCommitizen): bump_map_major_version_zero = defaults.BUMP_MAP_MAJOR_VERSION_ZERO change_type_order = defaults.CHANGE_TYPE_ORDER - def __init__(self, config: BaseConfig): + def __init__(self, config: BaseConfig) -> None: super().__init__(config) if "customize" not in self.config.settings: diff --git a/commitizen/cz/jira/jira.py b/commitizen/cz/jira/jira.py index b8fd056a71..05e23e1690 100644 --- a/commitizen/cz/jira/jira.py +++ b/commitizen/cz/jira/jira.py @@ -8,7 +8,7 @@ class JiraSmartCz(BaseCommitizen): def questions(self) -> Questions: - questions = [ + return [ { "type": "input", "name": "message", @@ -42,7 +42,6 @@ def questions(self) -> Questions: "filter": lambda x: "#comment " + x if x else "", }, ] - return questions def message(self, answers: dict) -> str: return " ".join( @@ -68,7 +67,7 @@ def example(self) -> str: ) def schema(self) -> str: - return " # " # noqa + return " # " def schema_pattern(self) -> str: return r".*[A-Z]{2,}\-[0-9]+( #| .* #).+( #.+)*" @@ -77,5 +76,4 @@ def info(self) -> str: dir_path = os.path.dirname(os.path.realpath(__file__)) filepath = os.path.join(dir_path, "jira_info.txt") with open(filepath, encoding=self.config.settings["encoding"]) as f: - content = f.read() - return content + return f.read() diff --git a/commitizen/cz/utils.py b/commitizen/cz/utils.py index 7bc89673c6..a6f687226c 100644 --- a/commitizen/cz/utils.py +++ b/commitizen/cz/utils.py @@ -5,28 +5,27 @@ from commitizen import git from commitizen.cz import exceptions +_RE_LOCAL_VERSION = re.compile(r"\+.+") -def required_validator(answer, msg=None): + +def required_validator(answer: str, msg: object = None) -> str: if not answer: raise exceptions.AnswerRequiredError(msg) return answer -def multiple_line_breaker(answer, sep="|"): +def multiple_line_breaker(answer: str, sep: str = "|") -> str: return "\n".join(line.strip() for line in answer.split(sep) if line) def strip_local_version(version: str) -> str: - return re.sub(r"\+.+", "", version) + return _RE_LOCAL_VERSION.sub("", version) def get_backup_file_path() -> str: project_root = git.find_git_project_root() - - if project_root is None: - project = "" - else: - project = project_root.as_posix().replace("/", "%") + project = project_root.as_posix().replace("/", "%") if project_root else "" user = os.environ.get("USER", "") + return os.path.join(tempfile.gettempdir(), f"cz.commit%{user}%{project}.backup") diff --git a/commitizen/defaults.py b/commitizen/defaults.py index 0b6c28e6a9..7a51389b7a 100644 --- a/commitizen/defaults.py +++ b/commitizen/defaults.py @@ -1,6 +1,7 @@ from __future__ import annotations import pathlib +import warnings from collections import OrderedDict from collections.abc import Iterable, MutableMapping, Sequence from typing import Any, TypedDict @@ -28,36 +29,39 @@ class CzSettings(TypedDict, total=False): class Settings(TypedDict, total=False): - name: str - version: str | None - version_files: list[str] - version_provider: str | None - version_scheme: str | None - version_type: str | None - tag_format: str - legacy_tag_formats: Sequence[str] - ignored_tag_formats: Sequence[str] - bump_message: str | None - retry_after_failure: bool allow_abort: bool allowed_prefixes: list[str] + always_signoff: bool + annotated_tag: bool + bump_message: str | None + change_type_map: dict[str, str] changelog_file: str changelog_format: str | None changelog_incremental: bool - changelog_start_rev: str | None changelog_merge_prerelease: bool - update_changelog_on_bump: bool - use_shortcuts: bool - style: list[tuple[str, str]] + changelog_start_rev: str | None customize: CzSettings + encoding: str + extras: dict[str, Any] + gpg_sign: bool + ignored_tag_formats: Sequence[str] + legacy_tag_formats: Sequence[str] major_version_zero: bool - pre_bump_hooks: list[str] | None + name: str post_bump_hooks: list[str] | None + pre_bump_hooks: list[str] | None prerelease_offset: int - encoding: str - always_signoff: bool + retry_after_failure: bool + style: list[tuple[str, str]] + tag_format: str template: str | None - extras: dict[str, Any] + update_changelog_on_bump: bool + use_shortcuts: bool + version_files: list[str] + version_provider: str | None + version_scheme: str | None + version_type: str | None + version: str | None CONFIG_FILES: list[str] = [ @@ -141,7 +145,7 @@ class Settings(TypedDict, total=False): def get_tag_regexes( version_regex: str, ) -> dict[str, str]: - regexs = { + regexes = { "version": version_regex, "major": r"(?P\d+)", "minor": r"(?P\d+)", @@ -150,6 +154,34 @@ def get_tag_regexes( "devrelease": r"(?P\.dev\d+)?", } return { - **{f"${k}": v for k, v in regexs.items()}, - **{f"${{{k}}}": v for k, v in regexs.items()}, + **{f"${k}": v for k, v in regexes.items()}, + **{f"${{{k}}}": v for k, v in regexes.items()}, + } + + +def __getattr__(name: str) -> Any: + # PEP-562: deprecate module-level variable + + # {"deprecated key": (value, "new key")} + deprecated_vars = { + "bump_pattern": (BUMP_PATTERN, "BUMP_PATTERN"), + "bump_map": (BUMP_MAP, "BUMP_MAP"), + "bump_map_major_version_zero": ( + BUMP_MAP_MAJOR_VERSION_ZERO, + "BUMP_MAP_MAJOR_VERSION_ZERO", + ), + "bump_message": (BUMP_MESSAGE, "BUMP_MESSAGE"), + "change_type_order": (CHANGE_TYPE_ORDER, "CHANGE_TYPE_ORDER"), + "encoding": (ENCODING, "ENCODING"), + "name": (DEFAULT_SETTINGS["name"], "DEFAULT_SETTINGS['name']"), } + if name in deprecated_vars: + value, replacement = deprecated_vars[name] + warnings.warn( + f"{name} is deprecated and will be removed in v5. " + f"Use {replacement} instead.", + DeprecationWarning, + stacklevel=2, + ) + return value + raise AttributeError(f"{name} is not an attribute of {__name__}") diff --git a/commitizen/exceptions.py b/commitizen/exceptions.py index 29733b624b..8c0956be53 100644 --- a/commitizen/exceptions.py +++ b/commitizen/exceptions.py @@ -1,4 +1,5 @@ import enum +from typing import Any from commitizen import out @@ -40,7 +41,7 @@ class ExitCode(enum.IntEnum): class CommitizenException(Exception): - def __init__(self, *args, **kwargs): + def __init__(self, *args: str, **kwargs: Any) -> None: self.output_method = kwargs.get("output_method") or out.error self.exit_code: ExitCode = self.__class__.exit_code if args: @@ -50,14 +51,14 @@ def __init__(self, *args, **kwargs): else: self.message = "" - def __str__(self): + def __str__(self) -> str: return self.message class ExpectedExit(CommitizenException): exit_code = ExitCode.EXPECTED_EXIT - def __init__(self, *args, **kwargs): + def __init__(self, *args: str, **kwargs: Any) -> None: output_method = kwargs.get("output_method") or out.write kwargs["output_method"] = output_method super().__init__(*args, **kwargs) diff --git a/commitizen/git.py b/commitizen/git.py index 19ca46b6c3..fb59750eaf 100644 --- a/commitizen/git.py +++ b/commitizen/git.py @@ -2,56 +2,64 @@ import os from enum import Enum -from os import linesep +from functools import lru_cache from pathlib import Path from tempfile import NamedTemporaryFile from commitizen import cmd, out from commitizen.exceptions import GitCommandError -UNIX_EOL = "\n" -WINDOWS_EOL = "\r\n" - -class EOLTypes(Enum): +class EOLType(Enum): """The EOL type from `git config core.eol`.""" LF = "lf" CRLF = "crlf" NATIVE = "native" - def get_eol_for_open(self) -> str: + @classmethod + def for_open(cls) -> str: + c = cmd.run("git config core.eol") + eol = c.out.strip().upper() + return cls._char_for_open()[cls._safe_cast(eol)] + + @classmethod + def _safe_cast(cls, eol: str) -> EOLType: + try: + return cls[eol] + except KeyError: + return cls.NATIVE + + @classmethod + @lru_cache + def _char_for_open(cls) -> dict[EOLType, str]: """Get the EOL character for `open()`.""" - map = { - EOLTypes.CRLF: WINDOWS_EOL, - EOLTypes.LF: UNIX_EOL, - EOLTypes.NATIVE: linesep, + return { + cls.LF: "\n", + cls.CRLF: "\r\n", + cls.NATIVE: os.linesep, } - return map[self] - class GitObject: rev: str name: str date: str - def __eq__(self, other) -> bool: - if not hasattr(other, "rev"): - return False - return self.rev == other.rev # type: ignore + def __eq__(self, other: object) -> bool: + return hasattr(other, "rev") and self.rev == other.rev class GitCommit(GitObject): def __init__( self, - rev, - title, + rev: str, + title: str, body: str = "", author: str = "", author_email: str = "", parents: list[str] | None = None, - ): + ) -> None: self.rev = rev.strip() self.title = title.strip() self.body = body.strip() @@ -60,26 +68,86 @@ def __init__( self.parents = parents or [] @property - def message(self): + def message(self) -> str: return f"{self.title}\n\n{self.body}".strip() - def __repr__(self): + @classmethod + def from_rev_and_commit(cls, rev_and_commit: str) -> GitCommit: + """Create a GitCommit instance from a formatted commit string. + + This method parses a multi-line string containing commit information in the following format: + ``` + + + + <author> + <author_email> + <body_line_1> + <body_line_2> + ... + ``` + + Args: + rev_and_commit (str): A string containing commit information with fields separated by newlines. + - rev: The commit hash/revision + - parents: Space-separated list of parent commit hashes + - title: The commit title/message + - author: The commit author's name + - author_email: The commit author's email + - body: Optional multi-line commit body + + Returns: + GitCommit: A new GitCommit instance with the parsed information. + + Example: + >>> commit_str = '''abc123 + ... def456 ghi789 + ... feat: add new feature + ... John Doe + ... john@example.com + ... This is a detailed description + ... of the new feature''' + >>> commit = GitCommit.from_rev_and_commit(commit_str) + >>> commit.rev + 'abc123' + >>> commit.title + 'feat: add new feature' + >>> commit.parents + ['def456', 'ghi789'] + """ + rev, parents, title, author, author_email, *body_list = rev_and_commit.split( + "\n" + ) + return cls( + rev=rev.strip(), + title=title.strip(), + body="\n".join(body_list).strip(), + author=author, + author_email=author_email, + parents=[p for p in parents.strip().split(" ") if p], + ) + + def __repr__(self) -> str: return f"{self.title} ({self.rev})" class GitTag(GitObject): - def __init__(self, name, rev, date): + def __init__(self, name: str, rev: str, date: str) -> None: self.rev = rev.strip() self.name = name.strip() self._date = date.strip() - def __repr__(self): + def __repr__(self) -> str: return f"GitTag('{self.name}', '{self.rev}', '{self.date}')" @property - def date(self): + def date(self) -> str: return self._date + @date.setter + def date(self, value: str) -> None: + self._date = value + @classmethod def from_line(cls, line: str, inner_delimiter: str) -> GitTag: name, objectname, date, obj = line.split(inner_delimiter) @@ -101,13 +169,11 @@ def tag( # according to https://git-scm.com/book/en/v2/Git-Basics-Tagging, # we're not able to create lightweight tag with message. # by adding message, we make it a annotated tags - c = cmd.run(f'git tag {_opt} "{tag if _opt == "" or msg is None else msg}"') - return c + return cmd.run(f'git tag {_opt} "{tag if _opt == "" or msg is None else msg}"') def add(*args: str) -> cmd.Command: - c = cmd.run(f"git add {' '.join(args)}") - return c + return cmd.run(f"git add {' '.join(args)}") def commit( @@ -119,19 +185,22 @@ def commit( f.write(message.encode("utf-8")) f.close() - command = f'git commit {args} -F "{f.name}"' - - if committer_date and os.name == "nt": # pragma: no cover - # Using `cmd /v /c "{command}"` sets environment variables only for that command - command = f'cmd /v /c "set GIT_COMMITTER_DATE={committer_date}&& {command}"' - elif committer_date: - command = f"GIT_COMMITTER_DATE={committer_date} {command}" - + command = _create_commit_cmd_string(args, committer_date, f.name) c = cmd.run(command) os.unlink(f.name) return c +def _create_commit_cmd_string(args: str, committer_date: str | None, name: str) -> str: + command = f'git commit {args} -F "{name}"' + if not committer_date: + return command + if os.name != "nt": + return f"GIT_COMMITTER_DATE={committer_date} {command}" + # Using `cmd /v /c "{command}"` sets environment variables only for that command + return f'cmd /v /c "set GIT_COMMITTER_DATE={committer_date}&& {command}"' + + def get_commits( start: str | None = None, end: str = "HEAD", @@ -140,27 +209,13 @@ def get_commits( ) -> list[GitCommit]: """Get the commits between start and end.""" git_log_entries = _get_log_as_str_list(start, end, args) - git_commits = [] - for rev_and_commit in git_log_entries: - if not rev_and_commit: - continue - rev, parents, title, author, author_email, *body_list = rev_and_commit.split( - "\n" - ) - if rev_and_commit: - git_commit = GitCommit( - rev=rev.strip(), - title=title.strip(), - body="\n".join(body_list).strip(), - author=author, - author_email=author_email, - parents=[p for p in parents.strip().split(" ") if p], - ) - git_commits.append(git_commit) - return git_commits - - -def get_filenames_in_commit(git_reference: str = ""): + return [ + GitCommit.from_rev_and_commit(rev_and_commit) + for rev_and_commit in filter(None, git_log_entries) + ] + + +def get_filenames_in_commit(git_reference: str = "") -> list[str]: """Get the list of files that were committed in the requested git reference. :param git_reference: a git reference as accepted by `git show`, default: the last commit @@ -170,8 +225,7 @@ def get_filenames_in_commit(git_reference: str = ""): c = cmd.run(f"git show --name-only --pretty=format: {git_reference}") if c.return_code == 0: return c.out.strip().split("\n") - else: - raise GitCommandError(c.err) + raise GitCommandError(c.err) def get_tags( @@ -197,16 +251,11 @@ def get_tags( if c.err: out.warn(f"Attempting to proceed after: {c.err}") - if not c.out: - return [] - - git_tags = [ + return [ GitTag.from_line(line=line, inner_delimiter=inner_delimiter) for line in c.out.split("\n")[:-1] ] - return git_tags - def tag_exist(tag: str) -> bool: c = cmd.run(f"git tag --list {tag}") @@ -231,18 +280,18 @@ def get_tag_message(tag: str) -> str | None: return c.out.strip() -def get_tag_names() -> list[str | None]: +def get_tag_names() -> list[str]: c = cmd.run("git tag --list") if c.err: return [] - return [tag.strip() for tag in c.out.split("\n") if tag.strip()] + return [tag for raw in c.out.split("\n") if (tag := raw.strip())] def find_git_project_root() -> Path | None: c = cmd.run("git rev-parse --show-toplevel") - if not c.err: - return Path(c.out.strip()) - return None + if c.err: + return None + return Path(c.out.strip()) def is_staging_clean() -> bool: @@ -253,32 +302,7 @@ def is_staging_clean() -> bool: def is_git_project() -> bool: c = cmd.run("git rev-parse --is-inside-work-tree") - if c.out.strip() == "true": - return True - return False - - -def get_eol_style() -> EOLTypes: - c = cmd.run("git config core.eol") - eol = c.out.strip().lower() - - # We enumerate the EOL types of the response of - # `git config core.eol`, and map it to our enumration EOLTypes. - # - # It is just like the variant of the "match" syntax. - map = { - "lf": EOLTypes.LF, - "crlf": EOLTypes.CRLF, - "native": EOLTypes.NATIVE, - } - - # If the response of `git config core.eol` is in the map: - if eol in map: - return map[eol] - else: - # The default value is "native". - # https://git-scm.com/docs/git-config#Documentation/git-config.txt-coreeol - return map["native"] + return c.out.strip() == "true" def get_core_editor() -> str | None: @@ -288,22 +312,18 @@ def get_core_editor() -> str | None: return None -def smart_open(*args, **kargs): +def smart_open(*args, **kwargs): # type: ignore[no-untyped-def,unused-ignore] # noqa: ANN201 """Open a file with the EOL style determined from Git.""" - return open(*args, newline=get_eol_style().get_eol_for_open(), **kargs) + return open(*args, newline=EOLType.for_open(), **kwargs) def _get_log_as_str_list(start: str | None, end: str, args: str) -> list[str]: """Get string representation of each log entry""" delimiter = "----------commit-delimiter----------" log_format: str = "%H%n%P%n%s%n%an%n%ae%n%b" - git_log_cmd = ( - f"git -c log.showSignature=False log --pretty={log_format}{delimiter} {args}" - ) - if start: - command = f"{git_log_cmd} {start}..{end}" - else: - command = f"{git_log_cmd} {end}" + command_range = f"{start}..{end}" if start else end + command = f"git -c log.showSignature=False log --pretty={log_format}{delimiter} {args} {command_range}" + c = cmd.run(command) if c.return_code != 0: raise GitCommandError(c.err) diff --git a/commitizen/hooks.py b/commitizen/hooks.py index f5505d0e82..f60bd9b43e 100644 --- a/commitizen/hooks.py +++ b/commitizen/hooks.py @@ -1,12 +1,13 @@ from __future__ import annotations import os +from collections.abc import Mapping from commitizen import cmd, out from commitizen.exceptions import RunHookError -def run(hooks, _env_prefix="CZ_", **env): +def run(hooks: str | list[str], _env_prefix: str = "CZ_", **env: object) -> None: if isinstance(hooks, str): hooks = [hooks] @@ -24,7 +25,7 @@ def run(hooks, _env_prefix="CZ_", **env): raise RunHookError(f"Running hook '{hook}' failed") -def _format_env(prefix: str, env: dict[str, str]) -> dict[str, str]: +def _format_env(prefix: str, env: Mapping[str, object]) -> dict[str, str]: """_format_env() prefixes all given environment variables with the given prefix so it can be passed directly to cmd.run().""" penv = dict(os.environ) diff --git a/commitizen/out.py b/commitizen/out.py index 40342e9de5..1bbfe4329d 100644 --- a/commitizen/out.py +++ b/commitizen/out.py @@ -1,5 +1,6 @@ import io import sys +from typing import Any from termcolor import colored @@ -8,12 +9,12 @@ sys.stdout.reconfigure(encoding="utf-8") -def write(value: str, *args) -> None: +def write(value: str, *args: object) -> None: """Intended to be used when value is multiline.""" print(value, *args) -def line(value: str, *args, **kwargs) -> None: +def line(value: str, *args: object, **kwargs: Any) -> None: """Wrapper in case I want to do something different later.""" print(value, *args, **kwargs) @@ -33,7 +34,7 @@ def info(value: str) -> None: line(message) -def diagnostic(value: str): +def diagnostic(value: str) -> None: line(value, file=sys.stderr) diff --git a/commitizen/providers/__init__.py b/commitizen/providers/__init__.py index 9cf4ce5927..3e01fe22f8 100644 --- a/commitizen/providers/__init__.py +++ b/commitizen/providers/__init__.py @@ -21,7 +21,6 @@ from commitizen.providers.uv_provider import UvProvider __all__ = [ - "get_provider", "CargoProvider", "CommitizenProvider", "ComposerProvider", @@ -30,6 +29,7 @@ "PoetryProvider", "ScmProvider", "UvProvider", + "get_provider", ] PROVIDER_ENTRYPOINT = "commitizen.provider" diff --git a/commitizen/providers/base_provider.py b/commitizen/providers/base_provider.py index 34048318e2..27c3123416 100644 --- a/commitizen/providers/base_provider.py +++ b/commitizen/providers/base_provider.py @@ -2,6 +2,7 @@ import json from abc import ABC, abstractmethod +from collections.abc import Mapping from pathlib import Path from typing import Any, ClassVar @@ -19,7 +20,7 @@ class VersionProvider(ABC): config: BaseConfig - def __init__(self, config: BaseConfig): + def __init__(self, config: BaseConfig) -> None: self.config = config @abstractmethod @@ -29,7 +30,7 @@ def get_version(self) -> str: """ @abstractmethod - def set_version(self, version: str): + def set_version(self, version: str) -> None: """ Set the new current version """ @@ -58,15 +59,15 @@ def get_version(self) -> str: document = json.loads(self.file.read_text()) return self.get(document) - def set_version(self, version: str): + def set_version(self, version: str) -> None: document = json.loads(self.file.read_text()) self.set(document, version) self.file.write_text(json.dumps(document, indent=self.indent) + "\n") - def get(self, document: dict[str, Any]) -> str: - return document["version"] # type: ignore + def get(self, document: Mapping[str, str]) -> str: + return document["version"] - def set(self, document: dict[str, Any], version: str): + def set(self, document: dict[str, Any], version: str) -> None: document["version"] = version @@ -79,7 +80,7 @@ def get_version(self) -> str: document = tomlkit.parse(self.file.read_text()) return self.get(document) - def set_version(self, version: str): + def set_version(self, version: str) -> None: document = tomlkit.parse(self.file.read_text()) self.set(document, version) self.file.write_text(tomlkit.dumps(document)) @@ -87,5 +88,5 @@ def set_version(self, version: str): def get(self, document: tomlkit.TOMLDocument) -> str: return document["project"]["version"] # type: ignore - def set(self, document: tomlkit.TOMLDocument, version: str): + def set(self, document: tomlkit.TOMLDocument, version: str) -> None: document["project"]["version"] = version # type: ignore diff --git a/commitizen/providers/cargo_provider.py b/commitizen/providers/cargo_provider.py index 2e73ff35a1..87e45cd71c 100644 --- a/commitizen/providers/cargo_provider.py +++ b/commitizen/providers/cargo_provider.py @@ -28,7 +28,7 @@ def get(self, document: tomlkit.TOMLDocument) -> str: ... return document["workspace"]["package"]["version"] # type: ignore - def set(self, document: tomlkit.TOMLDocument, version: str): + def set(self, document: tomlkit.TOMLDocument, version: str) -> None: try: document["workspace"]["package"]["version"] = version # type: ignore return diff --git a/commitizen/providers/commitizen_provider.py b/commitizen/providers/commitizen_provider.py index a1da25ff72..7ce177a604 100644 --- a/commitizen/providers/commitizen_provider.py +++ b/commitizen/providers/commitizen_provider.py @@ -11,5 +11,5 @@ class CommitizenProvider(VersionProvider): def get_version(self) -> str: return self.config.settings["version"] # type: ignore - def set_version(self, version: str): + def set_version(self, version: str) -> None: self.config.set_key("version", version) diff --git a/commitizen/providers/npm_provider.py b/commitizen/providers/npm_provider.py index 12900ff7de..3125447250 100644 --- a/commitizen/providers/npm_provider.py +++ b/commitizen/providers/npm_provider.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from collections.abc import Mapping from pathlib import Path from typing import Any, ClassVar @@ -58,8 +59,8 @@ def set_version(self, version: str) -> None: json.dumps(shrinkwrap_document, indent=self.indent) + "\n" ) - def get_package_version(self, document: dict[str, Any]) -> str: - return document["version"] # type: ignore + def get_package_version(self, document: Mapping[str, str]) -> str: + return document["version"] def set_package_version( self, document: dict[str, Any], version: str diff --git a/commitizen/providers/poetry_provider.py b/commitizen/providers/poetry_provider.py index 7aa28f56d9..1dd33f053e 100644 --- a/commitizen/providers/poetry_provider.py +++ b/commitizen/providers/poetry_provider.py @@ -15,5 +15,5 @@ class PoetryProvider(TomlProvider): def get(self, pyproject: tomlkit.TOMLDocument) -> str: return pyproject["tool"]["poetry"]["version"] # type: ignore - def set(self, pyproject: tomlkit.TOMLDocument, version: str): + def set(self, pyproject: tomlkit.TOMLDocument, version: str) -> None: pyproject["tool"]["poetry"]["version"] = version # type: ignore diff --git a/commitizen/providers/scm_provider.py b/commitizen/providers/scm_provider.py index cb575148cb..3085b16efa 100644 --- a/commitizen/providers/scm_provider.py +++ b/commitizen/providers/scm_provider.py @@ -23,6 +23,6 @@ def get_version(self) -> str: return "0.0.0" return str(versions[-1]) - def set_version(self, version: str): + def set_version(self, version: str) -> None: # Not necessary pass diff --git a/commitizen/tags.py b/commitizen/tags.py index 2b9a4b091a..b19bb89e09 100644 --- a/commitizen/tags.py +++ b/commitizen/tags.py @@ -2,7 +2,7 @@ import re import warnings -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from functools import cached_property from string import Template @@ -89,14 +89,14 @@ class TagRules: merge_prereleases: bool = False @cached_property - def version_regexes(self) -> Sequence[re.Pattern]: + def version_regexes(self) -> list[re.Pattern]: """Regexes for all legit tag formats, current and legacy""" tag_formats = [self.tag_format, *self.legacy_tag_formats] regexes = (self._format_regex(p) for p in tag_formats) return [re.compile(r) for r in regexes] @cached_property - def ignored_regexes(self) -> Sequence[re.Pattern]: + def ignored_regexes(self) -> list[re.Pattern]: """Regexes for known but ignored tag formats""" regexes = (self._format_regex(p, star=True) for p in self.ignored_tag_formats) return [re.compile(r) for r in regexes] @@ -135,8 +135,8 @@ def is_ignored_tag(self, tag: str | GitTag) -> bool: return any(regex.match(tag) for regex in self.ignored_regexes) def get_version_tags( - self, tags: Sequence[GitTag], warn: bool = False - ) -> Sequence[GitTag]: + self, tags: Iterable[GitTag], warn: bool = False + ) -> list[GitTag]: """Filter in version tags and warn on unexpected tags""" return [tag for tag in tags if self.is_version_tag(tag, warn)] @@ -174,11 +174,7 @@ def include_in_changelog(self, tag: GitTag) -> bool: version = self.extract_version(tag) except InvalidVersion: return False - - if self.merge_prereleases and version.is_prerelease: - return False - - return True + return not (self.merge_prereleases and version.is_prerelease) def search_version(self, text: str, last: bool = False) -> VersionTag | None: """ @@ -240,15 +236,15 @@ def normalize_tag( ) def find_tag_for( - self, tags: Sequence[GitTag], version: Version | str + self, tags: Iterable[GitTag], version: Version | str ) -> GitTag | None: """Find the first matching tag for a given version.""" version = self.scheme(version) if isinstance(version, str) else version - possible_tags = [ + possible_tags = set( self.normalize_tag(version, f) for f in (self.tag_format, *self.legacy_tag_formats) - ] - candidates = [t for t in tags if any(t.name == p for p in possible_tags)] + ) + candidates = [t for t in tags if t.name in possible_tags] if len(candidates) > 1: warnings.warn( UserWarning( diff --git a/commitizen/version_schemes.py b/commitizen/version_schemes.py index 84ded9316e..a59d3c0aa0 100644 --- a/commitizen/version_schemes.py +++ b/commitizen/version_schemes.py @@ -19,7 +19,7 @@ else: import importlib_metadata as metadata -from packaging.version import InvalidVersion # noqa: F401: expose the common exception +from packaging.version import InvalidVersion # noqa: F401 (expose the common exception) from packaging.version import Version as _BaseVersion from commitizen.defaults import MAJOR, MINOR, PATCH, Settings @@ -41,7 +41,9 @@ Increment: TypeAlias = Literal["MAJOR", "MINOR", "PATCH"] Prerelease: TypeAlias = Literal["alpha", "beta", "rc"] -DEFAULT_VERSION_PARSER = r"v?(?P<version>([0-9]+)\.([0-9]+)(?:\.([0-9]+))?(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+[0-9A-Za-z.]+)?(\w+)?)" +_DEFAULT_VERSION_PARSER = re.compile( + r"v?(?P<version>([0-9]+)\.([0-9]+)(?:\.([0-9]+))?(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+[0-9A-Za-z.]+)?(\w+)?)" +) @runtime_checkable @@ -49,7 +51,7 @@ class VersionProtocol(Protocol): parser: ClassVar[re.Pattern] """Regex capturing this version scheme into a `version` group""" - def __init__(self, version: str): + def __init__(self, version: str) -> None: """ Initialize a version object from its string representation. @@ -156,7 +158,7 @@ class BaseVersion(_BaseVersion): A base class implementing the `VersionProtocol` for PEP440-like versions. """ - parser: ClassVar[re.Pattern] = re.compile(DEFAULT_VERSION_PARSER) + parser: ClassVar[re.Pattern] = _DEFAULT_VERSION_PARSER """Regex capturing this version scheme into a `version` group""" @property @@ -265,39 +267,35 @@ def bump( if self.local and is_local_version: local_version = self.scheme(self.local).bump(increment) return self.scheme(f"{self.public}+{local_version}") # type: ignore - else: - if not self.is_prerelease: - base = self.increment_base(increment) - elif exact_increment: - base = self.increment_base(increment) - else: - base = f"{self.major}.{self.minor}.{self.micro}" - if increment == PATCH: - pass - elif increment == MINOR: - if self.micro != 0: - base = self.increment_base(increment) - elif increment == MAJOR: - if self.minor != 0 or self.micro != 0: - base = self.increment_base(increment) - dev_version = self.generate_devrelease(devrelease) - - release = list(self.release) - if len(release) < 3: - release += [0] * (3 - len(release)) - current_base = ".".join(str(part) for part in release) - if base == current_base: - pre_version = self.generate_prerelease( - prerelease, offset=prerelease_offset - ) - else: - base_version = cast(BaseVersion, self.scheme(base)) - pre_version = base_version.generate_prerelease( - prerelease, offset=prerelease_offset - ) - build_metadata = self.generate_build_metadata(build_metadata) - # TODO: post version - return self.scheme(f"{base}{pre_version}{dev_version}{build_metadata}") # type: ignore + + base = self._get_increment_base(increment, exact_increment) + dev_version = self.generate_devrelease(devrelease) + + release = list(self.release) + if len(release) < 3: + release += [0] * (3 - len(release)) + current_base = ".".join(str(part) for part in release) + + pre_version = ( + self if base == current_base else cast(BaseVersion, self.scheme(base)) + ).generate_prerelease(prerelease, offset=prerelease_offset) + + # TODO: post version + return self.scheme( + f"{base}{pre_version}{dev_version}{self.generate_build_metadata(build_metadata)}" + ) # type: ignore + + def _get_increment_base( + self, increment: Increment | None, exact_increment: bool + ) -> str: + if ( + not self.is_prerelease + or exact_increment + or (increment == MINOR and self.micro != 0) + or (increment == MAJOR and (self.minor != 0 or self.micro != 0)) + ): + return self.increment_base(increment) + return f"{self.major}.{self.minor}.{self.micro}" class Pep440(BaseVersion): @@ -316,7 +314,7 @@ class SemVer(BaseVersion): """ def __str__(self) -> str: - parts = [] + parts: list[str] = [] # Epoch if self.epoch != 0: @@ -364,7 +362,7 @@ def prerelease(self) -> str | None: return None def __str__(self) -> str: - parts = [] + parts: list[str] = [] # Epoch if self.epoch != 0: @@ -373,9 +371,19 @@ def __str__(self) -> str: # Release segment parts.append(".".join(str(x) for x in self.release)) + if prerelease := self._get_prerelease(): + parts.append(f"-{prerelease}") + + # Local version segment + if self.local: + parts.append(f"+{self.local}") + + return "".join(parts) + + def _get_prerelease(self) -> str: # Pre-release identifiers # See: https://semver.org/spec/v2.0.0.html#spec-item-9 - prerelease_parts = [] + prerelease_parts: list[str] = [] if self.prerelease: prerelease_parts.append(f"{self.prerelease}") @@ -387,15 +395,7 @@ def __str__(self) -> str: if self.dev is not None: prerelease_parts.append(f"dev.{self.dev}") - if prerelease_parts: - parts.append("-") - parts.append(".".join(prerelease_parts)) - - # Local version segment - if self.local: - parts.append(f"+{self.local}") - - return "".join(parts) + return ".".join(prerelease_parts) DEFAULT_SCHEME: VersionScheme = Pep440 @@ -419,7 +419,7 @@ def get_version_scheme(settings: Settings, name: str | None = None) -> VersionSc if deprecated_setting: warnings.warn( DeprecationWarning( - "`version_type` setting is deprecated and will be removed in commitizen 4. " + "`version_type` setting is deprecated and will be removed in v5. " "Please use `version_scheme` instead" ) ) diff --git a/docs/commands/commit.md b/docs/commands/commit.md index 5a073a2644..be9d193b97 100644 --- a/docs/commands/commit.md +++ b/docs/commands/commit.md @@ -42,7 +42,7 @@ cz c -a -- -n # Stage all changes and skip the pre-commit and commit- ``` !!! warning - The `--signoff` option (or `-s`) is now recommended being used with the new syntax: `cz commit -- -s`. The old syntax `cz commit --signoff` is deprecated. + The `--signoff` option (or `-s`) is now recommended being used with the new syntax: `cz commit -- -s`. The old syntax `cz commit --signoff` is deprecated and will be removed in v5. ### Retry diff --git a/pyproject.toml b/pyproject.toml index cff51a2094..f3d73d53ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -186,6 +186,9 @@ line-length = 88 [tool.ruff.lint] select = [ + # flake8-annotations + "ANN001", + "ANN2", # pycodestyle "E", # Pyflakes @@ -194,9 +197,16 @@ select = [ "UP", # isort "I", + # unsorted-dunder-all + "RUF022", + # unused-noqa + "RUF100", ] ignore = ["E501", "D1", "D415"] +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["ANN"] + [tool.ruff.lint.isort] known-first-party = ["commitizen", "tests"] @@ -204,7 +214,7 @@ known-first-party = ["commitizen", "tests"] convention = "google" [tool.mypy] -files = "commitizen" +files = ["commitizen", "tests"] disallow_untyped_decorators = true disallow_subclassing_any = true warn_return_any = true @@ -228,14 +238,14 @@ poetry_command = "" [tool.poe.tasks] format.help = "Format the code" format.sequence = [ - { cmd = "ruff check --fix commitizen tests" }, - { cmd = "ruff format commitizen tests" }, + { cmd = "ruff check --fix" }, + { cmd = "ruff format" }, ] lint.help = "Lint the code" lint.sequence = [ - { cmd = "ruff check commitizen/ tests/ --fix" }, - { cmd = "mypy commitizen/ tests/" }, + { cmd = "ruff check" }, + { cmd = "mypy" }, ] check-commit.help = "Check the commit message" diff --git a/tests/commands/test_bump_command.py b/tests/commands/test_bump_command.py index e15539d8a7..64b810e4de 100644 --- a/tests/commands/test_bump_command.py +++ b/tests/commands/test_bump_command.py @@ -12,8 +12,9 @@ from pytest_mock import MockFixture import commitizen.commands.bump as bump -from commitizen import cli, cmd, git, hooks +from commitizen import cli, cmd, defaults, git, hooks from commitizen.changelog_formats import ChangelogFormat +from commitizen.config.base_config import BaseConfig from commitizen.cz.base import BaseCommitizen from commitizen.exceptions import ( BumpTagFailedError, @@ -41,8 +42,8 @@ "fix(user): username exception", "refactor: remove ini configuration support", "refactor(config): remove ini configuration support", - "perf: update to use multiproess", - "perf(worker): update to use multiproess", + "perf: update to use multiprocess", + "perf(worker): update to use multiprocess", ), ) @pytest.mark.usefixtures("tmp_commitizen_project") @@ -1688,3 +1689,51 @@ def test_bump_warn_but_dont_fail_on_invalid_tags( assert err.count("Invalid version tag: '0.4.3.deadbeaf'") == 1 assert git.tag_exist("0.4.3") + + +def test_is_initial_tag(mocker: MockFixture, tmp_commitizen_project): + """Test the _is_initial_tag method behavior.""" + # Create a commit but no tags + create_file_and_commit("feat: initial commit") + + # Initialize Bump with minimal config + config = BaseConfig() + config.settings.update( + { + "name": defaults.DEFAULT_SETTINGS["name"], + "encoding": "utf-8", + "pre_bump_hooks": [], + "post_bump_hooks": [], + } + ) + + # Initialize with required arguments + arguments = { + "changelog": False, + "changelog_to_stdout": False, + "git_output_to_stderr": False, + "no_verify": False, + "check_consistency": False, + "retry": False, + "version_scheme": None, + "file_name": None, + "template": None, + "extras": None, + } + + bump_cmd = bump.Bump(config, arguments) # type: ignore + + # Test case 1: No current tag, not yes mode + mocker.patch("questionary.confirm", return_value=mocker.Mock(ask=lambda: True)) + assert bump_cmd._is_initial_tag(None, is_yes=False) is True + + # Test case 2: No current tag, yes mode + assert bump_cmd._is_initial_tag(None, is_yes=True) is True + + # Test case 3: Has current tag + mock_tag = mocker.Mock() + assert bump_cmd._is_initial_tag(mock_tag, is_yes=False) is False + + # Test case 4: No current tag, user denies + mocker.patch("questionary.confirm", return_value=mocker.Mock(ask=lambda: False)) + assert bump_cmd._is_initial_tag(None, is_yes=False) is False diff --git a/tests/commands/test_commit_command.py b/tests/commands/test_commit_command.py index 3a92f5af48..930e1a7a9b 100644 --- a/tests/commands/test_commit_command.py +++ b/tests/commands/test_commit_command.py @@ -523,3 +523,34 @@ def test_commit_command_shows_description_when_use_help_option( out, _ = capsys.readouterr() file_regression.check(out, extension=".txt") + + +@pytest.mark.usefixtures("staging_is_clean") +@pytest.mark.parametrize( + "out", ["no changes added to commit", "nothing added to commit"] +) +def test_commit_when_nothing_added_to_commit(config, mocker: MockFixture, out): + prompt_mock = mocker.patch("questionary.prompt") + prompt_mock.return_value = { + "prefix": "feat", + "subject": "user created", + "scope": "", + "is_breaking_change": False, + "body": "", + "footer": "", + } + + commit_mock = mocker.patch("commitizen.git.commit") + commit_mock.return_value = cmd.Command( + out=out, + err="", + stdout=out.encode(), + stderr=b"", + return_code=0, + ) + error_mock = mocker.patch("commitizen.out.error") + + commands.Commit(config, {})() + + commit_mock.assert_called_once() + error_mock.assert_called_once_with(out) diff --git a/tests/commands/test_init_command.py b/tests/commands/test_init_command.py index f617c51d8f..3f12d0bd7f 100644 --- a/tests/commands/test_init_command.py +++ b/tests/commands/test_init_command.py @@ -86,7 +86,7 @@ def test_init_without_setup_pre_commit_hook(tmpdir, mocker: MockFixture, config) def test_init_when_config_already_exists(config, capsys): # Set config path path = os.sep.join(["tests", "pyproject.toml"]) - config.add_path(path) + config.path = path commands.Init(config)() captured = capsys.readouterr() diff --git a/tests/commands/test_version_command.py b/tests/commands/test_version_command.py index 927cf55f25..3dcbed168b 100644 --- a/tests/commands/test_version_command.py +++ b/tests/commands/test_version_command.py @@ -97,7 +97,6 @@ def test_version_use_version_provider( { "report": False, "project": project, - "commitizen": False, "verbose": not project, }, )() diff --git a/tests/conftest.py b/tests/conftest.py index 60c586f2e6..1b49dcbfaa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,7 +169,7 @@ class SemverCommitizen(BaseCommitizen): "patch": "PATCH", } changelog_pattern = r"^(patch|minor|major)" - commit_parser = r"^(?P<change_type>patch|minor|major)(?:\((?P<scope>[^()\r\n]*)\)|\()?:?\s(?P<message>.+)" # noqa + commit_parser = r"^(?P<change_type>patch|minor|major)(?:\((?P<scope>[^()\r\n]*)\)|\()?:?\s(?P<message>.+)" change_type_map = { "major": "Breaking Changes", "minor": "Features", diff --git a/tests/test_bump_find_increment.py b/tests/test_bump_find_increment.py index ff24ff17a7..77e11c78c7 100644 --- a/tests/test_bump_find_increment.py +++ b/tests/test_bump_find_increment.py @@ -32,14 +32,14 @@ MAJOR_INCREMENTS_BREAKING_CHANGE_CC = [ "feat(cli): added version", "docs(README): motivation", - "BREAKING CHANGE: `extends` key in config file is now used for extending other config files", # noqa + "BREAKING CHANGE: `extends` key in config file is now used for extending other config files", "fix(setup.py): future is now required for every python version", ] MAJOR_INCREMENTS_BREAKING_CHANGE_ALT_CC = [ "feat(cli): added version", "docs(README): motivation", - "BREAKING-CHANGE: `extends` key in config file is now used for extending other config files", # noqa + "BREAKING-CHANGE: `extends` key in config file is now used for extending other config files", "fix(setup.py): future is now required for every python version", ] diff --git a/tests/test_changelog.py b/tests/test_changelog.py index b1c7c802e1..ed90ed08e4 100644 --- a/tests/test_changelog.py +++ b/tests/test_changelog.py @@ -1215,28 +1215,28 @@ def test_generate_tree_from_commits_with_no_commits(tags): ), ), ) -def test_order_changelog_tree(change_type_order, expected_reordering): - tree = changelog.order_changelog_tree(COMMITS_TREE, change_type_order) +def test_generate_ordered_changelog_tree(change_type_order, expected_reordering): + tree = changelog.generate_ordered_changelog_tree(COMMITS_TREE, change_type_order) for index, entry in enumerate(tuple(tree)): - version = tree[index]["version"] + version = entry["version"] if version in expected_reordering: # Verify that all keys are present - assert [*tree[index].keys()] == [*COMMITS_TREE[index].keys()] + assert [*entry.keys()] == [*COMMITS_TREE[index].keys()] # Verify that the reorder only impacted the returned dict and not the original expected = expected_reordering[version] - assert [*tree[index]["changes"].keys()] == expected["sorted"] + assert [*entry["changes"].keys()] == expected["sorted"] assert [*COMMITS_TREE[index]["changes"].keys()] == expected["original"] else: - assert [*entry["changes"].keys()] == [*tree[index]["changes"].keys()] + assert [*entry["changes"].keys()] == [*entry["changes"].keys()] -def test_order_changelog_tree_raises(): +def test_generate_ordered_changelog_tree_raises(): change_type_order = ["BREAKING CHANGE", "feat", "refactor", "feat"] with pytest.raises(InvalidConfigurationError) as excinfo: - changelog.order_changelog_tree(COMMITS_TREE, change_type_order) + list(changelog.generate_ordered_changelog_tree(COMMITS_TREE, change_type_order)) - assert "Change types contain duplicates types" in str(excinfo) + assert "Change types contain duplicated types" in str(excinfo) def test_render_changelog( @@ -1639,7 +1639,9 @@ def test_tags_rules_get_version_tags(capsys: pytest.CaptureFixture): def test_changelog_file_name_from_args_and_config(): mock_config = Mock(spec=BaseConfig) - mock_config.path.parent = "/my/project" + mock_path = Mock(spec=Path) + mock_path.parent = Path("/my/project") + mock_config.path = mock_path mock_config.settings = { "name": "cz_conventional_commits", "changelog_file": "CHANGELOG.md", diff --git a/tests/test_cli.py b/tests/test_cli.py index a91e633128..31371caea4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,7 @@ import os import subprocess import sys +import types from functools import partial import pytest @@ -182,3 +183,59 @@ def test_unknown_args_before_double_dash_raises(mocker: MockFixture): assert "Invalid commitizen arguments were found before -- separator" in str( excinfo.value ) + + +def test_commitizen_excepthook_non_commitizen_exception(mocker: MockFixture): + """Test that commitizen_excepthook delegates to original_excepthook for non-CommitizenException.""" + # Mock the original excepthook + mock_original_excepthook = mocker.Mock() + mocker.patch("commitizen.cli.original_excepthook", mock_original_excepthook) + + # Create a regular exception + test_exception = ValueError("test error") + + # Call commitizen_excepthook with the regular exception + cli.commitizen_excepthook(ValueError, test_exception, None) + + # Verify original_excepthook was called with correct arguments + mock_original_excepthook.assert_called_once_with(ValueError, test_exception, None) + + +def test_commitizen_excepthook_non_commitizen_exception_with_traceback( + mocker: MockFixture, +): + """Test that commitizen_excepthook handles traceback correctly for non-CommitizenException.""" + # Mock the original excepthook + mock_original_excepthook = mocker.Mock() + mocker.patch("commitizen.cli.original_excepthook", mock_original_excepthook) + + # Create a regular exception with a traceback + test_exception = ValueError("test error") + test_traceback = mocker.Mock(spec=types.TracebackType) + + # Call commitizen_excepthook with the regular exception and traceback + cli.commitizen_excepthook(ValueError, test_exception, test_traceback) + + # Verify original_excepthook was called with correct arguments including traceback + mock_original_excepthook.assert_called_once_with( + ValueError, test_exception, test_traceback + ) + + +def test_commitizen_excepthook_non_commitizen_exception_with_invalid_traceback( + mocker: MockFixture, +): + """Test that commitizen_excepthook handles invalid traceback correctly for non-CommitizenException.""" + # Mock the original excepthook + mock_original_excepthook = mocker.Mock() + mocker.patch("commitizen.cli.original_excepthook", mock_original_excepthook) + + # Create a regular exception with an invalid traceback + test_exception = ValueError("test error") + test_traceback = mocker.Mock() # Not a TracebackType + + # Call commitizen_excepthook with the regular exception and invalid traceback + cli.commitizen_excepthook(ValueError, test_exception, test_traceback) + + # Verify original_excepthook was called with None as traceback + mock_original_excepthook.assert_called_once_with(ValueError, test_exception, None) diff --git a/tests/test_cz_base.py b/tests/test_cz_base.py index 4ee1cc6eda..be93b4ca0f 100644 --- a/tests/test_cz_base.py +++ b/tests/test_cz_base.py @@ -42,9 +42,3 @@ def test_info(config): cz = DummyCz(config) with pytest.raises(NotImplementedError): cz.info() - - -def test_process_commit(config): - cz = DummyCz(config) - message = cz.process_commit("test(test_scope): this is test msg") - assert message == "test(test_scope): this is test msg" diff --git a/tests/test_cz_conventional_commits.py b/tests/test_cz_conventional_commits.py index 6d4e0f7435..c96e036707 100644 --- a/tests/test_cz_conventional_commits.py +++ b/tests/test_cz_conventional_commits.py @@ -2,48 +2,42 @@ from commitizen.cz.conventional_commits.conventional_commits import ( ConventionalCommitsCz, - parse_scope, - parse_subject, + _parse_scope, + _parse_subject, ) from commitizen.cz.exceptions import AnswerRequiredError -valid_scopes = ["", "simple", "dash-separated", "camelCaseUPPERCASE"] -scopes_transformations = [["with spaces", "with-spaces"], [None, ""]] - -valid_subjects = ["this is a normal text", "aword"] - -subjects_transformations = [["with dot.", "with dot"]] - -invalid_subjects = ["", " ", ".", " .", "", None] - - -def test_parse_scope_valid_values(): - for valid_scope in valid_scopes: - assert valid_scope == parse_scope(valid_scope) +@pytest.mark.parametrize( + "valid_scope", ["", "simple", "dash-separated", "camelCaseUPPERCASE"] +) +def test_parse_scope_valid_values(valid_scope): + assert valid_scope == _parse_scope(valid_scope) -def test_scopes_transformations(): - for scopes_transformation in scopes_transformations: - invalid_scope, transformed_scope = scopes_transformation - assert transformed_scope == parse_scope(invalid_scope) +@pytest.mark.parametrize( + "scopes_transformation", [["with spaces", "with-spaces"], ["", ""]] +) +def test_scopes_transformations(scopes_transformation): + invalid_scope, transformed_scope = scopes_transformation + assert transformed_scope == _parse_scope(invalid_scope) -def test_parse_subject_valid_values(): - for valid_subject in valid_subjects: - assert valid_subject == parse_subject(valid_subject) +@pytest.mark.parametrize("valid_subject", ["this is a normal text", "aword"]) +def test_parse_subject_valid_values(valid_subject): + assert valid_subject == _parse_subject(valid_subject) -def test_parse_subject_invalid_values(): - for valid_subject in invalid_subjects: - with pytest.raises(AnswerRequiredError): - parse_subject(valid_subject) +@pytest.mark.parametrize("invalid_subject", ["", " ", ".", " .", "\t\t."]) +def test_parse_subject_invalid_values(invalid_subject): + with pytest.raises(AnswerRequiredError): + _parse_subject(invalid_subject) -def test_subject_transformations(): - for subject_transformation in subjects_transformations: - invalid_subject, transformed_subject = subject_transformation - assert transformed_subject == parse_subject(invalid_subject) +@pytest.mark.parametrize("subject_transformation", [["with dot.", "with dot"]]) +def test_subject_transformations(subject_transformation): + invalid_subject, transformed_subject = subject_transformation + assert transformed_subject == _parse_subject(invalid_subject) def test_questions(config): @@ -89,7 +83,7 @@ def test_long_answer(config): message = conventional_commits.message(answers) assert ( message - == "fix(users): email pattern corrected\n\ncomplete content\n\ncloses #24" # noqa + == "fix(users): email pattern corrected\n\ncomplete content\n\ncloses #24" ) @@ -107,7 +101,7 @@ def test_breaking_change_in_footer(config): print(message) assert ( message - == "fix(users): email pattern corrected\n\ncomplete content\n\nBREAKING CHANGE: migrate by renaming user to users" # noqa + == "fix(users): email pattern corrected\n\ncomplete content\n\nBREAKING CHANGE: migrate by renaming user to users" ) @@ -130,26 +124,3 @@ def test_info(config): conventional_commits = ConventionalCommitsCz(config) info = conventional_commits.info() assert isinstance(info, str) - - -@pytest.mark.parametrize( - ("commit_message", "expected_message"), - [ - ( - "test(test_scope): this is test msg", - "this is test msg", - ), - ( - "test(test_scope)!: this is test msg", - "this is test msg", - ), - ( - "test!(test_scope): this is test msg", - "", - ), - ], -) -def test_process_commit(commit_message, expected_message, config): - conventional_commits = ConventionalCommitsCz(config) - message = conventional_commits.process_commit(commit_message) - assert message == expected_message diff --git a/tests/test_defaults.py b/tests/test_defaults.py new file mode 100644 index 0000000000..73cd35b80c --- /dev/null +++ b/tests/test_defaults.py @@ -0,0 +1,29 @@ +import pytest + +from commitizen import defaults + + +def test_getattr_deprecated_vars(): + # Test each deprecated variable + with pytest.warns(DeprecationWarning) as record: + assert defaults.bump_pattern == defaults.BUMP_PATTERN + assert defaults.bump_map == defaults.BUMP_MAP + assert ( + defaults.bump_map_major_version_zero == defaults.BUMP_MAP_MAJOR_VERSION_ZERO + ) + assert defaults.bump_message == defaults.BUMP_MESSAGE + assert defaults.change_type_order == defaults.CHANGE_TYPE_ORDER + assert defaults.encoding == defaults.ENCODING + assert defaults.name == defaults.DEFAULT_SETTINGS["name"] + + # Verify warning messages + assert len(record) == 7 + for warning in record: + assert "is deprecated and will be removed" in str(warning.message) + + +def test_getattr_non_existent(): + # Test non-existent attribute + with pytest.raises(AttributeError) as exc_info: + _ = defaults.non_existent_attribute + assert "is not an attribute of" in str(exc_info.value) diff --git a/tests/test_git.py b/tests/test_git.py index 8b2fc2b86e..e242b3a2ae 100644 --- a/tests/test_git.py +++ b/tests/test_git.py @@ -18,6 +18,13 @@ ) +@pytest.mark.parametrize("date", ["2020-01-21", "1970-01-01"]) +def test_git_tag_date(date: str): + git_tag = git.GitTag(rev="sha1-code", name="0.0.1", date="2025-05-30") + git_tag.date = date + assert git_tag.date == date + + def test_git_object_eq(): git_commit = git.GitCommit( rev="sha1-code", title="this is title", body="this is body" @@ -79,8 +86,7 @@ def test_get_reachable_tags_with_commits( monkeypatch.setenv("LANGUAGE", f"{locale}.UTF-8") monkeypatch.setenv("LC_ALL", f"{locale}.UTF-8") with tmp_commitizen_project.as_cwd(): - tags = git.get_tags(reachable_only=True) - assert tags == [] + assert git.get_tags(reachable_only=True) == [] def test_get_tag_names(mocker: MockFixture): @@ -271,7 +277,7 @@ def test_get_commits_with_signature(): def test_get_tag_names_has_correct_arrow_annotation(): arrow_annotation = inspect.getfullargspec(git.get_tag_names).annotations["return"] - assert arrow_annotation == "list[str | None]" + assert arrow_annotation == "list[str]" def test_get_latest_tag_name(tmp_commitizen_project): @@ -317,24 +323,18 @@ def test_is_staging_clean_when_updating_file(tmp_commitizen_project): assert git.is_staging_clean() is False -def test_git_eol_style(tmp_commitizen_project): +def test_get_eol_for_open(tmp_commitizen_project): with tmp_commitizen_project.as_cwd(): - assert git.get_eol_style() == git.EOLTypes.NATIVE + assert git.EOLType.for_open() == os.linesep cmd.run("git config core.eol lf") - assert git.get_eol_style() == git.EOLTypes.LF + assert git.EOLType.for_open() == "\n" cmd.run("git config core.eol crlf") - assert git.get_eol_style() == git.EOLTypes.CRLF + assert git.EOLType.for_open() == "\r\n" cmd.run("git config core.eol native") - assert git.get_eol_style() == git.EOLTypes.NATIVE - - -def test_eoltypes_get_eol_for_open(): - assert git.EOLTypes.get_eol_for_open(git.EOLTypes.NATIVE) == os.linesep - assert git.EOLTypes.get_eol_for_open(git.EOLTypes.LF) == "\n" - assert git.EOLTypes.get_eol_for_open(git.EOLTypes.CRLF) == "\r\n" + assert git.EOLType.for_open() == os.linesep def test_get_core_editor(mocker): @@ -401,3 +401,82 @@ def test_commit_with_spaces_in_path(mocker, file_path, expected_cmd): mock_run.assert_called_once_with(expected_cmd) mock_unlink.assert_called_once_with(file_path) + + +def test_get_filenames_in_commit_error(mocker: MockFixture): + """Test that GitCommandError is raised when git command fails.""" + mocker.patch( + "commitizen.cmd.run", + return_value=FakeCommand(out="", err="fatal: bad object HEAD", return_code=1), + ) + with pytest.raises(exceptions.GitCommandError) as excinfo: + git.get_filenames_in_commit() + assert str(excinfo.value) == "fatal: bad object HEAD" + + +def test_git_commit_from_rev_and_commit(): + # Test data with all fields populated + rev_and_commit = ( + "abc123\n" # rev + "def456 ghi789\n" # parents + "feat: add new feature\n" # title + "John Doe\n" # author + "john@example.com\n" # author_email + "This is a detailed description\n" # body + "of the new feature\n" + "with multiple lines" + ) + + commit = git.GitCommit.from_rev_and_commit(rev_and_commit) + + assert commit.rev == "abc123" + assert commit.title == "feat: add new feature" + assert ( + commit.body + == "This is a detailed description\nof the new feature\nwith multiple lines" + ) + assert commit.author == "John Doe" + assert commit.author_email == "john@example.com" + assert commit.parents == ["def456", "ghi789"] + + # Test with minimal data + minimal_commit = ( + "abc123\n" # rev + "\n" # no parents + "feat: minimal commit\n" # title + "John Doe\n" # author + "john@example.com\n" # author_email + ) + + commit = git.GitCommit.from_rev_and_commit(minimal_commit) + + assert commit.rev == "abc123" + assert commit.title == "feat: minimal commit" + assert commit.body == "" + assert commit.author == "John Doe" + assert commit.author_email == "john@example.com" + assert commit.parents == [] + + +@pytest.mark.parametrize( + "os_name,committer_date,expected_cmd", + [ + ( + "nt", + "2024-03-20", + 'cmd /v /c "set GIT_COMMITTER_DATE=2024-03-20&& git commit -F "temp.txt""', + ), + ( + "posix", + "2024-03-20", + 'GIT_COMMITTER_DATE=2024-03-20 git commit -F "temp.txt"', + ), + ("nt", None, 'git commit -F "temp.txt"'), + ("posix", None, 'git commit -F "temp.txt"'), + ], +) +def test_create_commit_cmd_string(mocker, os_name, committer_date, expected_cmd): + """Test the OS-specific behavior of _create_commit_cmd_string""" + mocker.patch("os.name", os_name) + result = git._create_commit_cmd_string("", committer_date, "temp.txt") + assert result == expected_cmd