Skip to content

Improve analysis performance #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 219 additions & 60 deletions dsm/src/main/java/org/hjug/dsm/DSM.java

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions dsm/src/main/java/org/hjug/dsm/EdgeToRemoveInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
@Data
public class EdgeToRemoveInfo {
private final DefaultWeightedEdge edge;
private final double edgeWeight;
private final int edgeInCycleCount;
private final int removedEdgeWeight;
private final int newCycleCount;
private final double averageCycleNodeCount;
private final double payoff; // impact / effort
}
105 changes: 105 additions & 0 deletions dsm/src/main/java/org/hjug/dsm/OptimalBackEdgeRemover.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package org.hjug.dsm;

import org.jgrapht.Graph;
import org.jgrapht.alg.cycle.CycleDetector;
import org.jgrapht.alg.cycle.JohnsonSimpleCycles;
import org.jgrapht.graph.AsSubgraph;

import java.util.*;

public class OptimalBackEdgeRemover<V, E> {
private Graph<V, E> graph;

/**
* Constructor initializing with the target graph.
* @param graph The directed weighted graph to analyze
*/
public OptimalBackEdgeRemover(Graph<V, E> graph) {
this.graph = graph;
}

/**
* Finds the optimal back edge(s) to remove to move the graph closer to a DAG.
* @return A set of edges to remove
*/
public Set<E> findOptimalBackEdgesToRemove() {
CycleDetector<V, E> cycleDetector = new CycleDetector<>(graph);

// If the graph is already acyclic, return empty set
if (!cycleDetector.detectCycles()) {
return Collections.emptySet();
}

// Find all cycles in the graph
JohnsonSimpleCycles<V, E> cycleFinder = new JohnsonSimpleCycles<>(graph);
List<List<V>> originalCycles = cycleFinder.findSimpleCycles();
int originalCycleCount = originalCycles.size();

// Identify edges that are part of at least one cycle
Set<E> edgesInCycles = new HashSet<>();
for (List<V> cycle : originalCycles) {
for (int i = 0; i < cycle.size(); i++) {
V source = cycle.get(i);
V target = cycle.get((i + 1) % cycle.size());
E edge = graph.getEdge(source, target);
edgesInCycles.add(edge);
}
}

// Calculate cycle elimination count for each edge
Map<E, Integer> edgeCycleEliminationCount = new HashMap<>();
for (E edge : edgesInCycles) {
// Create a subgraph without this edge
Graph<V, E> subgraph = new AsSubgraph<>(graph, graph.vertexSet(), new HashSet<>(graph.edgeSet()));
subgraph.removeEdge(edge);

// Calculate how many cycles would be eliminated
JohnsonSimpleCycles<V, E> subgraphCycleFinder = new JohnsonSimpleCycles<>(subgraph);
List<List<V>> remainingCycles = subgraphCycleFinder.findSimpleCycles();
int cyclesEliminated = originalCycleCount - remainingCycles.size();

edgeCycleEliminationCount.put(edge, cyclesEliminated);
}

// Find edges that eliminate the most cycles
int maxCycleElimination = 0;
List<E> maxEliminationEdges = new ArrayList<>();

for (Map.Entry<E, Integer> entry : edgeCycleEliminationCount.entrySet()) {
if (entry.getValue() > maxCycleElimination) {
maxCycleElimination = entry.getValue();
maxEliminationEdges.clear();
maxEliminationEdges.add(entry.getKey());
} else if (entry.getValue() == maxCycleElimination) {
maxEliminationEdges.add(entry.getKey());
}
}

// If no cycles are eliminated (shouldn't happen), return empty set
if (maxEliminationEdges.isEmpty() || maxCycleElimination == 0) {
return Collections.emptySet();
}

// If multiple edges eliminate the same number of cycles, choose the one with the lowest weight
if (maxEliminationEdges.size() > 1) {
double minWeight = Double.MAX_VALUE;
List<E> minWeightEdges = new ArrayList<>();

for (E edge : maxEliminationEdges) {
double weight = graph.getEdgeWeight(edge);
if (weight < minWeight) {
minWeight = weight;
minWeightEdges.clear();
minWeightEdges.add(edge);
} else if (weight == minWeight) {
minWeightEdges.add(edge);
}
}

return new HashSet<>(minWeightEdges);
}

// Return the single edge that eliminates the most cycles
return new HashSet<>(maxEliminationEdges);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.hjug.dsm;

import lombok.extern.slf4j.Slf4j;
import org.jgrapht.alg.cycle.CycleDetector;
import org.jgrapht.graph.AsSubgraph;
import org.jgrapht.opt.graph.sparse.SparseIntDirectedWeightedGraph;

import java.util.HashMap;
import java.util.Map;

@Slf4j
public class SparseGraphCircularReferenceChecker {

private final Map<Integer, AsSubgraph<Integer, Integer>> uniqueSubGraphs = new HashMap<>();

/**
* Detects cycles in the graph that is passed in
* and returns the unique cycles in the graph as a map of subgraphs
*
* @param graph
* @return a Map of unique cycles in the graph
*/
public Map<Integer, AsSubgraph<Integer, Integer>> getCycles(SparseIntDirectedWeightedGraph graph) {

if (!uniqueSubGraphs.isEmpty()) {
return uniqueSubGraphs;
}

// use CycleDetector.findCycles()?
Map<Integer, AsSubgraph<Integer, Integer>> cycles = detectCycles(graph);

cycles.forEach((vertex, subGraph) -> {
int vertexCount = subGraph.vertexSet().size();
int edgeCount = subGraph.edgeSet().size();

if (vertexCount > 1 && edgeCount > 1 && !isDuplicateSubGraph(subGraph, vertex)) {
uniqueSubGraphs.put(vertex, subGraph);
log.debug("Vertex: {} vertex count: {} edge count: {}", vertex, vertexCount, edgeCount);
}
});

return uniqueSubGraphs;
}

private boolean isDuplicateSubGraph(AsSubgraph<Integer, Integer> subGraph, Integer vertex) {
if (!uniqueSubGraphs.isEmpty()) {
for (AsSubgraph<Integer, Integer> renderedSubGraph : uniqueSubGraphs.values()) {
if (renderedSubGraph.vertexSet().size() == subGraph.vertexSet().size()
&& renderedSubGraph.edgeSet().size()
== subGraph.edgeSet().size()
&& renderedSubGraph.vertexSet().contains(vertex)) {
return true;
}
}
}

return false;
}

private Map<Integer, AsSubgraph<Integer, Integer>> detectCycles(
SparseIntDirectedWeightedGraph graph) {
Map<Integer, AsSubgraph<Integer, Integer>> cyclesForEveryVertexMap = new HashMap<>();
CycleDetector<Integer, Integer> cycleDetector = new CycleDetector<>(graph);
cycleDetector.findCycles().forEach(v -> {
AsSubgraph<Integer, Integer> subGraph =
new AsSubgraph<>(graph, cycleDetector.findCyclesContainingVertex(v));
cyclesForEveryVertexMap.put(v, subGraph);
});
return cyclesForEveryVertexMap;
}
}
180 changes: 180 additions & 0 deletions dsm/src/main/java/org/hjug/dsm/SparseIntDWGEdgeRemovalCalculator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package org.hjug.dsm;

import org.jgrapht.Graph;
import org.jgrapht.Graphs;
import org.jgrapht.alg.connectivity.KosarajuStrongConnectivityInspector;
import org.jgrapht.alg.util.Triple;
import org.jgrapht.graph.DefaultWeightedEdge;
import org.jgrapht.opt.graph.sparse.SparseIntDirectedWeightedGraph;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// TODO: Delete
class SparseIntDWGEdgeRemovalCalculator {
private final Graph<String, DefaultWeightedEdge> graph;
SparseIntDirectedWeightedGraph sparseGraph;
List<Triple<Integer, Integer, Double>> sparseEdges;
List<Integer> sparseEdgesAboveDiagonal;
private final double sumOfEdgeWeightsAboveDiagonal;
int vertexCount;
Map<String, Integer> vertexToInt;
Map<Integer, String> intToVertex;


SparseIntDWGEdgeRemovalCalculator(
Graph<String, DefaultWeightedEdge> graph,
SparseIntDirectedWeightedGraph sparseGraph,
List<Triple<Integer, Integer, Double>> sparseEdges,
List<Integer> sparseEdgesAboveDiagonal,
double sumOfEdgeWeightsAboveDiagonal,
int vertexCount,
Map<String, Integer> vertexToInt,
Map<Integer, String> intToVertex) {
this.graph = graph;
this.sparseGraph = sparseGraph;
this.sparseEdges = new CopyOnWriteArrayList<>(sparseEdges);
this.sparseEdgesAboveDiagonal = new CopyOnWriteArrayList<>(sparseEdgesAboveDiagonal);
this.sumOfEdgeWeightsAboveDiagonal = sumOfEdgeWeightsAboveDiagonal;
this.vertexCount = vertexCount;
this.vertexToInt = new ConcurrentHashMap<>(vertexToInt);
this.intToVertex = new ConcurrentHashMap<>(intToVertex);

}

public List<EdgeToRemoveInfo> getImpactOfSparseEdgesAboveDiagonalIfRemoved() {
return sparseEdgesAboveDiagonal.parallelStream()
.map(this::calculateSparseEdgeToRemoveInfo)
.sorted(Comparator.comparing(EdgeToRemoveInfo::getPayoff).thenComparing(EdgeToRemoveInfo::getRemovedEdgeWeight))
.collect(Collectors.toList());
}

private EdgeToRemoveInfo calculateSparseEdgeToRemoveInfo(Integer edgeToRemove) {
//clone graph and remove edge
int source = sparseGraph.getEdgeSource(edgeToRemove);
int target = sparseGraph.getEdgeTarget(edgeToRemove);
double weight = sparseGraph.getEdgeWeight(edgeToRemove);
Triple<Integer, Integer, Double> removedEdge = Triple.of(source, target, weight);

List<Triple<Integer, Integer, Double>> tempUpdatedEdgeList = new ArrayList<>(sparseEdges);
tempUpdatedEdgeList.remove(removedEdge);
List<Triple<Integer, Integer, Double>> updatedEdgeList = new CopyOnWriteArrayList<>(tempUpdatedEdgeList);

SparseIntDirectedWeightedGraph improvedGraph = new SparseIntDirectedWeightedGraph(vertexCount, updatedEdgeList);

// find edges above diagonal
List<Integer> sortedSparseVertices = orderVertices(improvedGraph);
List<Integer> updatedEdges = getSparseEdgesAboveDiagonal(improvedGraph, sortedSparseVertices);

// calculate new graph statistics
int newEdgeCount = updatedEdges.size();
double newEdgeWeightSum = updatedEdges.stream()
.mapToDouble(improvedGraph::getEdgeWeight).sum();
DefaultWeightedEdge defaultWeightedEdge =
graph.getEdge(intToVertex.get(source), intToVertex.get(target));
double payoff = (sumOfEdgeWeightsAboveDiagonal - newEdgeWeightSum) / weight;
return new EdgeToRemoveInfo(defaultWeightedEdge, (int) weight, newEdgeCount, payoff);
}

private List<Integer> orderVertices(SparseIntDirectedWeightedGraph sparseGraph) {
List<Set<Integer>> sccs = new CopyOnWriteArrayList<>(findStronglyConnectedSparseGraphComponents(sparseGraph));
// List<Integer> sparseIntSortedActivities = topologicalSortSparseGraph(sccs, sparseGraph);
List<Integer> sparseIntSortedActivities = topologicalParallelSortSparseGraph(sccs, sparseGraph);
// reversing corrects rendering of the DSM
// with sources as rows and targets as columns
// was needed after AI solution was generated and iterated
Collections.reverse(sparseIntSortedActivities);

return new CopyOnWriteArrayList<>(sparseIntSortedActivities);
}

/**
* Kosaraju SCC detector avoids stack overflow.
* It is used by JGraphT's CycleDetector, and makes sense to use it here as well for consistency
*
* @param graph
* @return
*/
private List<Set<Integer>> findStronglyConnectedSparseGraphComponents(Graph<Integer, Integer> graph) {
KosarajuStrongConnectivityInspector<Integer, Integer> kosaraju =
new KosarajuStrongConnectivityInspector<>(graph);
return kosaraju.stronglyConnectedSets();
}

private List<Integer> topologicalSortSparseGraph(List<Set<Integer>> sccs, Graph<Integer, Integer> graph) {
List<Integer> sortedActivities = new ArrayList<>();
Set<Integer> visited = new HashSet<>();

sccs.parallelStream()
.flatMap(Set::parallelStream)
.filter(activity -> !visited.contains(activity))
.forEach(activity -> topologicalSortUtilSparseGraph(activity, visited, sortedActivities, graph));


Collections.reverse(sortedActivities);
return sortedActivities;
}

private void topologicalSortUtilSparseGraph(
Integer activity, Set<Integer> visited, List<Integer> sortedActivities, Graph<Integer, Integer> graph) {
visited.add(activity);

for (Integer neighbor : Graphs.successorListOf(graph, activity)) {
if (!visited.contains(neighbor)) {
topologicalSortUtilSparseGraph(neighbor, visited, sortedActivities, graph);
}
}

sortedActivities.add(activity);
}

private List<Integer> getSparseEdgesAboveDiagonal(SparseIntDirectedWeightedGraph sparseGraph, List<Integer> sortedActivities) {
ConcurrentLinkedQueue<Integer> sparseEdgesAboveDiagonal = new ConcurrentLinkedQueue<>();

int size = sortedActivities.size();
IntStream.range(0, size).parallel().forEach(i -> {
for (int j = i + 1; j < size; j++) {
Integer edge = sparseGraph.getEdge(
sortedActivities.get(i),
sortedActivities.get(j)
);
if (edge != null) {
sparseEdgesAboveDiagonal.add(edge);
}
}
});

return new ArrayList<>(sparseEdgesAboveDiagonal);
}

private List<Integer> topologicalParallelSortSparseGraph(List<Set<Integer>> sccs, Graph<Integer, Integer> graph) {
ConcurrentLinkedQueue<Integer> sortedActivities = new ConcurrentLinkedQueue<>();
Set<Integer> visited = new ConcurrentSkipListSet<>();

sccs.parallelStream()
.flatMap(Set::parallelStream)
.filter(activity -> !visited.contains(activity))
.forEach(activity -> topologicalSortUtilSparseGraph(activity, visited, sortedActivities, graph));

ArrayList<Integer> sortedActivitiesList = new ArrayList<>(sortedActivities);
Collections.reverse(sortedActivitiesList);
return sortedActivitiesList;
}

private void topologicalSortUtilSparseGraph(
Integer activity, Set<Integer> visited, ConcurrentLinkedQueue<Integer> sortedActivities, Graph<Integer, Integer> graph) {
visited.add(activity);

Graphs.successorListOf(graph, activity).parallelStream()
.filter(neighbor -> !visited.contains(neighbor))
.forEach(neighbor -> topologicalSortUtilSparseGraph(neighbor, visited, sortedActivities, graph));

sortedActivities.add(activity);
}

}
1 change: 0 additions & 1 deletion dsm/src/test/java/org/hjug/dsm/DSMTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,5 @@ void getImpactOfEdgesAboveDiagonalIfRemoved() {

assertEquals("(H : E)", infos.get(0).getEdge().toString());
assertEquals(2, infos.get(0).getNewCycleCount());
assertEquals(4.5, infos.get(0).getAverageCycleNodeCount());
}
}
Loading
Loading