Skip to content

Commit b96a622

Browse files
committed
Add dataset.set_iam_integration() to select/deselect integrations
1 parent 0874c15 commit b96a622

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

libs/labelbox/src/labelbox/schema/dataset.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from labelbox.schema.identifiable import UniqueId, GlobalKey
3232
from labelbox.schema.task import Task
3333
from labelbox.schema.user import User
34+
from labelbox.schema.iam_integration import IAMIntegration
3435

3536
logger = logging.getLogger(__name__)
3637

@@ -912,3 +913,80 @@ def _convert_items_to_upsert_format(self, _items):
912913
} # remove None values
913914
_upsert_items.append(DataRowUpsertItem(payload=item, id=key))
914915
return _upsert_items
916+
917+
def set_iam_integration(self, iam_integration: Union[str, IAMIntegration, None] = None) -> Optional[IAMIntegration]:
918+
"""
919+
Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets.
920+
921+
>>> # Get all IAM integrations
922+
>>> iam_integrations = client.get_organization().get_iam_integrations()
923+
>>>
924+
>>> # Get IAM integration id
925+
>>> iam_integration_id = [integration.uid for integration
926+
>>> in iam_integrations
927+
>>> if integration.name == "My S3 integration"][0]
928+
>>>
929+
>>> # Set IAM integration for integration id
930+
>>> dataset.set_iam_integration(iam_integration_id)
931+
>>>
932+
>>> # Get IAM integration object
933+
>>> iam_integration = [integration.uid for integration
934+
>>> in iam_integrations
935+
>>> if integration.name == "My S3 integration"][0]
936+
>>>
937+
>>> # Set IAM integration for integration object
938+
>>> dataset.set_iam_integration(iam_integration)
939+
>>>
940+
>>> # Unset IAM integration
941+
>>> dataset.set_iam_integration()
942+
"""
943+
944+
# Unset IAM integration if iam_integration is None
945+
if iam_integration is None:
946+
query = """mutation DetachSignerPyApi($id: ID!) {
947+
clearSignerForDataset(where: {id: $id}) {
948+
id
949+
signer {
950+
id
951+
}
952+
}
953+
}"""
954+
response = self.client.execute(query, {"id": self.uid})
955+
956+
if response:
957+
return response["clearSignerForDataset"]["signer"]
958+
else:
959+
raise lb.exceptions.LabelboxError("Can't unset IAM integration")
960+
961+
else:
962+
963+
if isinstance(iam_integration, IAMIntegration):
964+
iam_integration_id = iam_integration.uid
965+
else:
966+
iam_integration_id = iam_integration
967+
968+
query = """mutation AttachSignerPyApi($signerId: ID!, $datasetId: ID!) {
969+
setSignerForDataset(data: {signerId: $signerId}, where: {id: $datasetId}) {
970+
id
971+
signer {
972+
id
973+
}
974+
}
975+
}"""
976+
response = self.client.execute(query, {"signerId": iam_integration_id, "datasetId": self.uid})
977+
978+
# Return IAM Integration object if
979+
if response:
980+
try:
981+
982+
iam_integration_id = response.get("setSignerForDataset", {}).get("signer", {})["id"]
983+
return [integration for integration
984+
in self.client.get_organization().get_iam_integrations()
985+
if integration.uid == iam_integration_id][0]
986+
except:
987+
raise LabelboxError(f"Can't retrieve IAM integration {iam_integration_id}")
988+
989+
else:
990+
raise LabelboxError(f"Can't set IAM integration {iam_integration_id}")
991+
992+
return response

libs/labelbox/tests/integration/test_delegated_access.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import requests
44
import pytest
5+
import uuid
56

67
from labelbox import Client
78

@@ -77,3 +78,52 @@ def test_no_default_integration(client):
7778
ds = client.create_dataset(name="new_ds")
7879
assert ds.iam_integration() is None
7980
ds.delete()
81+
82+
83+
@pytest.mark.skip(
84+
reason=
85+
"Google credentials are being updated for this test, disabling till it's all sorted out"
86+
)
87+
@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"),
88+
reason="DA_GCP_LABELBOX_API_KEY not found")
89+
def test_integration_change():
90+
"""
91+
This test is based on test_non_default_integration() and assumes the following:
92+
93+
1. aws delegated access is configured to work with lbox-test-bucket
94+
2. an integration called aws is available to the org
95+
96+
Currently tests against:
97+
Org ID: cl26d06tk0gch10901m7jeg9v
98+
Email: jtso+aws_sdk_tests@labelbox.com
99+
"""
100+
client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY"))
101+
integrations = client.get_organization().get_iam_integrations()
102+
103+
# Prepare dataset with an existing integration
104+
integration = [
105+
integration for integration
106+
in integrations
107+
if 'aws-da-test-bucket' in integration.name][0]
108+
109+
ds = client.create_dataset(iam_integration=integration, name=f"integration_change-{uuid.uuid4()}")
110+
111+
# Test unset integration
112+
ds.set_iam_integration()
113+
assert ds.iam_integration() is None
114+
115+
# Test set integration with object
116+
new_integration = ds.set_iam_integration(integration)
117+
assert new_integration == integration
118+
119+
# Test set integration with integration id
120+
integration_id = [
121+
integration.uid for integration
122+
in integrations
123+
if 'aws-da-test-bucket' in integration.name][0]
124+
125+
new_integration = ds.set_iam_integration(integration_id)
126+
assert new_integration == integration
127+
128+
# Cleaning
129+
ds.delete()

0 commit comments

Comments
 (0)