Skip to content

Add CLI support for execute sql scripts #711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/app/admin/api/v1/sys/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from backend.common.response.response_schema import ResponseSchemaModel, response_base
from backend.common.security.permission import RequestPermission
from backend.common.security.rbac import DependsRBAC
from backend.utils.file_ops import file_verify, upload_file
from backend.utils.file_ops import upload_file, upload_file_verify

router = APIRouter()

Expand All @@ -22,6 +22,6 @@
],
)
async def upload_files(file: Annotated[UploadFile, File()]) -> ResponseSchemaModel[UploadUrl]:
file_verify(file)
upload_file_verify(file)
filename = await upload_file(file)
return response_base.success(data={'url': f'/static/upload/{filename}'})
53 changes: 48 additions & 5 deletions backend/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@

from rich.panel import Panel
from rich.text import Text
from sqlalchemy import text

from backend import console, get_version
from backend.common.enums import DataBaseType, PrimaryKeyType
from backend.common.exception.errors import BaseExceptionMixin
from backend.core.conf import settings
from backend.utils.file_ops import install_git_plugin, install_zip_plugin
from backend.database.db import async_db_session
from backend.plugin.tools import get_plugin_sql
from backend.utils.file_ops import install_git_plugin, install_zip_plugin, parse_sql_script


def run(host: str, port: int, reload: bool, workers: int | None) -> None:
Expand Down Expand Up @@ -44,7 +48,9 @@ def run(host: str, port: int, reload: bool, workers: int | None) -> None:
)


async def install_plugin(path: str, repo_url: str) -> None:
async def install_plugin(
path: str, repo_url: str, no_sql: bool, db_type: DataBaseType, pk_type: PrimaryKeyType
) -> None:
if not path and not repo_url:
raise cappa.Exit('path 或 repo_url 必须指定其中一项', code=1)
if path and repo_url:
Expand All @@ -58,10 +64,28 @@ async def install_plugin(path: str, repo_url: str) -> None:
plugin_name = await install_zip_plugin(file=path)
if repo_url:
plugin_name = await install_git_plugin(repo_url=repo_url)

console.print(Text(f'插件 {plugin_name} 安装成功', style='bold green'))

sql_file = await get_plugin_sql(plugin_name, db_type, pk_type)
if sql_file and not no_sql:
console.print(Text('开始自动执行插件 SQL 脚本...', style='bold cyan'))
await execute_sql_scripts(sql_file)

except Exception as e:
raise cappa.Exit(e.msg if isinstance(e, BaseExceptionMixin) else str(e), code=1)

console.print(Text(f'插件 {plugin_name} 安装成功', style='bold cyan'))

async def execute_sql_scripts(sql_scripts: str) -> None:
async with async_db_session.begin() as db:
try:
stmts = await parse_sql_script(sql_scripts)
for stmt in stmts:
await db.execute(text(stmt))
except Exception as e:
raise cappa.Exit(f'SQL 脚本执行失败:{e}', code=1)

console.print(Text('SQL 脚本已执行完成', style='bold green'))


@cappa.command(help='运行服务')
Expand Down Expand Up @@ -104,22 +128,41 @@ class Add:
str | None,
cappa.Arg(long=True, help='Git 插件的仓库地址'),
]
no_sql: Annotated[
bool,
cappa.Arg(long=True, default=False, help='禁用插件 SQL 脚本自动执行'),
]
db_type: Annotated[
DataBaseType,
cappa.Arg(long=True, default='mysql', help='执行插件 SQL 脚本的数据库类型'),
]
pk_type: Annotated[
PrimaryKeyType,
cappa.Arg(long=True, default='autoincrement', help='执行插件 SQL 脚本数据库主键类型'),
]

async def __call__(self):
await install_plugin(path=self.path, repo_url=self.repo_url)
await install_plugin(self.path, self.repo_url, self.no_sql, self.db_type, self.pk_type)


@cappa.command(help='一个高效的 fba 命令行界面')
@dataclass
class FbaCli:
version: Annotated[
bool,
cappa.Arg(short='-V', long=True, default=False, help='打印当前版本号'),
]
sql: Annotated[
str,
cappa.Arg(long=True, default='', help='在事务中执行 SQL 脚本'),
]
subcmd: cappa.Subcommands[Run | Add | None] = None

