|
| 1 | +import random |
| 2 | +from utils.file_util import read_json_file, write_json_to_file, save_raw_text |
| 3 | +from utils.db_util import examples_to_str |
| 4 | +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union |
| 5 | + |
| 6 | + |
class MSchema:
    """In-memory description of a database schema that renders to M-Schema text.

    Instance state consists only of plain containers so it can be dumped to and
    loaded from JSON:

    * ``db_id``        -- identifier printed in the ``【DB_ID】`` header line.
    * ``schema``       -- optional schema/namespace used to qualify table names.
    * ``tables``       -- ``{table_name: {"fields": {...}, "examples": [...],
      "comment": ...}}`` as built by :meth:`add_table` / :meth:`add_field`.
    * ``foreign_keys`` -- list of ``[table, field, ref_schema, ref_table,
      ref_field]`` entries as built by :meth:`add_foreign_key`.
    """

    # Column types for which a single example value is shown.
    _TEMPORAL_TYPES = ('DATE', 'TIME', 'DATETIME', 'TIMESTAMP')
    # If the longest example exceeds this length, only the first one is shown.
    _LONG_EXAMPLE_LEN = 20
    # If the longest example exceeds this length, no examples are shown at all.
    _MAX_EXAMPLE_LEN = 50

    def __init__(self, db_id: str = 'Anonymous', schema: Optional[str] = None):
        self.db_id = db_id
        self.schema = schema
        self.tables = {}
        self.foreign_keys = []

    def add_table(self, name, fields=None, comment=None):
        """Register (or overwrite) table ``name``.

        ``fields`` is shallow-copied so later :meth:`add_field` calls never
        mutate the caller's mapping; ``None`` means "no fields yet".
        """
        self.tables[name] = {
            "fields": fields.copy() if fields is not None else {},
            'examples': [],
            'comment': comment,
        }

    def add_field(self, table_name: str, field_name: str, field_type: str = "",
                  primary_key: bool = False, nullable: bool = True, default: Any = None,
                  autoincrement: bool = False, comment: str = "",
                  examples: Optional[list] = None, **kwargs):
        """Add (or overwrite) a column on an already registered table.

        A non-``None`` ``default`` is stored as its string form. ``examples``
        is copied. Extra keyword arguments are stored verbatim next to the
        standard attributes.

        Raises:
            KeyError: if ``table_name`` has not been added via :meth:`add_table`.
        """
        self.tables[table_name]["fields"][field_name] = {
            "type": field_type,
            "primary_key": primary_key,
            "nullable": nullable,
            "default": default if default is None else f'{default}',
            "autoincrement": autoincrement,
            "comment": comment,
            "examples": list(examples) if examples is not None else [],
            **kwargs}

    def add_foreign_key(self, table_name, field_name, ref_schema, ref_table_name, ref_field_name):
        """Record that ``table_name.field_name`` references ``ref_table_name.ref_field_name``."""
        self.foreign_keys.append([table_name, field_name, ref_schema, ref_table_name, ref_field_name])

    def get_field_type(self, field_type, simple_mode=True) -> str:
        """Return ``field_type``; in simple mode cut it at the first ``(``.

        Example: ``"VARCHAR(50)"`` becomes ``"VARCHAR"`` in simple mode.
        """
        if not simple_mode:
            return field_type
        return field_type.split("(")[0]

    def has_table(self, table_name: str) -> bool:
        """Return True if ``table_name`` is registered (exact, case-sensitive match)."""
        return table_name in self.tables

    def has_column(self, table_name: str, field_name: str) -> bool:
        """Return True if the table exists and has column ``field_name`` (case-sensitive)."""
        return self.has_table(table_name) and field_name in self.tables[table_name]["fields"]

    def get_field_info(self, table_name: str, field_name: str) -> Dict:
        """Return the stored attribute dict of a column, or ``{}`` if it cannot be found."""
        try:
            return self.tables[table_name]['fields'][field_name]
        except (KeyError, TypeError):
            # Unknown table/column, or a malformed table entry: report "no info".
            return {}

    def _table_header(self, table_name: str, table_comment) -> str:
        """Build the ``# Table: ...`` line, schema-qualified and commented when available."""
        qualified = f"{self.schema}.{table_name}" if self.schema else table_name
        # The literal string 'None' is treated the same as a missing comment.
        if table_comment is not None and table_comment != 'None' and len(table_comment) > 0:
            return f"# Table: {qualified}, {table_comment}"
        return f"# Table: {qualified}"

    def _format_examples(self, field_info: Dict, base_type: str, example_num: int) -> str:
        """Return the ``, Examples: [...]`` suffix for a column, or ``""``.

        ``base_type`` is the upper-cased type name without length/precision.
        """
        raw_examples = field_info.get('examples') or []
        if not raw_examples or example_num <= 0:
            return ""

        # examples_to_str is a project helper; presumably it returns a list of
        # strings (len() is taken of each item below) -- verify against utils.db_util.
        examples = examples_to_str([e for e in raw_examples if e is not None])
        examples = examples[:example_num]

        if base_type in self._TEMPORAL_TYPES:
            # Slice instead of indexing: the list may be empty once None values are dropped.
            examples = examples[:1]
        elif examples:
            longest = max(len(s) for s in examples)
            if longest > self._MAX_EXAMPLE_LEN:
                examples = []
            elif longest > self._LONG_EXAMPLE_LEN:
                examples = examples[:1]

        if not examples:
            return ""
        example_str = ', '.join(str(example) for example in examples)
        return f", Examples: [{example_str}]"

    def _format_field(self, field_name: str, field_info: Dict,
                      example_num: int, show_type_detail: bool) -> str:
        """Render one column as ``(name:TYPE[, comment][, Primary Key][, Examples: [...]])``."""
        declared_type = field_info.get('type') or ''
        raw_type = self.get_field_type(declared_type, not show_type_detail)
        parts = [f"({field_name}:{raw_type.upper()}"]

        # Tolerate a missing/None comment and skip whitespace-only ones.
        comment = (field_info.get('comment') or '').strip()
        if comment:
            parts.append(f", {comment}")

        # Mark primary-key columns.
        if field_info.get('primary_key', False):
            parts.append(", Primary Key")

        # The temporal check always uses the upper-cased base type so that
        # 'datetime' and 'DATETIME(6)' are both recognised.
        base_type = self.get_field_type(declared_type, True).upper()
        parts.append(self._format_examples(field_info, base_type, example_num))
        parts.append(")")
        return ''.join(parts)

    def single_table_mschema(self, table_name: str, selected_columns: Optional[List] = None,
                             example_num=3, show_type_detail=False, shuffle=True) -> str:
        """Render one table as an M-Schema fragment.

        Args:
            table_name: table to render; an unknown table yields a header with
                an empty column list instead of raising.
            selected_columns: bare column names to keep (matched
                case-insensitively); ``None`` keeps every column.
            example_num: maximum number of example values per column; ``0``
                disables examples.
            show_type_detail: keep length/precision such as ``(50)`` in the type.
            shuffle: randomise the column order with ``random.shuffle``.

        Returns:
            The header line, ``[``, the comma-separated column lines and ``]``,
            joined by newlines.
        """
        table_info = self.tables.get(table_name, {})
        output = [self._table_header(table_name, table_info.get('comment', ''))]

        # Lower-case both sides: callers (including to_mschema) pass column
        # names in their original case.
        wanted = None
        if selected_columns is not None:
            wanted = {c.lower() for c in selected_columns}

        field_lines = []
        for field_name, field_info in table_info.get('fields', {}).items():
            if wanted is not None and field_name.lower() not in wanted:
                continue
            field_lines.append(
                self._format_field(field_name, field_info, example_num, show_type_detail))

        if shuffle:
            random.shuffle(field_lines)

        output.append('[')
        output.append(',\n'.join(field_lines))
        output.append(']')

        return '\n'.join(output)

    def to_mschema(self, selected_tables: Optional[List] = None,
                   selected_columns: Optional[List] = None,
                   example_num=3, show_type_detail=False, shuffle=True) -> str:
        """Convert the schema to an M-Schema string.

        Args:
            selected_tables: table names to include (case-insensitive);
                ``None`` selects all tables.
            selected_columns: entries of the form ``'table_name.column_name'``
                (case-insensitive); ``None`` selects all columns. When given,
                the set of tables is derived from these entries and
                ``selected_tables`` is ignored.
            example_num: maximum number of example values per column.
            show_type_detail: keep length/precision in column types.
            shuffle: randomise table order and the column order inside tables.

        Returns:
            The ``【DB_ID】`` and ``【Schema】`` header lines, one fragment per
            selected table and, if any foreign keys are registered, a
            ``【Foreign keys】`` section.
        """
        table_filter = None
        if selected_tables is not None:
            table_filter = {s.lower() for s in selected_tables}

        column_filter = None
        if selected_columns is not None:
            column_filter = {s.lower() for s in selected_columns}
            # A column selection determines the table selection by itself.
            table_filter = {s.split('.')[0] for s in column_filter}

        output = []
        # Render every selected table in turn.
        for table_name, table_info in self.tables.items():
            if table_filter is not None and table_name.lower() not in table_filter:
                continue
            if column_filter is None:
                cur_selected_columns = None
            else:
                cur_selected_columns = [
                    c for c in table_info.get('fields', {})
                    if f"{table_name}.{c}".lower() in column_filter
                ]
            output.append(self.single_table_mschema(
                table_name, cur_selected_columns, example_num, show_type_detail, shuffle))

        if shuffle:
            random.shuffle(output)

        output.insert(0, f"【DB_ID】 {self.db_id}")
        output.insert(1, "【Schema】")

        # Foreign keys: listed only when both tables are selected and the
        # referenced schema equals this object's schema.
        if self.foreign_keys:
            output.append("【Foreign keys】")
            for table1, column1, ref_schema, table2, column2 in self.foreign_keys:
                both_selected = table_filter is None or (
                    table1.lower() in table_filter and table2.lower() in table_filter)
                if both_selected and ref_schema == self.schema:
                    output.append(f"{table1}.{column1}={table2}.{column2}")

        return '\n'.join(output)

    def dump(self):
        """Return the schema as a dict that references (does not copy) the internal containers."""
        return {
            "db_id": self.db_id,
            "schema": self.schema,
            "tables": self.tables,
            "foreign_keys": self.foreign_keys,
        }

    def save(self, file_path: str):
        """Write :meth:`dump` to ``file_path`` via the project's ``write_json_to_file`` helper."""
        schema_dict = self.dump()
        write_json_to_file(file_path, schema_dict, is_json_line=False)

    def load(self, file_path: str):
        """Replace this object's state with the content read by ``read_json_file``.

        Missing keys fall back to the same defaults as a freshly constructed instance.
        """
        data = read_json_file(file_path)
        self.db_id = data.get("db_id", "Anonymous")
        self.schema = data.get("schema", None)
        self.tables = data.get("tables", {})
        self.foreign_keys = data.get("foreign_keys", [])
0 commit comments