Skip to content

Commit dd887bc

Browse files
authored
exclude trusted connector check for hidden model (#3838)
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 0cba0ed commit dd887bc

File tree

2 files changed

+112
-21
lines changed

2 files changed

+112
-21
lines changed

plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
import org.opensearch.transport.TransportService;
7373
import org.opensearch.transport.client.Client;
7474

75+
import com.google.common.annotations.VisibleForTesting;
7576
import com.google.common.collect.ImmutableList;
7677
import com.google.common.collect.ImmutableMap;
7778

@@ -171,7 +172,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegi
171172
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models."
172173
);
173174
}
174-
registerModelInput.setIsHidden(RestActionUtils.isSuperAdminUser(clusterService, client));
175+
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
176+
registerModelInput.setIsHidden(isSuperAdmin);
175177
if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) {
176178
mlModelGroupManager
177179
.validateUniqueModelGroupName(
@@ -368,7 +370,11 @@ private void validateInternalConnector(MLRegisterModelInput registerModelInput)
368370
throw new IllegalArgumentException("Connector endpoint is required when creating a remote model without connector id!");
369371
}
370372
// check if the connector url is trusted
371-
registerModelInput.getConnector().validateConnectorURL(trustedConnectorEndpointsRegex);
373+
// if the model is a hidden model, that means Superuser of this domain or cloud provider is settings up this
374+
// model, so no need to verify the connector endpoint as trusted or not
375+
if (!registerModelInput.getIsHidden()) {
376+
registerModelInput.getConnector().validateConnectorURL(trustedConnectorEndpointsRegex);
377+
}
372378
}
373379

374380
private void registerModel(MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelResponse> listener) {
@@ -470,4 +476,10 @@ private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterMode
470476
.tenantId(registerModelInput.getTenantId())
471477
.build();
472478
}
479+
480+
// this method is only to stub static method.
481+
@VisibleForTesting
482+
boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) {
483+
return RestActionUtils.isSuperAdminUser(clusterService, client);
484+
}
473485
}

plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java

Lines changed: 98 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import static org.mockito.ArgumentMatchers.eq;
1111
import static org.mockito.ArgumentMatchers.isA;
1212
import static org.mockito.Mockito.doAnswer;
13+
import static org.mockito.Mockito.doReturn;
1314
import static org.mockito.Mockito.doThrow;
1415
import static org.mockito.Mockito.mock;
16+
import static org.mockito.Mockito.spy;
1517
import static org.mockito.Mockito.verify;
1618
import static org.mockito.Mockito.when;
1719
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
@@ -182,25 +184,27 @@ public void setup() throws IOException {
182184
);
183185
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
184186
when(clusterService.getSettings()).thenReturn(settings);
185-
transportRegisterModelAction = new TransportRegisterModelAction(
186-
transportService,
187-
actionFilters,
188-
modelHelper,
189-
mlIndicesHandler,
190-
mlModelManager,
191-
mlTaskManager,
192-
clusterService,
193-
settings,
194-
threadPool,
195-
client,
196-
sdkClient,
197-
nodeFilter,
198-
mlTaskDispatcher,
199-
mlStats,
200-
modelAccessControlHelper,
201-
connectorAccessControlHelper,
202-
mlModelGroupManager,
203-
mlFeatureEnabledSetting
187+
transportRegisterModelAction = spy(
188+
new TransportRegisterModelAction(
189+
transportService,
190+
actionFilters,
191+
modelHelper,
192+
mlIndicesHandler,
193+
mlModelManager,
194+
mlTaskManager,
195+
clusterService,
196+
settings,
197+
threadPool,
198+
client,
199+
sdkClient,
200+
nodeFilter,
201+
mlTaskDispatcher,
202+
mlStats,
203+
modelAccessControlHelper,
204+
connectorAccessControlHelper,
205+
mlModelGroupManager,
206+
mlFeatureEnabledSetting
207+
)
204208
);
205209
assertNotNull(transportRegisterModelAction);
206210

@@ -594,6 +598,79 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi
594598
);
595599
}
596600

