Skip to content

Commit 494942e

Browse files
authored
Add CLI support for execute sql scripts (#711)
* Add CLI support for execute sql scripts * Update the arg helps
1 parent aa2b766 commit 494942e

File tree

12 files changed

+132
-9
lines changed

12 files changed

+132
-9
lines changed

backend/app/admin/api/v1/sys/files.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from backend.common.response.response_schema import ResponseSchemaModel, response_base
99
from backend.common.security.permission import RequestPermission
1010
from backend.common.security.rbac import DependsRBAC
11-
from backend.utils.file_ops import file_verify, upload_file
11+
from backend.utils.file_ops import upload_file, upload_file_verify
1212

1313
router = APIRouter()
1414

@@ -22,6 +22,6 @@
2222
],
2323
)
2424
async def upload_files(file: Annotated[UploadFile, File()]) -> ResponseSchemaModel[UploadUrl]:
25-
file_verify(file)
25+
upload_file_verify(file)
2626
filename = await upload_file(file)
2727
return response_base.success(data={'url': f'/static/upload/{filename}'})

backend/cli.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212

1313
from rich.panel import Panel
1414
from rich.text import Text
15+
from sqlalchemy import text
1516

1617
from backend import console, get_version
18+
from backend.common.enums import DataBaseType, PrimaryKeyType
1719
from backend.common.exception.errors import BaseExceptionMixin
1820
from backend.core.conf import settings
19-
from backend.utils.file_ops import install_git_plugin, install_zip_plugin
21+
from backend.database.db import async_db_session
22+
from backend.plugin.tools import get_plugin_sql
23+
from backend.utils.file_ops import install_git_plugin, install_zip_plugin, parse_sql_script
2024

2125

2226
def run(host: str, port: int, reload: bool, workers: int | None) -> None:
@@ -45,7 +49,9 @@ def run(host: str, port: int, reload: bool, workers: int | None) -> None:
4549
)
4650

4751

48-
async def install_plugin(path: str, repo_url: str) -> None:
52+
async def install_plugin(
53+
path: str, repo_url: str, no_sql: bool, db_type: DataBaseType, pk_type: PrimaryKeyType
54+
) -> None:
4955
if not path and not repo_url:
5056
raise cappa.Exit('path 或 repo_url 必须指定其中一项', code=1)
5157
if path and repo_url:
@@ -59,10 +65,28 @@ async def install_plugin(path: str, repo_url: str) -> None:
5965
plugin_name = await install_zip_plugin(file=path)
6066
if repo_url:
6167
plugin_name = await install_git_plugin(repo_url=repo_url)
68+
69+
console.print(Text(f'插件 {plugin_name} 安装成功', style='bold green'))
70+
71+
sql_file = await get_plugin_sql(plugin_name, db_type, pk_type)
72+
if sql_file and not no_sql:
73+
console.print(Text('开始自动执行插件 SQL 脚本...', style='bold cyan'))
74+
await execute_sql_scripts(sql_file)
75+
6276
except Exception as e:
6377
raise cappa.Exit(e.msg if isinstance(e, BaseExceptionMixin) else str(e), code=1)
6478

65-
console.print(Text(f'插件 {plugin_name} 安装成功', style='bold cyan'))
79+
80+
async def execute_sql_scripts(sql_scripts: str) -> None:
81+
async with async_db_session.begin() as db:
82+
try:
83+
stmts = await parse_sql_script(sql_scripts)
84+
for stmt in stmts:
85+
await db.execute(text(stmt))
86+
except Exception as e:
87+
raise cappa.Exit(f'SQL 脚本执行失败:{e}', code=1)
88+
89+
console.print(Text('SQL 脚本已执行完成', style='bold green'))
6690

6791

6892
@cappa.command(help='运行服务')
@@ -105,22 +129,41 @@ class Add:
105129
str | None,
106130
cappa.Arg(long=True, help='Git 插件的仓库地址'),
107131
]
132+
no_sql: Annotated[
133+
bool,
134+
cappa.Arg(long=True, default=False, help='禁用插件 SQL 脚本自动执行'),
135+
]
136+
db_type: Annotated[
137+
DataBaseType,
138+
cappa.Arg(long=True, default='mysql', help='执行插件 SQL 脚本的数据库类型'),
139+
]
140+
pk_type: Annotated[
141+
PrimaryKeyType,
142+
cappa.Arg(long=True, default='autoincrement', help='执行插件 SQL 脚本数据库主键类型'),
143+
]
108144

109145
async def __call__(self):
110-
await install_plugin(path=self.path, repo_url=self.repo_url)
146+
await install_plugin(self.path, self.repo_url, self.no_sql, self.db_type, self.pk_type)
111147

112148

149+
@cappa.command(help='一个高效的 fba 命令行界面')
113150
@dataclass
114151
class FbaCli:
115152
version: Annotated[
116153
bool,
117154
cappa.Arg(short='-V', long=True, default=False, help='打印当前版本号'),
118155
]
156+
sql: Annotated[
157+
str,
158+
cappa.Arg(long=True, default='', help='在事务中执行 SQL 脚本'),
159+
]
119160
subcmd: cappa.Subcommands[Run | Add | None] = None
120161

