|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +数据库迁移脚本 - 添加敏感度阈值字段到 risk_type_config 表 |
| 4 | +为 risk_type_config 表添加敏感度阈值相关字段: |
| 5 | +- high_sensitivity_threshold: 高敏感度阈值 (默认0.40) |
| 6 | +- medium_sensitivity_threshold: 中敏感度阈值 (默认0.60) |
| 7 | +- low_sensitivity_threshold: 低敏感度阈值 (默认0.95) |
| 8 | +- sensitivity_trigger_level: 触发检测命中的最低敏感度等级 (默认"medium") |
| 9 | +""" |
| 10 | +import sys |
| 11 | +import os |
| 12 | +from pathlib import Path |
| 13 | + |
| 14 | +# 添加项目根目录到 Python 路径 |
| 15 | +backend_dir = Path(__file__).resolve().parent.parent |
| 16 | +sys.path.insert(0, str(backend_dir)) |
| 17 | + |
| 18 | +from sqlalchemy import create_engine, text |
| 19 | +from config import settings |
| 20 | +from utils.logger import setup_logger |
| 21 | + |
| 22 | +logger = setup_logger() |
| 23 | + |
def migrate():
    """Add the sensitivity-threshold columns to the ``risk_type_config`` table.

    All schema and data changes run inside one explicit transaction:

    1. Look up which of the four sensitivity columns already exist.
    2. Add each missing column with its default value.
    3. Back-fill NULLs in existing rows with the same defaults.
    4. Log the resulting column metadata and the table's row count.
    5. Commit.

    After a successful commit the full column list of the table is logged.
    That step is informational only: if it fails, a warning is logged and
    the migration is still considered successful, because the transaction
    has already been committed and can no longer be rolled back.

    If the migration itself fails, the transaction is rolled back, the error
    is logged, and the process exits with status 1 via ``sys.exit(1)``.
    """
    # (column name, column definition) for every column this migration adds.
    # These are hard-coded constants, never user input, so interpolating
    # them into the DDL statement below is safe.
    new_columns = (
        ("high_sensitivity_threshold", "FLOAT DEFAULT 0.40"),
        ("medium_sensitivity_threshold", "FLOAT DEFAULT 0.60"),
        ("low_sensitivity_threshold", "FLOAT DEFAULT 0.95"),
        ("sensitivity_trigger_level", "VARCHAR(10) DEFAULT 'medium'"),
    )

    try:
        engine = create_engine(settings.database_url)

        logger.info("开始数据库迁移:添加敏感度阈值字段到 risk_type_config 表...")

        with engine.connect() as conn:
            trans = conn.begin()

            try:
                # Find out which target columns are already present so the
                # migration can be re-run without failing on duplicates.
                # NOTE(review): the lookup is not restricted to a schema;
                # confirm risk_type_config exists in only one schema.
                logger.info("检查现有字段...")
                result = conn.execute(text("""
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = 'risk_type_config'
                    AND column_name IN ('high_sensitivity_threshold', 'medium_sensitivity_threshold', 'low_sensitivity_threshold', 'sensitivity_trigger_level')
                """))
                existing_columns = [row[0] for row in result.fetchall()]
                logger.info(f"现有敏感度字段: {existing_columns}")

                # Add every column that is still missing.
                for column_name, column_definition in new_columns:
                    if column_name in existing_columns:
                        logger.info(f"{column_name} 字段已存在,跳过")
                        continue
                    logger.info(f"添加 {column_name} 字段...")
                    conn.execute(text(f"""
                        ALTER TABLE risk_type_config
                        ADD COLUMN {column_name} {column_definition}
                    """))

                # Back-fill rows whose new columns are still NULL.
                logger.info("更新现有记录的默认值...")
                conn.execute(text("""
                    UPDATE risk_type_config
                    SET
                    high_sensitivity_threshold = COALESCE(high_sensitivity_threshold, 0.40),
                    medium_sensitivity_threshold = COALESCE(medium_sensitivity_threshold, 0.60),
                    low_sensitivity_threshold = COALESCE(low_sensitivity_threshold, 0.95),
                    sensitivity_trigger_level = COALESCE(sensitivity_trigger_level, 'medium')
                    WHERE high_sensitivity_threshold IS NULL
                    OR medium_sensitivity_threshold IS NULL
                    OR low_sensitivity_threshold IS NULL
                    OR sensitivity_trigger_level IS NULL
                """))

                # Log the metadata of the migrated columns for verification.
                logger.info("验证迁移结果...")
                result = conn.execute(text("""
                    SELECT column_name, data_type, is_nullable, column_default
                    FROM information_schema.columns
                    WHERE table_name = 'risk_type_config'
                    AND column_name IN ('high_sensitivity_threshold', 'medium_sensitivity_threshold', 'low_sensitivity_threshold', 'sensitivity_trigger_level')
                    ORDER BY column_name
                """))

                migrated_columns = result.fetchall()
                logger.info("迁移后的敏感度字段:")
                for col in migrated_columns:
                    logger.info(f" {col[0]} ({col[1]}) - nullable: {col[2]}, default: {col[3]}")

                # Log how many rows the table holds.
                result = conn.execute(text("SELECT COUNT(*) FROM risk_type_config"))
                record_count = result.scalar()
                logger.info(f"risk_type_config 表中共有 {record_count} 条记录")

                trans.commit()

            except Exception as e:
                # Nothing has been committed yet, so undo everything.
                trans.rollback()
                logger.error(f"❌ 迁移失败,已回滚: {e}")
                raise

            logger.info("✅ 数据库迁移完成!")

            # Informational only. This runs after the commit, so it must not
            # reach the rollback handler above: a failure here would
            # otherwise be reported as a rolled-back migration even though
            # the schema change is already persisted.
            try:
                logger.info("最终表结构:")
                result = conn.execute(text("SELECT column_name FROM information_schema.columns WHERE table_name = 'risk_type_config' ORDER BY ordinal_position"))
                all_columns = [row[0] for row in result.fetchall()]
                for i, col in enumerate(all_columns, 1):
                    logger.info(f" {i:2d}. {col}")
            except Exception as e:
                logger.warning(f"迁移已提交,但读取最终表结构失败: {e}")

    except Exception as e:
        logger.error(f"❌ 数据库迁移错误: {e}")
        sys.exit(1)
| 143 | + |
def check_migration_needed():
    """Tell whether ``risk_type_config`` still lacks any sensitivity column.

    Queries ``information_schema.columns`` for the four sensitivity columns
    and compares the result against the required list.

    Returns:
        True if at least one required column is missing, or if the check
        itself raised (the error is logged and a migration is assumed to be
        needed); False if all four columns are present.
    """
    expected = ['high_sensitivity_threshold', 'medium_sensitivity_threshold', 'low_sensitivity_threshold', 'sensitivity_trigger_level']

    try:
        db_engine = create_engine(settings.database_url)
        with db_engine.connect() as connection:
            rows = connection.execute(text("""
                SELECT column_name
                FROM information_schema.columns
                WHERE table_name = 'risk_type_config'
                AND column_name IN ('high_sensitivity_threshold', 'medium_sensitivity_threshold', 'low_sensitivity_threshold', 'sensitivity_trigger_level')
            """)).fetchall()

        present = [row[0] for row in rows]
        # Keep the order of the required list so the log output is stable.
        absent = [name for name in expected if name not in present]

        if not absent:
            logger.info("所有敏感度字段都已存在,无需迁移")
            return False

        logger.info(f"需要迁移,缺失敏感度字段: {absent}")
        return True

    except Exception as e:
        logger.error(f"检查迁移状态失败: {e}")
        # A failed check is treated as "migration needed".
        return True
| 170 | + |
if __name__ == "__main__":
    # Script entry point: print a banner, then migrate only when the check
    # reports that a sensitivity column is missing.
    logger.info("象信AI安全护栏平台 - risk_type_config 敏感度字段迁移")
    logger.info("=" * 60)

    if not check_migration_needed():
        logger.info("✅ 数据库已是最新版本,无需迁移")
    else:
        migrate()
0 commit comments