Skip to content

[PLT-1018] Add dataset.set_iam_integration() to select/deselect integrations #1622

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions libs/labelbox/src/labelbox/schema/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from labelbox.schema.identifiable import UniqueId, GlobalKey
from labelbox.schema.task import Task
from labelbox.schema.user import User
from labelbox.schema.iam_integration import IAMIntegration

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -912,3 +913,98 @@ def _convert_items_to_upsert_format(self, _items):
} # remove None values
_upsert_items.append(DataRowUpsertItem(payload=item, id=key))
return _upsert_items

def add_iam_integration(self, iam_integration: Union[str, IAMIntegration]) -> IAMIntegration:
    """
    Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets.

    Args:
        iam_integration (Union[str, IAMIntegration]): IAM integration object or IAM integration id.

    Returns:
        IAMIntegration: IAM integration object.

    Raises:
        ResourceNotFoundError: If the mutation returns an empty response.
        LabelboxError: If the IAM integration can't be set or can't be
            looked up afterwards.

    Examples:

        >>> # Get all IAM integrations
        >>> iam_integrations = client.get_organization().get_iam_integrations()
        >>>
        >>> # Get IAM integration id
        >>> iam_integration_id = [integration.uid for integration
        >>>     in iam_integrations
        >>>     if integration.name == "My S3 integration"][0]
        >>>
        >>> # Set IAM integration for integration id
        >>> dataset.add_iam_integration(iam_integration_id)
        >>>
        >>> # Get IAM integration object
        >>> iam_integration = [integration for integration
        >>>     in iam_integrations
        >>>     if integration.name == "My S3 integration"][0]
        >>>
        >>> # Set IAM integration for IAMIntegration object
        >>> dataset.add_iam_integration(iam_integration)
    """

    # Accept either the integration object or its raw id.
    iam_integration_id = iam_integration.uid if isinstance(iam_integration, IAMIntegration) else iam_integration

    query = """
        mutation SetSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) {
            setSignerForDataset(
                data: { signerId: $signerId }
                where: { id: $datasetId }
            ) {
                id
                signer {
                    id
                }
            }
        }
    """

    response = self.client.execute(query, {"signerId": iam_integration_id, "datasetId": self.uid})

    if not response:
        raise ResourceNotFoundError(IAMIntegration, {"signerId": iam_integration_id, "datasetId": self.uid})

    try:
        # Use the signer id reported by the server; a missing or null
        # `signer` raises here and is reported by the handler below.
        iam_integration_id = response.get("setSignerForDataset", {}).get("signer", {})["id"]
        # Resolve the id back to a full IAMIntegration object; an empty
        # match list raises IndexError, which is also reported below.
        return [
            integration
            for integration in self.client.get_organization().get_iam_integrations()
            if integration.uid == iam_integration_id
        ][0]
    except Exception as exc:
        # `except Exception` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate; chaining keeps the root cause in the traceback.
        raise LabelboxError(f"Can't retrieve IAM integration {iam_integration_id}") from exc

def remove_iam_integration(self) -> None:
    """
    Detaches the IAM integration (signer) from this dataset.

    Sends the ``clearSignerForDataset`` mutation for this dataset's id.

    Args:
        None

    Returns:
        None

    Raises:
        ResourceNotFoundError: If the mutation returns an empty response.

    Examples:
        >>> dataset.remove_iam_integration()
    """

    mutation = """
        mutation DetachSignerPyApi($id: ID!) {
            clearSignerForDataset(where: { id: $id }) {
                id
            }
        }
    """
    variables = {"id": self.uid}

    result = self.client.execute(mutation, variables)

    # Any truthy payload counts as success; there is nothing to return.
    if result:
        return

    raise ResourceNotFoundError(Dataset, {"id": self.uid})

112 changes: 112 additions & 0 deletions libs/labelbox/tests/integration/test_delegated_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import requests
import pytest
import uuid

from labelbox import Client

