Skip to content

Commit 0f34877

Browse files
authored
fix config index masterkey pull up for multi-tenancy (#3700)
* fix config index masterkey pull up for multi-tenancy Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * apply spotless Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * add more unit test Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressed comment Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressed comment Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * changing to OpenSearchException Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 5e67dc1 commit 0f34877

File tree

2 files changed

+262
-30
lines changed

2 files changed

+262
-30
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import javax.crypto.spec.SecretKeySpec;
2727

28+
import org.opensearch.OpenSearchException;
2829
import org.opensearch.OpenSearchStatusException;
2930
import org.opensearch.ResourceNotFoundException;
3031
import org.opensearch.action.get.GetResponse;
@@ -33,6 +34,7 @@
3334
import org.opensearch.common.util.concurrent.ThreadContext;
3435
import org.opensearch.core.action.ActionListener;
3536
import org.opensearch.core.common.Strings;
37+
import org.opensearch.core.rest.RestStatus;
3638
import org.opensearch.index.engine.VersionConflictEngineException;
3739
import org.opensearch.ml.common.exception.MLException;
3840
import org.opensearch.ml.engine.indices.MLIndicesHandler;
@@ -254,9 +256,20 @@ private void handleGetDataObjectSuccess(
254256
try {
255257
GetResponse getMasterKeyResponse = response.parser() == null ? null : GetResponse.fromXContent(response.parser());
256258
if (getMasterKeyResponse != null && getMasterKeyResponse.isExists()) {
257-
this.tenantMasterKeys
258-
.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), (String) response.source().get(masterKeyId));
259-
log.info("ML encryption master key already initialized, no action needed");
259+
Map<String, Object> source = getMasterKeyResponse.getSourceAsMap();
260+
if (source != null) {
261+
Object keyValue = source.get(MASTER_KEY);
262+
if (keyValue instanceof String) {
263+
this.tenantMasterKeys.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), (String) keyValue);
264+
log.info("ML encryption master key already initialized, no action needed");
265+
} else {
266+
log.error("Master key not found or not a string for tenantId: {}, masterKeyId: {}", tenantId, masterKeyId);
267+
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
268+
}
269+
} else {
270+
log.error("Master key not found or not a string for tenantId: {}, masterKeyId: {}", tenantId, masterKeyId);
271+
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
272+
}
260273
latch.countDown();
261274
} else {
262275
initializeNewMasterKey(tenantId, masterKeyId, exceptionRef, latch, context);
@@ -351,7 +364,8 @@ private void handlePutDataObjectFailure(
351364
CountDownLatch latch
352365
) {
353366
Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
354-
if (cause instanceof VersionConflictEngineException) {
367+
if (cause instanceof VersionConflictEngineException
368+
|| (cause instanceof OpenSearchException && ((OpenSearchException) cause).status() == RestStatus.CONFLICT)) {
355369
handleVersionConflict(tenantId, masterKeyId, context, exceptionRef, latch);
356370
} else {
357371
log.debug("Failed to index ML encryption master key to config index", cause);
@@ -416,9 +430,20 @@ private void handleVersionConflictResponse(
416430
} else {
417431
GetResponse getMasterKeyResponse = response1.parser() == null ? null : GetResponse.fromXContent(response1.parser());
418432
if (getMasterKeyResponse != null && getMasterKeyResponse.isExists()) {
419-
this.tenantMasterKeys
420-
.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), (String) response1.source().get(masterKeyId));
421-
log.info("ML encryption master key already initialized, no action needed");
433+
Map<String, Object> source = getMasterKeyResponse.getSourceAsMap();
434+
if (source != null) {
435+
Object keyValue = source.get(MASTER_KEY);
436+
if (keyValue instanceof String) {
437+
this.tenantMasterKeys.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), (String) keyValue);
438+
log.info("ML encryption master key already initialized, no action needed");
439+
} else {
440+
log.error("Master key not found or not a string for tenantId: {}, masterKeyId: {}", tenantId, masterKeyId);
441+
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
442+
}
443+
} else {
444+
log.error("Master key not found or not a string for tenantId: {}, masterKeyId: {}", tenantId, masterKeyId);
445+
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
446+
}
422447
latch.countDown();
423448
} else {
424449
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));

ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java

