Skip to content

Commit fbebdaf

Browse files
authored
Merge pull request #3 from mwien/list_dags
List DAGs in MEC
2 parents 69b29bd + 43b4840 commit fbebdaf

File tree

8 files changed

+302
-3
lines changed

8 files changed

+302
-3
lines changed

cliquepicking_python/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cliquepicking_python/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "cliquepicking"
3-
version = "0.2.3"
3+
version = "0.2.4"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

cliquepicking_python/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@ The module provides the functions
77
- ```mec_size(G)```, which outputs the number of DAGs in the MEC represented by CPDAG G
88
- ```mec_sample_dags(G, k)```, which returns k uniformly sampled DAGs from the MEC represented by CPDAG G
99
- ```mec_sample_orders(G, k)```, which returns topological orders of k uniformly sampled DAGs from the MEC represented by CPDAG G
10+
- ```mec_list_dags(G)```, which returns a list of all DAGs in the MEC represented by CPDAG G
11+
- ```mec_list_orders(G)```, which returns topological orders of all DAGs in the MEC represented by CPDAG G
12+
13+
The DAGs are returned as edge lists and they can be read e.g. in networkx using ```nx.DiGraph(dag)``` (see the example at the bottom).
1014

1115
Be aware that ```mec_sample_dags(G, k)``` holds (and returns) k DAGs in memory. To avoid high memory demand for large graphs, generate DAGs in smaller batches or use ```mec_sample_orders(G, k)```, which only returns the easier-to-store topological order.
1216

17+
The same holds for ```mec_list_dags(G)```; consider checking the size of the MEC using ```mec_size(G)``` before calling this method.
18+
1319
In all cases, G should be given as an edge list (vertices should be represented by zero-indexed integers), which includes ```(a, b)``` and ```(b, a)``` for undirected edges $a - b$ and only ```(a, b)``` for directed edges $a \rightarrow b$. E.g.
1420

