Skip to content

Commit e3b5b38

Browse files
authored
[PLT-1018] Add dataset.set_iam_integration() to select/deselect integrations (#1622)
2 parents 126c0fb + 210edfd commit e3b5b38

File tree

2 files changed

+209
-0
lines changed

2 files changed

+209
-0
lines changed

libs/labelbox/src/labelbox/schema/dataset.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from labelbox.schema.identifiable import UniqueId, GlobalKey
3232
from labelbox.schema.task import Task
3333
from labelbox.schema.user import User
34+
from labelbox.schema.iam_integration import IAMIntegration
3435

3536
logger = logging.getLogger(__name__)
3637

@@ -912,3 +913,99 @@ def _convert_items_to_upsert_format(self, _items):
912913
} # remove None values
913914
_upsert_items.append(DataRowUpsertItem(payload=item, id=key))
914915
return _upsert_items
916+
917+
def add_iam_integration(self, iam_integration: Union[str, IAMIntegration]) -> IAMIntegration:
918+
"""
919+
Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets.
920+
921+
Args:
922+
iam_integration (Union[str, IAMIntegration]): IAM integration object or IAM integration id.
923+
924+
Returns:
925+
IAMIntegration: IAM integration object.
926+
927+
Raises:
928+
LabelboxError: If the IAM integration can't be set.
929+
930+
Examples:
931+
932+
>>> # Get all IAM integrations
933+
>>> iam_integrations = client.get_organization().get_iam_integrations()
934+
>>>
935+
>>> # Get IAM integration id
936+
>>> iam_integration_id = [integration.uid for integration
937+
>>> in iam_integrations
938+
>>> if integration.name == "My S3 integration"][0]
939+
>>>
940+
>>> # Set IAM integration for integration id
941+
>>> dataset.set_iam_integration(iam_integration_id)
942+
>>>
943+
>>> # Get IAM integration object
944+
>>> iam_integration = [integration.uid for integration
945+
>>> in iam_integrations
946+
>>> if integration.name == "My S3 integration"][0]
947+
>>>
948+
>>> # Set IAM integration for IAMIntegrtion object
949+
>>> dataset.set_iam_integration(iam_integration)
950+
"""
951+
952+
iam_integration_id = iam_integration.uid if isinstance(iam_integration, IAMIntegration) else iam_integration
953+
954+
query = """
955+
mutation SetSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) {
956+
setSignerForDataset(
957+
data: { signerId: $signerId }
958+
where: { id: $datasetId }
959+
) {
960+
id
961+
signer {
962+
id
963+
}
964+
}
965+
}
966+
"""
967+
968+
response = self.client.execute(query, {"signerId": iam_integration_id, "datasetId": self.uid})
969+
970+
if not response:
971+
raise ResourceNotFoundError(IAMIntegration, {"signerId": iam_integration_id, "datasetId": self.uid})
972+
973+
try:
974+
iam_integration_id = response.get("setSignerForDataset", {}).get("signer", {})["id"]
975+
976+
return [integration for integration
977+
in self.client.get_organization().get_iam_integrations()
978+
if integration.uid == iam_integration_id][0]
979+
except:
980+
raise LabelboxError(f"Can't retrieve IAM integration {iam_integration_id}")
981+
982+
def remove_iam_integration(self) -> None:
983+
"""
984+
Unsets the IAM integration for the dataset.
985+
986+
Args:
987+
None
988+
989+
Returns:
990+
None
991+
992+
Raises:
993+
LabelboxError: If the IAM integration can't be unset.
994+
995+
Examples:
996+
>>> dataset.remove_iam_integration()
997+
"""
998+
999+
query = """
1000+
mutation DetachSignerPyApi($id: ID!) {
1001+
clearSignerForDataset(where: { id: $id }) {
1002+
id
1003+
}
1004+
}
1005+
"""
1006+
1007+
response = self.client.execute(query, {"id": self.uid})
1008+
1009+
if not response:
1010+
raise ResourceNotFoundError(Dataset, {"id": self.uid})
1011+

libs/labelbox/tests/integration/test_delegated_access.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import requests
44
import pytest
5+
import uuid
56

