Skip to content

Commit a7e1cb2

Browse files
Add general-purpose "notifier" concept to DAGs (#28569)
* Add general-purpose "notifier" concept to DAGs This makes it easy for users to setup notifications for their DAGs using on_*_callbacks It's extensible and we can add it to more providers. Implemented a SlackNotifier in this phase. In the course of this, I extracted a 'Templater' class from AbstractBaseOperator and have both the Notifier & ABO inherit from it. This is necessary in other to avoid code repetition. * Renames and a fixup not to require a call to super in subclasses * Raise compat exception and add docs * Ignore import error due to optional provider feature * fixup! Ignore import error due to optional provider feature * Apply suggestions from code review Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com> * fixup! Apply suggestions from code review Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
1 parent 24af35b commit a7e1cb2

File tree

18 files changed

+790
-132
lines changed

18 files changed

+790
-132
lines changed

airflow/models/abstractoperator.py

Lines changed: 14 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@
2626
from airflow.exceptions import AirflowException
2727
from airflow.models.expandinput import NotFullyPopulated
2828
from airflow.models.taskmixin import DAGNode
29+
from airflow.template.templater import Templater
2930
from airflow.utils.context import Context
30-
from airflow.utils.helpers import render_template_as_native, render_template_to_string
31-
from airflow.utils.log.logging_mixin import LoggingMixin
32-
from airflow.utils.mixins import ResolveMixin
3331
from airflow.utils.session import NEW_SESSION, provide_session
3432
from airflow.utils.sqlalchemy import skip_locked, with_row_locks
3533
from airflow.utils.state import State, TaskInstanceState
@@ -76,7 +74,7 @@ class NotMapped(Exception):
7674
"""Raise if a task is neither mapped nor has any parent mapped groups."""
7775

7876

79-
class AbstractOperator(LoggingMixin, DAGNode):
77+
class AbstractOperator(Templater, DAGNode):
8078
"""Common implementation for operators, including unmapped and mapped.
8179
8280
This base class is more about sharing implementations, not defining a common
@@ -96,10 +94,6 @@ class AbstractOperator(LoggingMixin, DAGNode):
9694

9795
# Defines the operator level extra links.
9896
operator_extra_links: Collection[BaseOperatorLink]
99-
# For derived classes to define which fields will get jinjaified.
100-
template_fields: Collection[str]
101-
# Defines which files extensions to look for in the templated fields.
102-
template_ext: Sequence[str]
10397

10498
owner: str
10599
task_id: str
@@ -153,48 +147,6 @@ def dag_id(self) -> str:
153147
def node_id(self) -> str:
154148
return self.task_id
155149

156-
def get_template_env(self) -> jinja2.Environment:
157-
"""Fetch a Jinja template environment from the DAG or instantiate empty environment if no DAG."""
158-
# This is imported locally since Jinja2 is heavy and we don't need it
159-
# for most of the functionalities. It is imported by get_template_env()
160-
# though, so we don't need to put this after the 'if dag' check.
161-
from airflow.templates import SandboxedEnvironment
162-
163-
dag = self.get_dag()
164-
if dag:
165-
return dag.get_template_env(force_sandboxed=False)
166-
return SandboxedEnvironment(cache_size=0)
167-
168-
def prepare_template(self) -> None:
169-
"""Hook triggered after the templated fields get replaced by their content.
170-
171-
If you need your operator to alter the content of the file before the
172-
template is rendered, it should override this method to do so.
173-
"""
174-
175-
def resolve_template_files(self) -> None:
176-
"""Getting the content of files for template_field / template_ext."""
177-
if self.template_ext:
178-
for field in self.template_fields:
179-
content = getattr(self, field, None)
180-
if content is None:
181-
continue
182-
elif isinstance(content, str) and any(content.endswith(ext) for ext in self.template_ext):
183-
env = self.get_template_env()
184-
try:
185-
setattr(self, field, env.loader.get_source(env, content)[0]) # type: ignore
186-
except Exception:
187-
self.log.exception("Failed to resolve template field %r", field)
188-
elif isinstance(content, list):
189-
env = self.get_template_env()
190-
for i, item in enumerate(content):
191-
if isinstance(item, str) and any(item.endswith(ext) for ext in self.template_ext):
192-
try:
193-
content[i] = env.loader.get_source(env, item)[0] # type: ignore
194-
except Exception:
195-
self.log.exception("Failed to get source %s", item)
196-
self.prepare_template()
197-
198150
def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
199151
"""Get direct relative IDs to the current task, upstream or downstream."""
200152
if upstream:
@@ -580,6 +532,17 @@ def render_template_fields(
580532
"""
581533
raise NotImplementedError()
582534

535+
def _render(self, template, context, dag: DAG | None = None):
536+
if dag is None:
537+
dag = self.get_dag()
538+
return super()._render(template, context, dag=dag)
539+
540+
def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
541+
"""Get the template environment for rendering templates."""
542+
if dag is None:
543+
dag = self.get_dag()
544+
return super().get_template_env(dag=dag)
545+
583546
@provide_session
584547
def _do_render_template_fields(
585548
self,
@@ -591,6 +554,7 @@ def _do_render_template_fields(
591554
*,
592555
session: Session = NEW_SESSION,
593556
) -> None:
557+
"""Override the base to use custom error logging."""
594558
for attr_name in template_fields:
595559
try:
596560
value = getattr(parent, attr_name)
@@ -618,85 +582,3 @@ def _do_render_template_fields(
618582
raise
619583
else:
620584
setattr(parent, attr_name, rendered_content)
621-
622-
def render_template(
623-
self,
624-
content: Any,
625-
context: Context,
626-
jinja_env: jinja2.Environment | None = None,
627-
seen_oids: set[int] | None = None,
628-
) -> Any:
629-
"""Render a templated string.
630-
631-
If *content* is a collection holding multiple templated strings, strings
632-
in the collection will be templated recursively.
633-
634-
:param content: Content to template. Only strings can be templated (may
635-
be inside a collection).
636-
:param context: Dict with values to apply on templated content
637-
:param jinja_env: Jinja environment. Can be provided to avoid
638-
re-creating Jinja environments during recursion.
639-
:param seen_oids: template fields already rendered (to avoid
640-
*RecursionError* on circular dependencies)
641-
:return: Templated content
642-
"""
643-
# "content" is a bad name, but we're stuck to it being public API.
644-
value = content
645-
del content
646-
647-
if seen_oids is not None:
648-
oids = seen_oids
649-
else:
650-
oids = set()
651-
652-
if id(value) in oids:
653-
return value
654-
655-
if not jinja_env:
656-
jinja_env = self.get_template_env()
657-
658-
if isinstance(value, str):
659-
if any(value.endswith(ext) for ext in self.template_ext): # A filepath.
660-
template = jinja_env.get_template(value)
661-
else:
662-
template = jinja_env.from_string(value)
663-
dag = self.get_dag()
664-
if dag and dag.render_template_as_native_obj:
665-
return render_template_as_native(template, context)
666-
return render_template_to_string(template, context)
667-
668-
if isinstance(value, ResolveMixin):
669-
return value.resolve(context)
670-
671-
# Fast path for common built-in collections.
672-
if value.__class__ is tuple:
673-
return tuple(self.render_template(element, context, jinja_env, oids) for element in value)
674-
elif isinstance(value, tuple): # Special case for named tuples.
675-
return value.__class__(*(self.render_template(el, context, jinja_env, oids) for el in value))
676-
elif isinstance(value, list):
677-
return [self.render_template(element, context, jinja_env, oids) for element in value]
678-
elif isinstance(value, dict):
679-
return {k: self.render_template(v, context, jinja_env, oids) for k, v in value.items()}
680-
elif isinstance(value, set):
681-
return {self.render_template(element, context, jinja_env, oids) for element in value}
682-
683-
# More complex collections.
684-
self._render_nested_template_fields(value, context, jinja_env, oids)
685-
return value
686-
687-
def _render_nested_template_fields(
688-
self,
689-
value: Any,
690-
context: Context,
691-
jinja_env: jinja2.Environment,
692-
seen_oids: set[int],
693-
) -> None:
694-
if id(value) in seen_oids:
695-
return
696-
seen_oids.add(id(value))
697-
try:
698-
nested_template_fields = value.template_fields
699-
except AttributeError:
700-
# content has no inner template fields
701-
return
702-
self._do_render_template_fields(value, nested_template_fields, context, jinja_env, seen_oids)

airflow/notifications/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.

airflow/notifications/basenotifier.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from abc import abstractmethod
21+
from typing import TYPE_CHECKING, Sequence
22+
23+
import jinja2
24+
25+
from airflow.template.templater import Templater
26+
from airflow.utils.context import Context, context_merge
27+
28+
if TYPE_CHECKING:
29+
from airflow import DAG
30+
31+
32+
class BaseNotifier(Templater):
33+
"""BaseNotifier class for sending notifications"""
34+
35+
template_fields: Sequence[str] = ()
36+
template_ext: Sequence[str] = ()
37+
38+
def __init__(self):
39+
super().__init__()
40+
self.resolve_template_files()
41+
42+
def _update_context(self, context: Context) -> Context:
43+
"""
44+
Add additional context to the context
45+
46+
:param context: The airflow context
47+
"""
48+
additional_context = ((f, getattr(self, f)) for f in self.template_fields)
49+
context_merge(context, additional_context)
50+
return context
51+
52+
def _render(self, template, context, dag: DAG | None = None):
53+
dag = dag or context["dag"]
54+
return super()._render(template, context, dag)
55+
56+
def render_template_fields(
57+
self,
58+
context: Context,
59+
jinja_env: jinja2.Environment | None = None,
60+
) -> None:
61+
"""Template all attributes listed in *self.template_fields*.
62+
63+
This mutates the attributes in-place and is irreversible.
64+
65+
:param context: Context dict with values to apply on content.
66+
:param jinja_env: Jinja environment to use for rendering.
67+
"""
68+
dag = context["dag"]
69+
if not jinja_env:
70+
jinja_env = self.get_template_env(dag=dag)
71+
self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
72+
73+
@abstractmethod
74+
def notify(self, context: Context) -> None:
75+
"""
76+
Sends a notification
77+
78+
:param context: The airflow context
79+
"""
80+
...
81+
82+
def __call__(self, context: Context) -> None:
83+
"""
84+
Send a notification
85+
86+
:param context: The airflow context
87+
"""
88+
context = self._update_context(context)
89+
self.render_template_fields(context)
90+
try:
91+
self.notify(context)
92+
except Exception as e:
93+
self.log.exception("Failed to send notification: %s", e)

airflow/providers/slack/CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Features
4141
~~~~~~~~
4242

4343
* ``Implements SqlToSlackApiFileOperator (#26374)``
44+
* ``Added SlackNotifier (#28569)``
4445

4546
Bug Fixes
4647
~~~~~~~~~
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.

0 commit comments

Comments
 (0)