Skip to content

Commit 7254e8b

Browse files
author
nickchecan
committed
refactor: adapt ai client factory to also handle ollama models
1 parent 36faa04 commit 7254e8b

File tree

3 files changed

+31
-7
lines changed

3 files changed

+31
-7
lines changed

com.developer.nefarious.zjoule.plugin/src/com/developer/nefarious/zjoule/plugin/auth/SessionManager.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ public abstract class SessionManager {
2121

2222

2323
public static boolean isUserLoggedIn() {
24-
return (isSapSessionOn() || isOllamaSessionOn()) ? true : false;
24+
return (isSapSession() || isOllamaSession()) ? true : false;
2525
}
2626

27-
private static boolean isSapSessionOn() {
27+
public static boolean isSapSession() {
2828
MemoryAccessToken memoryAccessToken = MemoryAccessToken.getInstance();
2929
MemoryServiceKey memoryServiceKey = MemoryServiceKey.getInstance();
3030
MemoryResourceGroup memoryResourceGroup = MemoryResourceGroup.getInstance();
@@ -34,7 +34,7 @@ private static boolean isSapSessionOn() {
3434
|| memoryDeployment.isEmpty()) ? false : true;
3535
}
3636

37-
private static boolean isOllamaSessionOn() {
37+
public static boolean isOllamaSession() {
3838
MemoryOllamaEndpoint memoryOllamaEndpoint = MemoryOllamaEndpoint.getInstance();
3939
MemoryOllamaModel memoryOllamaModel = MemoryOllamaModel.getInstance();
4040

com.developer.nefarious.zjoule.plugin/src/com/developer/nefarious/zjoule/plugin/chat/AIClientFactory.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import com.developer.nefarious.zjoule.plugin.auth.AuthClient;
66
import com.developer.nefarious.zjoule.plugin.auth.AuthClientHelper;
7+
import com.developer.nefarious.zjoule.plugin.auth.SessionManager;
78
import com.developer.nefarious.zjoule.plugin.chat.memory.MemoryMessageHistory;
89
import com.developer.nefarious.zjoule.plugin.chat.openai.OpenAIClient;
910
import com.developer.nefarious.zjoule.plugin.chat.openai.OpenAIClientHelper;
@@ -29,8 +30,17 @@ public abstract class AIClientFactory {
2930
* @return an instance of {@link IAIClient} for the corresponding model, or {@code null} if unsupported.
3031
*/
3132
public static IAIClient getClient() {
32-
33-
// Load memory components for access token, service key, resource group, deployment, and message history.
33+
if (SessionManager.isSapSession()) {
34+
return getClientForSapAiCore();
35+
} else if (SessionManager.isOllamaSession()) {
36+
return getClientForOllama();
37+
} else {
38+
return null;
39+
}
40+
}
41+
42+
private static IAIClient getClientForSapAiCore() {
43+
// Load memory components for access token, service key, resource group, deployment, and message history.
3444
MemoryAccessToken memoryAccessToken = MemoryAccessToken.getInstance();
3545
MemoryServiceKey memoryServiceKey = MemoryServiceKey.getInstance();
3646
MemoryResourceGroup memoryResourceGroup = MemoryResourceGroup.getInstance();
@@ -52,6 +62,11 @@ public static IAIClient getClient() {
5262
return null;
5363
}
5464
}
65+
66+
private static IAIClient getClientForOllama() {
67+
return null;
68+
69+
}
5570

5671
/**
5772
* Checks if the given model name corresponds to an OpenAI model.
Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.mockito.MockedStatic;
1313
import org.mockito.MockitoAnnotations;
1414

15+
import com.developer.nefarious.zjoule.plugin.auth.SessionManager;
1516
import com.developer.nefarious.zjoule.plugin.chat.AIClientFactory;
1617
import com.developer.nefarious.zjoule.plugin.chat.IAIClient;
1718
import com.developer.nefarious.zjoule.plugin.chat.memory.MemoryMessageHistory;
@@ -22,8 +23,10 @@
2223
import com.developer.nefarious.zjoule.plugin.memory.MemoryServiceKey;
2324
import com.developer.nefarious.zjoule.plugin.models.Deployment;
2425

25-
public class AIClientFactoryTest {
26-
26+
public class AIClientFactorySapTest {
27+
28+
private MockedStatic<SessionManager> mockStaticSessionManager;
29+
2730
private MockedStatic<MemoryAccessToken> mockStaticMemoryAccessToken;
2831

2932
private MockedStatic<MemoryServiceKey> mockStaticMemoryServiceKey;
@@ -61,12 +64,15 @@ public void setUp() {
6164
mockStaticMemoryResourceGroup = mockStatic(MemoryResourceGroup.class);
6265
mockStaticMemoryDeployment = mockStatic(MemoryDeployment.class);
6366
mockStaticMemoryMessageHistory = mockStatic(MemoryMessageHistory.class);
67+
mockStaticSessionManager = mockStatic(SessionManager.class);
6468

6569
mockStaticMemoryAccessToken.when(MemoryAccessToken::getInstance).thenReturn(mockMemoryAccessToken);
6670
mockStaticMemoryServiceKey.when(MemoryServiceKey::getInstance).thenReturn(mockMemoryServiceKey);
6771
mockStaticMemoryResourceGroup.when(MemoryResourceGroup::getInstance).thenReturn(mockMemoryResourceGroup);
6872
mockStaticMemoryDeployment.when(MemoryDeployment::getInstance).thenReturn(mockMemoryDeployment);
6973
mockStaticMemoryMessageHistory.when(MemoryMessageHistory::getInstance).thenReturn(mockMemoryMessageHistory);
74+
mockStaticSessionManager.when(SessionManager::isSapSession).thenReturn(true);
75+
mockStaticSessionManager.when(SessionManager::isOllamaSession).thenReturn(false);
7076

7177
when(mockMemoryDeployment.load()).thenReturn(mockDeployment);
7278
}
@@ -138,6 +144,9 @@ public void tearDown() {
138144
if (mockStaticMemoryMessageHistory != null) {
139145
mockStaticMemoryMessageHistory.close();
140146
}
147+
if (mockStaticSessionManager != null) {
148+
mockStaticSessionManager.close();
149+
}
141150
}
142151

143152
}

0 commit comments

Comments
 (0)