diff --git a/common/build.gradle b/common/build.gradle index 3d831aff2b..a572fec814 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -23,6 +23,8 @@ dependencies { testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2' testImplementation "org.opensearch.test:framework:${opensearch_version}" + compileOnly group: 'org.opensearch', name:'opensearch-security-spi', version:"${opensearch_build}" + compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.11.0' compileOnly group: 'org.json', name: 'json', version: '20231013' diff --git a/common/src/main/java/org/opensearch/ml/common/ResourceSharingClientAccessor.java b/common/src/main/java/org/opensearch/ml/common/ResourceSharingClientAccessor.java new file mode 100644 index 0000000000..5d0978dbc2 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/ResourceSharingClientAccessor.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import org.opensearch.security.spi.resources.client.ResourceSharingClient; + +/** + * Accessor for resource sharing client + */ +public class ResourceSharingClientAccessor { + private ResourceSharingClient CLIENT; + + private static ResourceSharingClientAccessor resourceSharingClientAccessor; + + private ResourceSharingClientAccessor() {} + + public static ResourceSharingClientAccessor getInstance() { + if (resourceSharingClientAccessor == null) { + resourceSharingClientAccessor = new ResourceSharingClientAccessor(); + } + + return resourceSharingClientAccessor; + } + + /** + * Set the resource sharing client + */ + public void setResourceSharingClient(ResourceSharingClient client) { + resourceSharingClientAccessor.CLIENT = client; + } + + /** + * Get the resource sharing client + */ + public ResourceSharingClient getResourceSharingClient() { + return resourceSharingClientAccessor.CLIENT; + } + +} diff --git a/plugin/build.gradle b/plugin/build.gradle index 172b55b45e..0f8f633019 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -45,7 +45,7 @@ opensearchplugin { name 'opensearch-ml' description 'machine learning plugin for opensearch' classname 'org.opensearch.ml.plugin.MachineLearningPlugin' - extendedPlugins = ['opensearch-job-scheduler'] + extendedPlugins = ['opensearch-job-scheduler', 'opensearch-security;optional=true'] } configurations { @@ -71,6 +71,8 @@ dependencies { zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + compileOnly group: 'org.opensearch', name:'opensearch-security-spi', version:"${opensearch_build}" + implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" // Multi-tenant SDK Client diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java index 92fd0228f1..b04721b954 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java @@ -30,6 +30,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; @@ -66,6 +67,7 @@ public class CreateControllerTransportAction extends HandledTransportAction { MLIndicesHandler mlIndicesHandler; Client client; + Settings settings; MLModelManager mlModelManager; ClusterService clusterService; MLModelCacheHelper mlModelCacheHelper; @@ -78,6 +80,7 @@ public CreateControllerTransportAction( ActionFilters actionFilters, MLIndicesHandler mlIndicesHandler, Client client, + Settings settings, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, MLModelCacheHelper mlModelCacheHelper, @@ -87,6 +90,7 @@ public CreateControllerTransportAction( super(MLCreateControllerAction.NAME, transportService, actionFilters, MLCreateControllerRequest::new); this.mlIndicesHandler = mlIndicesHandler; this.client = client; + this.settings = settings; this.mlModelManager = mlModelManager; this.clusterService = clusterService; this.mlModelCacheHelper = mlModelCacheHelper; @@ -112,7 +116,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(hasPermission -> { if (hasPermission) { if (mlModel.getModelState() != MLModelState.DEPLOYING) { indexAndCreateController(mlModel, controller, wrappedListener); diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java index 3be5e07a0b..5e8464bad2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java @@ -27,6 +27,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -55,6 +56,7 @@ @FieldDefaults(level = AccessLevel.PRIVATE) public class DeleteControllerTransportAction extends HandledTransportAction { Client client; + Settings settings; NamedXContentRegistry xContentRegistry; ClusterService clusterService; MLModelManager mlModelManager; @@ -67,6 +69,7 @@ public DeleteControllerTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, NamedXContentRegistry xContentRegistry, ClusterService clusterService, MLModelManager mlModelManager, @@ -98,7 +101,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(hasPermission -> { if (hasPermission) { mlModelManager .getController( diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java index d70488948f..0e040672fa 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java @@ -19,6 +19,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -48,6 +49,7 @@ @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class GetControllerTransportAction extends HandledTransportAction { Client client; + Settings settings; NamedXContentRegistry xContentRegistry; ClusterService clusterService; MLModelManager mlModelManager; @@ -59,6 +61,7 @@ public GetControllerTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, NamedXContentRegistry xContentRegistry, ClusterService clusterService, MLModelManager mlModelManager, @@ -67,6 +70,7 @@ public GetControllerTransportAction( ) { super(MLControllerGetAction.NAME, transportService, actionFilters, MLControllerGetRequest::new); this.client = client; + this.settings = settings; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; this.mlModelManager = mlModelManager; @@ -96,34 +100,40 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { - if (hasPermission) { - wrappedListener.onResponse(MLControllerGetResponse.builder().controller(controller).build()); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - getErrorMessage( - "User doesn't have privilege to perform this operation on this model controller.", - modelId, - isHidden - ), - RestStatus.FORBIDDEN - ) + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + client, + settings, + ActionListener.wrap(hasPermission -> { + if (hasPermission) { + wrappedListener.onResponse(MLControllerGetResponse.builder().controller(controller).build()); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + getErrorMessage( + "User doesn't have privilege to perform this operation on this model controller.", + modelId, + isHidden + ), + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log + .error( + getErrorMessage( + "Permission denied: Unable to create the model controller for the given model.", + modelId, + isHidden + ), + exception ); - } - }, exception -> { - log - .error( - getErrorMessage( - "Permission denied: Unable to create the model controller for the given model.", - modelId, - isHidden - ), - exception - ); - wrappedListener.onFailure(exception); - })); + wrappedListener.onFailure(exception); + }) + ); }, e -> wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java index ac44069930..ca57388192 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java @@ -29,6 +29,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; @@ -60,6 +61,7 @@ @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class UpdateControllerTransportAction extends HandledTransportAction { Client client; + Settings settings; MLModelManager mlModelManager; MLModelCacheHelper mlModelCacheHelper; ClusterService clusterService; @@ -71,6 +73,7 @@ public UpdateControllerTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, MLModelCacheHelper mlModelCacheHelper, @@ -79,6 +82,7 @@ public UpdateControllerTransportAction( ) { super(MLUpdateControllerAction.NAME, transportService, actionFilters, MLUpdateControllerRequest::new); this.client = client; + this.settings = settings; this.mlModelManager = mlModelManager; this.clusterService = clusterService; this.mlModelCacheHelper = mlModelCacheHelper; @@ -104,7 +108,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(hasPermission -> { if (hasPermission) { mlModelManager.getController(modelId, ActionListener.wrap(controller -> { boolean isDeployRequiredAfterUpdate = controller.isDeployRequiredAfterUpdate(updateControllerInput); diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 9ae795438d..463e3e8ad5 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -177,7 +177,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(access -> { if (!access) { wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 90584451e0..7ab561cc43 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -20,6 +20,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -65,17 +66,20 @@ public class MLSearchHandler { private ModelAccessControlHelper modelAccessControlHelper; private ClusterService clusterService; + private Settings settings; public MLSearchHandler( Client client, NamedXContentRegistry xContentRegistry, ModelAccessControlHelper modelAccessControlHelper, - ClusterService clusterService + ClusterService clusterService, + Settings settings ) { this.modelAccessControlHelper = modelAccessControlHelper; this.client = client; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; + this.settings = settings; } /** @@ -144,7 +148,7 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId, .searchDataObjectAsync(searchDataObjectRequest) .whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrapperListener)); } else { - SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user); + SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user, settings); SearchRequest modelGroupSearchRequest = new SearchRequest(); sourceBuilder.fetchSource(new String[] { MLModelGroup.MODEL_GROUP_ID_FIELD, }, null); sourceBuilder.size(10000); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index 7a9b3925b4..f9ff39bdfc 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -19,6 +19,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -57,6 +58,7 @@ public class DeleteModelGroupTransportAction extends HandledTransportAction wrappedListener = ActionListener.runBefore(actionListener, context::restore); + // TODO: Remove this feature flag check once feature is GA, as it will be enabled by default validateAndDeleteModelGroup(modelGroupId, tenantId, wrappedListener); } } @@ -107,6 +112,7 @@ private void validateAndDeleteModelGroup(String modelGroupId, String tenantId, A modelGroupId, client, sdkClient, + settings, ActionListener .wrap( hasAccess -> handleAccessValidation(hasAccess, modelGroupId, tenantId, listener), diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java index f1dbe8be48..5bc07565fe 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java @@ -17,6 +17,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; @@ -52,6 +53,7 @@ public class GetModelGroupTransportAction extends HandledTransportAction { final Client client; + final Settings settings; final SdkClient sdkClient; final NamedXContentRegistry xContentRegistry; final ClusterService clusterService; @@ -63,6 +65,7 @@ public GetModelGroupTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, SdkClient sdkClient, NamedXContentRegistry xContentRegistry, ClusterService clusterService, @@ -71,6 +74,7 @@ public GetModelGroupTransportAction( ) { super(MLModelGroupGetAction.NAME, transportService, actionFilters, MLModelGroupGetRequest::new); this.client = client; + this.settings = settings; this.sdkClient = sdkClient; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; @@ -183,7 +187,7 @@ private void validateModelGroupAccess( MLModelGroup mlModelGroup, ActionListener wrappedListener ) { - modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { + modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, settings, ActionListener.wrap(access -> { if (!access) { wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java index 96af8cd317..e17af1e3ae 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java @@ -6,7 +6,10 @@ package org.opensearch.ml.action.model_group; import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -14,6 +17,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -35,6 +39,7 @@ @Log4j2 public class SearchModelGroupTransportAction extends HandledTransportAction { Client client; + Settings settings; SdkClient sdkClient; ClusterService clusterService; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @@ -46,6 +51,7 @@ public SearchModelGroupTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, SdkClient sdkClient, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, @@ -53,6 +59,7 @@ public SearchModelGroupTransportAction( ) { super(MLModelGroupSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.settings = settings; this.sdkClient = sdkClient; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; @@ -76,13 +83,19 @@ private void preProcessRoleAndPerformSearch( User user, ActionListener listener ) { + boolean isResourceSharingFeatureEnabled = ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED.get(settings) + && this.settings.getAsBoolean(OPENSEARCH_RESOURCE_SHARING_ENABLED, OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); final ActionListener doubleWrappedListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); - if (!modelAccessControlHelper.skipModelAccessControl(user)) { + // TODO: Remove this feature flag check once feature is GA, as it will be enabled by default + if (isResourceSharingFeatureEnabled) { + // User will be fetched from thread context using persistent header, so stash context will not stash user info + modelAccessControlHelper.addAccessibleModelGroupsFilter(request.source()); + } else if (!modelAccessControlHelper.skipModelAccessControl(user)) { // Security is enabled, filter is enabled and user isn't admin modelAccessControlHelper.addUserBackendRolesFilter(user, request.source()); log.debug("Filtering result by {}", user.getBackendRoles()); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index d3ab730f0d..1911ec528c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -8,7 +8,10 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.utils.MLExceptionUtils.logException; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT; import java.time.Instant; import java.util.HashSet; @@ -23,6 +26,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; @@ -35,6 +39,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; @@ -51,6 +56,7 @@ import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; @@ -66,6 +72,7 @@ public class TransportUpdateModelGroupAction extends HandledTransportAction { + if (!isAuthorized) { + listener + .onFailure( + new OpenSearchStatusException( + "User " + + user.getName() + + " is not authorized to update ml-model-group: " + + mlModelGroup.getName(), + RestStatus.FORBIDDEN + ) + ); + return; + } + // For backwards compatibility we still allow storing backend_roles data in ml_model_group + // index + updateModelGroup(modelGroupId, r.source(), updateModelGroupInput, wrappedListener, user); + }, listener::onFailure)); } else { - validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); + // TODO: At some point, this call must be replaced by the one above, (i.e. no user info to + // be stored in model-group index) + if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { + validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup); + } else { + validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); + } + + updateModelGroup(modelGroupId, r.source(), updateModelGroupInput, wrappedListener, user); } - updateModelGroup(modelGroupId, r.source(), updateModelGroupInput, wrappedListener, user); + } } catch (Exception e) { log.error("Failed to parse ml connector {}", r.id(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index c7f9d4ae18..2a121998eb 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -106,12 +106,11 @@ public class DeleteModelTransportAction extends HandledTransportAction { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } else if (isModelNotDeployed(mlModelState)) { - if (isSafeDelete) { - // We only check downstream task when it's not hidden and cluster setting is true. - checkDownstreamTaskBeforeDeleteModel( - modelId, - tenantId, - mlModel.getAlgorithm().name(), - isHidden, - actionListener - ); + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + client, + settings, + ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else if (isModelNotDeployed(mlModelState)) { + if (isSafeDelete) { + // We only check downstream task when it's not hidden and cluster setting is true. + checkDownstreamTaskBeforeDeleteModel( + modelId, + tenantId, + mlModel.getAlgorithm().name(), + isHidden, + actionListener + ); + } else { + deleteModel( + modelId, + tenantId, + mlModel.getAlgorithm().name(), + isHidden, + actionListener + ); + } + // deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, + // actionListener); } else { - deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, actionListener); + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete", + RestStatus.BAD_REQUEST + ) + ); } - // deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, actionListener); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete", - RestStatus.BAD_REQUEST - ) - ); - } - }, e -> { - log.error(getErrorMessage("Failed to validate Access", modelId, isHidden), e); - wrappedListener.onFailure(e); - })); + }, e -> { + log.error(getErrorMessage("Failed to validate Access", modelId, isHidden), e); + wrappedListener.onFailure(e); + }) + ); } } catch (Exception e) { log.error("Failed to parse ml model {}", r.id(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index 64c9eb6676..24246b82b1 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -141,27 +141,33 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } else { - log.debug("Completed Get Model Request, id:{}", modelId); - Connector connector = mlModel.getConnector(); - if (connector != null) { - connector.removeCredential(); + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + client, + settings, + ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else { + log.debug("Completed Get Model Request, id:{}", modelId); + Connector connector = mlModel.getConnector(); + if (connector != null) { + connector.removeCredential(); + } + wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); } - wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); - } - }, e -> { - log.error("Failed to validate Access for Model Id {}", modelId, e); - wrappedListener.onFailure(e); - })); + }, e -> { + log.error("Failed to validate Access for Model Id {}", modelId, e); + wrappedListener.onFailure(e); + }) + ); } } catch (Exception e) { log.error("Failed to parse ml model {}", r.id(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index d17085f01f..506831c626 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -164,6 +164,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (hasPermission) { updateRemoteOrTextEmbeddingModel( @@ -380,7 +381,7 @@ private void updateModelWithRegisteringToAnotherModelGroup( UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); if (newModelGroupId != null) { modelAccessControlHelper - .validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> { + .validateModelGroupAccess(user, newModelGroupId, client, settings, ActionListener.wrap(hasNewModelGroupPermission -> { if (hasNewModelGroupPermission) { mlModelGroupManager.getModelGroupResponse(sdkClient, newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { buildUpdateRequest( diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 92b4dbaa4e..6ffa23af63 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -57,6 +57,7 @@ public class TransportPredictionTaskAction extends HandledTransportAction { if (!access) { wrappedListener diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 00c577c8d6..26f3b3f768 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -213,6 +213,7 @@ private void checkUserAccess( registerModelInput.getModelGroupId(), client, sdkClient, + settings, ActionListener.wrap(access -> { if (access) { doRegister(registerModelInput, listener); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java index 4c1bf76529..e0fafeaa54 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -26,6 +26,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -72,6 +73,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction getModelListener = ActionListener.wrap(model -> { - modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job")); - } else { - if (model.getConnector() != null) { - Connector connector = model.getConnector(); - executeConnector(connector, mlInput, actionListener); - } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { - ActionListener listener = ActionListener - .wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> { - log.error("Failed to get connector {}", model.getConnectorId(), e); - actionListener.onFailure(e); - }); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { - connectorAccessControlHelper - .getConnector( - client, - model.getConnectorId(), - ActionListener.runBefore(listener, threadContext::restore) - ); - } + modelAccessControlHelper + .validateModelGroupAccess(user, model.getModelGroupId(), client, settings, ActionListener.wrap(access -> { + if (!access) { + actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job")); } else { - actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + if (model.getConnector() != null) { + Connector connector = model.getConnector(); + executeConnector(connector, mlInput, actionListener); + } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { + ActionListener listener = ActionListener + .wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> { + log.error("Failed to get connector {}", model.getConnectorId(), e); + actionListener.onFailure(e); + }); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper + .getConnector( + client, + model.getConnectorId(), + ActionListener.runBefore(listener, threadContext::restore) + ); + } + } else { + actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + } } - } - }, e -> { - log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e); - actionListener.onFailure(e); - })); + }, e -> { + log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e); + actionListener.onFailure(e); + })); }, e -> { log.error("Failed to retrieve the ML model with the given ID", e); actionListener diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 2422a439d2..ef824350ac 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -111,6 +111,7 @@ public class GetTaskTransportAction extends HandledTransportAction remoteJobStatusFields = it); @@ -374,6 +376,7 @@ private void processRemoteBatchPrediction( model.getModelGroupId(), client, sdkClient, + settings, ActionListener.wrap(access -> { if (!access) { actionListener diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index ada9f1a604..027b1bae7a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -307,6 +307,7 @@ private void validateAccess(String modelId, String tenantId, ActionListener { - if (!access) { - log.error("You don't have permissions to perform this operation on this model."); - wrappedListener - .onFailure( - new IllegalArgumentException( - "You don't have permissions to perform this operation on this model." - ) - ); - } else { - existingModel.setModelId(r.getId()); - if (existingModel.getTotalChunks() <= uploadModelChunkInput.getChunkNumber()) { - throw new Exception("Chunk number exceeds total chunks"); - } - byte[] bytes = uploadModelChunkInput.getContent(); - // Check the size of the content not to exceed 10 mb - if (bytes == null || bytes.length == 0) { - throw new Exception("Chunk size either 0 or null"); - } - if (validateChunkSize(bytes.length)) { - throw new Exception("Chunk size exceeds 10MB"); - } - mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { - if (!res) { - wrappedListener.onFailure(new RuntimeException("No response to create ML Model index")); - return; - } - int chunkNum = uploadModelChunkInput.getChunkNumber(); - MLModel mlModel = MLModel - .builder() - .algorithm(existingModel.getAlgorithm()) - .modelGroupId(existingModel.getModelGroupId()) - .version(existingModel.getVersion()) - .modelId(existingModel.getModelId()) - .modelFormat(existingModel.getModelFormat()) - .totalChunks(existingModel.getTotalChunks()) - .algorithm(existingModel.getAlgorithm()) - .chunkNumber(chunkNum) - .content(Base64.getEncoder().encodeToString(bytes)) - .build(); - IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); - indexRequest.id(uploadModelChunkInput.getModelId() + "_" + uploadModelChunkInput.getChunkNumber()); - indexRequest - .source( - mlModel - .toXContent( - XContentBuilder.builder(XContentType.JSON.xContent()), - ToXContent.EMPTY_PARAMS - ) + .validateModelGroupAccess( + user, + existingModel.getModelGroupId(), + client, + settings, + ActionListener.wrap(access -> { + if (!access) { + log.error("You don't have permissions to perform this operation on this model."); + wrappedListener + .onFailure( + new IllegalArgumentException( + "You don't have permissions to perform this operation on this model." + ) ); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexRequest, ActionListener.wrap(response -> { - log - .info( - "Index model successful for {} for chunk number {}", - uploadModelChunkInput.getModelId(), - chunkNum + 1 + } else { + existingModel.setModelId(r.getId()); + if (existingModel.getTotalChunks() <= uploadModelChunkInput.getChunkNumber()) { + throw new Exception("Chunk number exceeds total chunks"); + } + byte[] bytes = uploadModelChunkInput.getContent(); + // Check the size of the content not to exceed 10 mb + if (bytes == null || bytes.length == 0) { + throw new Exception("Chunk size either 0 or null"); + } + if (validateChunkSize(bytes.length)) { + throw new Exception("Chunk size exceeds 10MB"); + } + mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { + if (!res) { + wrappedListener.onFailure(new RuntimeException("No response to create ML Model index")); + return; + } + int chunkNum = uploadModelChunkInput.getChunkNumber(); + MLModel mlModel = MLModel + .builder() + .algorithm(existingModel.getAlgorithm()) + .modelGroupId(existingModel.getModelGroupId()) + .version(existingModel.getVersion()) + .modelId(existingModel.getModelId()) + .modelFormat(existingModel.getModelFormat()) + .totalChunks(existingModel.getTotalChunks()) + .algorithm(existingModel.getAlgorithm()) + .chunkNumber(chunkNum) + .content(Base64.getEncoder().encodeToString(bytes)) + .build(); + IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); + indexRequest + .id(uploadModelChunkInput.getModelId() + "_" + uploadModelChunkInput.getChunkNumber()); + indexRequest + .source( + mlModel + .toXContent( + XContentBuilder.builder(XContentType.JSON.xContent()), + ToXContent.EMPTY_PARAMS + ) ); - if (existingModel.getTotalChunks() == (uploadModelChunkInput.getChunkNumber() + 1)) { - Semaphore semaphore = new Semaphore(1); - semaphore.acquire(); - MLModel mlModelMeta = MLModel - .builder() - .name(existingModel.getName()) - .algorithm(existingModel.getAlgorithm()) - .version(existingModel.getVersion()) - .modelGroupId((existingModel.getModelGroupId())) - .modelFormat(existingModel.getModelFormat()) - .modelState(MLModelState.REGISTERED) - .modelConfig(existingModel.getModelConfig()) - .totalChunks(existingModel.getTotalChunks()) - .modelContentHash(existingModel.getModelContentHash()) - .modelContentSizeInBytes(existingModel.getModelContentSizeInBytes()) - .createdTime(existingModel.getCreatedTime()) - .build(); - IndexRequest indexReq = new IndexRequest(ML_MODEL_INDEX); - indexReq.id(modelId); - indexReq - .source( - mlModelMeta - .toXContent( - XContentBuilder.builder(XContentType.JSON.xContent()), - ToXContent.EMPTY_PARAMS - ) + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexRequest, ActionListener.wrap(response -> { + log + .info( + "Index model successful for {} for chunk number {}", + uploadModelChunkInput.getModelId(), + chunkNum + 1 ); - indexReq.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexReq, ActionListener.wrap(re -> { - log.debug("Index model successful", existingModel.getName()); - semaphore.release(); - }, e -> { - log.error("Failed to update model state", e); - semaphore.release(); - wrappedListener.onFailure(e); - })); - } - wrappedListener.onResponse(new MLUploadModelChunkResponse("Uploaded")); - }, e -> { - log.error("Failed to upload chunk model", e); - wrappedListener.onFailure(e); + if (existingModel.getTotalChunks() == (uploadModelChunkInput.getChunkNumber() + 1)) { + Semaphore semaphore = new Semaphore(1); + semaphore.acquire(); + MLModel mlModelMeta = MLModel + .builder() + .name(existingModel.getName()) + .algorithm(existingModel.getAlgorithm()) + .version(existingModel.getVersion()) + .modelGroupId((existingModel.getModelGroupId())) + .modelFormat(existingModel.getModelFormat()) + .modelState(MLModelState.REGISTERED) + .modelConfig(existingModel.getModelConfig()) + .totalChunks(existingModel.getTotalChunks()) + .modelContentHash(existingModel.getModelContentHash()) + .modelContentSizeInBytes(existingModel.getModelContentSizeInBytes()) + .createdTime(existingModel.getCreatedTime()) + .build(); + IndexRequest indexReq = new IndexRequest(ML_MODEL_INDEX); + indexReq.id(modelId); + indexReq + .source( + mlModelMeta + .toXContent( + XContentBuilder.builder(XContentType.JSON.xContent()), + ToXContent.EMPTY_PARAMS + ) + ); + indexReq.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexReq, ActionListener.wrap(re -> { + log.debug("Index model successful", existingModel.getName()); + semaphore.release(); + }, e -> { + log.error("Failed to update model state", e); + semaphore.release(); + wrappedListener.onFailure(e); + })); + } + wrappedListener.onResponse(new MLUploadModelChunkResponse("Uploaded")); + }, e -> { + log.error("Failed to upload chunk model", e); + wrappedListener.onFailure(e); + })); + }, ex -> { + log.error("Failed to init model index", ex); + wrappedListener.onFailure(ex); })); - }, ex -> { - log.error("Failed to init model index", ex); - wrappedListener.onFailure(ex); - })); - } - }, e -> { - logException("Failed to validate model access", e, log); - wrappedListener.onFailure(e); - })); + } + }, e -> { + logException("Failed to validate model access", e, log); + wrappedListener.onFailure(e); + }) + ); } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); wrappedListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java index ec3b67949c..7eafc5b78b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java @@ -12,6 +12,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.MLTaskState; @@ -37,6 +38,7 @@ public class TransportRegisterModelMetaAction extends HandledTransportAction { - if (access) { - createModelGroup(mlUploadInput, listener); - return; - } - if (isModelNameAlreadyExisting) { - listener - .onFailure( - new IllegalArgumentException( - "The name {" - + mlUploadInput.getName() - + "} you provided is unavailable because it is used by another model group with id {" - + mlUploadInput.getModelGroupId() - + "} to which you do not have access. Please provide a different name." - ) - ); - } else { - log.error("You don't have permissions to perform this operation on this model."); - listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); - } - }, e -> { - logException("Failed to validate model access", e, log); - listener.onFailure(e); - })); + modelAccessControlHelper + .validateModelGroupAccess(user, mlUploadInput.getModelGroupId(), client, settings, ActionListener.wrap(access -> { + if (access) { + createModelGroup(mlUploadInput, listener); + return; + } + if (isModelNameAlreadyExisting) { + listener + .onFailure( + new IllegalArgumentException( + "The name {" + + mlUploadInput.getName() + + "} you provided is unavailable because it is used by another model group with id {" + + mlUploadInput.getModelGroupId() + + "} to which you do not have access. Please provide a different name." + ) + ); + } else { + log.error("You don't have permissions to perform this operation on this model."); + listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); + } + }, e -> { + logException("Failed to validate model access", e, log); + listener.onFailure(e); + })); } private void createModelGroup(MLRegisterModelMetaInput mlUploadInput, ActionListener listener) { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index ac2cfded6c..335fb9ea42 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -11,6 +11,8 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT; import java.util.Collections; import java.util.HashSet; @@ -19,6 +21,7 @@ import org.apache.lucene.search.join.ScoreMode; import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.cluster.service.ClusterService; @@ -28,6 +31,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; @@ -45,6 +49,7 @@ import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; @@ -54,6 +59,7 @@ import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.transport.client.Client; import com.google.common.collect.ImmutableList; @@ -84,9 +90,42 @@ public ModelAccessControlHelper(ClusterService clusterService, Settings settings RangeQueryBuilder.class ); + private boolean isResourceSharingFeatureEnabled(Settings settings) { + return isModelAccessControlEnabled() + && settings.getAsBoolean(OPENSEARCH_RESOURCE_SHARING_ENABLED, OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT); + } + // TODO Eventually remove this when all usages of it have been migrated to the SdkClient version - public void validateModelGroupAccess(User user, String modelGroupId, Client client, ActionListener listener) { - if (modelGroupId == null || isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)) { + public void validateModelGroupAccess( + User user, + String modelGroupId, + Client client, + Settings settings, + ActionListener listener + ) { + if (modelGroupId == null) { + listener.onResponse(true); + return; + } + boolean isResourceSharingFeatureEnabled = isResourceSharingFeatureEnabled(settings); + if (isResourceSharingFeatureEnabled) { + ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, ActionListener.wrap(isAuthorized -> { + if (!isAuthorized) { + listener + .onFailure( + new OpenSearchStatusException( + "User " + user.getName() + " is not authorized to delete ml-model-group id: " + modelGroupId, + RestStatus.FORBIDDEN + ) + ); + return; + } + listener.onResponse(true); + }, listener::onFailure)); + return; + } + if (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)) { listener.onResponse(true); return; } @@ -132,11 +171,32 @@ public void validateModelGroupAccess( String modelGroupId, Client client, SdkClient sdkClient, + Settings settings, ActionListener listener ) { - if (modelGroupId == null - || (!mlFeatureEnabledSetting.isMultiTenancyEnabled() - && (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)))) { + if (modelGroupId == null || (!mlFeatureEnabledSetting.isMultiTenancyEnabled())) { + listener.onResponse(true); + return; + } + boolean isResourceSharingFeatureEnabled = isResourceSharingFeatureEnabled(settings); + if (isResourceSharingFeatureEnabled) { + ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, ActionListener.wrap(isAuthorized -> { + if (!isAuthorized) { + listener + .onFailure( + new OpenSearchStatusException( + "User " + user.getName() + " is not authorized to delete ml-model-group id: " + modelGroupId, + RestStatus.FORBIDDEN + ) + ); + return; + } + listener.onResponse(true); + }, listener::onFailure)); + return; + } + if (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)) { listener.onResponse(true); return; } @@ -313,7 +373,32 @@ public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuil return searchSourceBuilder; } - public SearchSourceBuilder createSearchSourceBuilder(User user) { + public SearchSourceBuilder createSearchSourceBuilder(User user, Settings settings) { + boolean isResourceSharingFeatureEnabled = isResourceSharingFeatureEnabled(settings); + // TODO: Remove this feature flag check once feature is GA, as it will be enabled by default + if (isResourceSharingFeatureEnabled) { + return addAccessibleModelGroupsFilter(new SearchSourceBuilder()); + } return addUserBackendRolesFilter(user, new SearchSourceBuilder()); } + + public SearchSourceBuilder addAccessibleModelGroupsFilter(SearchSourceBuilder searchSourceBuilder) { + ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + + resourceSharingClient.getAccessibleResourceIds(ML_MODEL_GROUP_INDEX, ActionListener.wrap(modelGroupIds -> { + if (modelGroupIds.isEmpty()) { + // User has no access → return nothing + searchSourceBuilder.query(QueryBuilders.boolQuery().mustNot(QueryBuilders.matchAllQuery())); + } else { + // Restrict search strictly to these _ids + // TODO check if this should be replaced with model_group_ids: MLModelGroup.MODEL_GROUP_ID_FIELD + + searchSourceBuilder.query(QueryBuilders.termsQuery(MLModelGroup.MODEL_GROUP_ID_FIELD + ".keyword", modelGroupIds)); + } + }, failure -> { + // do nothing to the source or return empty set? + searchSourceBuilder.query(QueryBuilders.boolQuery().mustNot(QueryBuilders.matchAllQuery())); + })); + return searchSourceBuilder; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 7e264d3347..d917d23368 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -7,9 +7,16 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT; +import static org.opensearch.security.spi.resources.ResourceAccessLevels.PLACE_HOLDER; import java.time.Instant; import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; @@ -19,6 +26,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; @@ -34,6 +42,7 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; @@ -48,6 +57,10 @@ import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; +import org.opensearch.security.spi.resources.sharing.Recipient; +import org.opensearch.security.spi.resources.sharing.Recipients; +import org.opensearch.security.spi.resources.sharing.ShareWith; import org.opensearch.transport.client.Client; import lombok.extern.log4j.Log4j2; @@ -56,6 +69,7 @@ public class MLModelGroupManager { private final MLIndicesHandler mlIndicesHandler; private final Client client; + private final Settings settings; private final SdkClient sdkClient; ClusterService clusterService; @@ -66,6 +80,7 @@ public class MLModelGroupManager { public MLModelGroupManager( MLIndicesHandler mlIndicesHandler, Client client, + Settings settings, SdkClient sdkClient, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, @@ -73,6 +88,7 @@ public MLModelGroupManager( ) { this.mlIndicesHandler = mlIndicesHandler; this.client = client; + this.settings = settings; this.sdkClient = sdkClient; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; @@ -83,6 +99,11 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener>> recipientMap = new AtomicReference<>(); + boolean isResourceSharingFeatureEnabled = ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED.get(settings) + && this.settings.getAsBoolean(OPENSEARCH_RESOURCE_SHARING_ENABLED, OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); validateUniqueModelGroupName(input.getName(), input.getTenantId(), ActionListener.wrap(modelGroups -> { @@ -101,9 +122,16 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener { + log + .debug( + "Successfully shared ml-model-group: {} with entities: {}", + modelName, + recipientMap + ); + + wrappedListener.onResponse(r.id()); + }, listener::onFailure) + ); + } else { + wrappedListener.onResponse(r.id()); + } + } catch (Exception e) { wrappedListener.onFailure(e); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 38a84804f1..1c100941ce 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -612,7 +612,7 @@ public Collection createComponents( mlFeatureEnabledSetting ); - mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry, modelAccessControlHelper); + mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, settings, xContentRegistry, modelAccessControlHelper); MLTaskDispatcher mlTaskDispatcher = new MLTaskDispatcher(clusterService, client, settings, nodeHelper); mlTrainingTaskRunner = new MLTrainingTaskRunner( @@ -724,7 +724,7 @@ public Collection createComponents( MetricsCorrelation metricsCorrelation = new MetricsCorrelation(client, settings, clusterService); MLEngineClassLoader.register(FunctionName.METRICS_CORRELATION, metricsCorrelation); - MLSearchHandler mlSearchHandler = new MLSearchHandler(client, xContentRegistry, modelAccessControlHelper, clusterService); + MLSearchHandler mlSearchHandler = new MLSearchHandler(client, xContentRegistry, modelAccessControlHelper, clusterService, settings); MLModelAutoReDeployer mlModelAutoRedeployer = new MLModelAutoReDeployer( clusterService, client, diff --git a/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java b/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java new file mode 100644 index 0000000000..13405ead58 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.resources; + +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; + +import java.util.Set; + +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; +import org.opensearch.security.spi.resources.ResourceProvider; +import org.opensearch.security.spi.resources.ResourceSharingExtension; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; + +public class MLResourceSharingExtension implements ResourceSharingExtension { + @Override + public Set getResourceProviders() { + return Set.of(new ResourceProvider(MLModelGroup.class.getCanonicalName(), ML_MODEL_GROUP_INDEX)); + } + + @Override + public void assignResourceSharingClient(ResourceSharingClient resourceSharingClient) { + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(resourceSharingClient); + } +} diff --git a/plugin/src/main/resources/META-INF/services/org.opensearch.security.spi.resources.ResourceSharingExtension b/plugin/src/main/resources/META-INF/services/org.opensearch.security.spi.resources.ResourceSharingExtension new file mode 100644 index 0000000000..00dfe98b5a --- /dev/null +++ b/plugin/src/main/resources/META-INF/services/org.opensearch.security.spi.resources.ResourceSharingExtension @@ -0,0 +1 @@ +org.opensearch.ml.resources.MLResourceSharingExtension diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java index 61c0282ac2..68238b30bc 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java @@ -149,6 +149,7 @@ public void setup() throws IOException { actionFilters, mlIndicesHandler, client, + Settings.EMPTY, clusterService, modelAccessControlHelper, mlModelCacheHelper, @@ -171,7 +172,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -236,7 +237,7 @@ public void testCreateControllerWithModelAccessControlNoPermission() { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -253,7 +254,7 @@ public void testCreateControllerWithModelAccessControlOtherException() { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java index 52bbfdad3d..403a7e806d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java @@ -147,6 +147,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + Settings.EMPTY, xContentRegistry, clusterService, mlModelManager, @@ -162,7 +163,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -216,7 +217,7 @@ public void testDeleteControllerWithModelAccessControlNoPermission() { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -241,7 +242,7 @@ public void testDeleteControllerWithModelAccessControlNoPermissionHiddenModel() ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -258,7 +259,7 @@ public void testDeleteControllerWithModelAccessControlOtherException() { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -283,7 +284,7 @@ public void testDeleteControllerWithModelAccessControlOtherExceptionHiddenModel( new RuntimeException("Permission denied: Unable to delete the model controller with the provided model. Details: ") ); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java index f414572b02..462a6fe739 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java @@ -103,6 +103,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, xContentRegistry, clusterService, mlModelManager, @@ -122,7 +123,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareControllerGetResponse(); doAnswer(invocation -> { @@ -160,7 +161,7 @@ public void testGetControllerWithModelAccessControlNoPermission() { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -177,7 +178,7 @@ public void testGetControllerWithModelAccessControlOtherException() { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java index 2bdef9c022..13c958d82d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java @@ -151,6 +151,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + Settings.EMPTY, clusterService, modelAccessControlHelper, mlModelCacheHelper, @@ -181,7 +182,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -246,7 +247,7 @@ public void testUpdateControllerWithModelAccessControlNoPermission() { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -271,7 +272,7 @@ public void testUpdateControllerWithModelAccessControlNoPermissionHiddenModel() ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -288,7 +289,7 @@ public void testUpdateControllerWithModelAccessControlOtherException() { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -310,7 +311,7 @@ public void testUpdateControllerWithModelAccessControlOtherExceptionHiddenModel( ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Permission denied: Unable to create the model controller for the model. Details: ")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index deb8c054af..f618b85b0a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -177,7 +177,7 @@ public void setup() { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true); @@ -358,7 +358,7 @@ public void testDoExecute_userHasNoAccessException() { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); ActionListener deployModelResponseListener = mock(ActionListener.class); transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); @@ -414,7 +414,7 @@ public void test_ValidationFailedException() { ActionListener listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); ActionListener deployModelResponseListener = mock(ActionListener.class); transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java index 0bf67454b9..83cf3a12b5 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java @@ -106,6 +106,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, sdkClient, xContentRegistry, clusterService, @@ -118,7 +119,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -229,7 +230,7 @@ public void test_UserHasNoAccessException() throws IOException { ActionListener listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -243,7 +244,7 @@ public void test_ValidationFailedException() { ActionListener listener = invocation.getArgument(6); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java index aa2ceb20ce..1a54ee2faf 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java @@ -97,6 +97,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, sdkClient, xContentRegistry, clusterService, @@ -109,7 +110,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -135,7 +136,7 @@ public void testGetModel_UserHasNoAccess() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLModelGroup(); doAnswer(invocation -> { @@ -155,7 +156,7 @@ public void testGetModel_ValidateAccessFailed() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLModelGroup(); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java index f6aac6ec92..f97a4b622e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java @@ -97,6 +97,7 @@ public void setup() { transportService, actionFilters, client, + Settings.EMPTY, sdkClient, clusterService, modelAccessControlHelper, diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index c62716d793..839dfa9d55 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -120,6 +120,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, sdkClient, xContentRegistry, clusterService, diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index 2f91a1837c..fb4c420183 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -190,7 +190,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(clusterService.getSettings()).thenReturn(settings); @@ -420,7 +420,7 @@ public void test_UserHasNoAccessException() throws IOException, InterruptedExcep ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); @@ -501,7 +501,7 @@ public void test_ValidationFailedException() throws IOException, InterruptedExce ActionListener listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java index e534e26505..9b0cc019a4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java @@ -117,7 +117,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -137,7 +137,7 @@ public void testGetModel_UserHasNodeAccess() throws IOException, InterruptedExce ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); @@ -199,7 +199,7 @@ public void testGetModel_ValidateAccessFailed() throws IOException, InterruptedE ActionListener listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLModel(false); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index d1a0279bb6..68999b0a50 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -114,7 +114,7 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { public void setup() { MockitoAnnotations.openMocks(this); sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); - mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService)); + mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService, Settings.EMPTY)); searchModelTransportAction = new SearchModelTransportAction( transportService, actionFilters, @@ -184,7 +184,7 @@ public void test_DoExecute_addBackendRoles() throws IOException { listener.onResponse(searchResponse); return null; }).when(client).search(any(), any()); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(2)).search(any(), any()); @@ -196,7 +196,7 @@ public void test_DoExecute_addBackendRoles_without_groupIds() { listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(2)).search(any(), any()); @@ -208,7 +208,7 @@ public void test_DoExecute_addBackendRoles_exception() { listener.onFailure(new RuntimeException("runtime exception")); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(1)).search(any(), any()); @@ -281,7 +281,7 @@ public void test_DoExecute_addBackendRoles_boolQuery() throws IOException { listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.boolQuery().must(QueryBuilders.matchQuery("name", "model_IT"))); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); @@ -295,7 +295,7 @@ public void test_DoExecute_addBackendRoles_termQuery() throws IOException { listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); @@ -330,7 +330,7 @@ public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws In return null; }).when(client).search(any(), any()); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); mlSearchActionRequest = new MLSearchActionRequest(searchRequest, "123456"); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index ab30855a2e..fd9c97dbbe 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -296,7 +296,9 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("test_model_group_id"), any(), isA(ActionListener.class)); + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("test_model_group_id"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(6); @@ -311,6 +313,7 @@ public void setup() throws IOException { eq("test_model_group_id"), any(), any(SdkClient.class), + any(), isA(ActionListener.class) ); @@ -321,7 +324,7 @@ public void setup() throws IOException { return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(6); @@ -336,6 +339,7 @@ public void setup() throws IOException { eq("updated_test_model_group_id"), any(), any(SdkClient.class), + any(), isA(ActionListener.class) ); @@ -602,7 +606,7 @@ public void testUpdateModelWithModelAccessControlNoPermission() throws Interrupt return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), any(), any(), any(), any(), any(SdkClient.class), isA(ActionListener.class)); + .validateModelGroupAccess(any(), any(), any(), any(), any(), any(SdkClient.class), any(), isA(ActionListener.class)); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -628,7 +632,7 @@ public void testUpdateModelWithModelAccessControlOtherException() { ) ); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -647,7 +651,7 @@ public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlNoPermis return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -671,7 +675,7 @@ public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlOtherExc return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -807,7 +811,16 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), any(), any(), eq("mockUpdateModelGroupId"), any(), eq(sdkClient), isA(ActionListener.class)); + .validateModelGroupAccess( + any(), + any(), + any(), + eq("mockUpdateModelGroupId"), + any(), + eq(sdkClient), + any(), + isA(ActionListener.class) + ); MLModelGroup modelGroup = MLModelGroup .builder() diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java index 9b1036f731..2539afa897 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java @@ -170,7 +170,7 @@ public void testPrediction_default_exception() { ActionListener listener = invocation.getArgument(6); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); @@ -209,7 +209,7 @@ public void testPrediction_OpenSearchStatusException() { ActionListener listener = invocation.getArgument(6); listener.onFailure(new OpenSearchStatusException("Testing OpenSearchStatusException", RestStatus.BAD_REQUEST)); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); @@ -232,7 +232,7 @@ public void testPrediction_MLResourceNotFoundException() { ActionListener listener = invocation.getArgument(6); listener.onFailure(new MLResourceNotFoundException("Testing MLResourceNotFoundException")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); @@ -255,7 +255,7 @@ public void testPrediction_MLLimitExceededException() { ActionListener listener = invocation.getArgument(6); listener.onFailure(new CircuitBreakingException("Memory Circuit Breaker is open, please check your resources!", CircuitBreaker.Durability.TRANSIENT)); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index b0e290a693..dcff39c59c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -212,7 +212,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); MLStat mlStat = mock(MLStat.class); when(mlStats.getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT))).thenReturn(mlStat); @@ -292,7 +292,7 @@ public void testDoExecute_userHasNoAccessException() { ActionListener listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -456,7 +456,7 @@ public void test_ValidationFailedException() { ActionListener listener = invocation.getArgument(6); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", "modelGroupID"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -706,7 +706,7 @@ public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException ActionListener listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); MLRegisterModelInput registerModelInput = MLRegisterModelInput .builder() @@ -754,7 +754,7 @@ public void test_NoAccessWhenModelNameAlreadyExists() throws IOException { ActionListener listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java index 88ae44c70b..1d78899439 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java @@ -140,6 +140,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, xContentRegistry, clusterService, scriptService, @@ -188,7 +189,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -289,7 +290,7 @@ public void test_BatchPredictCancel_NoModelGroupAccess() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index ca511306a4..616eedb0e9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -247,7 +247,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -337,7 +337,7 @@ public void test_BatchPredictStatus_NoModelGroupAccess() throws IOException { ActionListener listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); @@ -362,7 +362,7 @@ public void test_BatchPredictStatus_FeatureFlagDisabled() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); @@ -391,7 +391,7 @@ public void test_BatchPredictStatus_NoConnectorFound() throws IOException { ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -422,7 +422,7 @@ public void test_BatchPredictStatus_NoModel() throws IOException { ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index 72af11264d..55dfebd100 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -448,7 +448,7 @@ public void testDoExecute() { ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); @@ -479,7 +479,7 @@ public void testDoExecute_modelAccessControl_notEnabled() { ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); MLUndeployModelsResponse mlUndeployModelsResponse = new MLUndeployModelsResponse(mock(MLUndeployModelNodesResponse.class)); doAnswer(invocation -> { @@ -497,7 +497,7 @@ public void testDoExecute_validate_false() { ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java index 8375ae5fca..f6730459ad 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java @@ -95,7 +95,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -117,7 +117,7 @@ public void setup() throws IOException { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); - mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry, modelAccessControlHelper); + mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, settings, xContentRegistry, modelAccessControlHelper); MLModel mlModel = MLModel .builder() @@ -184,7 +184,7 @@ public void testDoExecute_userHasNoAccessException() { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); MLUploadModelChunkInput uploadModelChunkInput = prepareRequest(); uploadModelChunkInput.setChunkNumber(1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java index 7c04103a0e..2597bdbf1f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java @@ -86,6 +86,7 @@ public void setup() throws IOException { actionFilters, mlModelManager, client, + settings, modelAccessControlHelper, mlModelGroupManager ); @@ -94,7 +95,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -163,7 +164,7 @@ public void testDoExecute_userHasNoAccessException() { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); @@ -180,7 +181,7 @@ public void test_ValidationFailedException() { ActionListener listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); @@ -213,7 +214,7 @@ public void testDoExecute_NoAccessWhenModelNameAlreadyExists() throws IOExceptio ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java index 5083211d91..e31d556e87 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java @@ -114,14 +114,15 @@ public void setupModelGroup(String owner, String access, List backendRol // TODO Remove when all calls are migrated to SdkClient version public void test_UndefinedModelGroupID_NoSdkClient() { - modelAccessControlHelper.validateModelGroupAccess(null, null, client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(null, null, client, Settings.EMPTY, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); } public void test_UndefinedModelGroupID() { - modelAccessControlHelper.validateModelGroupAccess(null, mlFeatureEnabledSetting, null, null, client, sdkClient, actionListener); + modelAccessControlHelper + .validateModelGroupAccess(null, mlFeatureEnabledSetting, null, null, client, sdkClient, Settings.EMPTY, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -130,7 +131,7 @@ public void test_UndefinedModelGroupID() { // TODO Remove when all calls are migrated to SdkClient version public void test_UndefinedOwner_NoSdkClient() throws IOException { getResponse = modelGroupBuilder(null, null, null); - modelAccessControlHelper.validateModelGroupAccess(null, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(null, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -139,7 +140,16 @@ public void test_UndefinedOwner_NoSdkClient() throws IOException { public void test_UndefinedOwner() throws IOException { getResponse = modelGroupBuilder(null, null, null); modelAccessControlHelper - .validateModelGroupAccess(null, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, actionListener); + .validateModelGroupAccess( + null, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + actionListener + ); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -150,7 +160,7 @@ public void test_ExceptionEmptyBackendRoles_NoSdkClient() throws IOException { String owner = "owner|IT,HR|myTenant"; User user = User.parse("owner|IT,HR|myTenant"); getResponse = modelGroupBuilder(null, AccessMode.RESTRICTED.getValue(), owner); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Backend roles shouldn't be null", argumentCaptor.getValue().getMessage()); @@ -168,7 +178,16 @@ public void test_ExceptionEmptyBackendRoles() throws IOException, InterruptedExc CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + latchedActionListener + ); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -182,7 +201,7 @@ public void test_MatchingBackendRoles_NoSdkClient() throws IOException { List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.RESTRICTED.getValue(), backendRoles); User user = User.parse("owner|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -201,7 +220,16 @@ public void test_MatchingBackendRoles() throws IOException, InterruptedException CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + latchedActionListener + ); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -215,7 +243,7 @@ public void test_PublicModelGroup_NoSdkClient() throws IOException { List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PUBLIC.getValue(), backendRoles); User user = User.parse("owner|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -234,7 +262,16 @@ public void test_PublicModelGroup() throws IOException, InterruptedException { CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + latchedActionListener + ); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -248,7 +285,7 @@ public void test_PrivateModelGroupWithSameOwner_NoSdkClient() throws IOException List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PRIVATE.getValue(), backendRoles); User user = User.parse("owner|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -267,7 +304,16 @@ public void test_PrivateModelGroupWithSameOwner() throws IOException, Interrupte CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + latchedActionListener + ); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -281,7 +327,7 @@ public void test_PrivateModelGroupWithDifferentOwner_NoSdkClient() throws IOExce List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PRIVATE.getValue(), backendRoles); User user = User.parse("user|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertFalse(argumentCaptor.getValue()); @@ -300,7 +346,16 @@ public void test_PrivateModelGroupWithDifferentOwner() throws IOException, Inter CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + latchedActionListener + ); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -415,7 +470,7 @@ public void test_AddUserBackendRolesFilter() { public void test_CreateSearchSourceBuilder() { User user = User.parse("owner|IT,HR|myTenant"); - assertNotNull(modelAccessControlHelper.createSearchSourceBuilder(user)); + assertNotNull(modelAccessControlHelper.createSearchSourceBuilder(user, Settings.EMPTY)); } private GetResponse modelGroupBuilder(List backendRoles, String access, String owner) throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index 36ecd569b0..03dd82524f 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -112,6 +112,7 @@ public void setup() throws IOException { mlModelGroupManager = new MLModelGroupManager( mlIndicesHandler, client, + settings, sdkClient, clusterService, modelAccessControlHelper,