Lines changed: 230 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,15 @@ public class EncryptorImplTest {
7676
ThreadContext threadContext;
7777
final String USER_STRING = "myuser|role1,role2|myTenant";
7878
final String TENANT_ID = "myTenant";
79+
final String GENERATED_MASTER_KEY = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=";
7980

8081
Encryptor encryptor;
8182

8283
@Before
8384
public void setUp() {
8485
MockitoAnnotations.openMocks(this);
8586
masterKey = new ConcurrentHashMap<>();
86-
masterKey.put(DEFAULT_TENANT_ID, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
87+
masterKey.put(DEFAULT_TENANT_ID, GENERATED_MASTER_KEY);
8788
sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap());
8889

8990
doAnswer(invocation -> {
@@ -483,8 +484,7 @@ public void initMasterKey_AddTenantMasterKeys() throws IOException {
483484
Assert.assertNotNull(tenantMasterKey);
484485

485486
// Ensure that the master key for this tenant matches the expected value
486-
String expectedMasterKeyId = MASTER_KEY + "_" + hashString(TENANT_ID);
487-
Assert.assertEquals("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=", encryptor.getMasterKey(TENANT_ID));
487+
Assert.assertEquals(GENERATED_MASTER_KEY, encryptor.getMasterKey(TENANT_ID));
488488
}
489489

490490
@Test
@@ -514,24 +514,6 @@ public void encrypt_SdkClientPutDataObjectFailure() {
514514
encryptor.encrypt("test", null);
515515
}
516516

517-
@Test
518-
public void handleVersionConflictResponse_ShouldThrowException_WhenRetryFails() throws IOException {
519-
doAnswer(invocation -> {
520-
ActionListener<Boolean> actionListener = (ActionListener) invocation.getArgument(0);
521-
actionListener.onResponse(true);
522-
return null;
523-
}).when(mlIndicesHandler).initMLConfigIndex(any());
524-
525-
doAnswer(invocation -> {
526-
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
527-
actionListener.onFailure(new IOException("Failed to get master key"));
528-
return null;
529-
}).when(client).get(any(), any());
530-
531-
exceptionRule.expect(MLException.class);
532-
encryptor.encrypt("test", "someTenant");
533-
}
534-
535517
// Helper method to prepare a valid GetResponse
536518
private GetResponse prepareMLConfigResponse(String tenantId) throws IOException {
537519
// Compute the masterKeyId based on tenantId
@@ -543,8 +525,8 @@ private GetResponse prepareMLConfigResponse(String tenantId) throws IOException
543525
// Create the source map with the expected fields
544526
Map<String, Object> sourceMap = Map
545527
.of(
546-
masterKeyId,
547-
"m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=", // Valid MASTER_KEY for this tenant
528+
MASTER_KEY,
529+
GENERATED_MASTER_KEY, // Valid MASTER_KEY for this tenant
548530
CREATE_TIME_FIELD,
549531
Instant.now().toEpochMilli()
550532
);
@@ -565,6 +547,231 @@ private GetResponse prepareMLConfigResponse(String tenantId) throws IOException
565547
return new GetResponse(getResult);
566548
}
567549

550+
@Test
551+
public void encrypt_MasterKeyFieldMismatch_ShouldFallbackToProperKeyField() throws IOException {
552+
// This test simulates the case where the document ID is `master_key_<hash>`
553+
// but the actual `_source` only contains `master_key` (as expected in real DDB).
554+
555+
doAnswer(invocation -> {
556+
ActionListener<Boolean> actionListener = (ActionListener) invocation.getArgument(0);
557+
actionListener.onResponse(true); // init index success
558+
return null;
559+
}).when(mlIndicesHandler).initMLConfigIndex(any());
560+
561+
// Prepare a GetResponse where the _source has ONLY "master_key"
562+
Map<String, Object> sourceMap = Map.of(MASTER_KEY, GENERATED_MASTER_KEY, CREATE_TIME_FIELD, Instant.now().toEpochMilli());
563+
564+
XContentBuilder builder = XContentFactory.jsonBuilder();
565+
builder.startObject();
566+
for (Map.Entry<String, Object> entry : sourceMap.entrySet()) {
567+
builder.field(entry.getKey(), entry.getValue());
568+
}
569+
builder.endObject();
570+
571+
BytesReference sourceBytes = BytesReference.bytes(builder);
572+
String masterKeyId = MASTER_KEY + "_" + hashString(TENANT_ID); // Simulate full hashed ID
573+
GetResult getResult = new GetResult(ML_CONFIG_INDEX, masterKeyId, 1L, 1L, 1L, true, sourceBytes, null, null);
574+
GetResponse getResponse = new GetResponse(getResult);
575+
576+
// Simulate Get API call returning a GetResponse with only "master_key" field
577+
doAnswer(invocation -> {
578+
ActionListener<GetResponse> listener = invocation.getArgument(1);
579+
listener.onResponse(getResponse);
580+
return null;
581+
}).when(client).get(any(), any());
582+
583+
Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler);
584+
585+
// Old buggy code would try to access response.source().get(masterKeyId) and get null
586+
// This test ensures the new fix works — we access MASTER_KEY properly
587+
String encrypted = encryptor.encrypt("test", TENANT_ID);
588+
Assert.assertNotNull(encrypted);
589+
Assert.assertEquals("test", encryptor.decrypt(encrypted, TENANT_ID));
590+
}
591+
592+
@Test
593+
public void encrypt_MasterKeyFieldExistsButNotString_ShouldThrowError() throws IOException {
594+
doAnswer(invocation -> {
595+
ActionListener<Boolean> actionListener = invocation.getArgument(0);
596+
actionListener.onResponse(true);
597+
return null;
598+
}).when(mlIndicesHandler).initMLConfigIndex(any());
599+
600+
// Prepare _source with a non-string master key
601+
Map<String, Object> sourceMap = Map
602+
.of(
603+
MASTER_KEY,
604+
12345, // wrong type
605+
CREATE_TIME_FIELD,
606+
Instant.now().toEpochMilli()
607+
);
608+
609+
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
610+
for (Map.Entry<String, Object> entry : sourceMap.entrySet()) {
611+
builder.field(entry.getKey(), entry.getValue());
612+
}
613+
builder.endObject();
614+
615+
BytesReference sourceBytes = BytesReference.bytes(builder);
616+
String masterKeyId = MASTER_KEY + "_" + hashString(TENANT_ID);
617+
GetResult getResult = new GetResult(ML_CONFIG_INDEX, masterKeyId, 1L, 1L, 1L, true, sourceBytes, null, null);
618+
GetResponse getResponse = new GetResponse(getResult);
619+
620+
doAnswer(invocation -> {
621+
ActionListener<GetResponse> listener = invocation.getArgument(1);
622+
listener.onResponse(getResponse);
623+
return null;
624+
}).when(client).get(any(), any());
625+
626+
Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler);
627+
628+
exceptionRule.expect(ResourceNotFoundException.class);
629+
exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR);
630+
631+
encryptor.encrypt("test", TENANT_ID);
632+
}
633+
634+
@Test
635+
public void encrypt_MasterKeyFieldMissing_ShouldThrowError() throws IOException {
636+
doAnswer(invocation -> {
637+
ActionListener<Boolean> actionListener = invocation.getArgument(0);
638+
actionListener.onResponse(true);
639+
return null;
640+
}).when(mlIndicesHandler).initMLConfigIndex(any());
641+
642+
// _source does not include the "master_key" field
643+
Map<String, Object> sourceMap = Map.of(CREATE_TIME_FIELD, Instant.now().toEpochMilli());
644+
645+
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
646+
for (Map.Entry<String, Object> entry : sourceMap.entrySet()) {
647+
builder.field(entry.getKey(), entry.getValue());
648+
}
649+
builder.endObject();
650+
651+
BytesReference sourceBytes = BytesReference.bytes(builder);
652+
String masterKeyId = MASTER_KEY + "_" + hashString(TENANT_ID);
653+
GetResult getResult = new GetResult(ML_CONFIG_INDEX, masterKeyId, 1L, 1L, 1L, true, sourceBytes, null, null);
654+
GetResponse getResponse = new GetResponse(getResult);
655+
656+
doAnswer(invocation -> {
657+
ActionListener<GetResponse> listener = invocation.getArgument(1);
658+
listener.onResponse(getResponse);
659+
return null;
660+
}).when(client).get(any(), any());
661+
662+
Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler);
663+
664+
exceptionRule.expect(ResourceNotFoundException.class);
665+
exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR);
666+
667+
encryptor.encrypt("test", TENANT_ID);
668+
}
669+
670+
@Test
671+
public void handleVersionConflictResponse_RetrySucceeds() throws IOException {
672+
// Simulate successful ML Config Index initialization
673+
doAnswer(invocation -> {
674+
ActionListener<Boolean> listener = invocation.getArgument(0);
675+
listener.onResponse(true);
676+
return null;
677+
}).when(mlIndicesHandler).initMLConfigIndex(any());
678+
679+
// First, simulate a version conflict on the initial PUT
680+
doAnswer(invocation -> {
681+
ActionListener<IndexResponse> listener = invocation.getArgument(1);
682+
// Version conflict error is thrown
683+
listener.onFailure(new VersionConflictEngineException(new ShardId(ML_CONFIG_INDEX, "index_uuid", 1), "test_id", "failed"));
684+
return null;
685+
}).when(client).index(any(), any());
686+
687+
// Simulate that after the version conflict, the GET call returns a valid master key document.
688+
GetResponse validResponse = prepareMLConfigResponse(TENANT_ID);
689+
// This GET call will be triggered twice (once for the version conflict GET and again in the normal flow),
690+
// so we set up our stub to return a valid response each time.
691+
doAnswer(invocation -> {
692+
ActionListener<GetResponse> listener = invocation.getArgument(1);
693+
listener.onResponse(validResponse);
694+
return null;
695+
}).when(client).get(any(), any());
696+
697+
// Now run encryption; it should handle the version conflict by fetching the key, and then succeed.
698+
Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler);
699+
// This will go through the PUT failure, then version conflict handling, and use the returned key.
700+
String encrypted = encryptor.encrypt("test", TENANT_ID);
701+
Assert.assertNotNull(encrypted);
702+
Assert.assertEquals("test", encryptor.decrypt(encrypted, TENANT_ID));
703+
}
704+
705+
@Test
706+
public void handleVersionConflictResponse_RetryFails() throws IOException {
707+
// Simulate successful ML Config Index initialization
708+
doAnswer(invocation -> {
709+
ActionListener<Boolean> listener = invocation.getArgument(0);
710+
listener.onResponse(true);
711+
return null;
712+
}).when(mlIndicesHandler).initMLConfigIndex(any());
713+
714+
// Simulate a version conflict on the initial PUT
715+
doAnswer(invocation -> {
716+
ActionListener<IndexResponse> listener = invocation.getArgument(1);
717+
listener.onFailure(new VersionConflictEngineException(new ShardId(ML_CONFIG_INDEX, "index_uuid", 1), "test_id", "failed"));
718+
return null;
719+
}).when(client).index(any(), any());
720+
721+
// Simulate that the GET call in version conflict handling fails, e.g., by throwing an IOException.
722+
doAnswer(invocation -> {
723+
ActionListener<GetResponse> listener = invocation.getArgument(1);
724+
listener.onFailure(new IOException("Failed to get master key on retry"));
725+
return null;
726+
}).when(client).get(any(), any());
727+
728+
Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler);
729+
730+
// We expect an MLException (or a ResourceNotFoundException) to be thrown due to the failure in getting the key.
731+
exceptionRule.expect(MLException.class);
732+
exceptionRule.expectMessage("Failed to get master key"); // Or adjust based on your exact message.
733+
734+
encryptor.encrypt("test", TENANT_ID);
735+
}
736+
737+
@Test
738+
public void encrypt_GetSourceAsMapIsNull_ShouldThrowResourceNotFound() throws Exception {
739+
exceptionRule.expect(ResourceNotFoundException.class);
740+
exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR);
741+
742+
// Simulate ML config index init success
743+
doAnswer(invocation -> {
744+
ActionListener<Boolean> actionListener = (ActionListener) invocation.getArgument(0);
745+
actionListener.onResponse(true);
746+
return null;
747+
}).when(mlIndicesHandler).initMLConfigIndex(any());
748+
749+
// Create a GetResult with null sourceBytes
750+
String masterKeyId = MASTER_KEY + "_" + hashString(TENANT_ID);
751+
GetResult getResult = new GetResult(
752+
ML_CONFIG_INDEX,
753+
masterKeyId,
754+
1L,
755+
1L,
756+
1L,
757+
true, // exists = true
758+
null, // sourceBytes = null => getSourceAsMap() will return null
759+
null,
760+
null
761+
);
762+
GetResponse getResponse = new GetResponse(getResult);
763+
764+
// Mock the get response
765+
doAnswer(invocation -> {
766+
ActionListener<GetResponse> listener = invocation.getArgument(1);
767+
listener.onResponse(getResponse);
768+
return null;
769+
}).when(client).get(any(), any());
770+
771+
// Now run it
772+
encryptor.encrypt("test", TENANT_ID);
773+
}
774+
568775
// Helper method to prepare a valid IndexResponse
569776
private IndexResponse prepareIndexResponse() {
570777
ShardId shardId = new ShardId(ML_CONFIG_INDEX, "index_uuid", 0);

0 commit comments

Comments
 (0)