Skip to content

Commit d2eff5a

Browse files
committed
Add CLI support for execute sql scripts
1 parent d906a10 commit d2eff5a

File tree

12 files changed

+131
-9
lines changed

12 files changed

+131
-9
lines changed

backend/app/admin/api/v1/sys/files.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from backend.common.response.response_schema import ResponseSchemaModel, response_base
99
from backend.common.security.permission import RequestPermission
1010
from backend.common.security.rbac import DependsRBAC
11-
from backend.utils.file_ops import file_verify, upload_file
11+
from backend.utils.file_ops import upload_file, upload_file_verify
1212

1313
router = APIRouter()
1414

@@ -22,6 +22,6 @@
2222
],
2323
)
2424
async def upload_files(file: Annotated[UploadFile, File()]) -> ResponseSchemaModel[UploadUrl]:
25-
file_verify(file)
25+
upload_file_verify(file)
2626
filename = await upload_file(file)
2727
return response_base.success(data={'url': f'/static/upload/{filename}'})

backend/cli.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@
1111

1212
from rich.panel import Panel
1313
from rich.text import Text
14+
from sqlalchemy import text
1415

1516
from backend import console, get_version
17+
from backend.common.enums import DataBaseType, PrimaryKeyType
1618
from backend.common.exception.errors import BaseExceptionMixin
1719
from backend.core.conf import settings
18-
from backend.utils.file_ops import install_git_plugin, install_zip_plugin
20+
from backend.database.db import async_db_session
21+
from backend.plugin.tools import get_plugin_sql
22+
from backend.utils.file_ops import install_git_plugin, install_zip_plugin, parse_sql_script
1923

2024

2125
def run(host: str, port: int, reload: bool, workers: int | None) -> None:
@@ -44,7 +48,9 @@ def run(host: str, port: int, reload: bool, workers: int | None) -> None:
4448
)
4549

4650

47-
async def install_plugin(path: str, repo_url: str) -> None:
51+
async def install_plugin(
52+
path: str, repo_url: str, execute_sql: bool, db_type: DataBaseType, pk_type: PrimaryKeyType
53+
) -> None:
4854
if not path and not repo_url:
4955
raise cappa.Exit('path 或 repo_url 必须指定其中一项', code=1)
5056
if path and repo_url:
@@ -58,10 +64,28 @@ async def install_plugin(path: str, repo_url: str) -> None:
5864
plugin_name = await install_zip_plugin(file=path)
5965
if repo_url:
6066
plugin_name = await install_git_plugin(repo_url=repo_url)
67+
68+
console.print(Text(f'插件 {plugin_name} 安装成功', style='bold green'))
69+
70+
sql_file = await get_plugin_sql(plugin_name, db_type, pk_type)
71+
if sql_file and execute_sql:
72+
console.print(Text('开始自动执行插件 SQL 脚本...', style='bold cyan'))
73+
await execute_sql_scripts(sql_file)
74+
6175
except Exception as e:
6276
raise cappa.Exit(e.msg if isinstance(e, BaseExceptionMixin) else str(e), code=1)
6377

64-
console.print(Text(f'插件 {plugin_name} 安装成功', style='bold cyan'))
78+
79+
async def execute_sql_scripts(sql_scripts: str) -> None:
80+
async with async_db_session.begin() as db:
81+
try:
82+
stmts = await parse_sql_script(sql_scripts)
83+
for stmt in stmts:
84+
await db.execute(text(stmt))
85+
except Exception as e:
86+
raise cappa.Exit(f'SQL 脚本执行失败:{e}', code=1)
87+
88+
console.print(Text('SQL 脚本已执行完成', style='bold green'))
6589

6690

6791
@cappa.command(help='运行服务')
@@ -104,9 +128,21 @@ class Add:
104128
str | None,
105129
cappa.Arg(long=True, help='Git 插件的仓库地址'),
106130
]
131+
execute_sql: Annotated[
132+
bool,
133+
cappa.Arg(long=True, default=False, help='启用插件安装后自动执行 SQL 脚本'),
134+
]
135+
db_type: Annotated[
136+
DataBaseType,
137+
cappa.Arg(long=True, default='mysql', help='指定数据库类型,需启用 `--sql`'),
138+
]
139+
pk_type: Annotated[
140+
PrimaryKeyType,
141+
cappa.Arg(long=True, default='autoincrement', help='指定数据库主键类型,需启用 `--sql`'),
142+
]
107143

108144
async def __call__(self):
109-
await install_plugin(path=self.path, repo_url=self.repo_url)
145+
await install_plugin(self.path, self.repo_url, self.execute_sql, self.db_type, self.pk_type)
110146

111147