Expand Down Expand Up @@ -77,3 +78,114 @@ def test_no_default_integration(client):
ds = client.create_dataset(name="new_ds")
assert ds.iam_integration() is None
ds.delete()


@pytest.mark.skip(
    reason=
    "Google credentials are being updated for this test, disabling till it's all sorted out"
)
@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"),
                    reason="DA_GCP_LABELBOX_API_KEY not found")
def test_add_integration_from_object():
    """
    Attach an IAM integration to a dataset by passing the IAMIntegration object.

    This test is based on test_non_default_integration() and assumes the following:

    1. aws delegated access is configured to work with lbox-test-bucket
    2. an integration whose name contains 'aws-da-test-bucket' is available to the org

    Currently tests against:
    Org ID: cl26d06tk0gch10901m7jeg9v
    Email: jtso+aws_sdk_tests@labelbox.com
    """
    client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY"))
    integrations = client.get_organization().get_iam_integrations()

    integration = [
        integration for integration in integrations
        if 'aws-da-test-bucket' in integration.name
    ][0]

    # Prepare dataset with no integration
    ds = client.create_dataset(iam_integration=None, name=f"integration_add_obj-{uuid.uuid4()}")

    try:
        # Test set integration with object
        new_integration = ds.add_iam_integration(integration)
        assert new_integration == integration
    finally:
        # Cleaning: always remove the dataset, even when the assertion fails
        ds.delete()

@pytest.mark.skip(
    reason=
    "Google credentials are being updated for this test, disabling till it's all sorted out"
)
@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"),
                    reason="DA_GCP_LABELBOX_API_KEY not found")
def test_add_integration_from_uid():
    """
    Attach an IAM integration to a dataset by passing the integration id.

    This test is based on test_non_default_integration() and assumes the following:

    1. aws delegated access is configured to work with lbox-test-bucket
    2. an integration whose name contains 'aws-da-test-bucket' is available to the org

    Currently tests against:
    Org ID: cl26d06tk0gch10901m7jeg9v
    Email: jtso+aws_sdk_tests@labelbox.com
    """
    client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY"))
    integrations = client.get_organization().get_iam_integrations()

    integration = [
        integration for integration in integrations
        if 'aws-da-test-bucket' in integration.name
    ][0]
    # Same integration as above, so its uid can be read directly instead of
    # scanning the list a second time.
    integration_id = integration.uid

    # Prepare dataset with no integration
    ds = client.create_dataset(iam_integration=None, name=f"integration_add_id-{uuid.uuid4()}")

    try:
        # Test set integration with integration id
        new_integration = ds.add_iam_integration(integration_id)
        assert new_integration == integration
    finally:
        # Cleaning: always remove the dataset, even when the assertion fails
        ds.delete()

@pytest.mark.skip(
    reason=
    "Google credentials are being updated for this test, disabling till it's all sorted out"
)
@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"),
                    reason="DA_GCP_LABELBOX_API_KEY not found")
def test_integration_remove():
    """
    Detach the IAM integration from a dataset that was created with one.

    This test is based on test_non_default_integration() and assumes the following:

    1. aws delegated access is configured to work with lbox-test-bucket
    2. an integration whose name contains 'aws-da-test-bucket' is available to the org

    Currently tests against:
    Org ID: cl26d06tk0gch10901m7jeg9v
    Email: jtso+aws_sdk_tests@labelbox.com
    """
    client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY"))
    integrations = client.get_organization().get_iam_integrations()

    integration = [
        integration for integration in integrations
        if 'aws-da-test-bucket' in integration.name
    ][0]

    # Prepare dataset with an existing integration
    ds = client.create_dataset(iam_integration=integration, name=f"integration_remove-{uuid.uuid4()}")

    try:
        # Test unset integration
        ds.remove_iam_integration()
        assert ds.iam_integration() is None
    finally:
        # Cleaning: always remove the dataset, even when the assertion fails
        ds.delete()