Skip to content

remove ml config index creation from cron job #3850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ private void startSyncModelRoutingCron() {
log.info("Starting ML sync up job...");
syncModelRoutingCron = threadPool
.scheduleWithFixedDelay(
new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, encryptor, mlFeatureEnabledSetting),
new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, mlFeatureEnabledSetting),
TimeValue.timeValueSeconds(jobInterval),
GENERAL_THREAD_POOL
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@

package org.opensearch.ml.cluster;

import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;

Expand All @@ -23,14 +20,10 @@
import java.util.stream.Collectors;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
Expand All @@ -46,7 +39,6 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.remote.metadata.client.BulkDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
Expand All @@ -58,7 +50,6 @@
import org.opensearch.transport.client.Client;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;

Expand All @@ -71,7 +62,6 @@ public class MLSyncUpCron implements Runnable {
private ClusterService clusterService;
private DiscoveryNodeHelper nodeHelper;
private MLIndicesHandler mlIndicesHandler;
private Encryptor encryptor;
private volatile Boolean mlConfigInited;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
@VisibleForTesting
Expand All @@ -83,7 +73,6 @@ public MLSyncUpCron(
ClusterService clusterService,
DiscoveryNodeHelper nodeHelper,
MLIndicesHandler mlIndicesHandler,
Encryptor encryptor,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
this.client = client;
Expand All @@ -93,13 +82,11 @@ public MLSyncUpCron(
this.mlIndicesHandler = mlIndicesHandler;
this.updateModelStateSemaphore = new Semaphore(1);
this.mlConfigInited = false;
this.encryptor = encryptor;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public void run() {
initMLConfig();
if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_INDEX)) {
// no need to run sync up job if no model index
log.info("Skipping sync up job - ML model index not found");
Expand Down Expand Up @@ -241,45 +228,6 @@ private void undeployExpiredModels(
}, e -> { log.error("Failed to undeploy models {}", expiredModels, e); }));
}

@VisibleForTesting
void initMLConfig() {
if (mlConfigInited || mlFeatureEnabledSetting.isMultiTenancyEnabled()) {
return;
}
mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> {
if (!r) {
log.error("Failed to initialize or update ML Config index");
return;
}
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(getResponse -> {
if (!getResponse.isExists()) {
IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
final String masterKey = encryptor.generateMasterKey();
indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
indexRequest.opType(DocWriteRequest.OpType.CREATE);
client.index(indexRequest, ActionListener.wrap(indexResponse -> {
log.info("ML configuration initialized successfully");
// as this method is not being used for multi-tenancy use case, we are setting
// tenant id null by default
encryptor.setMasterKey(null, masterKey);
mlConfigInited = true;
}, e -> { log.debug("Failed to save ML encryption master key", e); }));
} else {
final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY);
// as this method is not being used for multi-tenancy use case, we are setting
// tenant id null by default
encryptor.setMasterKey(null, masterKey);
mlConfigInited = true;
log.info("ML configuration already initialized, no action needed");
}
}, e -> { log.debug("Failed to get ML encryption master key", e); }));
}
}, e -> { log.debug("Failed to init ML config index", e); }));
}

@VisibleForTesting
void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Set<String>> deployingModels) {
if (!updateModelStateSemaphore.tryAcquire()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;
import static org.opensearch.ml.utils.TestHelper.setupTestClusterState;
Expand All @@ -34,16 +30,13 @@
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.lucene.search.TotalHits;
import org.junit.Assert;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.Version;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
Expand Down Expand Up @@ -71,8 +64,6 @@
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.remote.metadata.client.SdkClient;
Expand All @@ -87,7 +78,6 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.client.Client;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

public class MLSyncUpCronTests extends OpenSearchTestCase {
Expand All @@ -112,7 +102,6 @@ public class MLSyncUpCronTests extends OpenSearchTestCase {
private final String mlNode2Id = "mlNode2";

private ClusterState testState;
private Encryptor encryptor;

@Mock
ThreadPool threadPool;
Expand All @@ -124,64 +113,17 @@ public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT);
mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT);
encryptor = spy(new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="));

testState = setupTestClusterState("node");
when(clusterService.state()).thenReturn(testState);

doAnswer(invocation -> {
ActionListener<Boolean> actionListener = invocation.getArgument(0);
actionListener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLConfigIndex(any());

Settings settings = Settings.builder().build();
sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()));
threadContext = new ThreadContext(settings);
threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
syncUpCron = new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, encryptor, mlFeatureEnabledSetting);
}

public void testInitMlConfig_MasterKeyNotExist() {
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
GetResponse response = mock(GetResponse.class);
when(response.isExists()).thenReturn(false);
listener.onResponse(response);
return null;
}).when(client).get(any(), any());

doAnswer(invocation -> {
ActionListener<IndexResponse> listener = invocation.getArgument(1);
IndexResponse indexResponse = mock(IndexResponse.class);
listener.onResponse(indexResponse);
return null;
}).when(client).index(any(), any());

syncUpCron.initMLConfig();
Assert.assertNotNull(encryptor.encrypt("test", null));
syncUpCron.initMLConfig();
verify(encryptor, times(1)).setMasterKey(any(), any());
}

public void testInitMlConfig_MasterKeyExists() {
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
GetResponse response = mock(GetResponse.class);
when(response.isExists()).thenReturn(true);
String masterKey = encryptor.generateMasterKey();
when(response.getSourceAsMap())
.thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
listener.onResponse(response);
return null;
}).when(client).get(any(), any());

syncUpCron.initMLConfig();
Assert.assertNotNull(encryptor.encrypt("test", null));
syncUpCron.initMLConfig();
verify(encryptor, times(1)).setMasterKey(any(), any());
syncUpCron = new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, mlFeatureEnabledSetting);
}

public void testRun_NoMLModelIndex() {
Expand Down
Loading