diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 90a201028..0bc5b74b1 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -31,6 +31,7 @@ from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.task import Task from labelbox.schema.user import User +from labelbox.schema.iam_integration import IAMIntegration logger = logging.getLogger(__name__) @@ -912,3 +913,99 @@ def _convert_items_to_upsert_format(self, _items): } # remove None values _upsert_items.append(DataRowUpsertItem(payload=item, id=key)) return _upsert_items + + def add_iam_integration(self, iam_integration: Union[str, IAMIntegration]) -> IAMIntegration: + """ + Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets. + + Args: + iam_integration (Union[str, IAMIntegration]): IAM integration object or IAM integration id. + + Returns: + IAMIntegration: IAM integration object. + + Raises: + LabelboxError: If the IAM integration can't be set. + + Examples: + + >>> # Get all IAM integrations + >>> iam_integrations = client.get_organization().get_iam_integrations() + >>> + >>> # Get IAM integration id + >>> iam_integration_id = [integration.uid for integration + >>> in iam_integrations + >>> if integration.name == "My S3 integration"][0] + >>> + >>> # Set IAM integration for integration id + >>> dataset.set_iam_integration(iam_integration_id) + >>> + >>> # Get IAM integration object + >>> iam_integration = [integration.uid for integration + >>> in iam_integrations + >>> if integration.name == "My S3 integration"][0] + >>> + >>> # Set IAM integration for IAMIntegrtion object + >>> dataset.set_iam_integration(iam_integration) + """ + + iam_integration_id = iam_integration.uid if isinstance(iam_integration, IAMIntegration) else iam_integration + + query = """ + mutation SetSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) { + setSignerForDataset( + data: { signerId: $signerId } + where: { id: $datasetId } + ) { + id + signer { + id + } + } + } + """ + + response = self.client.execute(query, {"signerId": iam_integration_id, "datasetId": self.uid}) + + if not response: + raise ResourceNotFoundError(IAMIntegration, {"signerId": iam_integration_id, "datasetId": self.uid}) + + try: + iam_integration_id = response.get("setSignerForDataset", {}).get("signer", {})["id"] + + return [integration for integration + in self.client.get_organization().get_iam_integrations() + if integration.uid == iam_integration_id][0] + except: + raise LabelboxError(f"Can't retrieve IAM integration {iam_integration_id}") + + def remove_iam_integration(self) -> None: + """ + Unsets the IAM integration for the dataset. + + Args: + None + + Returns: + None + + Raises: + LabelboxError: If the IAM integration can't be unset. + + Examples: + >>> dataset.remove_iam_integration() + """ + + query = """ + mutation DetachSignerPyApi($id: ID!) { + clearSignerForDataset(where: { id: $id }) { + id + } + } + """ + + response = self.client.execute(query, {"id": self.uid}) + + if not response: + raise ResourceNotFoundError(Dataset, {"id": self.uid}) + \ No newline at end of file diff --git a/libs/labelbox/tests/integration/test_delegated_access.py b/libs/labelbox/tests/integration/test_delegated_access.py index 3f025a1ab..1592319d2 100644 --- a/libs/labelbox/tests/integration/test_delegated_access.py +++ b/libs/labelbox/tests/integration/test_delegated_access.py @@ -2,6 +2,7 @@ import requests import pytest +import uuid from labelbox import Client @@ -77,3 +78,114 @@ def test_no_default_integration(client): ds = client.create_dataset(name="new_ds") assert ds.iam_integration() is None ds.delete() + + +@pytest.mark.skip( + reason= + "Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found") +def test_add_integration_from_object(): + """ + This test is based on test_non_default_integration() and assumes the following: + + 1. aws delegated access is configured to work with lbox-test-bucket + 2. an integration called aws is available to the org + + Currently tests against: + Org ID: cl26d06tk0gch10901m7jeg9v + Email: jtso+aws_sdk_tests@labelbox.com + """ + client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY")) + integrations = client.get_organization().get_iam_integrations() + + # Prepare dataset with no integration + integration = [ + integration for integration + in integrations + if 'aws-da-test-bucket' in integration.name][0] + + ds = client.create_dataset(iam_integration=None, name=f"integration_add_obj-{uuid.uuid4()}") + + # Test set integration with object + new_integration = ds.add_iam_integration(integration) + assert new_integration == integration + + # Cleaning + ds.delete() + +@pytest.mark.skip( + reason= + "Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found") +def test_add_integration_from_uid(): + """ + This test is based on test_non_default_integration() and assumes the following: + + 1. aws delegated access is configured to work with lbox-test-bucket + 2. an integration called aws is available to the org + + Currently tests against: + Org ID: cl26d06tk0gch10901m7jeg9v + Email: jtso+aws_sdk_tests@labelbox.com + """ + client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY")) + integrations = client.get_organization().get_iam_integrations() + + # Prepare dataset with no integration + integration = [ + integration for integration + in integrations + if 'aws-da-test-bucket' in integration.name][0] + + ds = client.create_dataset(iam_integration=None, name=f"integration_add_id-{uuid.uuid4()}") + + # Test set integration with integration id + integration_id = [ + integration.uid for integration + in integrations + if 'aws-da-test-bucket' in integration.name][0] + + new_integration = ds.add_iam_integration(integration_id) + assert new_integration == integration + + # Cleaning + ds.delete() + +@pytest.mark.skip( + reason= + "Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found") +def test_integration_remove(): + """ + This test is based on test_non_default_integration() and assumes the following: + + 1. aws delegated access is configured to work with lbox-test-bucket + 2. an integration called aws is available to the org + + Currently tests against: + Org ID: cl26d06tk0gch10901m7jeg9v + Email: jtso+aws_sdk_tests@labelbox.com + """ + client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY")) + integrations = client.get_organization().get_iam_integrations() + + # Prepare dataset with an existing integration + integration = [ + integration for integration + in integrations + if 'aws-da-test-bucket' in integration.name][0] + + ds = client.create_dataset(iam_integration=integration, name=f"integration_remove-{uuid.uuid4()}") + + # Test unset integration + ds.remove_iam_integration() + assert ds.iam_integration() is None + + # Cleaning + ds.delete() \ No newline at end of file