10
10
import static org .mockito .ArgumentMatchers .eq ;
11
11
import static org .mockito .ArgumentMatchers .isA ;
12
12
import static org .mockito .Mockito .doAnswer ;
13
+ import static org .mockito .Mockito .doReturn ;
13
14
import static org .mockito .Mockito .doThrow ;
14
15
import static org .mockito .Mockito .mock ;
16
+ import static org .mockito .Mockito .spy ;
15
17
import static org .mockito .Mockito .verify ;
16
18
import static org .mockito .Mockito .when ;
17
19
import static org .opensearch .ml .common .settings .MLCommonsSettings .ML_COMMONS_ALLOW_MODEL_URL ;
@@ -182,25 +184,27 @@ public void setup() throws IOException {
182
184
);
183
185
when (clusterService .getClusterSettings ()).thenReturn (clusterSettings );
184
186
when (clusterService .getSettings ()).thenReturn (settings );
185
- transportRegisterModelAction = new TransportRegisterModelAction (
186
- transportService ,
187
- actionFilters ,
188
- modelHelper ,
189
- mlIndicesHandler ,
190
- mlModelManager ,
191
- mlTaskManager ,
192
- clusterService ,
193
- settings ,
194
- threadPool ,
195
- client ,
196
- sdkClient ,
197
- nodeFilter ,
198
- mlTaskDispatcher ,
199
- mlStats ,
200
- modelAccessControlHelper ,
201
- connectorAccessControlHelper ,
202
- mlModelGroupManager ,
203
- mlFeatureEnabledSetting
187
+ transportRegisterModelAction = spy (
188
+ new TransportRegisterModelAction (
189
+ transportService ,
190
+ actionFilters ,
191
+ modelHelper ,
192
+ mlIndicesHandler ,
193
+ mlModelManager ,
194
+ mlTaskManager ,
195
+ clusterService ,
196
+ settings ,
197
+ threadPool ,
198
+ client ,
199
+ sdkClient ,
200
+ nodeFilter ,
201
+ mlTaskDispatcher ,
202
+ mlStats ,
203
+ modelAccessControlHelper ,
204
+ connectorAccessControlHelper ,
205
+ mlModelGroupManager ,
206
+ mlFeatureEnabledSetting
207
+ )
204
208
);
205
209
assertNotNull (transportRegisterModelAction );
206
210
@@ -594,6 +598,79 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi
594
598
);
595
599
}
596
600
601
+ @ Test
602
+ public void test_execute_registerRemoteModel_withUntrustedEndpoint () {
603
+ // Create request and input mocks
604
+ MLRegisterModelRequest request = mock (MLRegisterModelRequest .class );
605
+ MLRegisterModelInput input = MLRegisterModelInput
606
+ .builder ()
607
+ .functionName (FunctionName .REMOTE )
608
+ .isHidden (false )
609
+ .modelName ("test-model" )
610
+ .build ();
611
+
612
+ // Create a proper Connector instance instead of mocking it
613
+ Connector connector = mock (Connector .class );
614
+ when (connector .getActionEndpoint (anyString (), any (Map .class ))).thenReturn ("https://untrusted-endpoint.com" );
615
+ // Set the connector on the input
616
+ input .setConnector (connector );
617
+
618
+ when (request .getRegisterModelInput ()).thenReturn (input );
619
+
620
+ // Mock super admin check
621
+ doReturn (false ).when (transportRegisterModelAction ).isSuperAdminUserWrapper (any (), any ());
622
+
623
+ // Mock model group validation
624
+ SearchResponse searchResponse = mock (SearchResponse .class );
625
+ SearchHits searchHits = new SearchHits (new SearchHit [0 ], new TotalHits (0L , TotalHits .Relation .EQUAL_TO ), 0.0f );
626
+ when (searchResponse .getHits ()).thenReturn (searchHits );
627
+
628
+ doAnswer (invocation -> {
629
+ ActionListener <SearchResponse > listener = invocation .getArgument (2 );
630
+ listener .onResponse (searchResponse );
631
+ return null ;
632
+ }).when (mlModelGroupManager ).validateUniqueModelGroupName (any (), any (), any ());
633
+
634
+ // Mock connector validation to throw exception
635
+ doThrow (new IllegalArgumentException ("The connector endpoint provided is not trusted" )).when (connector ).validateConnectorURL (any ());
636
+
637
+ // Execute
638
+ transportRegisterModelAction .doExecute (task , request , actionListener );
639
+
640
+ // Verify
641
+ ArgumentCaptor <Exception > argumentCaptor = ArgumentCaptor .forClass (Exception .class );
642
+ verify (actionListener ).onFailure (argumentCaptor .capture ());
643
+ assertTrue (argumentCaptor .getValue ().getMessage ().contains ("not trusted" ));
644
+ }
645
+
646
+ @ Test
647
+ public void test_execute_registerRemoteModel_withUntrustedEndpoint_hidden_model () {
648
+ MLRegisterModelRequest request = mock (MLRegisterModelRequest .class );
649
+ MLRegisterModelInput input = mock (MLRegisterModelInput .class );
650
+ when (request .getRegisterModelInput ()).thenReturn (input );
651
+ when (input .getModelName ()).thenReturn ("Test Model" );
652
+ when (input .getVersion ()).thenReturn ("1" );
653
+ when (input .getModelGroupId ()).thenReturn ("modelGroupID" );
654
+ when (input .getFunctionName ()).thenReturn (FunctionName .REMOTE );
655
+ when (input .getIsHidden ()).thenReturn (true );
656
+
657
+ // Create a proper Connector instance instead of mocking it
658
+ Connector connector = mock (Connector .class );
659
+ when (connector .getActionEndpoint (anyString (), any (Map .class ))).thenReturn ("https://untrusted-endpoint.com" );
660
+ // Set the connector on the input
661
+ when (input .getConnector ()).thenReturn (connector );
662
+ MLCreateConnectorResponse mlCreateConnectorResponse = mock (MLCreateConnectorResponse .class );
663
+ doAnswer (invocation -> {
664
+ ActionListener <MLCreateConnectorResponse > listener = invocation .getArgument (2 );
665
+ listener .onResponse (mlCreateConnectorResponse );
666
+ return null ;
667
+ }).when (client ).execute (eq (MLCreateConnectorAction .INSTANCE ), any (), isA (ActionListener .class ));
668
+ MLRegisterModelResponse response = mock (MLRegisterModelResponse .class );
669
+ transportRegisterModelAction .doExecute (task , request , actionListener );
670
+ ArgumentCaptor <MLRegisterModelResponse > argumentCaptor = ArgumentCaptor .forClass (MLRegisterModelResponse .class );
671
+ verify (mlModelManager ).registerMLRemoteModel (eq (sdkClient ), eq (input ), isA (MLTask .class ), eq (actionListener ));
672
+ }
673
+
597
674
@ Test
598
675
public void test_ModelNameAlreadyExists () throws IOException {
599
676
when (node1 .getId ()).thenReturn ("NodeId1" );
@@ -647,6 +724,7 @@ public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException
647
724
);
648
725
}
649
726
727
+ @ Test
650
728
public void test_FailureWhenSearchingModelGroupName () throws IOException {
651
729
doAnswer (invocation -> {
652
730
ActionListener <SearchResponse > listener = invocation .getArgument (2 );
@@ -661,6 +739,7 @@ public void test_FailureWhenSearchingModelGroupName() throws IOException {
661
739
assertEquals ("Runtime exception" , argumentCaptor .getValue ().getMessage ());
662
740
}
663
741
742
+ @ Test
664
743
public void test_NoAccessWhenModelNameAlreadyExists () throws IOException {
665
744
666
745
SearchResponse searchResponse = createModelGroupSearchResponse (1 );
0 commit comments