112148
@dataclass
@@ -115,11 +151,17 @@ class FbaCli:
115151
bool,
116152
cappa.Arg(short='-V', long=True, default=False, help='打印当前版本号'),
117153
]
154+
sql: Annotated[
155+
str,
156+
cappa.Arg(long=True, default='', help='脚本文件绝对路径,使用当前数据库配置在事务中执行 SQL 脚本'),
157+
]
118158
subcmd: cappa.Subcommands[Run | Add | None] = None
119159

120-
def __call__(self):
160+
async def __call__(self):
121161
if self.version:
122162
get_version()
163+
if self.sql:
164+
await execute_sql_scripts(self.sql)
123165

124166

125167
def main() -> None:

backend/common/enums.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,17 @@ class UserPermissionType(StrEnum):
137137
staff = 'staff'
138138
status = 'status'
139139
multi_login = 'multi_login'
140+
141+
142+
class DataBaseType(StrEnum):
143+
"""数据库类型"""
144+
145+
mysql = 'mysql'
146+
postgresql = 'postgresql'
147+
148+
149+
class PrimaryKeyType(StrEnum):
150+
"""主键类型"""
151+
152+
autoincrement = 'autoincrement'
153+
snowflake = 'snowflake'

backend/plugin/tools.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from packaging.requirements import Requirement
1818
from starlette.concurrency import run_in_threadpool
1919

20-
from backend.common.enums import StatusType
20+
from backend.common.enums import DataBaseType, PrimaryKeyType, StatusType
2121
from backend.common.exception import errors
2222
from backend.common.log import log
2323
from backend.core.conf import settings
@@ -75,6 +75,34 @@ def get_plugin_models() -> list[type]:
7575
return classes
7676

7777

78+
async def get_plugin_sql(plugin: str, db_type: DataBaseType, pk_type: PrimaryKeyType) -> str | None:
79+
"""
80+
获取插件 SQL 脚本
81+
82+
:param plugin: 插件名称
83+
:param db_type: 数据库类型
84+
:param pk_type: 主键类型
85+
:return:
86+
"""
87+
if db_type == DataBaseType.mysql.value:
88+
mysql_dir = os.path.join(PLUGIN_DIR, plugin, 'sql', 'mysql')
89+
if pk_type == PrimaryKeyType.autoincrement:
90+
sql_file = os.path.join(mysql_dir, 'init.sql')
91+
else:
92+
sql_file = os.path.join(mysql_dir, 'init_snowflake.sql')
93+
else:
94+
postgresql_dir = os.path.join(PLUGIN_DIR, plugin, 'sql', 'postgresql')
95+
if pk_type == PrimaryKeyType.autoincrement.value:
96+
sql_file = os.path.join(postgresql_dir, 'init.sql')
97+
else:
98+
sql_file = os.path.join(postgresql_dir, 'init_snowflake.sql')
99+
100+
if not os.path.exists(sql_file):
101+
return None
102+
103+
return sql_file
104+
105+
78106
def load_plugin_config(plugin: str) -> dict[str, Any]:
79107
"""
80108
加载插件配置

backend/utils/file_ops.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from dulwich import porcelain
1111
from fastapi import UploadFile
12+
from sqlparse import split
1213

1314
from backend.common.enums import FileType
1415
from backend.common.exception import errors
@@ -35,7 +36,7 @@ def build_filename(file: UploadFile) -> str:
3536
return new_filename
3637

3738

38-
def file_verify(file: UploadFile) -> None:
39+
def upload_file_verify(file: UploadFile) -> None:
3940
"""
4041
文件验证
4142
@@ -161,3 +162,26 @@ async def install_git_plugin(repo_url: str) -> str:
161162
await redis_client.set(f'{settings.PLUGIN_REDIS_PREFIX}:changed', 'ture')
162163

163164
return repo_name
165+
166+
167+
async def parse_sql_script(filepath: str) -> list[str]:
168+
"""
169+
解析 SQL 脚本
170+
171+
:param filepath: 脚本文件路径
172+
:return:
173+
"""
174+
if not os.path.exists(filepath):
175+
raise errors.NotFoundError(msg='SQL 脚本文件不存在')
176+
177+
async with aiofiles.open(filepath, mode='r', encoding='utf-8') as f:
178+
contents = await f.read(1024)
179+
while additional_contents := await f.read(1024):
180+
contents += additional_contents
181+
182+
statements = split(contents)
183+
for statement in statements:
184+
if not any(statement.lower().startswith(_) for _ in ['select', 'insert']):
185+
raise errors.RequestError(msg='SQL 脚本文件中存在非法操作,仅允许 SELECT 和 INSERT')
186+
187+
return statements

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
"rtoml>=0.12.0",
4949
"sqlalchemy-crud-plus>=1.10.0",
5050
"sqlalchemy[asyncio]>=2.0.40",
51+
"sqlparse>=0.5.3",
5152
"user-agents==2.2.0",
5253
]
5354

0 commit comments

Comments
 (0)