diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java index 413793a58a..10615d3985 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java @@ -102,7 +102,7 @@ private void startSyncModelRoutingCron() { log.info("Starting ML sync up job..."); syncModelRoutingCron = threadPool .scheduleWithFixedDelay( - new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, encryptor, mlFeatureEnabledSetting), + new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, mlFeatureEnabledSetting), TimeValue.timeValueSeconds(jobInterval), GENERAL_THREAD_POOL ); diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index c0b72c0ec9..b5b17a3629 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -5,9 +5,6 @@ package org.opensearch.ml.cluster; -import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; -import static org.opensearch.ml.common.CommonValue.MASTER_KEY; -import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; @@ -23,14 +20,10 @@ import java.util.stream.Collectors; import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.DocWriteRequest; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; @@ -46,7 +39,6 @@ import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; -import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.remote.metadata.client.BulkDataObjectRequest; import org.opensearch.remote.metadata.client.SdkClient; @@ -58,7 +50,6 @@ import org.opensearch.transport.client.Client; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; import lombok.extern.log4j.Log4j2; @@ -71,7 +62,6 @@ public class MLSyncUpCron implements Runnable { private ClusterService clusterService; private DiscoveryNodeHelper nodeHelper; private MLIndicesHandler mlIndicesHandler; - private Encryptor encryptor; private volatile Boolean mlConfigInited; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @VisibleForTesting @@ -83,7 +73,6 @@ public MLSyncUpCron( ClusterService clusterService, DiscoveryNodeHelper nodeHelper, MLIndicesHandler mlIndicesHandler, - Encryptor encryptor, MLFeatureEnabledSetting mlFeatureEnabledSetting ) { this.client = client; @@ -93,13 +82,11 @@ public MLSyncUpCron( this.mlIndicesHandler = mlIndicesHandler; this.updateModelStateSemaphore = new Semaphore(1); this.mlConfigInited = false; - this.encryptor = encryptor; this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override public void run() { - initMLConfig(); if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_INDEX)) { // no need to run sync up job if no model index log.info("Skipping sync up job - ML model index not found"); @@ -241,45 +228,6 @@ private void undeployExpiredModels( }, e -> { log.error("Failed to undeploy models {}", expiredModels, e); })); } - @VisibleForTesting - void initMLConfig() { - if (mlConfigInited || mlFeatureEnabledSetting.isMultiTenancyEnabled()) { - return; - } - mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> { - if (!r) { - log.error("Failed to initialize or update ML Config index"); - return; - } - GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.get(getRequest, ActionListener.wrap(getResponse -> { - if (!getResponse.isExists()) { - IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY); - final String masterKey = encryptor.generateMasterKey(); - indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - indexRequest.opType(DocWriteRequest.OpType.CREATE); - client.index(indexRequest, ActionListener.wrap(indexResponse -> { - log.info("ML configuration initialized successfully"); - // as this method is not being used for multi-tenancy use case, we are setting - // tenant id null by default - encryptor.setMasterKey(null, masterKey); - mlConfigInited = true; - }, e -> { log.debug("Failed to save ML encryption master key", e); })); - } else { - final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); - // as this method is not being used for multi-tenancy use case, we are setting - // tenant id null by default - encryptor.setMasterKey(null, masterKey); - mlConfigInited = true; - log.info("ML configuration already initialized, no action needed"); - } - }, e -> { log.debug("Failed to get ML encryption master key", e); })); - } - }, e -> { log.debug("Failed to init ML config index", e); })); - } - @VisibleForTesting void refreshModelState(Map> modelWorkerNodes, Map> deployingModels) { if (!updateModelStateSemaphore.tryAcquire()) { diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index f6c61e7799..6b4e22457c 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -10,14 +10,10 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; -import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import static org.opensearch.ml.utils.TestHelper.setupTestClusterState; @@ -34,7 +30,6 @@ import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.search.TotalHits; -import org.junit.Assert; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -42,8 +37,6 @@ import org.mockito.MockitoAnnotations; import org.opensearch.Version; import org.opensearch.action.bulk.BulkRequest; -import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; @@ -71,8 +64,6 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpAction; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; -import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.remote.metadata.client.SdkClient; @@ -87,7 +78,6 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.Client; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; public class MLSyncUpCronTests extends OpenSearchTestCase { @@ -112,7 +102,6 @@ public class MLSyncUpCronTests extends OpenSearchTestCase { private final String mlNode2Id = "mlNode2"; private ClusterState testState; - private Encryptor encryptor; @Mock ThreadPool threadPool; @@ -124,64 +113,17 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); - encryptor = spy(new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=")); testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(0); - actionListener.onResponse(true); - return null; - }).when(mlIndicesHandler).initMLConfigIndex(any()); - Settings settings = Settings.builder().build(); sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap())); threadContext = new ThreadContext(settings); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - syncUpCron = new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, encryptor, mlFeatureEnabledSetting); - } - - public void testInitMlConfig_MasterKeyNotExist() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); - listener.onResponse(response); - return null; - }).when(client).get(any(), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - IndexResponse indexResponse = mock(IndexResponse.class); - listener.onResponse(indexResponse); - return null; - }).when(client).index(any(), any()); - - syncUpCron.initMLConfig(); - Assert.assertNotNull(encryptor.encrypt("test", null)); - syncUpCron.initMLConfig(); - verify(encryptor, times(1)).setMasterKey(any(), any()); - } - - public void testInitMlConfig_MasterKeyExists() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(true); - String masterKey = encryptor.generateMasterKey(); - when(response.getSourceAsMap()) - .thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); - listener.onResponse(response); - return null; - }).when(client).get(any(), any()); - - syncUpCron.initMLConfig(); - Assert.assertNotNull(encryptor.encrypt("test", null)); - syncUpCron.initMLConfig(); - verify(encryptor, times(1)).setMasterKey(any(), any()); + syncUpCron = new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, mlFeatureEnabledSetting); } public void testRun_NoMLModelIndex() {