
Commit 11e082f

pelinKuran and dlpzx authored
Automated metadata generation using genAI MVP (#1598)
### Feature
- Feature

### Detail
- Automated metadata generation using gen AI. MVP phase

### Related
- #1599

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.

Co-authored-by: dlpzx <dlpzx@amazon.com>
1 parent 4b67986 commit 11e082f

34 files changed: 1,850 additions, 13 deletions
backend/dataall/modules/s3_datasets/api/dataset/enums.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+from dataall.base.api.constants import GraphQLEnumMapper
+
+
+class MetadataGenerationTargets(GraphQLEnumMapper):
+    """Describes the s3_datasets metadata generation targets"""
+
+    Table = 'Table'
+    Folder = 'Folder'
+    S3_Dataset = 'S3_Dataset'
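
As context for the resolver changes further down, a minimal stand-in sketch (not part of the commit) of how these enum values are compared against incoming GraphQL strings; the real class derives from GraphQLEnumMapper, which is assumed here to behave like a standard Python Enum:

from enum import Enum

# Stand-in for the GraphQLEnumMapper-based class above, for illustration only.
class MetadataGenerationTargets(Enum):
    Table = 'Table'
    Folder = 'Folder'
    S3_Dataset = 'S3_Dataset'

target_type = 'Table'  # string arriving from the GraphQL layer
if target_type == MetadataGenerationTargets.Table.value:
    print('generate metadata for a table')

print([item.value for item in MetadataGenerationTargets])  # ['Table', 'Folder', 'S3_Dataset']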

backend/dataall/modules/s3_datasets/api/dataset/input_types.py

Lines changed: 17 additions & 0 deletions
@@ -46,6 +46,15 @@
         gql.Argument(name='expiryMaxDuration', type=gql.Integer),
     ],
 )
+DatasetMetadataInput = gql.InputType(
+    name='DatasetMetadataInput',
+    arguments=[
+        gql.Argument('label', gql.String),
+        gql.Argument('description', gql.String),
+        gql.Argument('tags', gql.ArrayType(gql.String)),
+        gql.Argument('topics', gql.ArrayType(gql.Ref('Topic'))),
+    ],
+)
 
 DatasetPresignedUrlInput = gql.InputType(
     name='DatasetPresignedUrlInput',
@@ -58,6 +67,14 @@
 
 CrawlerInput = gql.InputType(name='CrawlerInput', arguments=[gql.Argument(name='prefix', type=gql.String)])
 
+SampleDataInput = gql.InputType(
+    name='SampleDataInput',
+    arguments=[
+        gql.Field(name='fields', type=gql.ArrayType(gql.String)),
+        gql.Field(name='rows', type=gql.ArrayType(gql.String)),
+    ],
+)
+
 ImportDatasetInput = gql.InputType(
     name='ImportDatasetInput',
     arguments=[
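
To make the two new input types concrete, a hypothetical client-side payload matching them; the field names come from the diff above, the values are invented:

# Hypothetical payloads (illustration only) shaped like the new input types.
sample_data = {                      # SampleDataInput
    'fields': ['customer_id', 'country', 'signup_date'],   # column names
    'rows': ['1,DE,2021-01-04', '2,ES,2021-02-11'],         # rows serialized as strings
}
dataset_metadata = {                 # DatasetMetadataInput
    'label': 'Customer signups',
    'description': 'Daily customer sign-up events per country',
    'tags': ['customers', 'marketing'],
    'topics': ['Sales'],             # must be valid Topic enum values in the target deployment
}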

backend/dataall/modules/s3_datasets/api/dataset/mutations.py

Lines changed: 15 additions & 5 deletions
@@ -1,17 +1,15 @@
 from dataall.base.api import gql
-from dataall.modules.s3_datasets.api.dataset.input_types import (
-    ModifyDatasetInput,
-    NewDatasetInput,
-    ImportDatasetInput,
-)
+from dataall.modules.s3_datasets.api.dataset.input_types import ModifyDatasetInput, NewDatasetInput, ImportDatasetInput
 from dataall.modules.s3_datasets.api.dataset.resolvers import (
     create_dataset,
     update_dataset,
     generate_dataset_access_token,
     delete_dataset,
     import_dataset,
     start_crawler,
+    generate_metadata,
 )
+from dataall.modules.s3_datasets.api.dataset.enums import MetadataGenerationTargets
 
 createDataset = gql.MutationField(
     name='createDataset',
@@ -68,3 +66,15 @@
     resolver=start_crawler,
     type=gql.Ref('GlueCrawler'),
 )
+generateMetadata = gql.MutationField(
+    name='generateMetadata',
+    args=[
+        gql.Argument(name='resourceUri', type=gql.NonNullableType(gql.String)),
+        gql.Argument(name='targetType', type=gql.NonNullableType(MetadataGenerationTargets.toGraphQLEnum())),
+        gql.Argument(name='version', type=gql.Integer),
+        gql.Argument(name='metadataTypes', type=gql.NonNullableType(gql.ArrayType(gql.String))),
+        gql.Argument(name='sampleData', type=gql.Ref('SampleDataInput')),
+    ],
+    type=gql.Ref('GeneratedMetadata'),
+    resolver=generate_metadata,
+)
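
A hedged sketch of how a client could call the new generateMetadata mutation. The endpoint URL, token, and run_graphql helper are placeholders; only the argument and return field names are taken from the diff above, and the GraphQL enum name MetadataGenerationTargets is assumed from toGraphQLEnum():

import json
import urllib.request

API_URL = 'https://example.com/graphql/api'   # placeholder: your data.all API endpoint
HEADERS = {'Content-Type': 'application/json', 'Authorization': '<jwt-token>'}  # placeholder auth


def run_graphql(query: str, variables: dict) -> dict:
    """Minimal GraphQL POST helper used in these sketches (illustration only)."""
    req = urllib.request.Request(
        API_URL,
        data=json.dumps({'query': query, 'variables': variables}).encode(),
        headers=HEADERS,
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)


GENERATE_METADATA = """
mutation ($resourceUri: String!, $targetType: MetadataGenerationTargets!,
          $version: Int, $metadataTypes: [String]!, $sampleData: SampleDataInput) {
  generateMetadata(resourceUri: $resourceUri, targetType: $targetType, version: $version,
                   metadataTypes: $metadataTypes, sampleData: $sampleData) {
    label description tags topics
  }
}
"""

result = run_graphql(
    GENERATE_METADATA,
    {
        'resourceUri': 'abcd1234',                           # 8-char lowercase/digit uri (see validate_uri below)
        'targetType': 'S3_Dataset',
        'version': 1,
        'metadataTypes': ['label', 'description', 'tags'],   # allowed values come from MetadataGenerationTypes
    },
)
print(result)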

backend/dataall/modules/s3_datasets/api/dataset/queries.py

Lines changed: 17 additions & 0 deletions
@@ -4,6 +4,8 @@
     get_dataset_assume_role_url,
     get_file_upload_presigned_url,
     list_datasets_owned_by_env_group,
+    list_dataset_tables_folders,
+    read_sample_data,
 )
 
 getDataset = gql.QueryField(
@@ -45,3 +47,18 @@
     resolver=list_datasets_owned_by_env_group,
     test_scope='Dataset',
 )
+listDatasetTablesFolders = gql.QueryField(
+    name='listDatasetTablesFolders',
+    args=[
+        gql.Argument(name='datasetUri', type=gql.NonNullableType(gql.String)),
+        gql.Argument(name='filter', type=gql.Ref('DatasetFilter')),
+    ],
+    type=gql.Ref('DatasetItemsSearchResult'),
+    resolver=list_dataset_tables_folders,
+)
+listSampleData = gql.QueryField(
+    name='listSampleData',
+    args=[gql.Argument(name='tableUri', type=gql.NonNullableType(gql.String))],
+    type=gql.Ref('QueryPreviewResult'),  # basically returns nothing...?
+    resolver=read_sample_data,
+)  # return the data -> user invokes generateMetadata again + sample data ; similar api exists
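
The trailing comments sketch the intended flow: preview a table first, then feed the sample back into generateMetadata. A rough illustration of that round trip, reusing the hypothetical run_graphql helper and GENERATE_METADATA string from the previous sketch; the fields/rows field names on QueryPreviewResult are assumptions:

# Step 1: preview the table to collect sample data (field names on QueryPreviewResult assumed).
preview = run_graphql(
    'query ($tableUri: String!) { listSampleData(tableUri: $tableUri) { fields rows } }',
    {'tableUri': 'abcd1234'},
)['data']['listSampleData']

# Step 2: pass the preview back as sampleData when generating table metadata.
generated = run_graphql(
    GENERATE_METADATA,  # the mutation string from the previous sketch
    {
        'resourceUri': 'abcd1234',
        'targetType': 'Table',
        'metadataTypes': ['description'],
        'sampleData': {'fields': preview['fields'], 'rows': preview['rows']},
    },
)
print(generated)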

backend/dataall/modules/s3_datasets/api/dataset/resolvers.py

Lines changed: 69 additions & 1 deletion
@@ -1,5 +1,5 @@
 import logging
-
+import re
 from dataall.base.api.context import Context
 from dataall.base.feature_toggle_checker import is_feature_enabled
 from dataall.base.utils.expiration_util import Expiration
@@ -11,6 +11,9 @@
 from dataall.modules.s3_datasets.db.dataset_models import S3Dataset
 from dataall.modules.datasets_base.services.datasets_enums import DatasetRole, ConfidentialityClassification
 from dataall.modules.s3_datasets.services.dataset_service import DatasetService
+from dataall.modules.s3_datasets.services.dataset_table_service import DatasetTableService
+from dataall.modules.s3_datasets.services.dataset_location_service import DatasetLocationService
+from dataall.modules.s3_datasets.services.dataset_enums import MetadataGenerationTargets, MetadataGenerationTypes
 
 log = logging.getLogger(__name__)
 
@@ -156,6 +159,59 @@ def list_datasets_owned_by_env_group(
     return DatasetService.list_datasets_owned_by_env_group(environmentUri, groupUri, filter)
 
 
+# @ResourceThresholdRepository.invocation_handler('generate_metadata_ai')
+# To make this threshold work, threshold limits should be added in resource_threshold_repository to the resource paths dictionary.
+# As an example: 'nlq': 'modules.worksheets.features.max_count_per_day'; here max_count_per_day shall be defined for metadata generation,
+# or it could be used as it is with a different key, or even the same key after merge.
+@is_feature_enabled('modules.s3_datasets.features.generate_metadata_ai.active')
+def generate_metadata(
+    context: Context,
+    source: S3Dataset,
+    resourceUri: str,
+    targetType: str,
+    version: int,
+    metadataTypes: list,
+    sampleData: dict = {},
+):
+    RequestValidator.validate_uri(param_name='resourceUri', param_value=resourceUri)
+    if metadataTypes not in [item.value for item in MetadataGenerationTypes]:
+        raise InvalidInput(
+            'metadataType',
+            metadataTypes,
+            f'a list of allowed values {[item.value for item in MetadataGenerationTypes]}',
+        )
+    # TODO validate sampleData and make it generic for S3
+    if targetType == MetadataGenerationTargets.S3_Dataset.value:
+        return DatasetService.generate_metadata_for_dataset(
+            resourceUri=resourceUri, version=version, metadataTypes=metadataTypes
+        )
+    elif targetType == MetadataGenerationTargets.Table.value:
+        return DatasetTableService.generate_metadata_for_table(
+            resourceUri=resourceUri, version=version, metadataTypes=metadataTypes, sampleData=sampleData
+        )
+    elif targetType == MetadataGenerationTargets.Folder.value:
+        return DatasetLocationService.generate_metadata_for_folder(
+            resourceUri=resourceUri, version=version, metadataTypes=metadataTypes
+        )
+    else:
+        raise Exception('Unsupported target type for metadata generation')
+
+
+def read_sample_data(context: Context, source: S3Dataset, tableUri: str):
+    RequestValidator.validate_uri(param_name='tableUri', param_value=tableUri)
+    return DatasetTableService.preview(uri=tableUri)
+
+
+def update_dataset_metadata(context: Context, source: S3Dataset, resourceUri: str):
+    return DatasetService.update_dataset(uri=resourceUri, data=input)
+
+
+def list_dataset_tables_folders(context: Context, source: S3Dataset, datasetUri: str, filter: dict = None):
+    if not filter:
+        filter = {}
+    return DatasetService.list_dataset_tables_folders(dataset_uri=datasetUri, filter=filter)
+
+
 class RequestValidator:
     @staticmethod
     def validate_creation_request(data):
@@ -200,6 +256,18 @@ def validate_share_expiration_request(data):
             'is of invalid type',
         )
 
+    @staticmethod
+    def validate_uri(param_name: str, param_value: str):
+        if not param_value:
+            raise RequiredParameter(param_name)
+        pattern = r'^[a-z0-9]{8}$'
+        if not re.match(pattern, param_value):
+            raise InvalidInput(
+                param_name=param_name,
+                param_value=param_value,
+                constraint='8 characters long and contain only lowercase letters and numbers',
+            )
+
     @staticmethod
     def validate_import_request(data):
         RequestValidator.validate_creation_request(data)
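
The new validate_uri check is just a full-string regex; a quick illustration of what the pattern accepts and rejects:

import re

pattern = r'^[a-z0-9]{8}$'  # same pattern as RequestValidator.validate_uri

print(bool(re.match(pattern, 'abcd1234')))   # True  -- exactly 8 lowercase letters/digits
print(bool(re.match(pattern, 'ABCD1234')))   # False -- uppercase rejected
print(bool(re.match(pattern, 'abcd123')))    # False -- too short
print(bool(re.match(pattern, 'abcd12345')))  # False -- too long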

backend/dataall/modules/s3_datasets/api/dataset/types.py

Lines changed: 44 additions & 0 deletions
@@ -140,3 +140,47 @@
         gql.Field(name='status', type=gql.String),
     ],
 )
+SubitemDescription = gql.ObjectType(
+    name='SubitemDescription',
+    fields=[
+        gql.Field(name='label', type=gql.String),
+        gql.Field(name='description', type=gql.String),
+        gql.Field(name='subitem_id', type=gql.String),
+    ],
+)
+GeneratedMetadata = gql.ObjectType(
+    name='GeneratedMetadata',
+    fields=[
+        gql.Field(name='type', type=gql.String),  # Table, Column, Folder, Dataset
+        gql.Field(name='label', type=gql.String),
+        gql.Field(name='topics', type=gql.ArrayType(gql.String)),
+        gql.Field(name='tags', type=gql.ArrayType(gql.String)),
+        gql.Field(name='description', type=gql.String),
+        gql.Field(name='name', type=gql.String),
+        gql.Field(name='subitem_descriptions', type=gql.ArrayType(gql.Ref('SubitemDescription'))),
+    ],
+)
+
+DatasetItem = gql.ObjectType(
+    name='DatasetItem',
+    fields=[
+        gql.Field(name='name', type=gql.String),
+        gql.Field(name='targetType', type=gql.String),
+        gql.Field(name='targetUri', type=gql.String),
+    ],
+)
+
+DatasetItemsSearchResult = gql.ObjectType(
+    name='DatasetItemsSearchResult',
+    fields=[
+        gql.Field(name='count', type=gql.Integer),
+        gql.Field(name='nodes', type=gql.ArrayType(DatasetItem)),
+        gql.Field(name='pageSize', type=gql.Integer),
+        gql.Field(name='nextPage', type=gql.Integer),
+        gql.Field(name='pages', type=gql.Integer),
+        gql.Field(name='page', type=gql.Integer),
+        gql.Field(name='previousPage', type=gql.Integer),
+        gql.Field(name='hasNext', type=gql.Boolean),
+        gql.Field(name='hasPrevious', type=gql.Boolean),
+    ],
+)

backend/dataall/modules/s3_datasets/api/table_column/input_types.py

Lines changed: 8 additions & 0 deletions
@@ -18,3 +18,11 @@
         gql.Argument('topics', gql.Integer),
     ],
 )
+SubitemDescription = gql.InputType(
+    name='SubitemDescriptionInput',
+    arguments=[
+        gql.Argument(name='label', type=gql.String),
+        gql.Argument(name='description', type=gql.String),
+        gql.Argument(name='subitem_id', type=gql.String),
+    ],
+)

backend/dataall/modules/s3_datasets/api/table_column/mutations.py

Lines changed: 11 additions & 1 deletion
@@ -1,5 +1,9 @@
 from dataall.base.api import gql
-from dataall.modules.s3_datasets.api.table_column.resolvers import sync_table_columns, update_table_column
+from dataall.modules.s3_datasets.api.table_column.resolvers import (
+    sync_table_columns,
+    update_table_column,
+    batch_update_table_columns_description,
+)
 
 syncDatasetTableColumns = gql.MutationField(
     name='syncDatasetTableColumns',
@@ -18,3 +22,9 @@
     type=gql.Ref('DatasetTableColumn'),
     resolver=update_table_column,
 )
+batchUpdateDatasetTableColumn = gql.MutationField(
+    name='batchUpdateDatasetTableColumn',
+    args=[gql.Argument(name='columns', type=gql.ArrayType(gql.Ref('SubitemDescriptionInput')))],
+    type=gql.String,
+    resolver=batch_update_table_columns_description,
+)
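
A hedged sketch of calling the new batchUpdateDatasetTableColumn mutation with a list of SubitemDescriptionInput items, again reusing the hypothetical run_graphql helper from the earlier sketch; the subitem_id values are placeholders (presumably column URIs):

BATCH_UPDATE_COLUMNS = """
mutation ($columns: [SubitemDescriptionInput]) {
  batchUpdateDatasetTableColumn(columns: $columns)
}
"""

columns = [
    {'subitem_id': 'colabc12', 'label': 'customer_id', 'description': 'Unique customer identifier'},
    {'subitem_id': 'coldef34', 'label': 'country', 'description': 'ISO country code of the customer'},
]

print(run_graphql(BATCH_UPDATE_COLUMNS, {'columns': columns}))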

backend/dataall/modules/s3_datasets/api/table_column/resolvers.py

Lines changed: 6 additions & 0 deletions
@@ -41,3 +41,9 @@ def update_table_column(context: Context, source, columnUri: str = None, input:
 
     description = input.get('description', 'No description provided')
     return DatasetColumnService.update_table_column_description(column_uri=columnUri, description=description)
+
+
+def batch_update_table_columns_description(context: Context, source, columns):
+    if columns is None:
+        return None
+    return DatasetColumnService.batch_update_table_columns_description(columns=columns)