67
from labelbox import Client
78

@@ -77,3 +78,114 @@ def test_no_default_integration(client):
7778
ds = client.create_dataset(name="new_ds")
7879
assert ds.iam_integration() is None
7980
ds.delete()
81+
82+
83+
@pytest.mark.skip(
84+
reason=
85+
"Google credentials are being updated for this test, disabling till it's all sorted out"
86+
)
87+
@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"),
88+
reason="DA_GCP_LABELBOX_API_KEY not found")
89+
def test_add_integration_from_object():
90+
"""
91+
This test is based on test_non_default_integration() and assumes the following:
92+
93+
1. aws delegated access is configured to work with lbox-test-bucket
94+
2. an integration called aws is available to the org
95+
96+
Currently tests against:
97+
Org ID: cl26d06tk0gch10901m7jeg9v
98+
Email: jtso+aws_sdk_tests@labelbox.com
99+
"""
100+
client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY"))
101+
integrations = client.get_organization().get_iam_integrations()
102+
103+
# Prepare dataset with no integration
104+
integration = [
105+
integration for integration
106+
in integrations
107+
if 'aws-da-test-bucket' in integration.name][0]
108+
109+
ds = client.create_dataset(iam_integration=None, name=f"integration_add_obj-{uuid.uuid4()}")
110+
111+
# Test set integration with object
112+
new_integration = ds.add_iam_integration(integration)
113+
assert new_integration == integration
114+
115+
# Cleaning
116+
ds.delete()
117+
118+
@pytest.mark.skip(
119+
reason=
120+
"Google credentials are being updated for this test, disabling till it's all sorted out"
121+
)
122+
@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"),
123+
reason="DA_GCP_LABELBOX_API_KEY not found")
124+
def test_add_integration_from_uid():
125+
"""
126+
This test is based on test_non_default_integration() and assumes the following:
127+
128+
1. aws delegated access is configured to work with lbox-test-bucket
129+
2. an integration called aws is available to the org
130+
131+
Currently tests against:
132+
Org ID: cl26d06tk0gch10901m7jeg9v
133+
Email: jtso+aws_sdk_tests@labelbox.com
134+
"""
135+
client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY"))
136+
integrations = client.get_organization().get_iam_integrations()
137+
138+
# Prepare dataset with no integration
139+
integration = [
140+
integration for integration
141+
in integrations
142+
if 'aws-da-test-bucket' in integration.name][0]
143+
144+
ds = client.create_dataset(iam_integration=None, name=f"integration_add_id-{uuid.uuid4()}")
145+
146+
# Test set integration with integration id
147+
integration_id = [
148+
integration.uid for integration
149+
in integrations
150+
if 'aws-da-test-bucket' in integration.name][0]
151+
152+
new_integration = ds.add_iam_integration(integration_id)
153+
assert new_integration == integration
154+
155+
# Cleaning
156+
ds.delete()
157+
158+
@pytest.mark.skip(
159+
reason=
160+
"Google credentials are being updated for this test, disabling till it's all sorted out"
161+
)
162+
@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"),
163+
reason="DA_GCP_LABELBOX_API_KEY not found")
164+
def test_integration_remove():
165+
"""
166+
This test is based on test_non_default_integration() and assumes the following:
167+
168+
1. aws delegated access is configured to work with lbox-test-bucket
169+
2. an integration called aws is available to the org
170+
171+
Currently tests against:
172+
Org ID: cl26d06tk0gch10901m7jeg9v
173+
Email: jtso+aws_sdk_tests@labelbox.com
174+
"""
175+
client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY"))
176+
integrations = client.get_organization().get_iam_integrations()
177+
178+
# Prepare dataset with an existing integration
179+
integration = [
180+
integration for integration
181+
in integrations
182+
if 'aws-da-test-bucket' in integration.name][0]
183+
184+
ds = client.create_dataset(iam_integration=integration, name=f"integration_remove-{uuid.uuid4()}")
185+
186+
# Test unset integration
187+
ds.remove_iam_integration()
188+
assert ds.iam_integration() is None
189+
190+
# Cleaning
191+
ds.delete()

0 commit comments

Comments
 (0)