Skip to content

Add BDD-based rules engine trait #2703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions config/spotbugs/filter.xml
Original file line number Diff line number Diff line change
Expand Up @@ -209,4 +209,10 @@
<Bug code="CT"/>
</Match>

<!-- This is a false positive. Yeah, we have a terminal node that's a singleton, but the ctor is still valid. -->
<Match>
<Class name="software.amazon.smithy.rulesengine.logic.cfg.ResultNode" />
<Bug pattern="SING_SINGLETON_HAS_NONPRIVATE_CONSTRUCTOR" />
</Match>

</FindBugsFilter>
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/
package software.amazon.smithy.rulesengine.aws.language.functions;

import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -39,28 +39,67 @@ private AwsArn(Builder builder) {
* @return the optional ARN.
*/
public static Optional<AwsArn> parse(String arn) {
String[] base = arn.split(":", 6);
if (base.length != 6) {
if (arn == null || arn.length() < 8 || !arn.startsWith("arn:")) {
return Optional.empty();
}
// First section must be "arn".
if (!base[0].equals("arn")) {

// find each of the first five ':' positions
int p0 = 3; // after "arn"
int p1 = arn.indexOf(':', p0 + 1);
if (p1 < 0) {
return Optional.empty();
}

int p2 = arn.indexOf(':', p1 + 1);
if (p2 < 0) {
return Optional.empty();
}

int p3 = arn.indexOf(':', p2 + 1);
if (p3 < 0) {
return Optional.empty();
}
// Sections for partition, service, and resource type must not be empty.
if (base[1].isEmpty() || base[2].isEmpty() || base[5].isEmpty()) {

int p4 = arn.indexOf(':', p3 + 1);
if (p4 < 0) {
return Optional.empty();
}

// extract and validate mandatory parts
String partition = arn.substring(p0 + 1, p1);
String service = arn.substring(p1 + 1, p2);
String region = arn.substring(p2 + 1, p3);
String accountId = arn.substring(p3 + 1, p4);
String resource = arn.substring(p4 + 1);

if (partition.isEmpty() || service.isEmpty() || resource.isEmpty()) {
return Optional.empty();
}

return Optional.of(builder()
.partition(base[1])
.service(base[2])
.region(base[3])
.accountId(base[4])
.resource(Arrays.asList(base[5].split("[:/]", -1)))
.partition(partition)
.service(service)
.region(region)
.accountId(accountId)
.resource(splitResource(resource))
.build());
}

private static List<String> splitResource(String resource) {
List<String> result = new ArrayList<>();
int start = 0;
int length = resource.length();
for (int i = 0; i < length; i++) {
char c = resource.charAt(i);
if (c == ':' || c == '/') {
result.add(resource.substring(start, i));
start = i + 1;
}
}
result.add(resource.substring(start));
return result;
}

/**
* Builder to create an {@link AwsArn} instance.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ public Value evaluate(List<Value> arguments) {
public AwsPartition createFunction(FunctionNode functionNode) {
return new AwsPartition(functionNode);
}

@Override
public int getCostHeuristic() {
return 6;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ public Value evaluate(List<Value> arguments) {
public IsVirtualHostableS3Bucket createFunction(FunctionNode functionNode) {
return new IsVirtualHostableS3Bucket(functionNode);
}

@Override
public int getCostHeuristic() {
return 8;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,10 @@ public Value evaluate(List<Value> arguments) {
public ParseArn createFunction(FunctionNode functionNode) {
return new ParseArn(functionNode);
}

@Override
public int getCostHeuristic() {
return 9;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import software.amazon.smithy.model.FromSourceLocation;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.validation.AbstractValidator;
import software.amazon.smithy.model.validation.ValidationEvent;
import software.amazon.smithy.rulesengine.aws.language.functions.AwsBuiltIns;
import software.amazon.smithy.rulesengine.language.EndpointRuleSet;
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter;
import software.amazon.smithy.rulesengine.traits.BddTrait;
import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait;
import software.amazon.smithy.utils.SetUtils;

Expand All @@ -33,36 +32,30 @@ public class RuleSetAwsBuiltInValidator extends AbstractValidator {
@Override
public List<ValidationEvent> validate(Model model) {
List<ValidationEvent> events = new ArrayList<>();
for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) {
events.addAll(validateRuleSetAwsBuiltIns(serviceShape,
serviceShape.expectTrait(EndpointRuleSetTrait.class)
.getEndpointRuleSet()));

for (ServiceShape s : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) {
EndpointRuleSetTrait trait = s.expectTrait(EndpointRuleSetTrait.class);
validateRuleSetAwsBuiltIns(events, s, trait.getEndpointRuleSet().getParameters());
}

for (ServiceShape s : model.getServiceShapesWithTrait(BddTrait.class)) {
validateRuleSetAwsBuiltIns(events, s, s.expectTrait(BddTrait.class).getBdd().getParameters());
}

return events;
}

private List<ValidationEvent> validateRuleSetAwsBuiltIns(ServiceShape serviceShape, EndpointRuleSet ruleSet) {
List<ValidationEvent> events = new ArrayList<>();
for (Parameter parameter : ruleSet.getParameters()) {
private void validateRuleSetAwsBuiltIns(List<ValidationEvent> events, ServiceShape s, Iterable<Parameter> params) {
for (Parameter parameter : params) {
if (parameter.isBuiltIn()) {
validateBuiltIn(serviceShape, parameter.getBuiltIn().get(), parameter).ifPresent(events::add);
validateBuiltIn(events, s, parameter.getBuiltIn().get(), parameter);
}
}
return events;
}

private Optional<ValidationEvent> validateBuiltIn(
ServiceShape serviceShape,
String builtInName,
FromSourceLocation source
) {
if (ADDITIONAL_CONSIDERATION_BUILT_INS.contains(builtInName)) {
return Optional.of(danger(
serviceShape,
source,
String.format(ADDITIONAL_CONSIDERATION_MESSAGE, builtInName),
builtInName));
private void validateBuiltIn(List<ValidationEvent> events, ServiceShape s, String name, FromSourceLocation source) {
if (ADDITIONAL_CONSIDERATION_BUILT_INS.contains(name)) {
events.add(danger(s, source, String.format(ADDITIONAL_CONSIDERATION_MESSAGE, name), name));
}
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,21 +206,21 @@ public int hashCode() {
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("url: ").append(url).append("\n");
sb.append("url: ").append(url);

if (!headers.isEmpty()) {
sb.append("headers:\n");
sb.append("\nheaders:");
for (Map.Entry<String, List<Expression>> entry : headers.entrySet()) {
sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2))
.append("\n");
sb.append("\n");
sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2));
}
}

if (!properties.isEmpty()) {
sb.append("properties:\n");
sb.append("\nproperties:");
for (Map.Entry<Identifier, Literal> entry : properties.entrySet()) {
sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2))
.append("\n");
sb.append("\n");
sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@
import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr;
import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal;
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter;
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters;
import software.amazon.smithy.rulesengine.language.syntax.rule.Condition;
import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule;
import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule;
import software.amazon.smithy.rulesengine.language.syntax.rule.Rule;
import software.amazon.smithy.rulesengine.language.syntax.rule.RuleValueVisitor;
import software.amazon.smithy.rulesengine.logic.RuleBasedConditionEvaluator;
import software.amazon.smithy.rulesengine.logic.bdd.Bdd;
import software.amazon.smithy.rulesengine.logic.bdd.BddEvaluator;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand All @@ -44,6 +50,44 @@ public static Value evaluate(EndpointRuleSet ruleset, Map<Identifier, Value> par
return new RuleEvaluator().evaluateRuleSet(ruleset, parameterArguments);
}

/**
* Initializes a new {@link RuleEvaluator} instances, and evaluates the provided BDD and parameter arguments.
*
* @param bdd The endpoint bdd.
* @param parameterArguments The rule-set parameter identifiers and values to evaluate the BDD against.
* @return The resulting value from the final matched rule.
*/
public static Value evaluate(Bdd bdd, Map<Identifier, Value> parameterArguments) {
return new RuleEvaluator().evaluateBdd(bdd, parameterArguments);
}

private Value evaluateBdd(Bdd bdd, Map<Identifier, Value> parameterArguments) {
return scope.inScope(() -> {
for (Parameter parameter : bdd.getParameters()) {
parameter.getDefault().ifPresent(value -> scope.insert(parameter.getName(), value));
}

parameterArguments.forEach(scope::insert);
BddEvaluator evaluator = BddEvaluator.from(bdd);
Condition[] conds = bdd.getConditions().toArray(new Condition[0]);
RuleBasedConditionEvaluator conditionEvaluator = new RuleBasedConditionEvaluator(this, conds);
int result = evaluator.evaluate(conditionEvaluator);

if (result <= 0) {
throw new RuntimeException("No BDD result matched");
}

Rule rule = bdd.getResults().get(result);
if (rule instanceof EndpointRule) {
return resolveEndpoint(this, ((EndpointRule) rule).getEndpoint());
} else if (rule instanceof ErrorRule) {
return resolveError(this, ((ErrorRule) rule).getError());
} else {
throw new RuntimeException("Invalid BDD rule result: " + rule);
}
});
}

/**
* Evaluate the provided ruleset and parameter arguments.
*
Expand All @@ -70,6 +114,21 @@ public Value evaluateRuleSet(EndpointRuleSet ruleset, Map<Identifier, Value> par
});
}

/**
* Configure the rule evaluator with the given parameters and parameter values for manual evaluation.
*
* @param parameters Parameters of the ruleset to evaluate.
* @param parameterArguments Parameter values to evaluate the ruleset against.
* @return the updated evaluator.
*/
public RuleEvaluator withParameters(Parameters parameters, Map<Identifier, Value> parameterArguments) {
for (Parameter parameter : parameters) {
parameter.getDefault().ifPresent(value -> scope.insert(parameter.getName(), value));
}
parameterArguments.forEach(scope::insert);
return this;
}

/**
* Evaluates the given condition in the current scope.
*
Expand Down Expand Up @@ -159,32 +218,40 @@ public Value visitTreeRule(List<Rule> rules) {

@Override
public Value visitErrorRule(Expression error) {
return error.accept(self);
return resolveError(self, error);
}

@Override
public Value visitEndpointRule(Endpoint endpoint) {
EndpointValue.Builder builder = EndpointValue.builder()
.sourceLocation(endpoint)
.url(endpoint.getUrl()
.accept(RuleEvaluator.this)
.expectStringValue()
.getValue());

for (Map.Entry<Identifier, Literal> entry : endpoint.getProperties().entrySet()) {
builder.putProperty(entry.getKey().toString(), entry.getValue().accept(RuleEvaluator.this));
}

for (Map.Entry<String, List<Expression>> entry : endpoint.getHeaders().entrySet()) {
List<String> values = new ArrayList<>();
for (Expression expression : entry.getValue()) {
values.add(expression.accept(RuleEvaluator.this).expectStringValue().getValue());
}
builder.putHeader(entry.getKey(), values);
}
return builder.build();
return resolveEndpoint(self, endpoint);
}
});
});
}

private static Value resolveEndpoint(RuleEvaluator self, Endpoint endpoint) {
EndpointValue.Builder builder = EndpointValue.builder()
.sourceLocation(endpoint)
.url(endpoint.getUrl()
.accept(self)
.expectStringValue()
.getValue());

for (Map.Entry<Identifier, Literal> entry : endpoint.getProperties().entrySet()) {
builder.putProperty(entry.getKey().toString(), entry.getValue().accept(self));
}

for (Map.Entry<String, List<Expression>> entry : endpoint.getHeaders().entrySet()) {
List<String> values = new ArrayList<>();
for (Expression expression : entry.getValue()) {
values.add(expression.accept(self).expectStringValue().getValue());
}
builder.putHeader(entry.getKey(), values);
}
return builder.build();
}

private static Value resolveError(RuleEvaluator self, Expression error) {
return error.accept(self);
}
}
Loading
Loading