121-
def __call__(self):
162+
async def __call__(self):
122163
if self.version:
123164
get_version()
165+
if self.sql:
166+
await execute_sql_scripts(self.sql)
124167

125168

126169
def main() -> None:

backend/common/enums.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,17 @@ class UserPermissionType(StrEnum):
137137
staff = 'staff'
138138
status = 'status'
139139
multi_login = 'multi_login'
140+
141+
142+
class DataBaseType(StrEnum):
143+
"""数据库类型"""
144+
145+
mysql = 'mysql'
146+
postgresql = 'postgresql'
147+
148+
149+
class PrimaryKeyType(StrEnum):
150+
"""主键类型"""
151+
152+
autoincrement = 'autoincrement'
153+
snowflake = 'snowflake'

backend/plugin/tools.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from packaging.requirements import Requirement
1818
from starlette.concurrency import run_in_threadpool
1919

20-
from backend.common.enums import StatusType
20+
from backend.common.enums import DataBaseType, PrimaryKeyType, StatusType
2121
from backend.common.exception import errors
2222
from backend.common.log import log
2323
from backend.core.conf import settings
@@ -75,6 +75,34 @@ def get_plugin_models() -> list[type]:
7575
return classes
7676

7777

78+
async def get_plugin_sql(plugin: str, db_type: DataBaseType, pk_type: PrimaryKeyType) -> str | None:
79+
"""
80+
获取插件 SQL 脚本
81+
82+
:param plugin: 插件名称
83+
:param db_type: 数据库类型
84+
:param pk_type: 主键类型
85+
:return:
86+
"""
87+
if db_type == DataBaseType.mysql.value:
88+
mysql_dir = os.path.join(PLUGIN_DIR, plugin, 'sql', 'mysql')
89+
if pk_type == PrimaryKeyType.autoincrement:
90+
sql_file = os.path.join(mysql_dir, 'init.sql')
91+
else:
92+
sql_file = os.path.join(mysql_dir, 'init_snowflake.sql')
93+
else:
94+
postgresql_dir = os.path.join(PLUGIN_DIR, plugin, 'sql', 'postgresql')
95+
if pk_type == PrimaryKeyType.autoincrement.value:
96+
sql_file = os.path.join(postgresql_dir, 'init.sql')
97+
else:
98+
sql_file = os.path.join(postgresql_dir, 'init_snowflake.sql')
99+
100+
if not os.path.exists(sql_file):
101+
return None
102+
103+
return sql_file
104+
105+
78106
def load_plugin_config(plugin: str) -> dict[str, Any]:
79107
"""
80108
加载插件配置

backend/utils/file_ops.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from dulwich import porcelain
1111
from fastapi import UploadFile
12+
from sqlparse import split
1213

1314
from backend.common.enums import FileType
1415
from backend.common.exception import errors
@@ -35,7 +36,7 @@ def build_filename(file: UploadFile) -> str:
3536
return new_filename
3637

3738

38-
def file_verify(file: UploadFile) -> None:
39+
def upload_file_verify(file: UploadFile) -> None:
3940
"""
4041
文件验证
4142
@@ -161,3 +162,26 @@ async def install_git_plugin(repo_url: str) -> str:
161162
await redis_client.set(f'{settings.PLUGIN_REDIS_PREFIX}:changed', 'ture')
162163

163164
return repo_name
165+
166+
167+
async def parse_sql_script(filepath: str) -> list[str]:
168+
"""
169+
解析 SQL 脚本
170+
171+
:param filepath: 脚本文件路径
172+
:return:
173+
"""
174+
if not os.path.exists(filepath):
175+
raise errors.NotFoundError(msg='SQL 脚本文件不存在')
176+
177+
async with aiofiles.open(filepath, mode='r', encoding='utf-8') as f:
178+
contents = await f.read(1024)
179+
while additional_contents := await f.read(1024):
180+
contents += additional_contents
181+
182+
statements = split(contents)
183+
for statement in statements:
184+
if not any(statement.lower().startswith(_) for _ in ['select', 'insert']):
185+
raise errors.RequestError(msg='SQL 脚本文件中存在非法操作,仅允许 SELECT 和 INSERT')
186+
187+
return statements

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
"rtoml>=0.12.0",
4949
"sqlalchemy-crud-plus>=1.10.0",
5050
"sqlalchemy[asyncio]>=2.0.40",
51+
"sqlparse>=0.5.3",
5152
"user-agents==2.2.0",
5253
]
5354

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ sqlalchemy==2.0.40
257257
# sqlalchemy-crud-plus
258258
sqlalchemy-crud-plus==1.10.0
259259
# via fastapi-best-architecture
260+
sqlparse==0.5.3
261+
# via fastapi-best-architecture
260262
starlette==0.46.1
261263
# via
262264
# asgi-correlation-id

uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)