From d2eff5a84a923790848cb716ee22cfaad62abee1 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Wed, 2 Jul 2025 19:38:56 +0800 Subject: [PATCH 1/2] Add CLI support for execute sql scripts --- backend/app/admin/api/v1/sys/files.py | 4 +- backend/cli.py | 52 +++++++++++++++++-- backend/common/enums.py | 14 +++++ .../mysql/{init_test_data.sql => init.sql} | 0 ...flake_test_data.sql => init_snowflake.sql} | 0 .../{init_test_data.sql => init.sql} | 0 ...flake_test_data.sql => init_snowflake.sql} | 0 backend/plugin/tools.py | 30 ++++++++++- backend/utils/file_ops.py | 26 +++++++++- pyproject.toml | 1 + requirements.txt | 2 + uv.lock | 11 ++++ 12 files changed, 131 insertions(+), 9 deletions(-) rename backend/plugin/dict/sql/mysql/{init_test_data.sql => init.sql} (100%) rename backend/plugin/dict/sql/mysql/{init_snowflake_test_data.sql => init_snowflake.sql} (100%) rename backend/plugin/dict/sql/postgresql/{init_test_data.sql => init.sql} (100%) rename backend/plugin/dict/sql/postgresql/{init_snowflake_test_data.sql => init_snowflake.sql} (100%) diff --git a/backend/app/admin/api/v1/sys/files.py b/backend/app/admin/api/v1/sys/files.py index ff0fa400..e859bc96 100644 --- a/backend/app/admin/api/v1/sys/files.py +++ b/backend/app/admin/api/v1/sys/files.py @@ -8,7 +8,7 @@ from backend.common.response.response_schema import ResponseSchemaModel, response_base from backend.common.security.permission import RequestPermission from backend.common.security.rbac import DependsRBAC -from backend.utils.file_ops import file_verify, upload_file +from backend.utils.file_ops import upload_file, upload_file_verify router = APIRouter() @@ -22,6 +22,6 @@ ], ) async def upload_files(file: Annotated[UploadFile, File()]) -> ResponseSchemaModel[UploadUrl]: - file_verify(file) + upload_file_verify(file) filename = await upload_file(file) return response_base.success(data={'url': f'/static/upload/{filename}'}) diff --git a/backend/cli.py b/backend/cli.py index 1c967215..548e46cf 100644 --- a/backend/cli.py +++ b/backend/cli.py @@ -11,11 +11,15 @@ from rich.panel import Panel from rich.text import Text +from sqlalchemy import text from backend import console, get_version +from backend.common.enums import DataBaseType, PrimaryKeyType from backend.common.exception.errors import BaseExceptionMixin from backend.core.conf import settings -from backend.utils.file_ops import install_git_plugin, install_zip_plugin +from backend.database.db import async_db_session +from backend.plugin.tools import get_plugin_sql +from backend.utils.file_ops import install_git_plugin, install_zip_plugin, parse_sql_script def run(host: str, port: int, reload: bool, workers: int | None) -> None: @@ -44,7 +48,9 @@ def run(host: str, port: int, reload: bool, workers: int | None) -> None: ) -async def install_plugin(path: str, repo_url: str) -> None: +async def install_plugin( + path: str, repo_url: str, execute_sql: bool, db_type: DataBaseType, pk_type: PrimaryKeyType +) -> None: if not path and not repo_url: raise cappa.Exit('path 或 repo_url 必须指定其中一项', code=1) if path and repo_url: @@ -58,10 +64,28 @@ async def install_plugin(path: str, repo_url: str) -> None: plugin_name = await install_zip_plugin(file=path) if repo_url: plugin_name = await install_git_plugin(repo_url=repo_url) + + console.print(Text(f'插件 {plugin_name} 安装成功', style='bold green')) + + sql_file = await get_plugin_sql(plugin_name, db_type, pk_type) + if sql_file and execute_sql: + console.print(Text('开始自动执行插件 SQL 脚本...', style='bold cyan')) + await execute_sql_scripts(sql_file) + except Exception as e: raise cappa.Exit(e.msg if isinstance(e, BaseExceptionMixin) else str(e), code=1) - console.print(Text(f'插件 {plugin_name} 安装成功', style='bold cyan')) + +async def execute_sql_scripts(sql_scripts: str) -> None: + async with async_db_session.begin() as db: + try: + stmts = await parse_sql_script(sql_scripts) + for stmt in stmts: + await db.execute(text(stmt)) + except Exception as e: + raise cappa.Exit(f'SQL 脚本执行失败:{e}', code=1) + + console.print(Text('SQL 脚本已执行完成', style='bold green')) @cappa.command(help='运行服务') @@ -104,9 +128,21 @@ class Add: str | None, cappa.Arg(long=True, help='Git 插件的仓库地址'), ] + execute_sql: Annotated[ + bool, + cappa.Arg(long=True, default=False, help='启用插件安装后自动执行 SQL 脚本'), + ] + db_type: Annotated[ + DataBaseType, + cappa.Arg(long=True, default='mysql', help='指定数据库类型,需启用 `--sql`'), + ] + pk_type: Annotated[ + PrimaryKeyType, + cappa.Arg(long=True, default='autoincrement', help='指定数据库主键类型,需启用 `--sql`'), + ] async def __call__(self): - await install_plugin(path=self.path, repo_url=self.repo_url) + await install_plugin(self.path, self.repo_url, self.execute_sql, self.db_type, self.pk_type) @dataclass @@ -115,11 +151,17 @@ class FbaCli: bool, cappa.Arg(short='-V', long=True, default=False, help='打印当前版本号'), ] + sql: Annotated[ + str, + cappa.Arg(long=True, default='', help='脚本文件绝对路径,使用当前数据库配置在事务中执行 SQL 脚本'), + ] subcmd: cappa.Subcommands[Run | Add | None] = None - def __call__(self): + async def __call__(self): if self.version: get_version() + if self.sql: + await execute_sql_scripts(self.sql) def main() -> None: diff --git a/backend/common/enums.py b/backend/common/enums.py index efbf7bb0..9cfaa3af 100644 --- a/backend/common/enums.py +++ b/backend/common/enums.py @@ -137,3 +137,17 @@ class UserPermissionType(StrEnum): staff = 'staff' status = 'status' multi_login = 'multi_login' + + +class DataBaseType(StrEnum): + """数据库类型""" + + mysql = 'mysql' + postgresql = 'postgresql' + + +class PrimaryKeyType(StrEnum): + """主键类型""" + + autoincrement = 'autoincrement' + snowflake = 'snowflake' diff --git a/backend/plugin/dict/sql/mysql/init_test_data.sql b/backend/plugin/dict/sql/mysql/init.sql similarity index 100% rename from backend/plugin/dict/sql/mysql/init_test_data.sql rename to backend/plugin/dict/sql/mysql/init.sql diff --git a/backend/plugin/dict/sql/mysql/init_snowflake_test_data.sql b/backend/plugin/dict/sql/mysql/init_snowflake.sql similarity index 100% rename from backend/plugin/dict/sql/mysql/init_snowflake_test_data.sql rename to backend/plugin/dict/sql/mysql/init_snowflake.sql diff --git a/backend/plugin/dict/sql/postgresql/init_test_data.sql b/backend/plugin/dict/sql/postgresql/init.sql similarity index 100% rename from backend/plugin/dict/sql/postgresql/init_test_data.sql rename to backend/plugin/dict/sql/postgresql/init.sql diff --git a/backend/plugin/dict/sql/postgresql/init_snowflake_test_data.sql b/backend/plugin/dict/sql/postgresql/init_snowflake.sql similarity index 100% rename from backend/plugin/dict/sql/postgresql/init_snowflake_test_data.sql rename to backend/plugin/dict/sql/postgresql/init_snowflake.sql diff --git a/backend/plugin/tools.py b/backend/plugin/tools.py index a57a0034..8c213d4a 100644 --- a/backend/plugin/tools.py +++ b/backend/plugin/tools.py @@ -17,7 +17,7 @@ from packaging.requirements import Requirement from starlette.concurrency import run_in_threadpool -from backend.common.enums import StatusType +from backend.common.enums import DataBaseType, PrimaryKeyType, StatusType from backend.common.exception import errors from backend.common.log import log from backend.core.conf import settings @@ -75,6 +75,34 @@ def get_plugin_models() -> list[type]: return classes +async def get_plugin_sql(plugin: str, db_type: DataBaseType, pk_type: PrimaryKeyType) -> str | None: + """ + 获取插件 SQL 脚本 + + :param plugin: 插件名称 + :param db_type: 数据库类型 + :param pk_type: 主键类型 + :return: + """ + if db_type == DataBaseType.mysql.value: + mysql_dir = os.path.join(PLUGIN_DIR, plugin, 'sql', 'mysql') + if pk_type == PrimaryKeyType.autoincrement: + sql_file = os.path.join(mysql_dir, 'init.sql') + else: + sql_file = os.path.join(mysql_dir, 'init_snowflake.sql') + else: + postgresql_dir = os.path.join(PLUGIN_DIR, plugin, 'sql', 'postgresql') + if pk_type == PrimaryKeyType.autoincrement.value: + sql_file = os.path.join(postgresql_dir, 'init.sql') + else: + sql_file = os.path.join(postgresql_dir, 'init_snowflake.sql') + + if not os.path.exists(sql_file): + return None + + return sql_file + + def load_plugin_config(plugin: str) -> dict[str, Any]: """ 加载插件配置 diff --git a/backend/utils/file_ops.py b/backend/utils/file_ops.py index f0a1de56..3d074b9d 100644 --- a/backend/utils/file_ops.py +++ b/backend/utils/file_ops.py @@ -9,6 +9,7 @@ from dulwich import porcelain from fastapi import UploadFile +from sqlparse import split from backend.common.enums import FileType from backend.common.exception import errors @@ -35,7 +36,7 @@ def build_filename(file: UploadFile) -> str: return new_filename -def file_verify(file: UploadFile) -> None: +def upload_file_verify(file: UploadFile) -> None: """ 文件验证 @@ -161,3 +162,26 @@ async def install_git_plugin(repo_url: str) -> str: await redis_client.set(f'{settings.PLUGIN_REDIS_PREFIX}:changed', 'ture') return repo_name + + +async def parse_sql_script(filepath: str) -> list[str]: + """ + 解析 SQL 脚本 + + :param filepath: 脚本文件路径 + :return: + """ + if not os.path.exists(filepath): + raise errors.NotFoundError(msg='SQL 脚本文件不存在') + + async with aiofiles.open(filepath, mode='r', encoding='utf-8') as f: + contents = await f.read(1024) + while additional_contents := await f.read(1024): + contents += additional_contents + + statements = split(contents) + for statement in statements: + if not any(statement.lower().startswith(_) for _ in ['select', 'insert']): + raise errors.RequestError(msg='SQL 脚本文件中存在非法操作,仅允许 SELECT 和 INSERT') + + return statements diff --git a/pyproject.toml b/pyproject.toml index 01ddcb5f..cc619da1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "rtoml>=0.12.0", "sqlalchemy-crud-plus>=1.10.0", "sqlalchemy[asyncio]>=2.0.40", + "sqlparse>=0.5.3", "user-agents==2.2.0", ] diff --git a/requirements.txt b/requirements.txt index 8ac129cc..341874f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -257,6 +257,8 @@ sqlalchemy==2.0.40 # sqlalchemy-crud-plus sqlalchemy-crud-plus==1.10.0 # via fastapi-best-architecture +sqlparse==0.5.3 + # via fastapi-best-architecture starlette==0.46.1 # via # asgi-correlation-id diff --git a/uv.lock b/uv.lock index c2e8e67a..b98cacc4 100644 --- a/uv.lock +++ b/uv.lock @@ -678,6 +678,7 @@ dependencies = [ { name = "rtoml" }, { name = "sqlalchemy", extra = ["asyncio"] }, { name = "sqlalchemy-crud-plus" }, + { name = "sqlparse" }, { name = "user-agents" }, ] @@ -731,6 +732,7 @@ requires-dist = [ { name = "rtoml", specifier = ">=0.12.0" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.40" }, { name = "sqlalchemy-crud-plus", specifier = ">=1.10.0" }, + { name = "sqlparse", specifier = ">=0.5.3" }, { name = "user-agents", specifier = "==2.2.0" }, ] @@ -2240,6 +2242,15 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/a2/72/a7553fb70a41559d74654d9111644e8326d5d420588b59f39de872a3da8f/sqlalchemy_crud_plus-1.10.0-py3-none-any.whl", hash = "sha256:44db51256b57aa00757e48ade1e8be5ce6b6bd334fdd6d1fed5a22e97c6b0a6b" }, ] +[[package]] +name = "sqlparse" +version = "0.5.3" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/e5/40/edede8dd6977b0d3da179a342c198ed100dd2aba4be081861ee5911e4da4/sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/a9/5c/bfd6bd0bf979426d405cc6e71eceb8701b148b16c21d2dc3c261efc61c7b/sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca" }, +] + [[package]] name = "starlette" version = "0.46.1" From dc5280cba14ad3ee4c15b5bca672e82eb2085f67 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Wed, 2 Jul 2025 20:02:37 +0800 Subject: [PATCH 2/2] Update the arg helps --- backend/cli.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/backend/cli.py b/backend/cli.py index 548e46cf..4ecb8d55 100644 --- a/backend/cli.py +++ b/backend/cli.py @@ -49,7 +49,7 @@ def run(host: str, port: int, reload: bool, workers: int | None) -> None: async def install_plugin( - path: str, repo_url: str, execute_sql: bool, db_type: DataBaseType, pk_type: PrimaryKeyType + path: str, repo_url: str, no_sql: bool, db_type: DataBaseType, pk_type: PrimaryKeyType ) -> None: if not path and not repo_url: raise cappa.Exit('path 或 repo_url 必须指定其中一项', code=1) @@ -68,7 +68,7 @@ async def install_plugin( console.print(Text(f'插件 {plugin_name} 安装成功', style='bold green')) sql_file = await get_plugin_sql(plugin_name, db_type, pk_type) - if sql_file and execute_sql: + if sql_file and not no_sql: console.print(Text('开始自动执行插件 SQL 脚本...', style='bold cyan')) await execute_sql_scripts(sql_file) @@ -128,23 +128,24 @@ class Add: str | None, cappa.Arg(long=True, help='Git 插件的仓库地址'), ] - execute_sql: Annotated[ + no_sql: Annotated[ bool, - cappa.Arg(long=True, default=False, help='启用插件安装后自动执行 SQL 脚本'), + cappa.Arg(long=True, default=False, help='禁用插件 SQL 脚本自动执行'), ] db_type: Annotated[ DataBaseType, - cappa.Arg(long=True, default='mysql', help='指定数据库类型,需启用 `--sql`'), + cappa.Arg(long=True, default='mysql', help='执行插件 SQL 脚本的数据库类型'), ] pk_type: Annotated[ PrimaryKeyType, - cappa.Arg(long=True, default='autoincrement', help='指定数据库主键类型,需启用 `--sql`'), + cappa.Arg(long=True, default='autoincrement', help='执行插件 SQL 脚本数据库主键类型'), ] async def __call__(self): - await install_plugin(self.path, self.repo_url, self.execute_sql, self.db_type, self.pk_type) + await install_plugin(self.path, self.repo_url, self.no_sql, self.db_type, self.pk_type) +@cappa.command(help='一个高效的 fba 命令行界面') @dataclass class FbaCli: version: Annotated[ @@ -153,7 +154,7 @@ class FbaCli: ] sql: Annotated[ str, - cappa.Arg(long=True, default='', help='脚本文件绝对路径,使用当前数据库配置在事务中执行 SQL 脚本'), + cappa.Arg(long=True, default='', help='在事务中执行 SQL 脚本'), ] subcmd: cappa.Subcommands[Run | Add | None] = None