Skip to content

Commit 67ffb0b

Browse files
committed
role_id 权限限制+查询
1 parent c812e28 commit 67ffb0b

File tree

5 files changed

+171
-55
lines changed

5 files changed

+171
-55
lines changed

README.md

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
2626
- `api_key`:AskTable API 密钥(必需,环境变量)
2727
- `datasource_id`:数据源ID(必需,环境变量)
2828
- `base_url`:本地IP服务地址(可选,填写则走本地部署,不填则走SaaS)
29+
- `role_id` :角色id(可选,填写则只能访问该角色被允许的数据,不填则即可查询所有数据)
2930

3031
---
3132

@@ -36,7 +37,7 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
3637
- gen_conclusion , 根据用户的问题,直接返回数据结果
3738
- 输入:用户问题,如:"请给我出销售额前10的产品"
3839
- 输出:对应的数据结果
39-
- list_available_datasources , 获取当前用户apikey下的可用的所有数据库(数据源)信息
40+
- list_available_datasources , 获取当前APIKEY下的用户(role_id)所有可用数据库(数据源)信息
4041
- 输入:我数据库中有哪些数据源?
4142
- 输出:对应的数据源信息,包括数据源id、数据库引擎、数据库描述
4243

@@ -76,8 +77,9 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
7677
"command": "uvx",
7778
"args": ["asktable-mcp-server@latest", "--transport", "stdio"],
7879
"env": {
79-
"api_key": "your_api_key",
80-
"datasource_id": "your_datasource_id"
80+
"api_key": "your_api_key", // 必填
81+
"datasource_id": "your_datasource_id", // 必填
82+
// "role_id": "your_role_id" // 可选:如需限定角色权限,请填写
8183
}
8284
}
8385
}
@@ -95,9 +97,10 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
9597
"command": "uvx",
9698
"args": ["asktable-mcp-server@latest", "--transport", "stdio"],
9799
"env": {
98-
"api_key": "your_api_key",
99-
"datasource_id": "your_datasource_id",
100-
"base_url": "http://your_local_ip:port/api"
100+
"api_key": "your_api_key", // 必填
101+
"datasource_id": "your_datasource_id",// 必填
102+
"base_url": "http://your_local_ip:port/api", // 必填
103+
// "role_id": "your_role_id" // 可选:如需限定角色权限,请填写
101104
}
102105
}
103106
}
@@ -112,7 +115,8 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
112115
{
113116
"mcpServers": {
114117
"asktable-mcp-server": {
115-
"url": "http://localhost:8095/sse/?apikey=your_apikey&datasouce_id=your_datasouce_id",
118+
// role_id 为可选参数,不指定则使用默认权限
119+
"url": "http://localhost:8095/sse/?apikey=your_apikey&datasouce_id=your_datasouce_id&role_id=your_role_id",
116120
"headers": {},
117121
"timeout": 300,
118122
"sse_read_timeout": 300

src/asktable_mcp_server/server.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ async def gen_sql(query: str) -> str:
3838
'question': query
3939
}
4040

41-
# 如果环境变量中有base_url,添加到参数中
41+
# 如果环境变量中有base_url、role_id,添加到参数中
4242
base_url = os.getenv('base_url')
43+
role_id = os.getenv('role_id')
4344
if base_url:
4445
params['base_url'] = base_url
45-
46+
if role_id:
47+
params['role_id'] = role_id
4648
message = await get_asktable_sql(**params)
4749
return message
4850

@@ -74,11 +76,13 @@ async def gen_conclusion(query: str) -> str:
7476
'question': query
7577
}
7678

77-
# 如果环境变量中有base_url,添加到参数中
79+
# 如果环境变量中有base_url、role_id,添加到参数中
7880
base_url = os.getenv('base_url')
81+
role_id = os.getenv('role_id')
7982
if base_url:
8083
params['base_url'] = base_url
81-
84+
if role_id:
85+
params['role_id'] = role_id
8286

8387
message = await get_asktable_data(**params)
8488
return message
@@ -126,8 +130,9 @@ async def list_available_datasources() -> str:
126130
"""
127131
api_key = os.getenv('api_key')
128132
base_url = os.getenv('base_url') or None
133+
role_id = os.getenv('role_id') or None
129134

130-
result = await get_datasources_info(api_key=api_key, base_url=base_url)
135+
result = await get_datasources_info(api_key=api_key, base_url=base_url,role_id=role_id)
131136
logging.info(result['status'])
132137
return result['data']
133138

src/asktable_mcp_server/sse_server.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ async def lifespan(fastmcp_instance):
3434
# 启动逻辑
3535
logger.info("服务器正在初始化...")
3636

37-
# 在这里可以进行一些初始化工作
38-
await asyncio.sleep(3) # 模拟初始化时间
37+
await asyncio.sleep(3)
3938

4039
server_ready = True
4140
logger.info("服务器初始化完成,准备接受请求")
@@ -65,14 +64,15 @@ async def health_check(request: Request):
6564
@mcp.custom_route("/sse/", methods=["GET"])
6665
async def sse_endpoint(request: Request):
6766
"""自定义SSE端点,检查服务器是否准备就绪"""
68-
global API_KEY, DATASOURCE_ID, server_ready
67+
global API_KEY, DATASOURCE_ID, server_ready,ROLE_ID
6968

7069
if not server_ready:
7170
return {"error": "Server is still initializing, please wait"}
7271

7372
# 从URL参数获取配置
7473
API_KEY = request.query_params.get('apikey')
7574
DATASOURCE_ID = request.query_params.get('datasource_id')
75+
ROLE_ID = request.query_params.get('role_id')
7676

7777
if not API_KEY or not DATASOURCE_ID:
7878
logging.info("error: Missing required parameters: apikey and datasource_id")
@@ -102,7 +102,7 @@ async def gen_sql(query: str) -> str:
102102
- 需要将自然语言转化为SQL查询
103103
- 仅需要SQL文本而不需要执行结果
104104
"""
105-
global API_KEY, DATASOURCE_ID, server_ready
105+
global API_KEY, DATASOURCE_ID, server_ready,ROLE_ID
106106

107107
if not server_ready:
108108
return "Server is still initializing, please wait"
@@ -119,13 +119,25 @@ async def gen_sql(query: str) -> str:
119119
api_key = API_KEY
120120
datasource_id = DATASOURCE_ID
121121

122+
try:
123+
request = get_http_request()
124+
role_id = request.query_params.get('role_id', None)
125+
except RuntimeError:
126+
role_id = None
127+
128+
if not role_id:
129+
role_id = ROLE_ID
130+
122131
logging.info(f"api_key:{api_key}")
123132
logging.info(f"datasource_id:{datasource_id}")
133+
logging.info(f"role_id:{role_id}")
124134

125135
params = {
126136
'api_key': api_key,
127137
'datasource_id': datasource_id,
128-
'question': query
138+
'question': query,
139+
'role_id':role_id
140+
129141
}
130142
if args.base_url:
131143
params['base_url'] = args.base_url
@@ -159,7 +171,6 @@ async def gen_conclusion(query: str) -> str:
159171
if not server_ready:
160172
return "Server is still initializing, please wait"
161173

162-
# 其余代码保持不变...
163174
if not API_KEY or not DATASOURCE_ID:
164175
try:
165176
request = get_http_request()
@@ -172,16 +183,26 @@ async def gen_conclusion(query: str) -> str:
172183
api_key = API_KEY
173184
datasource_id = DATASOURCE_ID
174185

186+
try:
187+
request = get_http_request()
188+
role_id = request.query_params.get('role_id', None)
189+
except RuntimeError:
190+
role_id = None
191+
192+
if not role_id:
193+
role_id = ROLE_ID
194+
175195
logging.info(f"api_key:{api_key}")
176196
logging.info(f"datasource_id:{datasource_id}")
197+
logging.info(f"role_id:{role_id}")
177198

178199
params = {
179-
'api_key': str(api_key),
180-
'datasource_id': str(datasource_id),
181-
'question': query
200+
'api_key': api_key,
201+
'datasource_id': datasource_id,
202+
'question': query,
203+
'role_id':role_id
204+
182205
}
183-
if args.base_url:
184-
params['base_url'] = args.base_url
185206

186207
message = await get_asktable_data(**params)
187208
return message

src/asktable_mcp_server/tools.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,63 @@
11
from asktable import Asktable
22
import json
33
import asyncio
4+
from utils import AskTableHelper
5+
6+
async def get_asktable_data(api_key, datasource_id, question, base_url=None,role_id=None):
7+
"""
8+
获取asktable数据
9+
:param api_key:
10+
:param datasource_id:
11+
:param question:
12+
:param base_url:
13+
:param role_id:
14+
:return:
15+
"""
416

5-
async def get_asktable_data(api_key, datasource_id, question, base_url=None):
617
# 如果没有传入base_url或传入None,使用默认值
718
if base_url is None:
819
base_url = "https://api.asktable.com"
920

1021
asktable_client = Asktable(base_url=base_url, api_key=api_key)
11-
answer_response = asktable_client.answers.create(datasource_id=datasource_id, question=question)
22+
answer_response = asktable_client.answers.create(datasource_id=datasource_id, question=question,role_id=role_id)
1223
if answer_response.answer is None:
1324
return "没有查询到相关信息"
1425
return answer_response.answer.text
1526

16-
async def get_asktable_sql(api_key, datasource_id, question, base_url=None):
27+
async def get_asktable_sql(api_key, datasource_id, question, base_url=None,role_id=None):
1728
# 如果没有传入base_url或传入None,使用默认值
1829
if base_url is None:
1930
base_url = "https://api.asktable.com"
2031

2132
asktable_client = Asktable(base_url=base_url, api_key=api_key)
22-
query_response = asktable_client.sqls.create(datasource_id=datasource_id, question=question)
33+
query_response = asktable_client.sqls.create(datasource_id=datasource_id, question=question,role_id=role_id)
2334
if query_response.query.sql is None:
2435
return "没有查询到相关信息"
2536
return query_response.query.sql
2637

2738

28-
async def get_datasources_info(api_key, base_url=None):
39+
async def get_datasources_info(api_key, base_url=None,role_id=None):
2940
""""
30-
返回用户数据库meta data
41+
返回用户数据库meta_data;
42+
若输入roleid,则返回该角色可访问的数据库meta_data
43+
若不输入roleid,则返回所有数据库meta_data
3144
args:
3245
api_key: str
3346
base_url: str
47+
role_id:str
3448
return:
3549
{
3650
"status": "success" or "failure",
3751
"data": json.dumps(result, ensure_ascii=False, indent=2)
3852
}
3953
"""
40-
if base_url is None:
41-
base_url = "https://api.asktable.com"
42-
43-
asktable_client = Asktable(base_url=base_url, api_key=api_key)
44-
meta_data_list = asktable_client.datasources.list()
45-
# 最前面判断
46-
if not meta_data_list.items:
47-
return {
48-
"status": "failure",
49-
"data": "该用户还没有创建任何数据库"
50-
}
51-
52-
# 提取指定字段
53-
result = [
54-
{
55-
"datasource_id": ds.id,
56-
"数据库引擎": ds.engine,
57-
"数据库描述": ds.desc,
58-
}
59-
for ds in meta_data_list.items
60-
]
54+
helper = AskTableHelper(api_key=api_key,base_url=base_url)
55+
if role_id is None:
56+
return helper.get_datasources_info()
57+
else:
58+
return helper.get_datasources_info_by_role(role_id=role_id)
6159

62-
return {
63-
"status": "success",
64-
"data": json.dumps(result, ensure_ascii=False, indent=2)
65-
}
6660

6761

6862
if __name__ == "__main__":
69-
pass
63+
pass

src/asktable_mcp_server/utils.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import json
2+
from asktable import Asktable
3+
4+
class AskTableHelper:
5+
def __init__(self, api_key, base_url=None):
6+
if base_url is None:
7+
base_url = "https://api.asktable.com"
8+
self.asktable_client = Asktable(base_url=base_url, api_key=api_key)
9+
10+
def get_datasource_ids_by_role(self, role_id):
11+
"""
12+
根据 role_id 获取该角色下所有 policy 涉及的 datasource_id 列表
13+
:param role_id: 角色ID
14+
:return: list, 包含所有 datasource_id
15+
"""
16+
role_source = self.asktable_client.roles.get_polices(role_id=role_id)
17+
datasource_ids = set()
18+
for policy in role_source:
19+
if hasattr(policy, 'dataset_config') and hasattr(policy.dataset_config, 'datasource_ids'):
20+
datasource_ids.update(policy.dataset_config.datasource_ids)
21+
return list(datasource_ids)
22+
23+
def get_role_name_by_role(self, role_id):
24+
"""
25+
根据 role_id 获取角色名称
26+
:param role_id: 角色ID
27+
:return: 角色名称字符串
28+
"""
29+
role_source = self.asktable_client.roles.retrieve(role_id=role_id)
30+
return role_source.name
31+
32+
def get_datasources_info(self):
33+
"""
34+
返回当前 API KEY 下所有数据库 meta data
35+
:return: dict, {status, data}
36+
"""
37+
meta_data_list = self.asktable_client.datasources.list()
38+
if not meta_data_list.items:
39+
return {
40+
"status": "failure",
41+
"data": "该用户还没有创建任何数据库"
42+
}
43+
result = [
44+
{
45+
"datasource_id": ds.id,
46+
"数据库引擎": ds.engine,
47+
"数据库描述": ds.desc,
48+
}
49+
for ds in meta_data_list.items
50+
]
51+
return {
52+
"status": "success",
53+
"data": json.dumps(result, ensure_ascii=False, indent=2)
54+
}
55+
56+
def get_datasources_info_by_role(self, role_id):
57+
"""
58+
输入 role_id,返回该角色可访问的数据源的描述、引擎和id
59+
:param role_id: 角色ID
60+
:return: dict, {status, data}
61+
"""
62+
# 1. 获取所有数据源
63+
meta_data_list = self.asktable_client.datasources.list()
64+
if not meta_data_list.items:
65+
return {
66+
"status": "failure",
67+
"data": "该用户还没有创建任何数据库"
68+
}
69+
70+
# 2. 获取该角色可访问的数据源ID
71+
datasource_ids = set(self.get_datasource_ids_by_role(role_id))
72+
73+
# 3. 只保留属于该 role 的数据源
74+
result = [
75+
{
76+
"datasource_id": ds.id,
77+
"数据库引擎": ds.engine,
78+
"数据库描述": ds.desc,
79+
}
80+
for ds in meta_data_list.items if ds.id in datasource_ids
81+
]
82+
83+
if not result:
84+
return {
85+
"status": "failure",
86+
"data": "该角色没有可访问的数据库"
87+
}
88+
89+
return {
90+
"status": "success",
91+
"data": json.dumps(result, ensure_ascii=False, indent=2)
92+
}

0 commit comments

Comments
 (0)