diff --git a/common/build.gradle b/common/build.gradle index 9db59f5070..49bfbb857f 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -22,6 +22,8 @@ dependencies { testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2' testImplementation "org.opensearch.test:framework:${opensearch_version}" + compileOnly group: 'org.opensearch', name:'opensearch-security-spi', version:"${opensearch_build}" + compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.11.0' compileOnly group: 'org.json', name: 'json', version: '20231013' diff --git a/common/src/main/java/org/opensearch/ml/common/ResourceSharingClientAccessor.java b/common/src/main/java/org/opensearch/ml/common/ResourceSharingClientAccessor.java new file mode 100644 index 0000000000..5d0978dbc2 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/ResourceSharingClientAccessor.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import org.opensearch.security.spi.resources.client.ResourceSharingClient; + +/** + * Accessor for resource sharing client + */ +public class ResourceSharingClientAccessor { + private ResourceSharingClient CLIENT; + + private static ResourceSharingClientAccessor resourceSharingClientAccessor; + + private ResourceSharingClientAccessor() {} + + public static ResourceSharingClientAccessor getInstance() { + if (resourceSharingClientAccessor == null) { + resourceSharingClientAccessor = new ResourceSharingClientAccessor(); + } + + return resourceSharingClientAccessor; + } + + /** + * Set the resource sharing client + */ + public void setResourceSharingClient(ResourceSharingClient client) { + resourceSharingClientAccessor.CLIENT = client; + } + + /** + * Get the resource sharing client + */ + public ResourceSharingClient getResourceSharingClient() { + return resourceSharingClientAccessor.CLIENT; + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java index cd2be26209..00d2700d54 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.transport.model_group; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.ByteArrayInputStream; @@ -16,6 +17,7 @@ import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.DocRequest; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; @@ -24,7 +26,7 @@ import lombok.Builder; import lombok.Getter; -public class MLModelGroupDeleteRequest extends ActionRequest { +public class MLModelGroupDeleteRequest extends ActionRequest implements DocRequest { @Getter String modelGroupId; @Getter @@ -78,4 +80,24 @@ public static MLModelGroupDeleteRequest fromActionRequest(ActionRequest actionRe throw new UncheckedIOException("failed to parse ActionRequest into MLModelGroupDeleteRequest", e); } } + + /** + * Get the index that this request operates on + * + * @return the index + */ + @Override + public String index() { + return ML_MODEL_GROUP_INDEX; + } + + /** + * Get the id of the document for this request + * + * @return the id + */ + @Override + public String id() { + return modelGroupId; + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java index cc9dbcd444..4870302c78 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.transport.model_group; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.ByteArrayInputStream; @@ -16,6 +17,7 @@ import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.DocRequest; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; @@ -30,7 +32,7 @@ @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLModelGroupGetRequest extends ActionRequest { +public class MLModelGroupGetRequest extends ActionRequest implements DocRequest { String modelGroupId; String tenantId; @@ -83,4 +85,24 @@ public static MLModelGroupGetRequest fromActionRequest(ActionRequest actionReque throw new UncheckedIOException("failed to parse ActionRequest into MLModelGroupGetRequest", e); } } + + /** + * Get the index that this request operates on + * + * @return the index + */ + @Override + public String index() { + return ML_MODEL_GROUP_INDEX; + } + + /** + * Get the id of the document for this request + * + * @return the id + */ + @Override + public String id() { + return modelGroupId; + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java index e130975c71..f4edb60cd5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.transport.model_group; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.utils.StringUtils.validateFields; import java.io.ByteArrayInputStream; @@ -17,6 +18,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.DocRequest; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; @@ -32,7 +34,7 @@ @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLUpdateModelGroupRequest extends ActionRequest { +public class MLUpdateModelGroupRequest extends ActionRequest implements DocRequest { MLUpdateModelGroupInput updateModelGroupInput; @@ -80,4 +82,25 @@ public static MLUpdateModelGroupRequest fromActionRequest(ActionRequest actionRe } } + + /** + * Get the index that this request operates on + * + * @return the index + */ + @Override + public String index() { + return ML_MODEL_GROUP_INDEX; + } + + /** + * Get the id of the document for this request + * + * @return the id + */ + @Override + public String id() { + return updateModelGroupInput.getModelGroupID(); + + } } diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java index b8f07c9238..1d4a05c910 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java @@ -73,7 +73,7 @@ public void toXContent() throws IOException { public void parse() throws IOException { String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + "\"backend_roles\":[\"role1\",\"role2\"]," - + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"user_requested_tenant\":null,\"custom_attribute_names\":[]}," + "\"access\":\"PUBLIC\"}"; XContentParser parser = XContentType.JSON .xContent() @@ -176,7 +176,7 @@ public void toXContent_WithTenantId() throws IOException { public void parse_WithTenantId() throws IOException { String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + "\"backend_roles\":[\"role1\",\"role2\"]," - + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null,\"user_requested_tenant_access\":null}," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"user_requested_tenant\":null,\"user_requested_tenant_access\":null,\"custom_attribute_names\":[]}," + "\"access\":\"PUBLIC\",\"tenant_id\":\"test_tenant\"}"; XContentParser parser = XContentType.JSON @@ -201,7 +201,7 @@ public void parse_WithTenantId() throws IOException { public void parse_WithoutTenantId() throws IOException { String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + "\"backend_roles\":[\"role1\",\"role2\"]," - + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"user_requested_tenant\":null,\"custom_attribute_names\":[]}," + "\"access\":\"PUBLIC\"}"; XContentParser parser = XContentType.JSON diff --git a/plugin/build.gradle b/plugin/build.gradle index 935491db0d..4d309512c9 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -44,11 +44,12 @@ opensearchplugin { name 'opensearch-ml' description 'machine learning plugin for opensearch' classname 'org.opensearch.ml.plugin.MachineLearningPlugin' - extendedPlugins = ['opensearch-job-scheduler'] + extendedPlugins = ['opensearch-job-scheduler', 'opensearch-security;optional=true'] } configurations { zipArchive + opensearchPlugin } dependencies { @@ -70,6 +71,10 @@ dependencies { zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + compileOnly group: 'org.opensearch', name:'opensearch-security-spi', version:"${opensearch_build}" + opensearchPlugin "org.opensearch.plugin:opensearch-job-scheduler:${opensearch_build}@zip" + opensearchPlugin "org.opensearch.plugin:opensearch-security:${opensearch_build}@zip" + implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" // Multi-tenant SDK Client @@ -232,38 +237,88 @@ integTest { testLogging.showStandardStreams = true } -testClusters.integTest { - testDistribution = "ARCHIVE" - // Cluster shrink exception thrown if we try to set numberOfNodes to 1, so only apply if > 1 - if (_numNodes > 1) numberOfNodes = _numNodes - // When running integration tests it doesn't forward the --debug-jvm to the cluster anymore - // i.e. we have to use a custom property to flag when we want to debug elasticsearch JVM - // since we also support multi node integration tests we increase debugPort per node - if (System.getProperty("cluster.debug") != null) { - def debugPort = 5005 - nodes.forEach { node -> - node.jvmArgs("-agentlib:jdwp=transport=dt_socket,server=n,suspend=y,address=*:${debugPort}") - debugPort += 1 - } - } - plugin(project.tasks.bundlePlugin.archiveFile) - plugin(provider(new Callable(){ +// Resolve plugin zips +ext.resolvePluginFile = { pluginId -> + return new Callable() { @Override RegularFile call() throws Exception { return new RegularFile() { @Override File getAsFile() { - return configurations.zipArchive.asFileTree.getSingleFile() + return configurations.opensearchPlugin.resolvedConfiguration.resolvedArtifacts + .find { ResolvedArtifact f -> f.name.startsWith(pluginId) } + .file } } } - })) + } +} + +def jobSchedulerFile = resolvePluginFile("opensearch-job-scheduler") +def securityPluginFile = resolvePluginFile("opensearch-security") + +// Enable Security if -Dsecurity=true or -Dhttps=true +def securityEnabled = System.getProperty("security", "false") == "true" || + System.getProperty("https", "false") == "true" - nodes.each { node -> - def plugins = node.plugins - def firstPlugin = plugins.get(0) - plugins.remove(0) - plugins.add(firstPlugin) +// Single authoritative cluster definition +testClusters.integTest { + testDistribution = "ARCHIVE" + if (_numNodes > 1) numberOfNodes = _numNodes + configureClusterPlugins(delegate, jobSchedulerFile, securityPluginFile, securityEnabled) +} + +def configureClusterPlugins(cluster, jobSchedZip, securityZip, securityEnabled) { + // Enable attachable debugging if requested (prefer 'opensearch.debug'; also accept legacy 'cluster.debug') + if (System.getProperty("opensearch.debug") != null || System.getProperty("cluster.debug") != null) { + def debugPort = 5005 + cluster.nodes.each { node -> + // server=y (listen), suspend=n (don’t block startup) + node.jvmArgs "-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:${debugPort}" + debugPort++ + } + } + + cluster.with { + // Install dependency plugins FIRST to avoid the reordering hack + plugin(provider(jobSchedZip)) + if (securityEnabled) { + plugin(provider(securityZip)) + } + // Install the plugin under test LAST + plugin(project.tasks.bundlePlugin.archiveFile) + + if (securityEnabled) { + nodes.each { node -> + node.extraConfigFile("kirk.pem", file("build/resources/test/kirk.pem")) + node.extraConfigFile("kirk-key.pem", file("build/resources/test/kirk-key.pem")) + node.extraConfigFile("esnode.pem", file("build/resources/test/esnode.pem")) + node.extraConfigFile("esnode-key.pem", file("build/resources/test/esnode-key.pem")) + node.extraConfigFile("root-ca.pem", file("build/resources/test/root-ca.pem")) + + node.setting "plugins.security.ssl.transport.pemcert_filepath", "esnode.pem" + node.setting "plugins.security.ssl.transport.pemkey_filepath", "esnode-key.pem" + node.setting "plugins.security.ssl.transport.pemtrustedcas_filepath", "root-ca.pem" + node.setting "plugins.security.ssl.transport.enforce_hostname_verification", "false" + + node.setting "plugins.security.ssl.http.enabled", "true" + node.setting "plugins.security.ssl.http.pemcert_filepath", "esnode.pem" + node.setting "plugins.security.ssl.http.pemkey_filepath", "esnode-key.pem" + node.setting "plugins.security.ssl.http.pemtrustedcas_filepath", "root-ca.pem" + + node.setting "plugins.security.allow_unsafe_democertificates", "true" + node.setting "plugins.security.allow_default_init_securityindex", "true" + node.setting "plugins.security.authcz.admin_dn", "\n - CN=kirk,OU=client,O=client,L=test,C=de" + node.setting "plugins.security.audit.type", "internal_opensearch" + node.setting "plugins.security.enable_snapshot_restore_privilege", "true" + node.setting "plugins.security.check_snapshot_restore_write_privileges","true" + node.setting "plugins.security.restapi.roles_enabled", "[\"all_access\", \"security_rest_api_access\"]" + node.setting "plugins.security.system_indices.enabled", "true" + if (System.getProperty("resource_sharing.enabled") == "true") { + node.setting "plugins.security.experimental.resource_sharing.enabled", "true" + } + } + } } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java index 92fd0228f1..2026fcd897 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java @@ -112,35 +112,41 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (hasPermission) { - if (mlModel.getModelState() != MLModelState.DEPLOYING) { - indexAndCreateController(mlModel, controller, wrappedListener); + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + MLCreateControllerAction.NAME, + client, + ActionListener.wrap(hasPermission -> { + if (hasPermission) { + if (mlModel.getModelState() != MLModelState.DEPLOYING) { + indexAndCreateController(mlModel, controller, wrappedListener); + } else { + String errorMessage = + "Creating a model controller during its corresponding model in DEPLOYING state is not allowed, please either create the model controller after it is deployed or before deploying it."; + errorMessage = getErrorMessage(errorMessage, modelId, isHidden); + log.error(errorMessage); + wrappedListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.CONFLICT)); + } } else { - String errorMessage = - "Creating a model controller during its corresponding model in DEPLOYING state is not allowed, please either create the model controller after it is deployed or before deploying it."; + String errorMessage = "User doesn't have privilege to perform this operation on this model controller."; errorMessage = getErrorMessage(errorMessage, modelId, isHidden); log.error(errorMessage); - wrappedListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.CONFLICT)); + wrappedListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.FORBIDDEN)); } - } else { - String errorMessage = "User doesn't have privilege to perform this operation on this model controller."; - errorMessage = getErrorMessage(errorMessage, modelId, isHidden); - log.error(errorMessage); - wrappedListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.FORBIDDEN)); - } - }, exception -> { - log - .error( - getErrorMessage( - "Permission denied: Unable to create the model controller. Details: {}", - modelId, - isHidden - ), - exception - ); - wrappedListener.onFailure(exception); - })); + }, exception -> { + log + .error( + getErrorMessage( + "Permission denied: Unable to create the model controller. Details: {}", + modelId, + isHidden + ), + exception + ); + wrappedListener.onFailure(exception); + }) + ); } else { wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java index 3be5e07a0b..69b2663705 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java @@ -98,49 +98,55 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { - if (hasPermission) { - mlModelManager - .getController( - modelId, - ActionListener - .wrap( - controller -> deleteControllerWithDeployedModel( + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + MLControllerDeleteAction.NAME, + client, + ActionListener.wrap(hasPermission -> { + if (hasPermission) { + mlModelManager + .getController( + modelId, + ActionListener + .wrap( + controller -> deleteControllerWithDeployedModel( + modelId, + mlModel.getIsHidden(), + wrappedListener + ), + deleteException -> { + log.error(deleteException); + wrappedListener.onFailure(deleteException); + } + ) + ); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + getErrorMessage( + "User doesn't have privilege to perform this operation on this model controller.", modelId, - mlModel.getIsHidden(), - wrappedListener + isHidden ), - deleteException -> { - log.error(deleteException); - wrappedListener.onFailure(deleteException); - } + RestStatus.FORBIDDEN ) + ); + } + }, exception -> { + log + .error( + getErrorMessage( + "Permission denied: Unable to delete the model controller with the provided model. Details: ", + modelId, + isHidden + ), + exception ); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - getErrorMessage( - "User doesn't have privilege to perform this operation on this model controller.", - modelId, - isHidden - ), - RestStatus.FORBIDDEN - ) - ); - } - }, exception -> { - log - .error( - getErrorMessage( - "Permission denied: Unable to delete the model controller with the provided model. Details: ", - modelId, - isHidden - ), - exception - ); - wrappedListener.onFailure(exception); - })); + wrappedListener.onFailure(exception); + }) + ); }, e -> { log .warn( diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java index d70488948f..9b9a230a7f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java @@ -96,34 +96,40 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { - if (hasPermission) { - wrappedListener.onResponse(MLControllerGetResponse.builder().controller(controller).build()); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - getErrorMessage( - "User doesn't have privilege to perform this operation on this model controller.", - modelId, - isHidden - ), - RestStatus.FORBIDDEN - ) + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + MLControllerGetAction.NAME, + client, + ActionListener.wrap(hasPermission -> { + if (hasPermission) { + wrappedListener.onResponse(MLControllerGetResponse.builder().controller(controller).build()); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + getErrorMessage( + "User doesn't have privilege to perform this operation on this model controller.", + modelId, + isHidden + ), + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log + .error( + getErrorMessage( + "Permission denied: Unable to create the model controller for the given model.", + modelId, + isHidden + ), + exception ); - } - }, exception -> { - log - .error( - getErrorMessage( - "Permission denied: Unable to create the model controller for the given model.", - modelId, - isHidden - ), - exception - ); - wrappedListener.onFailure(exception); - })); + wrappedListener.onFailure(exception); + }) + ); }, e -> wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java index ac44069930..1385500551 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java @@ -104,51 +104,58 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (hasPermission) { - mlModelManager.getController(modelId, ActionListener.wrap(controller -> { - boolean isDeployRequiredAfterUpdate = controller.isDeployRequiredAfterUpdate(updateControllerInput); - controller.update(updateControllerInput); - updateController(mlModel, controller, isDeployRequiredAfterUpdate, wrappedListener); - }, e -> { - if (mlModel.getIsControllerEnabled() == null || !mlModel.getIsControllerEnabled()) { - final String errorMsg = getErrorMessage( - "Model controller haven't been created for the model. Consider calling create model controller api instead.", - modelId, - isHidden - ); - wrappedListener.onFailure(new OpenSearchStatusException(errorMsg, RestStatus.CONFLICT)); - log.error(errorMsg, e); - } else { - log.error(e); - wrappedListener.onFailure(e); - } - })); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - getErrorMessage( - "User doesn't have privilege to perform this operation on this model controller.", + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + MLUpdateControllerAction.NAME, + client, + + ActionListener.wrap(hasPermission -> { + if (hasPermission) { + mlModelManager.getController(modelId, ActionListener.wrap(controller -> { + boolean isDeployRequiredAfterUpdate = controller.isDeployRequiredAfterUpdate(updateControllerInput); + controller.update(updateControllerInput); + updateController(mlModel, controller, isDeployRequiredAfterUpdate, wrappedListener); + }, e -> { + if (mlModel.getIsControllerEnabled() == null || !mlModel.getIsControllerEnabled()) { + final String errorMsg = getErrorMessage( + "Model controller haven't been created for the model. Consider calling create model controller api instead.", modelId, isHidden - ), - RestStatus.FORBIDDEN - ) + ); + wrappedListener.onFailure(new OpenSearchStatusException(errorMsg, RestStatus.CONFLICT)); + log.error(errorMsg, e); + } else { + log.error(e); + wrappedListener.onFailure(e); + } + })); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + getErrorMessage( + "User doesn't have privilege to perform this operation on this model controller.", + modelId, + isHidden + ), + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log + .error( + getErrorMessage( + "Permission denied: Unable to create the model controller for the model. Details: ", + modelId, + isHidden + ), + exception ); - } - }, exception -> { - log - .error( - getErrorMessage( - "Permission denied: Unable to create the model controller for the model. Details: ", - modelId, - isHidden - ), - exception - ); - wrappedListener.onFailure(exception); - })); + wrappedListener.onFailure(exception); + }) + ); } else { wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 9ae795438d..e48b415574 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -177,22 +177,29 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } else { - deployModel(deployModelRequest, mlModel, modelId, tenantId, wrappedListener, listener); - } - }, e -> { - log.error(getErrorMessage("Failed to Validate Access for the given model", modelId, isHidden), e); - wrappedListener.onFailure(e); - })); + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + MLDeployModelAction.NAME, + client, + + ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else { + deployModel(deployModelRequest, mlModel, modelId, tenantId, wrappedListener, listener); + } + }, e -> { + log.error(getErrorMessage("Failed to Validate Access for the given model", modelId, isHidden), e); + wrappedListener.onFailure(e); + }) + ); } }, e -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 449a490ef3..1bd117ec92 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -13,6 +13,7 @@ import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import org.apache.lucene.search.TotalHits; @@ -20,6 +21,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Nullable; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -36,6 +38,7 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; @@ -51,7 +54,6 @@ import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.transport.client.Client; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Throwables; import lombok.extern.log4j.Log4j2; @@ -136,67 +138,48 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId, request.source().fetchSource(rebuiltFetchSourceContext); final ActionListener doubleWrapperListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); - if (modelAccessControlHelper.skipModelAccessControl(user) - || !MLIndicesHandler - .doesMultiTenantIndexExist( - clusterService, - mlFeatureEnabledSetting.isMultiTenancyEnabled(), - CommonValue.ML_MODEL_GROUP_INDEX - )) { + boolean skip = modelAccessControlHelper.skipModelAccessControl(user); + boolean hasIndex = MLIndicesHandler + .doesMultiTenantIndexExist( + clusterService, + mlFeatureEnabledSetting.isMultiTenancyEnabled(), + CommonValue.ML_MODEL_GROUP_INDEX + ); + boolean rsClientPresent = ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null; + if (skip || !hasIndex) { + // No gating at all SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest .builder() .indices(request.indices()) .searchSourceBuilder(request.source()) .tenantId(tenantId) .build(); + sdkClient .searchDataObjectAsync(searchDataObjectRequest) .whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrapperListener)); + } else if (rsClientPresent) { + modelAccessControlHelper + .addAccessibleModelGroupsFilterAndSearch( + tenantId, + request, + sdkClient, + (ids) -> modelGroupGateAndSearch( + tenantId, + request, + sdkClient, + ids, + /*useBackendRoles*/ request.source(), + wrappedListener + ), + doubleWrapperListener + ); } else { - SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user); - SearchRequest modelGroupSearchRequest = new SearchRequest(); - sourceBuilder.fetchSource(new String[] { MLModelGroup.MODEL_GROUP_ID_FIELD, }, null); - sourceBuilder.size(10000); - modelGroupSearchRequest.source(sourceBuilder); - modelGroupSearchRequest.indices(CommonValue.ML_MODEL_GROUP_INDEX); - ActionListener modelGroupSearchActionListener = ActionListener.wrap(r -> { - if (Optional - .ofNullable(r) - .map(SearchResponse::getHits) - .map(SearchHits::getTotalHits) - .map(TotalHits::value) - .orElse(0L) > 0) { - List modelGroupIds = new ArrayList<>(); - Arrays.stream(r.getHits().getHits()).forEach(hit -> { modelGroupIds.add(hit.getId()); }); - - request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds)); - } else { - log.debug("No model group found"); - request.source().query(rewriteQueryBuilder(request.source().query(), null)); - } - SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest - .builder() - .indices(request.indices()) - .searchSourceBuilder(request.source()) - .tenantId(tenantId) - .build(); - sdkClient - .searchDataObjectAsync(searchDataObjectRequest) - .whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrapperListener)); - }, e -> { - log.error("Fail to search model groups!", e); - wrappedListener.onFailure(e); - }); - SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest - .builder() - .indices(modelGroupSearchRequest.indices()) - .searchSourceBuilder(modelGroupSearchRequest.source()) - .tenantId(tenantId) - .build(); - sdkClient - .searchDataObjectAsync(searchDataObjectRequest) - .whenComplete(SdkClientUtils.wrapSearchCompletion(modelGroupSearchActionListener)); + // Legacy backend-roles/owner path + SearchSourceBuilder searchSourceBuilder = modelAccessControlHelper + .addUserBackendRolesFilter(user, new SearchSourceBuilder()); + modelGroupGateAndSearch(tenantId, request, sdkClient, /*modelGroupIds*/ null, searchSourceBuilder, doubleWrapperListener); } } catch (Exception e) { log.error(e.getMessage(), e); @@ -204,8 +187,68 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId, } } - @VisibleForTesting - static QueryBuilder rewriteQueryBuilder(QueryBuilder queryBuilder, List modelGroupIds) { + public void modelGroupGateAndSearch( + String tenantId, + SearchRequest request, + SdkClient sdkClient, + @Nullable Set modelGroupIds, + SearchSourceBuilder sourceBuilder, + ActionListener wrappedListener + ) { + + // build discovery source + sourceBuilder.fetchSource(new String[] { MLModelGroup.MODEL_GROUP_ID_FIELD }, null); + sourceBuilder.size(10_000); + + if (modelGroupIds != null) { + // RSC pre-filter → merge as filter (doesn't affect scoring) + sourceBuilder.query(modelAccessControlHelper.mergeWithAccessFilter(sourceBuilder.query(), modelGroupIds)); + } + + SearchRequest modelGroupSearchReq = new SearchRequest().indices(CommonValue.ML_MODEL_GROUP_INDEX).source(sourceBuilder); + + SearchDataObjectRequest mgSearch = SearchDataObjectRequest + .builder() + .indices(modelGroupSearchReq.indices()) + .searchSourceBuilder(modelGroupSearchReq.source()) + .tenantId(tenantId) + .build(); + + sdkClient.searchDataObjectAsync(mgSearch).whenComplete(SdkClientUtils.wrapSearchCompletion(ActionListener.wrap(mgResp -> { + long total = Optional + .ofNullable(mgResp) + .map(SearchResponse::getHits) + .map(SearchHits::getTotalHits) + .map(TotalHits::value) + .orElse(0L); + + List mGIds = new ArrayList<>(); + if (total > 0) { + Arrays.stream(mgResp.getHits().getHits()).forEach(h -> mGIds.add(h.getId())); + } + + // Apply the model-group constraint to the ORIGINAL request + SearchSourceBuilder reqSrc = request.source() != null ? request.source() : new SearchSourceBuilder(); + reqSrc.query(rewriteQueryBuilder(reqSrc.query(), total > 0 ? mGIds : null)); + request.source(reqSrc); + + // Final search + SearchDataObjectRequest finalSearch = SearchDataObjectRequest + .builder() + .indices(request.indices()) + .searchSourceBuilder(request.source()) + .tenantId(tenantId) + .build(); + + sdkClient.searchDataObjectAsync(finalSearch).whenComplete(SdkClientUtils.wrapSearchCompletion(wrappedListener)); + + }, e -> { + log.error("Fail to search model groups!", e); + wrappedListener.onFailure(e); + }))); + } + + public static QueryBuilder rewriteQueryBuilder(QueryBuilder queryBuilder, List modelGroupIds) { ExistsQueryBuilder existsQueryBuilder = new ExistsQueryBuilder(MLModelGroup.MODEL_GROUP_ID_FIELD); BoolQueryBuilder modelGroupIdMustNotExistBoolQuery = new BoolQueryBuilder(); modelGroupIdMustNotExistBoolQuery.mustNot(existsQueryBuilder); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index b77b88ad1a..fbb87c0ff6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -27,6 +27,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; @@ -93,7 +94,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); - validateAndDeleteModelGroup(modelGroupId, tenantId, wrappedListener); + + // if resource sharing feature is enabled, access will be automatically checked by security plugin, so no need to check again + if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + checkForAssociatedModels(modelGroupId, tenantId, wrappedListener); + } else { + validateAndDeleteModelGroup(modelGroupId, tenantId, wrappedListener); + } } } @@ -105,6 +112,7 @@ private void validateAndDeleteModelGroup(String modelGroupId, String tenantId, A mlFeatureEnabledSetting, tenantId, modelGroupId, + MLModelGroupDeleteAction.NAME, client, sdkClient, ActionListener diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java index f1dbe8be48..74fc87da30 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java @@ -27,6 +27,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; @@ -183,21 +184,35 @@ private void validateModelGroupAccess( MLModelGroup mlModelGroup, ActionListener wrappedListener ) { - modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model group", - RestStatus.FORBIDDEN - ) - ); - } else { - wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build()); - } - }, e -> { - log.error("Failed to validate access for Model Group {}", modelGroupId, e); - wrappedListener.onFailure(e); - })); + // if resource sharing feature is enabled, security plugin will have automatically evaluated access to this model group, hence no + // need to validate again + if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build()); + return; + } + modelAccessControlHelper + .validateModelGroupAccess( + user, + modelGroupId, + MLModelGroupGetAction.NAME, + client, + + ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model group", + RestStatus.FORBIDDEN + ) + ); + } else { + wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build()); + } + }, e -> { + log.error("Failed to validate access for Model Group {}", modelGroupId, e); + wrappedListener.onFailure(e); + }) + ); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java index 96af8cd317..22c2175f1d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java @@ -6,8 +6,11 @@ package org.opensearch.ml.action.model_group; import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; +import java.util.Collections; + import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; @@ -17,15 +20,19 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.utils.PluginClient; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.SearchDataObjectRequest; import org.opensearch.remote.metadata.common.SdkClientUtils; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; @@ -35,6 +42,7 @@ @Log4j2 public class SearchModelGroupTransportAction extends HandledTransportAction { Client client; + PluginClient pluginClient; SdkClient sdkClient; ClusterService clusterService; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @@ -49,10 +57,12 @@ public SearchModelGroupTransportAction( SdkClient sdkClient, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, - MLFeatureEnabledSetting mlFeatureEnabledSetting + MLFeatureEnabledSetting mlFeatureEnabledSetting, + PluginClient pluginClient ) { super(MLModelGroupSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.pluginClient = pluginClient; this.sdkClient = sdkClient; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; @@ -82,23 +92,53 @@ private void preProcessRoleAndPerformSearch( final ActionListener doubleWrappedListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); + // If resource-sharing feature is enabled, we fetch accessible model-groups and restrict the search to those model-groups only. + if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + // If a model-group is shared, then it will have been shared at-least at read access, hence the final result is guaranteed + // to only contain model-groups that the user at-least has read access to. + addAccessibleModelGroupsFilterAndSearch(tenantId, request, doubleWrappedListener); + // pluginClient.search(request, doubleWrappedListener); + return; + } if (!modelAccessControlHelper.skipModelAccessControl(user)) { // Security is enabled, filter is enabled and user isn't admin modelAccessControlHelper.addUserBackendRolesFilter(user, request.source()); log.debug("Filtering result by {}", user.getBackendRoles()); } - SearchDataObjectRequest searchDataObjecRequest = SearchDataObjectRequest - .builder() - .indices(request.indices()) - .searchSourceBuilder(request.source()) - .tenantId(tenantId) - .build(); - sdkClient - .searchDataObjectAsync(searchDataObjecRequest) - .whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrappedListener)); + search(tenantId, request, doubleWrappedListener); } catch (Exception e) { log.error("Failed to search", e); listener.onFailure(e); } } + + private void addAccessibleModelGroupsFilterAndSearch( + String tenantId, + SearchRequest request, + ActionListener wrappedListener + ) { + SearchSourceBuilder sourceBuilder = request.source() != null ? request.source() : new SearchSourceBuilder(); + ResourceSharingClient rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + // filter by accessible model-groups + rsc.getAccessibleResourceIds(ML_MODEL_GROUP_INDEX, ActionListener.wrap(ids -> { + sourceBuilder.query(modelAccessControlHelper.mergeWithAccessFilter(sourceBuilder.query(), ids)); + request.source(sourceBuilder); + search(tenantId, request, wrappedListener); + }, e -> { + // Fail-safe: deny-all and still return a response + sourceBuilder.query(modelAccessControlHelper.mergeWithAccessFilter(sourceBuilder.query(), Collections.emptySet())); + request.source(sourceBuilder); + search(tenantId, request, wrappedListener); + })); + } + + private void search(String tenantId, SearchRequest request, ActionListener listener) { + SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest + .builder() + .indices(request.indices()) + .searchSourceBuilder(request.source()) + .tenantId(tenantId) + .build(); + sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete(SdkClientUtils.wrapSearchCompletion(listener)); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index d3ab730f0d..abf8dfc33b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -35,6 +35,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; @@ -146,12 +147,23 @@ protected void doExecute(Task task, ActionRequest request, ActionListener feature is disabled, follow old route + if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() == null) { + // TODO: At some point, this call must be replaced by the one above, (i.e. no user info to + // be stored in model-group index) + if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { + validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup); + } else { + validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); + } + } + // For backwards compatibility we still allow storing backend_roles + // data in ml_model_group + // index updateModelGroup(modelGroupId, r.source(), updateModelGroupInput, wrappedListener, user); + } } catch (Exception e) { log.error("Failed to parse ml connector {}", r.id(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index c7f9d4ae18..01d1228ce9 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -110,8 +110,6 @@ public class DeleteModelTransportAction extends HandledTransportAction isSafeDelete = it); this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; @@ -215,42 +212,56 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } else if (isModelNotDeployed(mlModelState)) { - if (isSafeDelete) { - // We only check downstream task when it's not hidden and cluster setting is true. - checkDownstreamTaskBeforeDeleteModel( - modelId, - tenantId, - mlModel.getAlgorithm().name(), - isHidden, - actionListener - ); + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + MLModelDeleteAction.NAME, + client, + + ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else if (isModelNotDeployed(mlModelState)) { + if (isSafeDelete) { + // We only check downstream task when it's not hidden and cluster setting is true. + checkDownstreamTaskBeforeDeleteModel( + modelId, + tenantId, + mlModel.getAlgorithm().name(), + isHidden, + actionListener + ); + } else { + deleteModel( + modelId, + tenantId, + mlModel.getAlgorithm().name(), + isHidden, + actionListener + ); + } + // deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, + // actionListener); } else { - deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, actionListener); + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete", + RestStatus.BAD_REQUEST + ) + ); } - // deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, actionListener); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete", - RestStatus.BAD_REQUEST - ) - ); - } - }, e -> { - log.error(getErrorMessage("Failed to validate Access", modelId, isHidden), e); - wrappedListener.onFailure(e); - })); + }, e -> { + log.error(getErrorMessage("Failed to validate Access", modelId, isHidden), e); + wrappedListener.onFailure(e); + }) + ); } } catch (Exception e) { log.error("Failed to parse ml model {}", r.id(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index 64c9eb6676..641e8785f4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -20,7 +20,6 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; @@ -65,15 +64,12 @@ public class GetModelTransportAction extends HandledTransportAction { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } else { - log.debug("Completed Get Model Request, id:{}", modelId); - Connector connector = mlModel.getConnector(); - if (connector != null) { - connector.removeCredential(); + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + MLModelGetAction.NAME, + client, + + ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else { + log.debug("Completed Get Model Request, id:{}", modelId); + Connector connector = mlModel.getConnector(); + if (connector != null) { + connector.removeCredential(); + } + wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); } - wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); - } - }, e -> { - log.error("Failed to validate Access for Model Id {}", modelId, e); - wrappedListener.onFailure(e); - })); + }, e -> { + log.error("Failed to validate Access for Model Id {}", modelId, e); + wrappedListener.onFailure(e); + }) + ); } } catch (Exception e) { log.error("Failed to parse ml model {}", r.id(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index 15dbee24b1..64f0ebcbe8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -81,7 +81,6 @@ public class UpdateModelTransportAction extends HandledTransportAction { final Client client; private final SdkClient sdkClient; - final Settings settings; final ClusterService clusterService; final ModelAccessControlHelper modelAccessControlHelper; final ConnectorAccessControlHelper connectorAccessControlHelper; @@ -115,7 +114,6 @@ public UpdateModelTransportAction( this.mlModelGroupManager = mlModelGroupManager; this.clusterService = clusterService; this.mlEngine = mlEngine; - this.settings = settings; this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); clusterService @@ -174,6 +172,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { @@ -437,42 +436,49 @@ private void updateModelWithRegisteringToAnotherModelGroup( UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); if (newModelGroupId != null) { modelAccessControlHelper - .validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> { - if (hasNewModelGroupPermission) { - mlModelGroupManager.getModelGroupResponse(sdkClient, newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { - buildUpdateRequest( - modelId, - newModelGroupId, - updateRequest, - updateModelInput, - newModelGroupResponse, - wrappedListener, - isUpdateModelCache - ); - }, - exception -> wrappedListener + .validateModelGroupAccess( + user, + newModelGroupId, + MLUpdateModelAction.NAME, + client, + ActionListener.wrap(hasNewModelGroupPermission -> { + if (hasNewModelGroupPermission) { + mlModelGroupManager + .getModelGroupResponse(sdkClient, newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { + buildUpdateRequest( + modelId, + newModelGroupId, + updateRequest, + updateModelInput, + newModelGroupResponse, + wrappedListener, + isUpdateModelCache + ); + }, + exception -> wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: " + + newModelGroupId, + RestStatus.NOT_FOUND + ) + ) + )); + } else { + wrappedListener .onFailure( new OpenSearchStatusException( - "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: " + "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID " + newModelGroupId, - RestStatus.NOT_FOUND + RestStatus.FORBIDDEN ) - ) - )); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID " - + newModelGroupId, - RestStatus.FORBIDDEN - ) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); - wrappedListener.onFailure(exception); - })); + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); + wrappedListener.onFailure(exception); + }) + ); } else { buildUpdateRequest(modelId, tenantId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index ab5944db94..47c4a21743 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -136,6 +136,7 @@ public void onResponse(MLModel mlModel) { mlFeatureEnabledSetting, tenantId, mlModel.getModelGroupId(), + MLPredictionTaskAction.NAME, client, sdkClient, ActionListener.wrap(access -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 00c577c8d6..1cb0bc7850 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -211,6 +211,7 @@ private void checkUserAccess( mlFeatureEnabledSetting, registerModelInput.getTenantId(), registerModelInput.getModelGroupId(), + MLRegisterModelAction.NAME, client, sdkClient, ActionListener.wrap(access -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java index 762b807428..a96917e4f9 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -193,40 +193,50 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener getModelListener = ActionListener.wrap(model -> { - modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job")); - } else { - if (model.getConnector() != null) { - Connector connector = model.getConnector(); - executeConnector(connector, mlInput, actionListener); - } else if (MLIndicesHandler - .doesMultiTenantIndexExist( - clusterService, - mlFeatureEnabledSetting.isMultiTenancyEnabled(), - ML_CONNECTOR_INDEX - )) { - ActionListener listener = ActionListener - .wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> { - log.error("Failed to get connector {}", model.getConnectorId(), e); - actionListener.onFailure(e); - }); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { - connectorAccessControlHelper - .getConnector( - client, - model.getConnectorId(), - ActionListener.runBefore(listener, threadContext::restore) - ); + modelAccessControlHelper + .validateModelGroupAccess( + user, + model.getModelGroupId(), + MLCancelBatchJobAction.NAME, + client, + ActionListener.wrap(access -> { + if (!access) { + actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job")); + } else { + if (model.getConnector() != null) { + Connector connector = model.getConnector(); + executeConnector(connector, mlInput, actionListener); + } else if (MLIndicesHandler + .doesMultiTenantIndexExist( + clusterService, + mlFeatureEnabledSetting.isMultiTenancyEnabled(), + ML_CONNECTOR_INDEX + )) { + ActionListener listener = ActionListener + .wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> { + log.error("Failed to get connector {}", model.getConnectorId(), e); + actionListener.onFailure(e); + }); + try ( + ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext() + ) { + connectorAccessControlHelper + .getConnector( + client, + model.getConnectorId(), + ActionListener.runBefore(listener, threadContext::restore) + ); + } + } else { + actionListener + .onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + } } - } else { - actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); - } - } - }, e -> { - log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e); - actionListener.onFailure(e); - })); + }, e -> { + log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e); + actionListener.onFailure(e); + }) + ); }, e -> { log.error("Failed to retrieve the ML model with the given ID", e); actionListener diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 674aa0244f..895a418012 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -373,6 +373,7 @@ private void processRemoteBatchPrediction( mlFeatureEnabledSetting, tenantId, model.getModelGroupId(), + MLTaskGetAction.NAME, client, sdkClient, ActionListener.wrap(access -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index f08749f758..a6ff369726 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -27,7 +27,6 @@ import org.opensearch.action.support.WriteRequest; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -81,7 +80,6 @@ public class TransportUndeployModelsAction extends HandledTransportAction { - if (!access) { - log.error("You don't have permissions to perform this operation on this model."); - wrappedListener - .onFailure( - new IllegalArgumentException( - "You don't have permissions to perform this operation on this model." - ) - ); - } else { - existingModel.setModelId(r.getId()); - if (existingModel.getTotalChunks() <= uploadModelChunkInput.getChunkNumber()) { - throw new Exception("Chunk number exceeds total chunks"); - } - byte[] bytes = uploadModelChunkInput.getContent(); - // Check the size of the content not to exceed 10 mb - if (bytes == null || bytes.length == 0) { - throw new Exception("Chunk size either 0 or null"); - } - if (validateChunkSize(bytes.length)) { - throw new Exception("Chunk size exceeds 10MB"); - } - mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { - if (!res) { - wrappedListener.onFailure(new RuntimeException("No response to create ML Model index")); - return; - } - int chunkNum = uploadModelChunkInput.getChunkNumber(); - MLModel mlModel = MLModel - .builder() - .algorithm(existingModel.getAlgorithm()) - .modelGroupId(existingModel.getModelGroupId()) - .version(existingModel.getVersion()) - .modelId(existingModel.getModelId()) - .modelFormat(existingModel.getModelFormat()) - .totalChunks(existingModel.getTotalChunks()) - .algorithm(existingModel.getAlgorithm()) - .chunkNumber(chunkNum) - .content(Base64.getEncoder().encodeToString(bytes)) - .build(); - IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); - indexRequest.id(uploadModelChunkInput.getModelId() + "_" + uploadModelChunkInput.getChunkNumber()); - indexRequest - .source( - mlModel - .toXContent( - XContentBuilder.builder(XContentType.JSON.xContent()), - ToXContent.EMPTY_PARAMS - ) + .validateModelGroupAccess( + user, + existingModel.getModelGroupId(), + MLUploadModelChunkAction.NAME, + client, + ActionListener.wrap(access -> { + if (!access) { + log.error("You don't have permissions to perform this operation on this model."); + wrappedListener + .onFailure( + new IllegalArgumentException( + "You don't have permissions to perform this operation on this model." + ) ); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexRequest, ActionListener.wrap(response -> { - log - .info( - "Index model successful for {} for chunk number {}", - uploadModelChunkInput.getModelId(), - chunkNum + 1 + } else { + existingModel.setModelId(r.getId()); + if (existingModel.getTotalChunks() <= uploadModelChunkInput.getChunkNumber()) { + throw new Exception("Chunk number exceeds total chunks"); + } + byte[] bytes = uploadModelChunkInput.getContent(); + // Check the size of the content not to exceed 10 mb + if (bytes == null || bytes.length == 0) { + throw new Exception("Chunk size either 0 or null"); + } + if (validateChunkSize(bytes.length)) { + throw new Exception("Chunk size exceeds 10MB"); + } + mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { + if (!res) { + wrappedListener.onFailure(new RuntimeException("No response to create ML Model index")); + return; + } + int chunkNum = uploadModelChunkInput.getChunkNumber(); + MLModel mlModel = MLModel + .builder() + .algorithm(existingModel.getAlgorithm()) + .modelGroupId(existingModel.getModelGroupId()) + .version(existingModel.getVersion()) + .modelId(existingModel.getModelId()) + .modelFormat(existingModel.getModelFormat()) + .totalChunks(existingModel.getTotalChunks()) + .algorithm(existingModel.getAlgorithm()) + .chunkNumber(chunkNum) + .content(Base64.getEncoder().encodeToString(bytes)) + .build(); + IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); + indexRequest + .id(uploadModelChunkInput.getModelId() + "_" + uploadModelChunkInput.getChunkNumber()); + indexRequest + .source( + mlModel + .toXContent( + XContentBuilder.builder(XContentType.JSON.xContent()), + ToXContent.EMPTY_PARAMS + ) ); - if (existingModel.getTotalChunks() == (uploadModelChunkInput.getChunkNumber() + 1)) { - Semaphore semaphore = new Semaphore(1); - semaphore.acquire(); - MLModel mlModelMeta = MLModel - .builder() - .name(existingModel.getName()) - .algorithm(existingModel.getAlgorithm()) - .version(existingModel.getVersion()) - .modelGroupId((existingModel.getModelGroupId())) - .modelFormat(existingModel.getModelFormat()) - .modelState(MLModelState.REGISTERED) - .modelConfig(existingModel.getModelConfig()) - .totalChunks(existingModel.getTotalChunks()) - .modelContentHash(existingModel.getModelContentHash()) - .modelContentSizeInBytes(existingModel.getModelContentSizeInBytes()) - .createdTime(existingModel.getCreatedTime()) - .build(); - IndexRequest indexReq = new IndexRequest(ML_MODEL_INDEX); - indexReq.id(modelId); - indexReq - .source( - mlModelMeta - .toXContent( - XContentBuilder.builder(XContentType.JSON.xContent()), - ToXContent.EMPTY_PARAMS - ) + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexRequest, ActionListener.wrap(response -> { + log + .info( + "Index model successful for {} for chunk number {}", + uploadModelChunkInput.getModelId(), + chunkNum + 1 ); - indexReq.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexReq, ActionListener.wrap(re -> { - log.debug("Index model successful", existingModel.getName()); - semaphore.release(); - }, e -> { - log.error("Failed to update model state", e); - semaphore.release(); - wrappedListener.onFailure(e); - })); - } - wrappedListener.onResponse(new MLUploadModelChunkResponse("Uploaded")); - }, e -> { - log.error("Failed to upload chunk model", e); - wrappedListener.onFailure(e); + if (existingModel.getTotalChunks() == (uploadModelChunkInput.getChunkNumber() + 1)) { + Semaphore semaphore = new Semaphore(1); + semaphore.acquire(); + MLModel mlModelMeta = MLModel + .builder() + .name(existingModel.getName()) + .algorithm(existingModel.getAlgorithm()) + .version(existingModel.getVersion()) + .modelGroupId((existingModel.getModelGroupId())) + .modelFormat(existingModel.getModelFormat()) + .modelState(MLModelState.REGISTERED) + .modelConfig(existingModel.getModelConfig()) + .totalChunks(existingModel.getTotalChunks()) + .modelContentHash(existingModel.getModelContentHash()) + .modelContentSizeInBytes(existingModel.getModelContentSizeInBytes()) + .createdTime(existingModel.getCreatedTime()) + .build(); + IndexRequest indexReq = new IndexRequest(ML_MODEL_INDEX); + indexReq.id(modelId); + indexReq + .source( + mlModelMeta + .toXContent( + XContentBuilder.builder(XContentType.JSON.xContent()), + ToXContent.EMPTY_PARAMS + ) + ); + indexReq.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexReq, ActionListener.wrap(re -> { + log.debug("Index model successful", existingModel.getName()); + semaphore.release(); + }, e -> { + log.error("Failed to update model state", e); + semaphore.release(); + wrappedListener.onFailure(e); + })); + } + wrappedListener.onResponse(new MLUploadModelChunkResponse("Uploaded")); + }, e -> { + log.error("Failed to upload chunk model", e); + wrappedListener.onFailure(e); + })); + }, ex -> { + log.error("Failed to init model index", ex); + wrappedListener.onFailure(ex); })); - }, ex -> { - log.error("Failed to init model index", ex); - wrappedListener.onFailure(ex); - })); - } - }, e -> { - logException("Failed to validate model access", e, log); - wrappedListener.onFailure(e); - })); + } + }, e -> { + logException("Failed to validate model access", e, log); + wrappedListener.onFailure(e); + }) + ); } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); wrappedListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java index ec3b67949c..b25f07d0ee 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java @@ -92,30 +92,38 @@ private void checkUserAccess( ) { User user = RestActionUtils.getUserContext(client); - modelAccessControlHelper.validateModelGroupAccess(user, mlUploadInput.getModelGroupId(), client, ActionListener.wrap(access -> { - if (access) { - createModelGroup(mlUploadInput, listener); - return; - } - if (isModelNameAlreadyExisting) { - listener - .onFailure( - new IllegalArgumentException( - "The name {" - + mlUploadInput.getName() - + "} you provided is unavailable because it is used by another model group with id {" - + mlUploadInput.getModelGroupId() - + "} to which you do not have access. Please provide a different name." - ) - ); - } else { - log.error("You don't have permissions to perform this operation on this model."); - listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); - } - }, e -> { - logException("Failed to validate model access", e, log); - listener.onFailure(e); - })); + modelAccessControlHelper + .validateModelGroupAccess( + user, + mlUploadInput.getModelGroupId(), + MLRegisterModelMetaAction.NAME, + client, + ActionListener.wrap(access -> { + if (access) { + createModelGroup(mlUploadInput, listener); + return; + } + if (isModelNameAlreadyExisting) { + listener + .onFailure( + new IllegalArgumentException( + "The name {" + + mlUploadInput.getName() + + "} you provided is unavailable because it is used by another model group with id {" + + mlUploadInput.getModelGroupId() + + "} to which you do not have access. Please provide a different name." + ) + ); + } else { + log.error("You don't have permissions to perform this operation on this model."); + listener + .onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); + } + }, e -> { + logException("Failed to validate model access", e, log); + listener.onFailure(e); + }) + ); } private void createModelGroup(MLRegisterModelMetaInput mlUploadInput, ActionListener listener) { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index c22ac5be48..2507dc2342 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -14,11 +14,16 @@ import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; +import java.util.function.Consumer; import org.apache.lucene.search.join.ScoreMode; import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; @@ -26,6 +31,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; @@ -43,6 +49,7 @@ import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; @@ -50,8 +57,10 @@ import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.remote.metadata.client.GetDataObjectRequest; import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.transport.client.Client; import com.google.common.collect.ImmutableList; @@ -83,8 +92,29 @@ public ModelAccessControlHelper(ClusterService clusterService, Settings settings ); // TODO Eventually remove this when all usages of it have been migrated to the SdkClient version - public void validateModelGroupAccess(User user, String modelGroupId, Client client, ActionListener listener) { - if (modelGroupId == null || isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)) { + public void validateModelGroupAccess(User user, String modelGroupId, String action, Client client, ActionListener listener) { + if (modelGroupId == null) { + listener.onResponse(true); + return; + } + if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, action, ActionListener.wrap(isAuthorized -> { + if (!isAuthorized) { + listener + .onFailure( + new OpenSearchStatusException( + "User " + user.getName() + " is not authorized to delete ml-model-group id: " + modelGroupId, + RestStatus.FORBIDDEN + ) + ); + return; + } + listener.onResponse(true); + }, listener::onFailure)); + return; + } + if (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)) { listener.onResponse(true); return; } @@ -128,13 +158,33 @@ public void validateModelGroupAccess( MLFeatureEnabledSetting mlFeatureEnabledSetting, String tenantId, String modelGroupId, + String action, Client client, SdkClient sdkClient, ActionListener listener ) { - if (modelGroupId == null - || (!mlFeatureEnabledSetting.isMultiTenancyEnabled() - && (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)))) { + if (modelGroupId == null) { + listener.onResponse(true); + return; + } + if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, action, ActionListener.wrap(isAuthorized -> { + if (!isAuthorized) { + listener + .onFailure( + new OpenSearchStatusException( + "User " + user.getName() + " is not authorized to delete ml-model-group id: " + modelGroupId, + RestStatus.FORBIDDEN + ) + ); + return; + } + listener.onResponse(true); + }, listener::onFailure)); + return; + } + if (!mlFeatureEnabledSetting.isMultiTenancyEnabled() && (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user))) { listener.onResponse(true); return; } @@ -149,7 +199,7 @@ public void validateModelGroupAccess( sdkClient.getDataObjectAsync(getModelGroupRequest).whenComplete((r, throwable) -> { if (throwable == null) { try { - GetResponse gr = r.getResponse(); + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); if (gr != null && gr.isExists()) { try ( XContentParser parser = jsonXContent @@ -311,7 +361,44 @@ public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuil return searchSourceBuilder; } - public SearchSourceBuilder createSearchSourceBuilder(User user) { - return addUserBackendRolesFilter(user, new SearchSourceBuilder()); + public QueryBuilder mergeWithAccessFilter(QueryBuilder existing, Set ids) { + QueryBuilder accessFilter = (ids == null || ids.isEmpty()) + ? QueryBuilders.boolQuery().mustNot(QueryBuilders.matchAllQuery()) // deny-all + : QueryBuilders.idsQuery().addIds(ids.toArray(new String[0])); // use termsQuery(field, ids) if not _id + + if (existing == null) + return QueryBuilders.boolQuery().filter(accessFilter); + if (existing instanceof BoolQueryBuilder) { + ((BoolQueryBuilder) existing).filter(accessFilter); + return existing; + } + return QueryBuilders.boolQuery().must(existing).filter(accessFilter); } + + public void addAccessibleModelGroupsFilterAndSearch( + String tenantId, + SearchRequest request, + SdkClient sdkClient, + Consumer> onSuccess, + ActionListener wrappedListener + ) { + ResourceSharingClient rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + // filter by accessible model-groups + rsc.getAccessibleResourceIds(ML_MODEL_GROUP_INDEX, ActionListener.wrap(onSuccess::accept, e -> { + // Fail-safe: deny-all and still return a response + SearchSourceBuilder reqSrc = request.source() != null ? request.source() : new SearchSourceBuilder(); + reqSrc.query(mergeWithAccessFilter(reqSrc.query(), Collections.emptySet())); + request.source(reqSrc); + + SearchDataObjectRequest finalSearch = SearchDataObjectRequest + .builder() + .indices(request.indices()) + .searchSourceBuilder(request.source()) + .tenantId(tenantId) + .build(); + + sdkClient.searchDataObjectAsync(finalSearch).whenComplete(SdkClientUtils.wrapSearchCompletion(wrappedListener)); + })); + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 7e264d3347..8db619ac35 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -35,7 +35,6 @@ import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.exception.MLResourceNotFoundException; -import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; @@ -60,7 +59,6 @@ public class MLModelGroupManager { ClusterService clusterService; ModelAccessControlHelper modelAccessControlHelper; - private MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject public MLModelGroupManager( @@ -68,15 +66,13 @@ public MLModelGroupManager( Client client, SdkClient sdkClient, ClusterService clusterService, - ModelAccessControlHelper modelAccessControlHelper, - MLFeatureEnabledSetting mlFeatureEnabledSetting + ModelAccessControlHelper modelAccessControlHelper ) { this.mlIndicesHandler = mlIndicesHandler; this.client = client; this.sdkClient = sdkClient; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; - this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } public void createModelGroup(MLRegisterModelGroupInput input, ActionListener listener) { @@ -101,6 +97,7 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener indicesToListen; @@ -592,6 +597,7 @@ public Collection createComponents( ) { this.indexUtils = new IndexUtils(client, clusterService); this.client = client; + this.pluginClient = new PluginClient(client); this.threadPool = threadPool; this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; @@ -883,7 +889,8 @@ public Collection createComponents( cmHandler, sdkClient, toolFactoryWrapper, - mcpToolsHelper + mcpToolsHelper, + pluginClient ); } @@ -1420,4 +1427,11 @@ public ScheduledJobRunner getJobRunner() { public ScheduledJobParser getJobParser() { return (parser, id, jobDocVersion) -> MLJobParameter.parse(parser); } + + @Override + public void assignSubject(PluginSubject pluginSubject) { + if (this.pluginClient != null) { + this.pluginClient.setSubject(pluginSubject); + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java b/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java new file mode 100644 index 0000000000..497680a722 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.resources; + +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; + +import java.util.Set; + +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; +import org.opensearch.security.spi.resources.ResourceProvider; +import org.opensearch.security.spi.resources.ResourceSharingExtension; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; + +public class MLResourceSharingExtension implements ResourceSharingExtension { + + @Override + public Set getResourceProviders() { + return Set.of(new ResourceProvider(MLModelGroup.class.getCanonicalName(), ML_MODEL_GROUP_INDEX)); + } + + @Override + public void assignResourceSharingClient(ResourceSharingClient resourceSharingClient) { + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(resourceSharingClient); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/utils/PluginClient.java b/plugin/src/main/java/org/opensearch/ml/utils/PluginClient.java new file mode 100644 index 0000000000..618bb39d5d --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/utils/PluginClient.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.ml.utils; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionType; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.identity.Subject; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.FilterClient; + +import lombok.Setter; + +/** + * A special client for executing transport actions as this plugin's system subject. + */ +@Setter +public class PluginClient extends FilterClient { + + private static final Logger logger = LogManager.getLogger(PluginClient.class); + + private Subject subject; + + public PluginClient(Client delegate) { + super(delegate); + } + + public PluginClient(Client delegate, Subject subject) { + super(delegate); + this.subject = subject; + } + + @Override + protected void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (subject == null) { + throw new IllegalStateException("PluginClient is not initialized."); + } + try (ThreadContext.StoredContext ctx = threadPool().getThreadContext().newStoredContext(false)) { + subject.runAs(() -> { + logger.info("Running transport action with subject: {}", subject.getPrincipal().getName()); + super.doExecute(action, request, ActionListener.runBefore(listener, ctx::restore)); + }); + } catch (RuntimeException e) { + throw e; + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/plugin/src/main/resources/META-INF/services/org.opensearch.security.spi.resources.ResourceSharingExtension b/plugin/src/main/resources/META-INF/services/org.opensearch.security.spi.resources.ResourceSharingExtension new file mode 100644 index 0000000000..00dfe98b5a --- /dev/null +++ b/plugin/src/main/resources/META-INF/services/org.opensearch.security.spi.resources.ResourceSharingExtension @@ -0,0 +1 @@ +org.opensearch.ml.resources.MLResourceSharingExtension diff --git a/plugin/src/main/resources/plugin-additional-permissions.yml b/plugin/src/main/resources/plugin-additional-permissions.yml new file mode 100644 index 0000000000..424f580481 --- /dev/null +++ b/plugin/src/main/resources/plugin-additional-permissions.yml @@ -0,0 +1,24 @@ +cluster_permissions: + - "indices:data/read*" + - "indices:data/write/index*" + - "cluster:admin/opensearch/ml/*" +index_permissions: + - index_patterns: + - ".plugins-ml-agent" + - ".plugins-ml-config" + - ".plugins-ml-connector" + - ".plugins-ml-controller" + - ".plugins-ml-model-group" + - ".plugins-ml-model" + - ".plugins-ml-task" + - ".plugins-ml-memory-meta" + - ".plugins-ml-memory-message" + - ".plugins-ml-memory-container" + - ".plugins-ml-stop-words" + - ".plugins-ml-mcp-session-management" + - ".plugins-ml-mcp-tools" + - ".plugins-ml-jobs" + allowed_actions: + - "indices:data/write/index*" + - "indices:data/read*" + - "indices:admin/create" diff --git a/plugin/src/main/resources/resource-action-groups.yml b/plugin/src/main/resources/resource-action-groups.yml new file mode 100644 index 0000000000..40034b2b85 --- /dev/null +++ b/plugin/src/main/resources/resource-action-groups.yml @@ -0,0 +1,18 @@ +# For resource-access-management +resource_types: + org.opensearch.ml.common.MLModelGroup: + ml_read_only: + - "cluster:admin/opensearch/ml/model_groups/search" + - "cluster:admin/opensearch/ml/model_groups/get" + - "cluster:admin/opensearch/ml/models/search" + - "cluster:admin/opensearch/ml/models/get" + - "indices:data/read*" + + ml_read_write: + - "cluster:admin/opensearch/ml/model*" + - "indices:*" + + ml_full_access: + - "cluster:admin/opensearch/ml/model*" + - "indices:*" + - "cluster:admin/security/resource/share" diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java index 61c0282ac2..0c0ee5a9fc 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java @@ -168,10 +168,10 @@ public void setup() throws IOException { createControllerRequest = MLCreateControllerRequest.builder().controllerInput(controller).build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -233,10 +233,10 @@ public void testCreateControllerWithTextEmbeddingModelSuccess() { @Test public void testCreateControllerWithModelAccessControlNoPermission() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -250,10 +250,10 @@ public void testCreateControllerWithModelAccessControlNoPermission() { @Test public void testCreateControllerWithModelAccessControlOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java index 52bbfdad3d..b433d0ecef 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java @@ -159,10 +159,10 @@ public void setup() throws IOException { mlControllerDeleteRequest = MLControllerDeleteRequest.builder().modelId("testModelId").build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -213,10 +213,10 @@ public void testDeleteControllerFailedWithControllerFeatureFlagDisabled() { @Test public void testDeleteControllerWithModelAccessControlNoPermission() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -238,10 +238,10 @@ public void testDeleteControllerWithModelAccessControlNoPermissionHiddenModel() return null; }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -255,10 +255,10 @@ public void testDeleteControllerWithModelAccessControlNoPermissionHiddenModel() @Test public void testDeleteControllerWithModelAccessControlOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -277,13 +277,13 @@ public void testDeleteControllerWithModelAccessControlOtherExceptionHiddenModel( return null; }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener .onFailure( new RuntimeException("Permission denied: Unable to delete the model controller with the provided model. Details: ") ); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java index f414572b02..a030e1e363 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java @@ -119,10 +119,10 @@ public void setup() throws IOException { }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareControllerGetResponse(); doAnswer(invocation -> { @@ -157,10 +157,10 @@ public void testGetControllerFailedWithControllerFeatureFlagDisabled() { @Test public void testGetControllerWithModelAccessControlNoPermission() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -174,10 +174,10 @@ public void testGetControllerWithModelAccessControlNoPermission() { @Test public void testGetControllerWithModelAccessControlOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java index 2bdef9c022..c25c2aa50c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java @@ -178,10 +178,10 @@ public void setup() throws IOException { updateControllerRequest = MLUpdateControllerRequest.builder().updateControllerInput(updatedController).build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -243,10 +243,10 @@ public void testUpdateControllerWithTextEmbeddingModelSuccess() { @Test public void testUpdateControllerWithModelAccessControlNoPermission() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -268,10 +268,10 @@ public void testUpdateControllerWithModelAccessControlNoPermissionHiddenModel() return null; }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -285,10 +285,10 @@ public void testUpdateControllerWithModelAccessControlNoPermissionHiddenModel() @Test public void testUpdateControllerWithModelAccessControlOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -307,10 +307,10 @@ public void testUpdateControllerWithModelAccessControlOtherExceptionHiddenModel( return null; }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("Permission denied: Unable to create the model controller for the model. Details: ")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -453,7 +453,7 @@ public void testUpdateControllerWithNullUpdateResponse() { public void testUpdateControllerWithDeploySuccessNullFailures() { when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(3); listener.onResponse(mlDeployControllerNodesResponse); return null; }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index b98af5a906..06ea12ce73 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -174,10 +174,10 @@ public void setup() { when(threadPool.executor(anyString())).thenReturn(executorService); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true); @@ -355,10 +355,10 @@ public void testDoExecute_userHasNoAccessException() { }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); ActionListener deployModelResponseListener = mock(ActionListener.class); transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); @@ -411,10 +411,10 @@ public void test_ValidationFailedException() { }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); ActionListener deployModelResponseListener = mock(ActionListener.class); transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java index c59964063a..a8cb908ab0 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java @@ -115,10 +115,10 @@ public void setup() throws IOException { ); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -226,10 +226,10 @@ public void test_DeleteRequestInternalServerError() { @Test public void test_UserHasNoAccessException() throws IOException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -240,10 +240,10 @@ public void test_UserHasNoAccessException() throws IOException { @Test public void test_ValidationFailedException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java index aa2ceb20ce..bd6792136b 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java @@ -106,10 +106,10 @@ public void setup() throws IOException { ); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -132,10 +132,10 @@ public void test_Success() throws IOException { public void testGetModel_UserHasNoAccess() throws IOException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLModelGroup(); doAnswer(invocation -> { @@ -152,10 +152,10 @@ public void testGetModel_UserHasNoAccess() throws IOException { public void testGetModel_ValidateAccessFailed() throws IOException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLModelGroup(); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java index f6aac6ec92..7287b551c6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java @@ -5,13 +5,16 @@ package org.opensearch.ml.action.model_group; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.util.Collections; +import java.util.Set; +import java.util.concurrent.CompletableFuture; import org.apache.lucene.search.TotalHits; import org.junit.Before; @@ -31,68 +34,68 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.search.MLSearchActionRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.remote.metadata.client.SdkClient; -import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.client.SearchDataObjectResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; public class SearchModelGroupTransportActionTests extends OpenSearchTestCase { + @Mock Client client; + @Mock SdkClient sdkClient; @Mock NamedXContentRegistry namedXContentRegistry; - @Mock TransportService transportService; - @Mock ActionFilters actionFilters; - SearchRequest searchRequest; - - SearchResponse searchResponse; - SearchSourceBuilder searchSourceBuilder; - - MLSearchActionRequest mlSearchActionRequest; - @Mock FetchSourceContext fetchSourceContext; @Mock ActionListener actionListener; - @Mock ThreadPool threadPool; - @Mock ClusterService clusterService; + SearchModelGroupTransportAction searchModelGroupTransportAction; @Mock private ModelAccessControlHelper modelAccessControlHelper; @Mock private MLFeatureEnabledSetting mlFeatureEnabledSetting; + ThreadContext threadContext; @Before - public void setup() { + @Override + public void setUp() throws Exception { + super.setUp(); MockitoAnnotations.openMocks(this); - sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + searchModelGroupTransportAction = new SearchModelGroupTransportAction( transportService, actionFilters, @@ -111,23 +114,31 @@ public void setup() { searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.fetchSource(fetchSourceContext); - searchRequest = new SearchRequest(new String[0], searchSourceBuilder); - mlSearchActionRequest = new MLSearchActionRequest(searchRequest, null); when(fetchSourceContext.includes()).thenReturn(new String[] {}); when(fetchSourceContext.excludes()).thenReturn(new String[] {}); - SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); - InternalSearchResponse internalSearchResponse = new InternalSearchResponse( - searchHits, - InternalAggregations.EMPTY, - null, - null, - false, - null, - 0 - ); - searchResponse = new SearchResponse( - internalSearchResponse, + // By default, do not skip access control + when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); + // Simplify the merged query for tests + when(modelAccessControlHelper.mergeWithAccessFilter(any(QueryBuilder.class), any(Set.class))) + .thenAnswer(inv -> QueryBuilders.termQuery("dummy", "value")); + } + + @Override + public void tearDown() throws Exception { + try { + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); + } finally { + super.tearDown(); + } + } + + /** Helper: empty SDK response that can be converted to SearchResponse by the utils wrapper */ + private SearchDataObjectResponse emptySearchDataObjectResponse() { + SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internal = new InternalSearchResponse(hits, InternalAggregations.EMPTY, null, null, false, null, 0); + SearchResponse osSearchResponse = new SearchResponse( + internal, null, 0, 0, @@ -138,68 +149,93 @@ public void setup() { null ); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); + SearchDataObjectResponse sdkResp = mock(SearchDataObjectResponse.class); + try { + when(sdkResp.searchResponse()).thenReturn(osSearchResponse); + } catch (Throwable ignore) {} + return sdkResp; } @Test - public void test_DoExecute() { - when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); - searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + public void test_DoExecute_success_callsSdkClient_andAddsBackendRoleFilter() { + MLSearchActionRequest mlReq = new MLSearchActionRequest(new SearchRequest(new String[0], searchSourceBuilder), "tenant-x"); + + CompletableFuture future = new CompletableFuture<>(); + when(sdkClient.searchDataObjectAsync(any(SearchDataObjectRequest.class))).thenReturn(future); + + searchModelGroupTransportAction.doExecute(null, mlReq, actionListener); + + future.complete(emptySearchDataObjectResponse()); verify(modelAccessControlHelper).addUserBackendRolesFilter(any(), any()); - verify(client).search(any(), any()); + verify(sdkClient).searchDataObjectAsync(any(SearchDataObjectRequest.class)); + verify(actionListener).onResponse(any(SearchResponse.class)); } @Test - public void test_DoExecute_Exception() throws InterruptedException { + public void test_DoExecute_exception_propagatesFailure() { + MLSearchActionRequest mlReq = new MLSearchActionRequest(new SearchRequest(new String[0], searchSourceBuilder), "tenant-x"); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("search failed")); - return null; - }).when(client).search(any(), any()); + CompletableFuture future = new CompletableFuture<>(); + when(sdkClient.searchDataObjectAsync(any(SearchDataObjectRequest.class))).thenReturn(future); - when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); - searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + searchModelGroupTransportAction.doExecute(null, mlReq, actionListener); + + future.completeExceptionally(new RuntimeException("search failed")); verify(modelAccessControlHelper).addUserBackendRolesFilter(any(), any()); - verify(client).search(any(), any()); + verify(sdkClient).searchDataObjectAsync(any(SearchDataObjectRequest.class)); verify(actionListener).onFailure(any(RuntimeException.class)); } @Test - public void test_skipModelAccessControlTrue() { + public void test_skipModelAccessControlTrue_stillCallsSdkClient() { when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); - searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(client).search(any(), any()); + MLSearchActionRequest mlReq = new MLSearchActionRequest( + new SearchRequest(new String[0], searchSourceBuilder), + "tenant-x" + ); + + CompletableFuture future = new CompletableFuture<>(); + when(sdkClient.searchDataObjectAsync(any(SearchDataObjectRequest.class))).thenReturn(future); + + searchModelGroupTransportAction.doExecute(null, mlReq, actionListener); + future.complete(emptySearchDataObjectResponse()); + + verify(sdkClient).searchDataObjectAsync(any(SearchDataObjectRequest.class)); + verify(actionListener).onResponse(any(SearchResponse.class)); } @Test - public void test_ThreadContextError() { - when(modelAccessControlHelper.skipModelAccessControl(any())).thenThrow(new RuntimeException("thread context error")); + public void test_ThreadContextError_wrappedWithMessage() { + when(modelAccessControlHelper.skipModelAccessControl(any())) + .thenThrow(new RuntimeException("thread context error")); + + MLSearchActionRequest mlReq = new MLSearchActionRequest( + new SearchRequest(new String[0], searchSourceBuilder), + "tenant-x" + ); + + searchModelGroupTransportAction.doExecute(null, mlReq, actionListener); - searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Fail to search", argumentCaptor.getValue().getMessage()); + ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(captor.capture()); + assertEquals("Fail to search", captor.getValue().getMessage()); } @Test - public void testDoExecute_MultiTenancyEnabled_TenantFilteringNotEnabled() throws InterruptedException { + public void testDoExecute_MultiTenancyEnabled_TenantFilteringNotEnabled() { when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query - SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + sourceBuilder.query(QueryBuilders.termQuery("field", "value")); + SearchRequest request = + new SearchRequest("my_index").source(sourceBuilder); - mlSearchActionRequest = new MLSearchActionRequest(request, null); + MLSearchActionRequest mlReq = new MLSearchActionRequest(request, null); - searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + searchModelGroupTransportAction.doExecute(null, mlReq, actionListener); ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(captor.capture()); @@ -209,21 +245,101 @@ public void testDoExecute_MultiTenancyEnabled_TenantFilteringNotEnabled() throws } @Test - public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws InterruptedException { + public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() { when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query - SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); - mlSearchActionRequest = new MLSearchActionRequest(request, "123456"); + sourceBuilder.query(QueryBuilders.termQuery("field", "value")); + SearchRequest request = + new SearchRequest("my_index").source(sourceBuilder); + MLSearchActionRequest mlReq = new MLSearchActionRequest(request, "123456"); + + CompletableFuture future = new CompletableFuture<>(); + when(sdkClient.searchDataObjectAsync(any(SearchDataObjectRequest.class))).thenReturn(future); + + searchModelGroupTransportAction.doExecute(null, mlReq, actionListener); + future.complete(emptySearchDataObjectResponse()); + + verify(actionListener).onResponse(any(SearchResponse.class)); + } + + @Test + public void testResourceSharingEnabled_successPath_filtersByAccessibleIds_andCallsSdkClient() { + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + + ArgumentCaptor>> rscListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + CompletableFuture future = new CompletableFuture<>(); + when(sdkClient.searchDataObjectAsync(any(SearchDataObjectRequest.class))).thenReturn(future); + + SearchSourceBuilder ssb = new SearchSourceBuilder(); + SearchRequest sr = new SearchRequest(new String[] { CommonValue.ML_MODEL_GROUP_INDEX }, ssb); + MLSearchActionRequest req = new MLSearchActionRequest(sr, "tenant-1"); + + searchModelGroupTransportAction.doExecute(null, req, actionListener); + + verify(rsc).getAccessibleResourceIds(eq(CommonValue.ML_MODEL_GROUP_INDEX), rscListenerCaptor.capture()); + rscListenerCaptor.getValue().onResponse(Set.of("idA", "idB")); + future.complete(emptySearchDataObjectResponse()); + + verify(modelAccessControlHelper, atLeastOnce()).mergeWithAccessFilter(any(), eq(Set.of("idA", "idB"))); + + ArgumentCaptor sreq = ArgumentCaptor.forClass(SearchDataObjectRequest.class); + verify(sdkClient).searchDataObjectAsync(sreq.capture()); + SearchDataObjectRequest sent = sreq.getValue(); + + // Adjust these getters if your SDK uses record-style accessors + assertArrayEquals(new String[] { CommonValue.ML_MODEL_GROUP_INDEX }, sent.indices()); + assertEquals("tenant-1", sent.tenantId()); + + verify(actionListener).onResponse(any(SearchResponse.class)); + } + + @Test + public void testResourceSharingEnabled_failSafePath_usesEmptySet_andCallsSdkClient() { + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + + ArgumentCaptor>> rscListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + CompletableFuture future = new CompletableFuture<>(); + when(sdkClient.searchDataObjectAsync(any(SearchDataObjectRequest.class))).thenReturn(future); + + SearchRequest sr = new SearchRequest(new String[] { CommonValue.ML_MODEL_GROUP_INDEX }, new SearchSourceBuilder()); + MLSearchActionRequest req = new MLSearchActionRequest(sr, "tenant-2"); + + searchModelGroupTransportAction.doExecute(null, req, actionListener); + + // Simulate failure -> deny-all (empty set) + verify(rsc).getAccessibleResourceIds(eq(CommonValue.ML_MODEL_GROUP_INDEX), rscListenerCaptor.capture()); + rscListenerCaptor.getValue().onFailure(new RuntimeException("boom")); + + future.complete(emptySearchDataObjectResponse()); + + verify(modelAccessControlHelper, atLeastOnce()).mergeWithAccessFilter(any(), eq(Collections.emptySet())); + + verify(sdkClient).searchDataObjectAsync(any(SearchDataObjectRequest.class)); + verify(actionListener).onResponse(any(SearchResponse.class)); + } + + @Test + public void testThreadContext_isRestored_afterExecution() { + String key = "test-header"; + threadContext.putHeader(key, "original"); + + SearchRequest sr = new SearchRequest(new String[] { CommonValue.ML_MODEL_GROUP_INDEX }, new SearchSourceBuilder()); + MLSearchActionRequest req = new MLSearchActionRequest(sr, "tenant-4"); + + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); + + CompletableFuture future = new CompletableFuture<>(); + when(sdkClient.searchDataObjectAsync(any(SearchDataObjectRequest.class))).thenReturn(future); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); + searchModelGroupTransportAction.doExecute(null, req, actionListener); + future.complete(emptySearchDataObjectResponse()); - searchModelGroupTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + assertEquals("original", threadContext.getHeader(key)); verify(actionListener).onResponse(any(SearchResponse.class)); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index 2f91a1837c..93e7b1580a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -187,10 +187,10 @@ public void setup() throws IOException { ); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(clusterService.getSettings()).thenReturn(settings); @@ -417,10 +417,10 @@ public void test_UserHasNoAccessException() throws IOException, InterruptedExcep }).when(client).get(any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); @@ -498,10 +498,10 @@ public void test_ValidationFailedException() throws IOException, InterruptedExce }).when(client).get(any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java index e534e26505..248a0d84b2 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java @@ -105,7 +105,6 @@ public void setup() throws IOException { actionFilters, client, sdkClient, - settings, xContentRegistry, clusterService, modelAccessControlHelper, @@ -114,10 +113,10 @@ public void setup() throws IOException { ); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -134,10 +133,10 @@ public void testGetModel_UserHasNodeAccess() throws IOException, InterruptedExce }).when(client).get(any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); @@ -196,10 +195,10 @@ public void testGetModelHidden_SuperUserPermissionError() throws IOException, In public void testGetModel_ValidateAccessFailed() throws IOException, InterruptedException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLModel(false); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index 9a9ec854b2..104d70f66e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -186,7 +186,7 @@ public void test_DoExecute_addBackendRoles() throws IOException { listener.onResponse(searchResponse); return null; }).when(client).search(any(), any()); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.addUserBackendRolesFilter(any(), any())).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(2)).search(any(), any()); @@ -198,7 +198,7 @@ public void test_DoExecute_addBackendRoles_without_groupIds() { listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.addUserBackendRolesFilter(any(), any())).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(2)).search(any(), any()); @@ -210,7 +210,7 @@ public void test_DoExecute_addBackendRoles_exception() { listener.onFailure(new RuntimeException("runtime exception")); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.addUserBackendRolesFilter(any(), any())).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(1)).search(any(), any()); @@ -283,7 +283,7 @@ public void test_DoExecute_addBackendRoles_boolQuery() throws IOException { listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.addUserBackendRolesFilter(any(), any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.boolQuery().must(QueryBuilders.matchQuery("name", "model_IT"))); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); @@ -297,7 +297,7 @@ public void test_DoExecute_addBackendRoles_termQuery() throws IOException { listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.addUserBackendRolesFilter(any(), any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); @@ -332,7 +332,7 @@ public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws In return null; }).when(client).search(any(), any()); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.addUserBackendRolesFilter(any(), any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); mlSearchActionRequest = new MLSearchActionRequest(searchRequest, "123456"); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 6d7e3ea9a9..dc3a4a3da6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -293,13 +293,15 @@ public void setup() throws IOException { // TODO eventually remove if migrated to sdkClient doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("test_model_group_id"), any(), isA(ActionListener.class)); + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("test_model_group_id"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(true); return null; }) @@ -310,21 +312,22 @@ public void setup() throws IOException { any(), eq("test_model_group_id"), any(), + any(), any(SdkClient.class), isA(ActionListener.class) ); // TODO eventually remove if migrated to sdkClient doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(true); return null; }) @@ -335,6 +338,7 @@ public void setup() throws IOException { any(), eq("updated_test_model_group_id"), any(), + any(), any(SdkClient.class), isA(ActionListener.class) ); @@ -597,12 +601,12 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl @Test public void testUpdateModelWithModelAccessControlNoPermission() throws InterruptedException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(false); return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), any(), any(), any(), any(), any(SdkClient.class), isA(ActionListener.class)); + .validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(SdkClient.class), isA(ActionListener.class)); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -620,7 +624,7 @@ public void testUpdateModelWithModelAccessControlNoPermission() throws Interrupt @Test public void testUpdateModelWithModelAccessControlOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener .onFailure( new RuntimeException( @@ -628,7 +632,7 @@ public void testUpdateModelWithModelAccessControlOtherException() { ) ); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -642,12 +646,12 @@ public void testUpdateModelWithModelAccessControlOtherException() { @Test public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlNoPermission() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -661,7 +665,7 @@ public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlNoPermis @Test public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener .onFailure( new RuntimeException( @@ -671,7 +675,7 @@ public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlOtherExc return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -831,7 +835,16 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), any(), any(), eq("mockUpdateModelGroupId"), any(), eq(sdkClient), isA(ActionListener.class)); + .validateModelGroupAccess( + any(), + any(), + any(), + eq("mockUpdateModelGroupId"), + any(), + any(), + eq(sdkClient), + isA(ActionListener.class) + ); MLModelGroup modelGroup = MLModelGroup .builder() diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java index 9b1036f731..343170fce2 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java @@ -167,10 +167,10 @@ public void testPrediction_default_exception() { when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); @@ -206,10 +206,10 @@ public void testPrediction_OpenSearchStatusException() { when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onFailure(new OpenSearchStatusException("Testing OpenSearchStatusException", RestStatus.BAD_REQUEST)); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); @@ -229,10 +229,10 @@ public void testPrediction_MLResourceNotFoundException() { when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onFailure(new MLResourceNotFoundException("Testing MLResourceNotFoundException")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); @@ -252,10 +252,10 @@ public void testPrediction_MLLimitExceededException() { when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onFailure(new CircuitBreakingException("Memory Circuit Breaker is open, please check your resources!", CircuitBreaker.Durability.TRANSIENT)); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index b0e290a693..823f2a1095 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -209,10 +209,10 @@ public void setup() throws IOException { assertNotNull(transportRegisterModelAction); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); MLStat mlStat = mock(MLStat.class); when(mlStats.getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT))).thenReturn(mlStat); @@ -289,10 +289,10 @@ public void testDoExecute_LocalModelDisabledException() { @Test public void testDoExecute_userHasNoAccessException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -453,10 +453,10 @@ public void testTransportRegisterModelActionDoExecuteWithDispatchException() { @Test public void test_ValidationFailedException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", "modelGroupID"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -703,10 +703,10 @@ public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException return null; }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); MLRegisterModelInput registerModelInput = MLRegisterModelInput .builder() @@ -751,10 +751,10 @@ public void test_NoAccessWhenModelNameAlreadyExists() throws IOException { }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java index 88ae44c70b..68fd45ba3b 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java @@ -185,10 +185,10 @@ public void setup() throws IOException { }).when(mlModelManager).getModel(eq("testModelID"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -286,10 +286,10 @@ public void test_BatchPredictCancel_NoModelGroupAccess() throws IOException { remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index ca511306a4..1819e4ca21 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -244,10 +244,10 @@ public void setup() throws IOException { }).when(mlModelManager).getModel(eq("testModelID"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -334,10 +334,10 @@ public void test_BatchPredictStatus_NoModelGroupAccess() throws IOException { remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); @@ -359,10 +359,10 @@ public void test_BatchPredictStatus_FeatureFlagDisabled() throws IOException { remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); @@ -388,10 +388,10 @@ public void test_BatchPredictStatus_NoConnectorFound() throws IOException { remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -419,10 +419,10 @@ public void test_BatchPredictStatus_NoModel() throws IOException { remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index 722ba90bee..ebf476c051 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -153,7 +153,6 @@ public void setup() throws IOException { threadPool, client, sdkClient, - settings, xContentRegistry, nodeFilter, mlTaskDispatcher, @@ -452,10 +451,10 @@ public void testHiddenModelPermissionError() { public void testDoExecute() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); @@ -484,10 +483,10 @@ public void testDoExecute() { public void testDoExecute_modelAccessControl_notEnabled() { when(modelAccessControlHelper.isModelAccessControlEnabled()).thenReturn(false); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); MLUndeployModelsResponse mlUndeployModelsResponse = new MLUndeployModelsResponse(mock(MLUndeployModelNodesResponse.class)); doAnswer(invocation -> { @@ -502,10 +501,10 @@ public void testDoExecute_modelAccessControl_notEnabled() { public void testDoExecute_validate_false() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(6); + ActionListener listener = invocation.getArgument(7); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java index 8375ae5fca..7e2c6c4bbc 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java @@ -92,10 +92,10 @@ public void setup() throws IOException { }).when(executorService).execute(any(Runnable.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -181,10 +181,10 @@ public void testUploadModelChunkNumberEqualsChunkCount() { public void testDoExecute_userHasNoAccessException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); MLUploadModelChunkInput uploadModelChunkInput = prepareRequest(); uploadModelChunkInput.setChunkNumber(1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java index 12286f57f7..b491902b23 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java @@ -91,10 +91,10 @@ public void setup() throws IOException { ); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -160,10 +160,10 @@ public void testDoExecute_failureWithCreateModelGroup() { @Test public void testDoExecute_userHasNoAccessException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); @@ -177,10 +177,10 @@ public void testDoExecute_userHasNoAccessException() { @Test public void test_ValidationFailedException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); @@ -210,10 +210,10 @@ public void testDoExecute_ModelNameAlreadyExists() throws IOException { @Test public void testDoExecute_NoAccessWhenModelNameAlreadyExists() throws IOException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java index 5083211d91..a89f607be9 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java @@ -114,14 +114,15 @@ public void setupModelGroup(String owner, String access, List backendRol // TODO Remove when all calls are migrated to SdkClient version public void test_UndefinedModelGroupID_NoSdkClient() { - modelAccessControlHelper.validateModelGroupAccess(null, null, client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(null, null, null, client, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); } public void test_UndefinedModelGroupID() { - modelAccessControlHelper.validateModelGroupAccess(null, mlFeatureEnabledSetting, null, null, client, sdkClient, actionListener); + modelAccessControlHelper + .validateModelGroupAccess(null, mlFeatureEnabledSetting, null, null, null, client, sdkClient, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -130,7 +131,7 @@ public void test_UndefinedModelGroupID() { // TODO Remove when all calls are migrated to SdkClient version public void test_UndefinedOwner_NoSdkClient() throws IOException { getResponse = modelGroupBuilder(null, null, null); - modelAccessControlHelper.validateModelGroupAccess(null, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(null, "testGroupID", null, client, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -139,7 +140,7 @@ public void test_UndefinedOwner_NoSdkClient() throws IOException { public void test_UndefinedOwner() throws IOException { getResponse = modelGroupBuilder(null, null, null); modelAccessControlHelper - .validateModelGroupAccess(null, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, actionListener); + .validateModelGroupAccess(null, mlFeatureEnabledSetting, null, "testGroupID", null, client, sdkClient, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -150,7 +151,7 @@ public void test_ExceptionEmptyBackendRoles_NoSdkClient() throws IOException { String owner = "owner|IT,HR|myTenant"; User user = User.parse("owner|IT,HR|myTenant"); getResponse = modelGroupBuilder(null, AccessMode.RESTRICTED.getValue(), owner); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", null, client, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Backend roles shouldn't be null", argumentCaptor.getValue().getMessage()); @@ -168,7 +169,7 @@ public void test_ExceptionEmptyBackendRoles() throws IOException, InterruptedExc CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", null, client, sdkClient, latchedActionListener); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -182,7 +183,7 @@ public void test_MatchingBackendRoles_NoSdkClient() throws IOException { List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.RESTRICTED.getValue(), backendRoles); User user = User.parse("owner|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", null, client, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -201,7 +202,7 @@ public void test_MatchingBackendRoles() throws IOException, InterruptedException CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", null, client, sdkClient, latchedActionListener); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -215,7 +216,7 @@ public void test_PublicModelGroup_NoSdkClient() throws IOException { List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PUBLIC.getValue(), backendRoles); User user = User.parse("owner|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", null, client, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -234,7 +235,7 @@ public void test_PublicModelGroup() throws IOException, InterruptedException { CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", null, client, sdkClient, latchedActionListener); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -248,7 +249,7 @@ public void test_PrivateModelGroupWithSameOwner_NoSdkClient() throws IOException List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PRIVATE.getValue(), backendRoles); User user = User.parse("owner|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", null, client, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -267,7 +268,7 @@ public void test_PrivateModelGroupWithSameOwner() throws IOException, Interrupte CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", null, client, sdkClient, latchedActionListener); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -281,7 +282,7 @@ public void test_PrivateModelGroupWithDifferentOwner_NoSdkClient() throws IOExce List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PRIVATE.getValue(), backendRoles); User user = User.parse("user|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", null, client, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertFalse(argumentCaptor.getValue()); @@ -300,7 +301,7 @@ public void test_PrivateModelGroupWithDifferentOwner() throws IOException, Inter CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", null, client, sdkClient, latchedActionListener); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -413,11 +414,6 @@ public void test_AddUserBackendRolesFilter() { assertNotNull(modelAccessControlHelper.addUserBackendRolesFilter(user, builder)); } - public void test_CreateSearchSourceBuilder() { - User user = User.parse("owner|IT,HR|myTenant"); - assertNotNull(modelAccessControlHelper.createSearchSourceBuilder(user)); - } - private GetResponse modelGroupBuilder(List backendRoles, String access, String owner) throws IOException { MLModelGroup mlModelGroup = MLModelGroup .builder() diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index 36ecd569b0..5525991879 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -109,14 +109,7 @@ public void setup() throws IOException { Settings settings = Settings.builder().build(); sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap())); threadContext = new ThreadContext(settings); - mlModelGroupManager = new MLModelGroupManager( - mlIndicesHandler, - client, - sdkClient, - clusterService, - modelAccessControlHelper, - mlFeatureEnabledSetting - ); + mlModelGroupManager = new MLModelGroupManager(mlIndicesHandler, client, sdkClient, clusterService, modelAccessControlHelper); assertNotNull(mlModelGroupManager); indexResponse = new IndexResponse(new ShardId(ML_MODEL_GROUP_INDEX, "_na_", 0), "model_group_ID", 1, 0, 2, true); // when(indexResponse.getId()).thenReturn("modelGroupID"); diff --git a/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java b/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java new file mode 100644 index 0000000000..f74fef5354 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.resources; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; + +import java.util.Iterator; +import java.util.Set; + +import org.junit.After; +import org.junit.Test; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; +import org.opensearch.security.spi.resources.ResourceProvider; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; + +public class MLResourceSharingExtensionTests { + private static String extractIndexFrom(Set providers) { + assertThat("providers should not be null", providers, is(not(nullValue()))); + assertThat("Expected exactly one provider", providers.size(), equalTo(1)); + Iterator it = providers.iterator(); + assertThat(it.hasNext(), equalTo(true)); + return it.next().resourceIndexName(); + } + + @After + public void tearDown() { + // Reset the accessor to avoid cross-test leakage + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); + } + + @Test + public void testGetResourceProviders_returnsExpectedSingleProvider() { + MLResourceSharingExtension ext = new MLResourceSharingExtension(); + + Set providers = ext.getResourceProviders(); + assertThat(providers, is(not(nullValue()))); + assertThat(providers.size(), equalTo(1)); + + ResourceProvider provider = providers.iterator().next(); + assertThat( + "Resource type should be MLModelGroup canonical name", + provider.resourceType(), + equalTo(MLModelGroup.class.getCanonicalName()) + ); + + String index = provider.resourceIndexName(); + assertThat("Index must not be empty", index.trim(), equalTo(CommonValue.ML_MODEL_GROUP_INDEX)); + + } + + @Test(expected = UnsupportedOperationException.class) + public void testGetResourceProviders_returnsUnmodifiableSet() { + MLResourceSharingExtension ext = new MLResourceSharingExtension(); + Set providers = ext.getResourceProviders(); + + // Attempt to modify — Set.of(...) should be unmodifiable and throw + providers.add(new ResourceProvider("some.Type", "some-index")); + } + + @Test + public void testAssignResourceSharingClient_setsClientOnAccessor() { + MLResourceSharingExtension ext = new MLResourceSharingExtension(); + ResourceSharingClient mockClient = mock(ResourceSharingClient.class); + + assertThat(ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), is(nullValue())); + + ext.assignResourceSharingClient(mockClient); + + assertThat( + "Accessor should hold the client passed to extension", + ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), + equalTo(mockClient) + ); + } + + @Test + public void testAssignResourceSharingClient_overwritesExistingClient() { + MLResourceSharingExtension ext = new MLResourceSharingExtension(); + ResourceSharingClient first = mock(ResourceSharingClient.class); + ResourceSharingClient second = mock(ResourceSharingClient.class); + + // Prime with the first client + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(first); + assertThat(ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), equalTo(first)); + + // Now assign a new one via the extension + ext.assignResourceSharingClient(second); + + assertThat( + "Accessor should be updated to the new client", + ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), + equalTo(second) + ); + } + + @Test + public void testGetResourceProviders_isDeterministicAcrossCalls() { + MLResourceSharingExtension ext = new MLResourceSharingExtension(); + + Set first = ext.getResourceProviders(); + Set second = ext.getResourceProviders(); + + // Same contents + assertThat(first, equalTo(second)); + + // Extract and compare details for additional safety + String idx1 = extractIndexFrom(first); + String idx2 = extractIndexFrom(second); + assertThat(idx1, equalTo(idx2)); + } +}