
Commit cbc6c98

pankajastro authored and Cloud Composer Team committed
Add deferrable mode in RedshiftPauseClusterOperator (#28850)
In this PR, I'm adding `aiobotocore` as an additional dependency until aio-libs/aiobotocore#976 has been resolved. I'm adding `AwsBaseAsyncHook`, a basic async AWS hook. At present it supports the default `botocore` auth, i.e. if no Airflow connection is provided then auth falls back to the ENV, and if an Airflow connection is provided then basic auth with secret-key/access-key-id/profile/token and the assume-role ARN method. Maybe we can support the other auth methods incrementally, depending on community interest.

Because the dependency makes things a little complicated, I have created a new test module `deferrable` inside the AWS provider tests and I'm keeping the async-related tests in that module. I have also added a new CI job to run the AWS deferrable operator tests, and I'm ignoring the deferrable tests in the other CI job test runs.

Add a trigger class which will wait until the Redshift pause request reaches a terminal state, i.e. paused or failed. We also have retry logic like the sync operator.

This PR donates the `RedshiftPauseClusterOperatorAsync` operator developed in the [astronomer-providers](https://github.com/astronomer/astronomer-providers) repo to Apache Airflow.

GitOrigin-RevId: cf77c3b96609aa8c260566274d54b06eb38c8100
1 parent bd8aab2 commit cbc6c98
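
As a usage illustration (not part of this commit's diff; the DAG id, cluster identifier, and connection id below are placeholders), pausing a cluster with the new mode looks roughly like this:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftPauseClusterOperator

    with DAG(
        dag_id="example_redshift_pause_deferrable",  # illustrative DAG id
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ):
        pause_cluster = RedshiftPauseClusterOperator(
            task_id="pause_redshift_cluster",
            cluster_identifier="my-redshift-cluster",  # illustrative cluster id
            aws_conn_id="aws_default",
            deferrable=True,   # hand the wait to the triggerer instead of blocking a worker slot
            poll_interval=10,  # seconds between status checks while deferred
        )

Running the deferred path requires aiobotocore (>=2.1.1, as installed in the CI job below) and a running triggerer process.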

File tree

20 files changed: +821, -18 lines changed

.github/workflows/ci.yml

Lines changed: 24 additions & 0 deletions

@@ -799,6 +799,30 @@ jobs:
         run: breeze ci fix-ownership
         if: always()

+  tests-aws-async-provider:
+    timeout-minutes: 50
+    name: "Pytest for AWS Async Provider"
+    runs-on: "${{needs.build-info.outputs.runs-on}}"
+    needs: [build-info, wait-for-ci-images]
+    if: needs.build-info.outputs.run-tests == 'true'
+    steps:
+      - name: Cleanup repo
+        shell: bash
+        run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*"
+      - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )"
+        uses: actions/checkout@v3
+        with:
+          persist-credentials: false
+      - name: "Prepare breeze & CI image"
+        uses: ./.github/actions/prepare_breeze_and_image
+      - name: "Run AWS Async Test"
+        run: "breeze shell \
+          'pip install aiobotocore>=2.1.1 && pytest /opt/airflow/tests/providers/amazon/aws/deferrable'"
+      - name: "Post Tests"
+        uses: ./.github/actions/post_tests
+      - name: "Fix ownership"
+        run: breeze ci fix-ownership
+        if: always()

   tests-helm:
     timeout-minutes: 80

airflow/providers/amazon/aws/hooks/base_aws.py

Lines changed: 129 additions & 2 deletions

@@ -51,7 +51,10 @@

 from airflow.compat.functools import cached_property
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowNotFoundException
+from airflow.exceptions import (
+    AirflowException,
+    AirflowNotFoundException,
+)
 from airflow.hooks.base import BaseHook
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
 from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
@@ -62,7 +65,6 @@

 BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])

-
 if TYPE_CHECKING:
     from airflow.models.connection import Connection  # Avoid circular imports.

@@ -877,3 +879,128 @@ def _parse_s3_config(config_file_name: str, config_format: str | None = "boto",
         config_format=config_format,
         profile=profile,
     )
+
+
+try:
+    import aiobotocore.credentials
+    from aiobotocore.session import AioSession, get_session
+except ImportError:
+    pass
+
+
+class BaseAsyncSessionFactory(BaseSessionFactory):
+    """
+    Base AWS Session Factory class to handle aiobotocore session creation.
+
+    It currently handles ENV, AWS secret key and the STS client method ``assume_role``
+    provided in the Airflow connection.
+    """
+
+    async def get_role_credentials(self) -> dict:
+        """Get the role_arn and method credentials from connection details and fetch the role credentials."""
+        async with self._basic_session.create_client("sts", region_name=self.region_name) as client:
+            response = await client.assume_role(
+                RoleArn=self.role_arn,
+                RoleSessionName=self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}"),
+                **self.conn.assume_role_kwargs,
+            )
+            return response["Credentials"]
+
+    async def _get_refresh_credentials(self) -> dict[str, Any]:
+        self.log.debug("Refreshing credentials")
+        assume_role_method = self.conn.assume_role_method
+        if assume_role_method != "assume_role":
+            raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")
+
+        credentials = await self.get_role_credentials()
+
+        expiry_time = credentials["Expiration"].isoformat()
+        self.log.debug("New credentials expiry_time: %s", expiry_time)
+        credentials = {
+            "access_key": credentials.get("AccessKeyId"),
+            "secret_key": credentials.get("SecretAccessKey"),
+            "token": credentials.get("SessionToken"),
+            "expiry_time": expiry_time,
+        }
+        return credentials
+
+    def _get_session_with_assume_role(self) -> AioSession:
+
+        assume_role_method = self.conn.assume_role_method
+        if assume_role_method != "assume_role":
+            raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")
+
+        credentials = aiobotocore.credentials.AioRefreshableCredentials.create_from_metadata(
+            metadata=self._get_refresh_credentials(),
+            refresh_using=self._get_refresh_credentials,
+            method="sts-assume-role",
+        )
+
+        session = aiobotocore.session.get_session()
+        session._credentials = credentials
+        return session
+
+    @cached_property
+    def _basic_session(self) -> AioSession:
+        """Cached property with basic aiobotocore.session.AioSession."""
+        session_kwargs = self.conn.session_kwargs
+        aws_access_key_id = session_kwargs.get("aws_access_key_id")
+        aws_secret_access_key = session_kwargs.get("aws_secret_access_key")
+        aws_session_token = session_kwargs.get("aws_session_token")
+        region_name = session_kwargs.get("region_name")
+        profile_name = session_kwargs.get("profile_name")
+
+        aio_session = get_session()
+        if profile_name is not None:
+            aio_session.set_config_variable("profile", profile_name)
+        if aws_access_key_id or aws_secret_access_key or aws_session_token:
+            aio_session.set_credentials(
+                access_key=aws_access_key_id,
+                secret_key=aws_secret_access_key,
+                token=aws_session_token,
+            )
+        if region_name is not None:
+            aio_session.set_config_variable("region", region_name)
+        return aio_session
+
+    def create_session(self) -> AioSession:
+        """Create aiobotocore Session from connection and config."""
+        if not self._conn:
+            self.log.info("No connection ID provided. Fallback on boto3 credential strategy")
+            return get_session()
+        elif not self.role_arn:
+            return self._basic_session
+        return self._get_session_with_assume_role()
+
+
+class AwsBaseAsyncHook(AwsBaseHook):
+    """
+    Interacts with AWS using aiobotocore asynchronously.
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is None or empty then the default botocore behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default botocore configuration would be used (and must be
+        maintained on each worker node).
+    :param verify: Whether to verify SSL certificates.
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param client_type: boto3.client client_type. Eg 's3', 'emr' etc
+    :param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc
+    :param config: Configuration for botocore client.
+    """
+
+    def get_async_session(self) -> AioSession:
+        """Get the underlying aiobotocore.session.AioSession(...)."""
+        return BaseAsyncSessionFactory(
+            conn=self.conn_config, region_name=self.region_name, config=self.config
+        ).create_session()
+
+    async def get_client_async(self):
+        """Get the underlying aiobotocore client using the aiobotocore session."""
+        return self.get_async_session().create_client(
+            self.client_type,
+            region_name=self.region_name,
+            verify=self.verify,
+            endpoint_url=self.conn_config.endpoint_url,
+            config=self.config,
+        )
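
As a rough illustration of how a consumer is expected to use this hook (the connection id and client_type are placeholders; the pattern mirrors the Redshift hook in the next file), get_client_async() returns an aiobotocore client creator that must be entered as an async context manager:

    import asyncio

    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook


    async def describe_clusters_example() -> dict:
        # Same constructor arguments as the sync AwsBaseHook; client_type selects the boto-style client.
        hook = AwsBaseAsyncHook(aws_conn_id="aws_default", client_type="redshift")
        # `await hook.get_client_async()` yields the client creator; `async with` opens the client.
        async with await hook.get_client_async() as client:
            return await client.describe_clusters()


    if __name__ == "__main__":
        print(asyncio.run(describe_clusters_example()))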

airflow/providers/amazon/aws/hooks/redshift_cluster.py

Lines changed: 82 additions & 1 deletion

@@ -16,12 +16,14 @@
 # under the License.
 from __future__ import annotations

+import asyncio
 import warnings
 from typing import Any, Sequence

+import botocore.exceptions
 from botocore.exceptions import ClientError

-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook


 class RedshiftHook(AwsBaseHook):
@@ -200,3 +202,82 @@ def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifi
             return snapshot_status
         except self.get_conn().exceptions.ClusterSnapshotNotFoundFault:
             return None
+
+
+class RedshiftAsyncHook(AwsBaseAsyncHook):
+    """Interact with AWS Redshift using the aiobotocore library."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        kwargs["client_type"] = "redshift"
+        super().__init__(*args, **kwargs)
+
+    async def cluster_status(self, cluster_identifier: str, delete_operation: bool = False) -> dict[str, Any]:
+        """
+        Connect to the AWS Redshift cluster via aiobotocore and return the status
+        of the cluster based on the cluster_identifier passed.
+
+        :param cluster_identifier: unique identifier of a cluster
+        :param delete_operation: whether the method has been called as part of a delete cluster operation
+        """
+        async with await self.get_client_async() as client:
+            try:
+                response = await client.describe_clusters(ClusterIdentifier=cluster_identifier)
+                cluster_state = (
+                    response["Clusters"][0]["ClusterStatus"] if response and response["Clusters"] else None
+                )
+                return {"status": "success", "cluster_state": cluster_state}
+            except botocore.exceptions.ClientError as error:
+                if delete_operation and error.response.get("Error", {}).get("Code", "") == "ClusterNotFound":
+                    return {"status": "success", "cluster_state": "cluster_not_found"}
+                return {"status": "error", "message": str(error)}
+
+    async def pause_cluster(self, cluster_identifier: str, poll_interval: float = 5.0) -> dict[str, Any]:
+        """
+        Connect to the AWS Redshift cluster via aiobotocore and
+        pause the cluster based on the cluster_identifier passed.
+
+        :param cluster_identifier: unique identifier of a cluster
+        :param poll_interval: polling period in seconds to check for the status
+        """
+        try:
+            async with await self.get_client_async() as client:
+                response = await client.pause_cluster(ClusterIdentifier=cluster_identifier)
+                status = response["Cluster"]["ClusterStatus"] if response and response["Cluster"] else None
+                if status == "pausing":
+                    flag = asyncio.Event()
+                    while True:
+                        expected_response = await asyncio.create_task(
+                            self.get_cluster_status(cluster_identifier, "paused", flag)
+                        )
+                        await asyncio.sleep(poll_interval)
+                        if flag.is_set():
+                            return expected_response
+                return {"status": "error", "cluster_state": status}
+        except botocore.exceptions.ClientError as error:
+            return {"status": "error", "message": str(error)}
+
+    async def get_cluster_status(
+        self,
+        cluster_identifier: str,
+        expected_state: str,
+        flag: asyncio.Event,
+        delete_operation: bool = False,
+    ) -> dict[str, Any]:
+        """
+        Check for the expected Redshift cluster state.
+
+        :param cluster_identifier: unique identifier of a cluster
+        :param expected_state: expected state, e.g. "available", "pausing", "paused"
+        :param flag: asyncio event flag, set to true on success or on any error
+        :param delete_operation: whether the method has been called as part of a delete cluster operation
+        """
+        try:
+            response = await self.cluster_status(cluster_identifier, delete_operation=delete_operation)
+            if ("cluster_state" in response and response["cluster_state"] == expected_state) or response[
+                "status"
+            ] == "error":
+                flag.set()
+            return response
+        except botocore.exceptions.ClientError as error:
+            flag.set()
+            return {"status": "error", "message": str(error)}

airflow/providers/amazon/aws/operators/redshift_cluster.py

Lines changed: 48 additions & 12 deletions

@@ -22,6 +22,7 @@
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
+from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -447,6 +448,7 @@ class RedshiftPauseClusterOperator(BaseOperator):

     :param cluster_identifier: id of the AWS Redshift Cluster
     :param aws_conn_id: aws connection to use
+    :param deferrable: Run operator in the deferrable mode. This mode requires an additional dependency, aiobotocore>=2.1.1
     """

     template_fields: Sequence[str] = ("cluster_identifier",)
@@ -458,11 +460,15 @@ def __init__(
         *,
         cluster_identifier: str,
         aws_conn_id: str = "aws_default",
+        deferrable: bool = False,
+        poll_interval: int = 10,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.cluster_identifier = cluster_identifier
         self.aws_conn_id = aws_conn_id
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
         # These parameters are added to address an issue with the boto3 API where the API
         # prematurely reports the cluster as available to receive requests. This causes the cluster
         # to reject initial attempts to pause the cluster despite reporting the correct state.
@@ -472,18 +478,48 @@ def __init__(
     def execute(self, context: Context):
         redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)

-        while self._attempts >= 1:
-            try:
-                redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
-                return
-            except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
-                self._attempts = self._attempts - 1
-
-                if self._attempts > 0:
-                    self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
-                    time.sleep(self._attempt_interval)
-                else:
-                    raise error
+        if self.deferrable:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=RedshiftClusterTrigger(
+                    task_id=self.task_id,
+                    poll_interval=self.poll_interval,
+                    aws_conn_id=self.aws_conn_id,
+                    cluster_identifier=self.cluster_identifier,
+                    attempts=self._attempts,
+                    operation_type="pause_cluster",
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            while self._attempts >= 1:
+                try:
+                    redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
+                    return
+                except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+                    self._attempts = self._attempts - 1
+
+                    if self._attempts > 0:
+                        self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
+                        time.sleep(self._attempt_interval)
+                    else:
+                        raise error
+
+    def execute_complete(self, context: Context, event: Any = None) -> None:
+        """
+        Callback for when the trigger fires - returns immediately.
+
+        Relies on the trigger to throw an exception, otherwise it assumes execution was successful.
+        """
+        if event:
+            if "status" in event and event["status"] == "error":
+                msg = f"{event['status']}: {event['message']}"
+                raise AirflowException(msg)
+            elif "status" in event and event["status"] == "success":
+                self.log.info("%s completed successfully.", self.task_id)
+                self.log.info("Paused cluster successfully")
+        else:
+            raise AirflowException("No event received from trigger")


 class RedshiftDeleteClusterOperator(BaseOperator):
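
A quick way to exercise the deferral path is a unit-test-style check that execute() raises TaskDeferred when deferrable=True (a sketch assuming pytest; the task id and cluster identifier are placeholders, and no AWS call is made before the operator defers):

    import pytest

    from airflow.exceptions import TaskDeferred
    from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftPauseClusterOperator


    def test_pause_cluster_defers():
        op = RedshiftPauseClusterOperator(
            task_id="pause_redshift_cluster",
            cluster_identifier="my-redshift-cluster",  # illustrative cluster id
            deferrable=True,
        )
        # With deferrable=True, execute() should hand off to the trigger by raising TaskDeferred.
        with pytest.raises(TaskDeferred) as deferred:
            op.execute(context=None)
        assert deferred.value.method_name == "execute_complete"
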
Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