1521
```python

cliquepicking_python/src/lib.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use pyo3::prelude::*;
22

33
use cliquepicking_rs::count::count_cpdag;
4+
use cliquepicking_rs::enumerate::list_cpdag;
5+
use cliquepicking_rs::enumerate::list_cpdag_orders;
46
use cliquepicking_rs::partially_directed_graph::PartiallyDirectedGraph;
57
use cliquepicking_rs::sample::sample_cpdag;
68
use cliquepicking_rs::sample::sample_cpdag_orders;
@@ -13,6 +15,8 @@ fn cliquepicking(m: &Bound<'_, PyModule>) -> PyResult<()> {
1315
m.add_function(wrap_pyfunction!(mec_size, m)?)?;
1416
m.add_function(wrap_pyfunction!(mec_sample_dags, m)?)?;
1517
m.add_function(wrap_pyfunction!(mec_sample_orders, m)?)?;
18+
m.add_function(wrap_pyfunction!(mec_list_dags, m)?)?;
19+
m.add_function(wrap_pyfunction!(mec_list_orders, m)?)?;
1620
Ok(())
1721
}
1822

@@ -37,14 +41,34 @@ fn mec_sample_dags(cpdag: Vec<(usize, usize)>, k: usize) -> PyResult<Vec<Vec<(us
3741
Ok(samples)
3842
}
3943

40-
/// Sample k DAGs uniformly from the Markov equivalence class represented by CPDAG cpdag.
44+
/// Sample k DAGs (represented by a topological order) uniformly from the Markov equivalence class represented by CPDAG cpdag.
4145
#[pyfunction]
4246
fn mec_sample_orders(cpdag: Vec<(usize, usize)>, k: usize) -> PyResult<Vec<Vec<usize>>> {
4347
let mx = max_element(&cpdag);
4448
let g = PartiallyDirectedGraph::from_edge_list(cpdag, mx + 1);
4549
Ok(sample_cpdag_orders(&g, k))
4650
}
4751

52+
/// List all DAGs from the Markov equivalence class represented by CPDAG cpdag.
53+
#[pyfunction]
54+
fn mec_list_dags(cpdag: Vec<(usize, usize)>) -> PyResult<Vec<Vec<(usize, usize)>>> {
55+
let mx = max_element(&cpdag);
56+
let g = PartiallyDirectedGraph::from_edge_list(cpdag, mx + 1);
57+
let samples = list_cpdag(&g)
58+
.into_iter()
59+
.map(|sample| sample.to_edge_list())
60+
.collect();
61+
Ok(samples)
62+
}
63+
64+
/// List all DAGs (represented by a topological order) from the Markov equivalence class represented by CPDAG cpdag.
65+
#[pyfunction]
66+
fn mec_list_orders(cpdag: Vec<(usize, usize)>) -> PyResult<Vec<Vec<usize>>> {
67+
let mx = max_element(&cpdag);
68+
let g = PartiallyDirectedGraph::from_edge_list(cpdag, mx + 1);
69+
Ok(list_cpdag_orders(&g))
70+
}
71+
4872
// small helper
4973
fn max_element(tuple_list: &[(usize, usize)]) -> usize {
5074
let mut mx = 0;

cliquepicking_rs/src/enumerate.rs

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
use crate::{
2+
directed_graph::DirectedGraph, graph::Graph, partially_directed_graph::PartiallyDirectedGraph,
3+
};
4+
5+
/// Mutable bookkeeping shared by `visit`, `unvisit` and `rec_list_chordal_orders`
/// while vertex orderings are enumerated by backtracking.
#[derive(Debug)]
struct McsState {
    /// Vertices visited so far, in visit order.
    ordering: Vec<usize>,
    /// `sets[c]` collects vertices that were recorded with cardinality `c`.
    /// Entries can be stale; consumers re-check them against `cardinality`.
    sets: Vec<Vec<usize>>,
    /// Per-vertex count of already visited neighbours; `usize::MAX` marks a visited vertex.
    cardinality: Vec<usize>,
    /// Index into `sets` from which the next vertex is taken.
    max_cardinality: usize,
    /// Number of vertices currently in `ordering`.
    position: usize,
}

impl McsState {
    /// Creates the initial state for a graph with `n` vertices: nothing is visited and
    /// every vertex sits in bucket 0.
    pub fn new(n: usize) -> McsState {
        // `visit` increments `max_cardinality` and then inspects `sets[max_cardinality]`.
        // When the last vertex of a clique spanning all vertices is visited that index
        // equals `n`, so buckets 0..=n must exist. The extra bucket also keeps `sets[0]`
        // valid when `n == 0`.
        let mut sets = vec![Vec::new(); n + 1];
        sets[0] = (0..n).collect();
        McsState {
            ordering: Vec::new(),
            sets,
            cardinality: vec![0; n],
            max_cardinality: 0,
            position: 0,
        }
    }
}
27+
28+
fn visit(g: &Graph, state: &mut McsState, u: usize) {
29+
state.position += 1;
30+
state.ordering.push(u);
31+
state.cardinality[u] = usize::MAX; // TODO: use Option to encode this
32+
for &v in g.neighbors(u) {
33+
if state.cardinality[v] < g.n {
34+
state.cardinality[v] += 1;
35+
state.sets[state.cardinality[v]].push(v);
36+
}
37+
}
38+
state.max_cardinality += 1;
39+
while state.max_cardinality > 0 && state.sets[state.max_cardinality].is_empty() {
40+
state.max_cardinality -= 1;
41+
}
42+
}
43+
44+
fn unvisit(g: &Graph, state: &mut McsState, u: usize, last_cardinality: usize) {
45+
state.position -= 1;
46+
state.ordering.pop();
47+
state.cardinality[u] = last_cardinality;
48+
state.sets[state.cardinality[u]].push(u);
49+
50+
for &v in g.neighbors(u).rev() {
51+
// TODO: sets will get bigger and bigger -> cleanup?
52+
if state.cardinality[v] < g.n {
53+
state.cardinality[v] -= 1;
54+
state.sets[state.cardinality[v]].push(v);
55+
}
56+
}
57+
58+
state.max_cardinality = state.cardinality[u];
59+
}
60+
61+
fn reach(g: &Graph, st: &[usize], s: usize) -> Vec<usize> {
62+
let mut visited = vec![false; g.n];
63+
visited[s] = true;
64+
let mut blocked = vec![true; g.n];
65+
st.iter().for_each(|&v| blocked[v] = false);
66+
let mut queue = vec![s];
67+
68+
while let Some(u) = queue.pop() {
69+
for &v in g.neighbors(u) {
70+
if !visited[v] && !blocked[v] {
71+
queue.push(v);
72+
visited[v] = true;
73+
}
74+
}
75+
}
76+
77+
visited
78+
.iter()
79+
.enumerate()
80+
.filter(|(_, &val)| val)
81+
.map(|(i, _)| i)
82+
.collect()
83+
}
84+
85+
fn rec_list_chordal_orders(g: &Graph, orders: &mut Vec<Vec<usize>>, state: &mut McsState) {
86+
if state.position == g.n {
87+
orders.push(state.ordering.clone());
88+
return;
89+
}
90+
91+
// do this better
92+
let u = loop {
93+
while state.max_cardinality > 0 && state.sets[state.max_cardinality].is_empty() {
94+
state.max_cardinality -= 1;
95+
}
96+
let next_vertex = state.sets[state.max_cardinality].pop().unwrap();
97+
// use Result instead of this hack
98+
if state.cardinality[next_vertex] == state.max_cardinality {
99+
break next_vertex;
100+
}
101+
};
102+
103+
let last_cardinality = state.cardinality[u];
104+
visit(g, state, u);
105+
rec_list_chordal_orders(g, orders, state);
106+
unvisit(g, state, u, last_cardinality);
107+
108+
let st: Vec<_> = state.sets[state.max_cardinality]
109+
.iter()
110+
.copied()
111+
.filter(|&v| state.max_cardinality == state.cardinality[v])
112+
.collect();
113+
let reachable = reach(g, &st, u);
114+
115+
for x in reachable {
116+
if x == u || state.cardinality[x] != state.max_cardinality {
117+
continue;
118+
}
119+
let last_cardinality = state.cardinality[x];
120+
visit(g, state, x);
121+
rec_list_chordal_orders(g, orders, state);
122+
unvisit(g, state, x, last_cardinality);
123+
}
124+
}
125+
126+
/// Runs the recursive enumeration on `g` from a fresh search state and returns every
/// ordering it records.
fn list_chordal_orders(g: &Graph) -> Vec<Vec<usize>> {
    let mut state = McsState::new(g.n);
    let mut collected = Vec::new();
    rec_list_chordal_orders(g, &mut collected, &mut state);
    collected
}
131+
132+
fn sort_order(d: &DirectedGraph, cmp: &[usize], order: &[usize]) -> Vec<usize> {
133+
let mut component_no = vec![usize::MAX; *cmp.iter().max().unwrap() + 1];
134+
let mut sorted_order = Vec::new();
135+
136+
let to = d.topological_order();
137+
let mut found_comps = 0;
138+
for &u in to.iter() {
139+
if component_no[cmp[u]] == usize::MAX {
140+
component_no[cmp[u]] = found_comps;
141+
found_comps += 1;
142+
sorted_order.push(Vec::new());
143+
}
144+
}
145+
146+
for &u in order.iter() {
147+
let cmp_u = component_no[cmp[u]];
148+
sorted_order[cmp_u].push(u);
149+
}
150+
151+
sorted_order.into_iter().flatten().collect()
152+
}
153+
154+
// TODO: rename
155+
pub fn list_cpdag_orders(g: &PartiallyDirectedGraph) -> Vec<Vec<usize>> {
156+
let undirected_subgraph = g.undirected_subgraph();
157+
let directed_subgraph = g.directed_subgraph();
158+
let unsorted_orders = list_chordal_orders(&undirected_subgraph);
159+
160+
// could use a method which only returns list of vertex lists
161+
let (_, vertices) = undirected_subgraph.connected_components();
162+
let mut cmp = vec![0; g.n];
163+
vertices
164+
.iter()
165+
.enumerate()
166+
.for_each(|(i, l)| l.iter().for_each(|&v| cmp[v] = i));
167+
168+
unsorted_orders
169+
.iter()
170+
.map(|order| sort_order(&directed_subgraph, &cmp, order))
171+
.collect()
172+
}
173+
174+
pub fn list_cpdag(g: &PartiallyDirectedGraph) -> Vec<DirectedGraph> {
175+
let undirected_subgraph = g.undirected_subgraph();
176+
let directed_subgraph = g.directed_subgraph();
177+
178+
let mut dags = Vec::new();
179+
for order in list_cpdag_orders(g).iter() {
180+
let mut position = vec![0; order.len()];
181+
order.iter().enumerate().for_each(|(i, &v)| position[v] = i);
182+
let mut dag_edge_list = directed_subgraph.to_edge_list();
183+
for &(u, v) in undirected_subgraph.to_edge_list().iter() {
184+
if u > v {
185+
continue;
186+
}
187+
if position[u] < position[v] {
188+
dag_edge_list.push((u, v));
189+
} else {
190+
dag_edge_list.push((v, u));
191+
}
192+
}
193+
dags.push(DirectedGraph::from_edge_list(dag_edge_list, order.len()));
194+
}
195+
dags
196+
}
197+
198+
#[cfg(test)]
mod tests {

    use crate::partially_directed_graph::PartiallyDirectedGraph;

    /// Expands every pair `(a, b)` into the two arcs `(a, b)` and `(b, a)`, which is how
    /// an undirected edge is written in an edge list.
    fn both_directions(pairs: &[(usize, usize)]) -> Vec<(usize, usize)> {
        let mut arcs = Vec::with_capacity(2 * pairs.len());
        for &(a, b) in pairs {
            arcs.push((a, b));
            arcs.push((b, a));
        }
        arcs
    }

    /// Fully undirected fixture on 6 vertices with 11 edges.
    fn get_paper_graph() -> PartiallyDirectedGraph {
        let undirected = [
            (0, 1),
            (0, 2),
            (1, 2),
            (1, 3),
            (1, 4),
            (1, 5),
            (2, 3),
            (2, 4),
            (2, 5),
            (3, 4),
            (4, 5),
        ];
        PartiallyDirectedGraph::from_edge_list(both_directions(&undirected), 6)
    }

    /// Fixture on 4 vertices: undirected edges 0 - 1 and 1 - 2, directed edges 0 -> 3 and 2 -> 3.
    fn get_basic_graph() -> PartiallyDirectedGraph {
        let mut edges = both_directions(&[(0, 1), (1, 2)]);
        edges.push((0, 3));
        edges.push((2, 3));
        PartiallyDirectedGraph::from_edge_list(edges, 4)
    }

    #[test]
    fn list_cpdag_basic_check() {
        let paper_dags = super::list_cpdag(&get_paper_graph());
        assert_eq!(paper_dags.len(), 54);
        let basic_dags = super::list_cpdag(&get_basic_graph());
        assert_eq!(basic_dags.len(), 3);
        // TODO: better tests
    }

    #[test]
    fn list_cpdag_orders_basic_check() {
        let paper_orders = super::list_cpdag_orders(&get_paper_graph());
        assert_eq!(paper_orders.len(), 54);
        let basic_orders = super::list_cpdag_orders(&get_basic_graph());
        assert_eq!(basic_orders.len(), 3);
        // TODO: better tests
    }
}

cliquepicking_rs/src/graph.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ impl Graph {
4646
}
4747
}
4848

49+
pub fn to_edge_list(&self) -> Vec<(usize, usize)> {
50+
let mut edge_list = Vec::new();
51+
for u in 0..self.n {
52+
for &v in self.neighbors(u) {
53+
edge_list.push((u, v));
54+
}
55+
}
56+
edge_list
57+
}
58+
4959
pub fn neighbors(&self, u: usize) -> std::slice::Iter<'_, usize> {
5060
self.neighbors[u].iter()
5161
}

cliquepicking_rs/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pub mod count;
22
pub mod directed_graph;
3+
pub mod enumerate;
34
pub mod graph;
45
pub mod partially_directed_graph;
56
pub mod sample;

cliquepicking_rs/src/sample.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ pub fn sample_chordal(g: &Graph, k: usize) -> Vec<DirectedGraph> {
444444
}
445445

446446
// there are unnecessary allocations/conversions here, maybe optimize this at some point
447+
// maybe call "sample_from_cpdag"
447448
pub fn sample_cpdag(g: &PartiallyDirectedGraph, k: usize) -> Vec<DirectedGraph> {
448449
let undirected_subgraph = g.undirected_subgraph();
449450
let directed_subgraph = g.directed_subgraph();

0 commit comments

Comments
 (0)