601+
@Test
602+
public void test_execute_registerRemoteModel_withUntrustedEndpoint() {
603+
// Create request and input mocks
604+
MLRegisterModelRequest request = mock(MLRegisterModelRequest.class);
605+
MLRegisterModelInput input = MLRegisterModelInput
606+
.builder()
607+
.functionName(FunctionName.REMOTE)
608+
.isHidden(false)
609+
.modelName("test-model")
610+
.build();
611+
612+
// Create a proper Connector instance instead of mocking it
613+
Connector connector = mock(Connector.class);
614+
when(connector.getActionEndpoint(anyString(), any(Map.class))).thenReturn("https://untrusted-endpoint.com");
615+
// Set the connector on the input
616+
input.setConnector(connector);
617+
618+
when(request.getRegisterModelInput()).thenReturn(input);
619+
620+
// Mock super admin check
621+
doReturn(false).when(transportRegisterModelAction).isSuperAdminUserWrapper(any(), any());
622+
623+
// Mock model group validation
624+
SearchResponse searchResponse = mock(SearchResponse.class);
625+
SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), 0.0f);
626+
when(searchResponse.getHits()).thenReturn(searchHits);
627+
628+
doAnswer(invocation -> {
629+
ActionListener<SearchResponse> listener = invocation.getArgument(2);
630+
listener.onResponse(searchResponse);
631+
return null;
632+
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any());
633+
634+
// Mock connector validation to throw exception
635+
doThrow(new IllegalArgumentException("The connector endpoint provided is not trusted")).when(connector).validateConnectorURL(any());
636+
637+
// Execute
638+
transportRegisterModelAction.doExecute(task, request, actionListener);
639+
640+
// Verify
641+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
642+
verify(actionListener).onFailure(argumentCaptor.capture());
643+
assertTrue(argumentCaptor.getValue().getMessage().contains("not trusted"));
644+
}
645+
646+
@Test
647+
public void test_execute_registerRemoteModel_withUntrustedEndpoint_hidden_model() {
648+
MLRegisterModelRequest request = mock(MLRegisterModelRequest.class);
649+
MLRegisterModelInput input = mock(MLRegisterModelInput.class);
650+
when(request.getRegisterModelInput()).thenReturn(input);
651+
when(input.getModelName()).thenReturn("Test Model");
652+
when(input.getVersion()).thenReturn("1");
653+
when(input.getModelGroupId()).thenReturn("modelGroupID");
654+
when(input.getFunctionName()).thenReturn(FunctionName.REMOTE);
655+
when(input.getIsHidden()).thenReturn(true);
656+
657+
// Create a proper Connector instance instead of mocking it
658+
Connector connector = mock(Connector.class);
659+
when(connector.getActionEndpoint(anyString(), any(Map.class))).thenReturn("https://untrusted-endpoint.com");
660+
// Set the connector on the input
661+
when(input.getConnector()).thenReturn(connector);
662+
MLCreateConnectorResponse mlCreateConnectorResponse = mock(MLCreateConnectorResponse.class);
663+
doAnswer(invocation -> {
664+
ActionListener<MLCreateConnectorResponse> listener = invocation.getArgument(2);
665+
listener.onResponse(mlCreateConnectorResponse);
666+
return null;
667+
}).when(client).execute(eq(MLCreateConnectorAction.INSTANCE), any(), isA(ActionListener.class));
668+
MLRegisterModelResponse response = mock(MLRegisterModelResponse.class);
669+
transportRegisterModelAction.doExecute(task, request, actionListener);
670+
ArgumentCaptor<MLRegisterModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class);
671+
verify(mlModelManager).registerMLRemoteModel(eq(sdkClient), eq(input), isA(MLTask.class), eq(actionListener));
672+
}
673+
597674
@Test
598675
public void test_ModelNameAlreadyExists() throws IOException {
599676
when(node1.getId()).thenReturn("NodeId1");
@@ -647,6 +724,7 @@ public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException
647724
);
648725
}
649726

727+
@Test
650728
public void test_FailureWhenSearchingModelGroupName() throws IOException {
651729
doAnswer(invocation -> {
652730
ActionListener<SearchResponse> listener = invocation.getArgument(2);
@@ -661,6 +739,7 @@ public void test_FailureWhenSearchingModelGroupName() throws IOException {
661739
assertEquals("Runtime exception", argumentCaptor.getValue().getMessage());
662740
}
663741

742+
@Test
664743
public void test_NoAccessWhenModelNameAlreadyExists() throws IOException {
665744

666745
SearchResponse searchResponse = createModelGroupSearchResponse(1);

0 commit comments

Comments
 (0)