def __call__(self):
async def __call__(self):
if self.version:
get_version()
if self.sql:
await execute_sql_scripts(self.sql)


def main() -> None:
Expand Down
14 changes: 14 additions & 0 deletions backend/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,17 @@ class UserPermissionType(StrEnum):
staff = 'staff'
status = 'status'
multi_login = 'multi_login'


class DataBaseType(StrEnum):
"""数据库类型"""

mysql = 'mysql'
postgresql = 'postgresql'


class PrimaryKeyType(StrEnum):
"""主键类型"""

autoincrement = 'autoincrement'
snowflake = 'snowflake'
30 changes: 29 additions & 1 deletion backend/plugin/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from packaging.requirements import Requirement
from starlette.concurrency import run_in_threadpool

from backend.common.enums import StatusType
from backend.common.enums import DataBaseType, PrimaryKeyType, StatusType
from backend.common.exception import errors
from backend.common.log import log
from backend.core.conf import settings
Expand Down Expand Up @@ -75,6 +75,34 @@ def get_plugin_models() -> list[type]:
return classes


async def get_plugin_sql(plugin: str, db_type: DataBaseType, pk_type: PrimaryKeyType) -> str | None:
"""
获取插件 SQL 脚本

:param plugin: 插件名称
:param db_type: 数据库类型
:param pk_type: 主键类型
:return:
"""
if db_type == DataBaseType.mysql.value:
mysql_dir = os.path.join(PLUGIN_DIR, plugin, 'sql', 'mysql')
if pk_type == PrimaryKeyType.autoincrement:
sql_file = os.path.join(mysql_dir, 'init.sql')
else:
sql_file = os.path.join(mysql_dir, 'init_snowflake.sql')
else:
postgresql_dir = os.path.join(PLUGIN_DIR, plugin, 'sql', 'postgresql')
if pk_type == PrimaryKeyType.autoincrement.value:
sql_file = os.path.join(postgresql_dir, 'init.sql')
else:
sql_file = os.path.join(postgresql_dir, 'init_snowflake.sql')

if not os.path.exists(sql_file):
return None

return sql_file


def load_plugin_config(plugin: str) -> dict[str, Any]:
"""
加载插件配置
Expand Down
26 changes: 25 additions & 1 deletion backend/utils/file_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from dulwich import porcelain
from fastapi import UploadFile
from sqlparse import split

from backend.common.enums import FileType
from backend.common.exception import errors
Expand All @@ -35,7 +36,7 @@ def build_filename(file: UploadFile) -> str:
return new_filename


def file_verify(file: UploadFile) -> None:
def upload_file_verify(file: UploadFile) -> None:
"""
文件验证

Expand Down Expand Up @@ -161,3 +162,26 @@ async def install_git_plugin(repo_url: str) -> str:
await redis_client.set(f'{settings.PLUGIN_REDIS_PREFIX}:changed', 'ture')

return repo_name


async def parse_sql_script(filepath: str) -> list[str]:
"""
解析 SQL 脚本

:param filepath: 脚本文件路径
:return:
"""
if not os.path.exists(filepath):
raise errors.NotFoundError(msg='SQL 脚本文件不存在')

async with aiofiles.open(filepath, mode='r', encoding='utf-8') as f:
contents = await f.read(1024)
while additional_contents := await f.read(1024):
contents += additional_contents

statements = split(contents)
for statement in statements:
if not any(statement.lower().startswith(_) for _ in ['select', 'insert']):
raise errors.RequestError(msg='SQL 脚本文件中存在非法操作,仅允许 SELECT 和 INSERT')

return statements
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
"rtoml>=0.12.0",
"sqlalchemy-crud-plus>=1.10.0",
"sqlalchemy[asyncio]>=2.0.40",
"sqlparse>=0.5.3",
"user-agents==2.2.0",
]

Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ sqlalchemy==2.0.40
# sqlalchemy-crud-plus
sqlalchemy-crud-plus==1.10.0
# via fastapi-best-architecture
sqlparse==0.5.3
# via fastapi-best-architecture
starlette==0.46.1
# via
# asgi-correlation-id
Expand Down
11 changes: 11 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.