Skip to content

Commit 4f42588

Browse files
committed
drop foreign key first before dropping column
1 parent 4dfc236 commit 4f42588

File tree

3 files changed

+87
-25
lines changed

3 files changed

+87
-25
lines changed

sqlalchemy_iris/alembic.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import logging
22

3+
from typing import Optional
4+
35
from sqlalchemy.ext.compiler import compiles
6+
from sqlalchemy.sql.base import Executable
7+
from sqlalchemy.sql.elements import ClauseElement
8+
49
from alembic.ddl import DefaultImpl
510
from alembic.ddl.base import ColumnNullable
611
from alembic.ddl.base import ColumnType
@@ -56,25 +61,43 @@ def compare_server_default(
5661
rendered_inspector_default,
5762
)
5863

59-
def correct_for_autogen_constraints(
64+
def drop_column(
6065
self,
61-
conn_unique_constraints,
62-
conn_indexes,
63-
metadata_unique_constraints,
64-
metadata_indexes,
65-
):
66-
67-
doubled_constraints = {
68-
index
69-
for index in conn_indexes
70-
if index.info.get("duplicates_constraint")
71-
}
72-
73-
for ix in doubled_constraints:
74-
conn_indexes.remove(ix)
66+
table_name: str,
67+
column: Column,
68+
schema: Optional[str] = None,
69+
**kw,
70+
) -> None:
71+
column_name = column.name
72+
fkeys = self.dialect.get_foreign_keys(self.connection, table_name, schema)
73+
fkey = [
74+
fkey["name"] for fkey in fkeys if column_name in fkey["constrained_columns"]
75+
]
76+
if len(fkey) == 1:
77+
self._exec(_ExecDropForeignKey(table_name, fkey[0], schema))
78+
super().drop_column(table_name, column, schema, **kw)
79+
80+
81+
class _ExecDropForeignKey(Executable, ClauseElement):
82+
inherit_cache = False
83+
84+
def __init__(
85+
self, table_name: str, foreignkey_name: Column, schema: Optional[str]
86+
) -> None:
87+
self.table_name = table_name
88+
self.foreignkey_name = foreignkey_name
89+
self.schema = schema
90+
91+
92+
@compiles(_ExecDropForeignKey, "iris")
93+
def _exec_drop_foreign_key(
94+
element: _ExecDropForeignKey, compiler: IRISDDLCompiler, **kw
95+
) -> str:
96+
return "%s DROP FOREIGN KEY %s" % (
97+
alter_table(compiler, element.table_name, element.schema),
98+
format_column_name(compiler, element.foreignkey_name),
99+
)
75100

76-
# if not sqla_compat.sqla_2:
77-
# self._skip_functional_indexes(metadata_indexes, conn_indexes)
78101

79102
@compiles(ColumnNullable, "iris")
80103
def visit_column_nullable(

sqlalchemy_iris/base.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,6 @@ def get_temp_table_names(self, connection, dblink=None, **kw):
10731073

10741074
@reflection.cache
10751075
def has_table(self, connection, table_name, schema=None, **kw):
1076-
self._ensure_has_table_connection(connection)
10771076
tables = ischema.tables
10781077
schema_name = self.get_schema(schema)
10791078

@@ -1085,14 +1084,7 @@ def has_table(self, connection, table_name, schema=None, **kw):
10851084
)
10861085
return bool(connection.execute(s).scalar())
10871086

1088-
def _default_or_error(self, connection, tablename, schema, method, **kw):
1089-
if self.has_table(connection, tablename, schema, **kw):
1090-
return method()
1091-
else:
1092-
raise exc.NoSuchTableError(f"{schema}.{tablename}")
1093-
10941087
def _get_all_objects(self, connection, schema, filter_names, scope, kind, **kw):
1095-
self._ensure_has_table_connection(connection)
10961088
tables = ischema.tables
10971089
schema_name = self.get_schema(schema)
10981090

tests/test_alembic.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,24 @@
33
except: # noqa
44
pass
55
else:
6+
from sqlalchemy import MetaData
7+
from sqlalchemy import Table
8+
from sqlalchemy import inspect
9+
from sqlalchemy import ForeignKey
10+
from sqlalchemy import Column
11+
from sqlalchemy import Integer
12+
from sqlalchemy import text
13+
from sqlalchemy.types import Text
14+
from sqlalchemy.types import LargeBinary
15+
16+
from alembic import op
17+
from alembic.testing import fixture
18+
from alembic.testing import combinations
19+
from alembic.testing import eq_
20+
from alembic.testing.fixtures import TestBase
21+
from alembic.testing.fixtures import op_fixture
22+
from alembic.testing.suite._autogen_fixtures import AutogenFixtureTest
23+
624
from alembic.testing.suite.test_op import (
725
BackendAlterColumnTest as _BackendAlterColumnTest,
826
)
@@ -23,3 +41,32 @@ def test_alter_column_autoincrement_pk_implicit_true(self):
2341

2442
def test_alter_column_autoincrement_pk_explicit_true(self):
2543
pass
44+
45+
@combinations(
46+
(None,),
47+
("test",),
48+
argnames="schema",
49+
id_="s",
50+
)
51+
class RoundTripTest(TestBase):
52+
@fixture
53+
def tables(self, connection):
54+
self.meta = MetaData()
55+
self.meta.schema = self.schema
56+
self.tbl_other = Table(
57+
"other", self.meta, Column("oid", Integer, primary_key=True)
58+
)
59+
self.tbl = Table(
60+
"round_trip_table",
61+
self.meta,
62+
Column("id", Integer, primary_key=True),
63+
Column("oid_fk", ForeignKey("other.oid")),
64+
)
65+
self.meta.create_all(connection)
66+
yield
67+
self.meta.drop_all(connection)
68+
69+
def test_drop_col_with_fk(self, ops_context, connection, tables):
70+
ops_context.drop_column("round_trip_table", "oid_fk", self.meta.schema)
71+
insp = inspect(connection)
72+
eq_(insp.get_foreign_keys("round_trip_table", schema=self.meta.schema), [])

0 commit comments

Comments
 (0)