Skip to content

Commit 3fcb10a

Browse files
ricken07tzolov
authored and committed
feat(mistral-ai): Add moderation model support (#2201)
Implement MistralAI moderation capabilities to detect potentially harmful content. This allows Spring AI applications to use Mistral's content moderation services to identify and filter inappropriate content before processing - Add MistralAiModerationApi for interacting with Mistral's moderation endpoints - Create MistralAiModerationModel implementing the ModerationModel interface - Add configuration properties and auto-configuration for the moderation model - Extend Categories and CategoryScores with additional moderation categories - Add integration tests to verify moderation functionality Signed-off-by: Ricken Bazolo <ricken.bazolo@gmail.com>
1 parent d30631e commit 3fcb10a

File tree

10 files changed

+611
-4
lines changed

10 files changed

+611
-4
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiChatAutoConfiguration.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
2222
import org.springframework.ai.mistralai.MistralAiChatModel;
2323
import org.springframework.ai.mistralai.api.MistralAiApi;
24+
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
25+
import org.springframework.ai.mistralai.moderation.MistralAiModerationModel;
2426
import org.springframework.ai.model.SpringAIModelProperties;
2527
import org.springframework.ai.model.SpringAIModels;
2628
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
@@ -57,7 +59,8 @@
5759
*/
5860
@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class,
5961
ToolCallingAutoConfiguration.class })
60-
@EnableConfigurationProperties({ MistralAiCommonProperties.class, MistralAiChatProperties.class })
62+
@EnableConfigurationProperties({ MistralAiCommonProperties.class, MistralAiChatProperties.class,
63+
MistralAiModerationProperties.class })
6164
@ConditionalOnProperty(name = SpringAIModelProperties.CHAT_MODEL, havingValue = SpringAIModels.MISTRAL,
6265
matchIfMissing = true)
6366
@ConditionalOnClass(MistralAiApi.class)
@@ -93,6 +96,27 @@ public MistralAiChatModel mistralAiChatModel(MistralAiCommonProperties commonPro
9396
return chatModel;
9497
}
9598

99+
@Bean
100+
@ConditionalOnMissingBean
101+
public MistralAiModerationModel mistralAiModerationModel(MistralAiCommonProperties commonProperties,
102+
MistralAiModerationProperties moderationProperties, RetryTemplate retryTemplate,
103+
ObjectProvider<RestClient.Builder> restClientBuilderProvider, ResponseErrorHandler responseErrorHandler) {
104+
105+
var apiKey = moderationProperties.getApiKey();
106+
var baseUrl = moderationProperties.getBaseUrl();
107+
108+
var resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonProperties.getApiKey();
109+
var resoledBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonProperties.getBaseUrl();
110+
111+
Assert.hasText(resolvedApiKey, "Mistral API key must be set");
112+
Assert.hasText(resoledBaseUrl, "Mistral base URL must be set");
113+
114+
var mistralAiModerationAi = new MistralAiModerationApi(resoledBaseUrl, resolvedApiKey,
115+
restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler);
116+
117+
return new MistralAiModerationModel(mistralAiModerationAi, retryTemplate, moderationProperties.getOptions());
118+
}
119+
96120
private MistralAiApi mistralAiApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl,
97121
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
98122

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package org.springframework.ai.model.mistralai.autoconfigure;
2+
3+
import org.springframework.ai.mistralai.moderation.MistralAiModerationOptions;
4+
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
5+
import org.springframework.boot.context.properties.ConfigurationProperties;
6+
import org.springframework.boot.context.properties.NestedConfigurationProperty;
7+
8+
/**
9+
* @author Ricken Bazolo
10+
*/
11+
@ConfigurationProperties(MistralAiModerationProperties.CONFIG_PREFIX)
12+
public class MistralAiModerationProperties extends MistralAiParentProperties {
13+
14+
public static final String CONFIG_PREFIX = "spring.ai.mistralai.moderation";
15+
16+
private static final String DEFAULT_MODERATION_MODEL = MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue();
17+
18+
@NestedConfigurationProperty
19+
private MistralAiModerationOptions options = MistralAiModerationOptions.builder()
20+
.model(DEFAULT_MODERATION_MODEL)
21+
.build();
22+
23+
public MistralAiModerationProperties() {
24+
super.setBaseUrl(MistralAiCommonProperties.DEFAULT_BASE_URL);
25+
}
26+
27+
public MistralAiModerationOptions getOptions() {
28+
return this.options;
29+
}
30+
31+
public void setOptions(MistralAiModerationOptions options) {
32+
this.options = options;
33+
}
34+
35+
}

auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiPropertiesTests.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,4 +145,17 @@ public void embeddingOptionsTest() {
145145
});
146146
}
147147

148+
@Test
149+
public void moderationOptionsTest() {
150+
new ApplicationContextRunner()
151+
.withPropertyValues("spring.ai.mistralai.base-url=TEST_BASE_URL", "spring.ai.mistralai.api-key=abc123",
152+
"spring.ai.mistralai.moderation.options.model=MODERATION_MODEL")
153+
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
154+
RestClientAutoConfiguration.class, MistralAiChatAutoConfiguration.class))
155+
.run(context -> {
156+
var moderationProperties = context.getBean(MistralAiModerationProperties.class);
157+
assertThat(moderationProperties.getOptions().getModel()).isEqualTo("MODERATION_MODEL");
158+
});
159+
}
160+
148161
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package org.springframework.ai.mistralai.api;
2+
3+
import com.fasterxml.jackson.annotation.JsonInclude;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import org.springframework.ai.retry.RetryUtils;
6+
import org.springframework.http.HttpHeaders;
7+
import org.springframework.http.MediaType;
8+
import org.springframework.http.ResponseEntity;
9+
import org.springframework.util.Assert;
10+
import org.springframework.web.client.ResponseErrorHandler;
11+
import org.springframework.web.client.RestClient;
12+
13+
import java.util.function.Consumer;
14+
15+
/**
16+
* MistralAI Moderation API.
17+
*
18+
* @author Ricken Bazolo
19+
* @see <a href= "https://docs.mistral.ai/capabilities/guardrailing/</a>
20+
*/
21+
public class MistralAiModerationApi {
22+
23+
private static final String DEFAULT_BASE_URL = "https://api.mistral.ai";
24+
25+
private final RestClient restClient;
26+
27+
public MistralAiModerationApi(String mistralAiApiKey) {
28+
this(DEFAULT_BASE_URL, mistralAiApiKey, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
29+
}
30+
31+
public MistralAiModerationApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder,
32+
ResponseErrorHandler responseErrorHandler) {
33+
34+
Consumer<HttpHeaders> jsonContentHeaders = headers -> {
35+
headers.setBearerAuth(mistralAiApiKey);
36+
headers.setContentType(MediaType.APPLICATION_JSON);
37+
};
38+
39+
this.restClient = restClientBuilder.baseUrl(baseUrl)
40+
.defaultHeaders(jsonContentHeaders)
41+
.defaultStatusHandler(responseErrorHandler)
42+
.build();
43+
}
44+
45+
public ResponseEntity<MistralAiModerationResponse> moderate(MistralAiModerationRequest mistralAiModerationRequest) {
46+
Assert.notNull(mistralAiModerationRequest, "Moderation request cannot be null.");
47+
Assert.hasLength(mistralAiModerationRequest.prompt(), "Prompt cannot be empty.");
48+
Assert.notNull(mistralAiModerationRequest.model(), "Model cannot be null.");
49+
50+
return this.restClient.post()
51+
.uri("v1/moderations")
52+
.body(mistralAiModerationRequest)
53+
.retrieve()
54+
.toEntity(MistralAiModerationResponse.class);
55+
}
56+
57+
public enum Model {
58+
59+
// @formatter:off
60+
MISTRAL_MODERATION("mistral-moderation-latest");
61+
// @formatter:on
62+
63+
private final String value;
64+
65+
Model(String value) {
66+
this.value = value;
67+
}
68+
69+
public String getValue() {
70+
return this.value;
71+
}
72+
73+
}
74+
75+
// @formatter:off
76+
@JsonInclude(JsonInclude.Include.NON_NULL)
77+
public record MistralAiModerationRequest(
78+
@JsonProperty("input") String prompt,
79+
@JsonProperty("model") String model
80+
) {
81+
82+
public MistralAiModerationRequest(String prompt) {
83+
this(prompt, null);
84+
}
85+
}
86+
87+
@JsonInclude(JsonInclude.Include.NON_NULL)
88+
public record MistralAiModerationResponse(
89+
@JsonProperty("id") String id,
90+
@JsonProperty("model") String model,
91+
@JsonProperty("results") MistralAiModerationResult[] results) {
92+
93+
}
94+
95+
@JsonInclude(JsonInclude.Include.NON_NULL)
96+
public record MistralAiModerationResult(
97+
@JsonProperty("categories") Categories categories,
98+
@JsonProperty("category_scores") CategoryScores categoryScores) {
99+
100+
public boolean flagged() {
101+
return categories != null && (categories.sexual() || categories.hateAndDiscrimination() || categories.violenceAndThreats()
102+
|| categories.selfHarm() || categories.dangerousAndCriminalContent() || categories.health()
103+
|| categories.financial() || categories.law() || categories.pii());
104+
}
105+
106+
}
107+
108+
@JsonInclude(JsonInclude.Include.NON_NULL)
109+
public record Categories(
110+
@JsonProperty("sexual") boolean sexual,
111+
@JsonProperty("hate_and_discrimination") boolean hateAndDiscrimination,
112+
@JsonProperty("violence_and_threats") boolean violenceAndThreats,
113+
@JsonProperty("selfharm") boolean selfHarm,
114+
@JsonProperty("dangerous_and_criminal_content") boolean dangerousAndCriminalContent,
115+
@JsonProperty("health") boolean health,
116+
@JsonProperty("financial") boolean financial,
117+
@JsonProperty("law") boolean law,
118+
@JsonProperty("pii") boolean pii) {
119+
120+
}
121+
122+
@JsonInclude(JsonInclude.Include.NON_NULL)
123+
public record CategoryScores(
124+
@JsonProperty("sexual") double sexual,
125+
@JsonProperty("hate_and_discrimination") double hateAndDiscrimination,
126+
@JsonProperty("violence_and_threats") double violenceAndThreats,
127+
@JsonProperty("selfharm") double selfHarm,
128+
@JsonProperty("dangerous_and_criminal_content") double dangerousAndCriminalContent,
129+
@JsonProperty("health") double health,
130+
@JsonProperty("financial") double financial,
131+
@JsonProperty("law") double law,
132+
@JsonProperty("pii") double pii) {
133+
134+
}
135+
// @formatter:onn
136+
137+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package org.springframework.ai.mistralai.moderation;
2+
3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
6+
import org.springframework.ai.model.ModelOptionsUtils;
7+
import org.springframework.ai.moderation.*;
8+
import org.springframework.ai.retry.RetryUtils;
9+
import org.springframework.http.ResponseEntity;
10+
import org.springframework.retry.support.RetryTemplate;
11+
import org.springframework.util.Assert;
12+
13+
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationRequest;
14+
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResponse;
15+
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResult;
16+
17+
import java.util.ArrayList;
18+
import java.util.List;
19+
20+
/**
21+
* @author Ricken Bazolo
22+
*/
23+
public class MistralAiModerationModel implements ModerationModel {
24+
25+
private final Logger logger = LoggerFactory.getLogger(getClass());
26+
27+
private final MistralAiModerationApi mistralAiModerationApi;
28+
29+
private final RetryTemplate retryTemplate;
30+
31+
private final MistralAiModerationOptions defaultOptions;
32+
33+
public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi) {
34+
this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE,
35+
MistralAiModerationOptions.builder()
36+
.model(MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue())
37+
.build());
38+
}
39+
40+
public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, MistralAiModerationOptions options) {
41+
this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE, options);
42+
}
43+
44+
public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, RetryTemplate retryTemplate,
45+
MistralAiModerationOptions options) {
46+
Assert.notNull(mistralAiModerationApi, "mistralAiModerationApi must not be null");
47+
Assert.notNull(retryTemplate, "retryTemplate must not be null");
48+
Assert.notNull(options, "options must not be null");
49+
this.mistralAiModerationApi = mistralAiModerationApi;
50+
this.retryTemplate = retryTemplate;
51+
this.defaultOptions = options;
52+
}
53+
54+
@Override
55+
public ModerationResponse call(ModerationPrompt moderationPrompt) {
56+
return this.retryTemplate.execute(ctx -> {
57+
58+
var instructions = moderationPrompt.getInstructions().getText();
59+
60+
var moderationRequest = new MistralAiModerationRequest(instructions);
61+
62+
if (this.defaultOptions != null) {
63+
moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest,
64+
MistralAiModerationRequest.class);
65+
}
66+
else {
67+
// moderationPrompt.getOptions() never null but model can be empty, cause
68+
// by ModerationPrompt constructor
69+
moderationRequest = ModelOptionsUtils.merge(toMistralAiModerationOptions(moderationPrompt.getOptions()),
70+
moderationRequest, MistralAiModerationRequest.class);
71+
}
72+
73+
var moderationResponseEntity = this.mistralAiModerationApi.moderate(moderationRequest);
74+
75+
return convertResponse(moderationResponseEntity, moderationRequest);
76+
});
77+
}
78+
79+
private ModerationResponse convertResponse(ResponseEntity<MistralAiModerationResponse> moderationResponseEntity,
80+
MistralAiModerationRequest openAiModerationRequest) {
81+
var moderationApiResponse = moderationResponseEntity.getBody();
82+
if (moderationApiResponse == null) {
83+
logger.warn("No moderation response returned for request: {}", openAiModerationRequest);
84+
return new ModerationResponse(new Generation());
85+
}
86+
87+
List<ModerationResult> moderationResults = new ArrayList<>();
88+
if (moderationApiResponse.results() != null) {
89+
90+
for (MistralAiModerationResult result : moderationApiResponse.results()) {
91+
Categories categories = null;
92+
CategoryScores categoryScores = null;
93+
if (result.categories() != null) {
94+
categories = Categories.builder()
95+
.sexual(result.categories().sexual())
96+
.pii(result.categories().pii())
97+
.law(result.categories().law())
98+
.financial(result.categories().financial())
99+
.health(result.categories().health())
100+
.dangerousAndCriminalContent(result.categories().dangerousAndCriminalContent())
101+
.violence(result.categories().violenceAndThreats())
102+
.hate(result.categories().hateAndDiscrimination())
103+
.selfHarm(result.categories().selfHarm())
104+
.build();
105+
}
106+
if (result.categoryScores() != null) {
107+
categoryScores = CategoryScores.builder()
108+
.sexual(result.categoryScores().sexual())
109+
.pii(result.categoryScores().pii())
110+
.law(result.categoryScores().law())
111+
.financial(result.categoryScores().financial())
112+
.health(result.categoryScores().health())
113+
.dangerousAndCriminalContent(result.categoryScores().dangerousAndCriminalContent())
114+
.violence(result.categoryScores().violenceAndThreats())
115+
.hate(result.categoryScores().hateAndDiscrimination())
116+
.selfHarm(result.categoryScores().selfHarm())
117+
.build();
118+
}
119+
var moderationResult = ModerationResult.builder()
120+
.categories(categories)
121+
.categoryScores(categoryScores)
122+
.flagged(result.flagged())
123+
.build();
124+
moderationResults.add(moderationResult);
125+
}
126+
127+
}
128+
129+
var moderation = Moderation.builder()
130+
.id(moderationApiResponse.id())
131+
.model(moderationApiResponse.model())
132+
.results(moderationResults)
133+
.build();
134+
135+
return new ModerationResponse(new Generation(moderation));
136+
}
137+
138+
private MistralAiModerationOptions toMistralAiModerationOptions(ModerationOptions runtimeModerationOptions) {
139+
var mistralAiModerationOptionsBuilder = MistralAiModerationOptions.builder();
140+
if (runtimeModerationOptions != null && runtimeModerationOptions.getModel() != null) {
141+
mistralAiModerationOptionsBuilder.model(runtimeModerationOptions.getModel());
142+
}
143+
return mistralAiModerationOptionsBuilder.build();
144+
}
145+
146+
}

0 commit comments

Comments
 (0)