Hello! I've been trying to replicate the results on Sachs using the provided hyperparameters, but I'm getting SHD ~37-40 instead of the low 10s. Any clue why?
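```python
import networkx as nx
from cdt.data import load_dataset
from castle.metrics import MetricsDAG  # gCastle
from diffan.diffan import DiffAN  # import path assumed; adjust to your checkout

data, graph = load_dataset("sachs")
data = data.to_numpy()
graph = nx.to_numpy_array(graph)
num_nodes = data.shape[1]
model = DiffAN(num_nodes, residue=True)
pred_graph, order = model.fit(data)
metrics = MetricsDAG(pred_graph, graph).metrics
print(metrics)
```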
This produces:

```
{"fdr": 0.8919, "tpr": 0.2222, "fpr": 0.8919, "shd": 37, "nnz": 37, "precision": 0.1081, "recall": 0.2222, "F1": 0.1455, "gscore": 0.0}
```
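To read the SHD above: as far as I can tell from gCastle's source, `MetricsDAG` compares the undirected skeletons (each extra or missing adjacency costs 1) and then adds 1 for every predicted edge whose reverse is the true edge. Below is a stdlib-only sketch of that count, under my own assumptions: 0/1 adjacency matrices only (no gCastle `-1` "undirected" entries), no self-loops, and a hypothetical helper name `shd_breakdown`. It splits the total into its parts.

```python
# Sketch of an SHD breakdown for 0/1 adjacency matrices (adj[i][j] == 1 means
# i -> j). Assumption: this mirrors the convention I believe gCastle's
# MetricsDAG uses (skeleton extra + skeleton missing + reversed edges); it
# ignores gCastle's -1 "undirected" entries and self-loops.
from typing import Dict, List

Matrix = List[List[int]]


def shd_breakdown(pred: Matrix, true: Matrix) -> Dict[str, int]:
    d = len(true)
    extra = missing = 0
    # Compare undirected skeletons over unordered pairs i < j.
    for i in range(d):
        for j in range(i + 1, d):
            in_pred = bool(pred[i][j] or pred[j][i])
            in_true = bool(true[i][j] or true[j][i])
            if in_pred and not in_true:
                extra += 1
            elif in_true and not in_pred:
                missing += 1
    # A predicted edge i -> j counts as reversed when only j -> i is true.
    reversed_edges = sum(
        1
        for i in range(d)
        for j in range(d)
        if pred[i][j] and not true[i][j] and true[j][i]
    )
    return {
        "extra": extra,
        "missing": missing,
        "reversed": reversed_edges,
        "shd": extra + missing + reversed_edges,
    }


# Toy check: true graph 0 -> 1 -> 2; the prediction reverses 0 -> 1,
# adds 0 -> 2 and misses 1 -> 2.
true_adj = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
pred_adj = [[0, 0, 1], [1, 0, 0], [0, 0, 0]]
print(shd_breakdown(pred_adj, true_adj))  # {'extra': 1, 'missing': 1, 'reversed': 1, 'shd': 3}
```

Because it only uses `len` and `[i][j]` indexing, it should also accept the NumPy arrays from the snippet directly (`shd_breakdown(pred_graph, graph)`). That would show whether most of the 37 comes from extra edges or from reversals.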
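Another thing that might help narrow it down: DiffAN first estimates a topological `order` and then prunes edges, so scoring the returned `order` on its own could show whether the problem is in the ordering or in the pruning step. The sketch below counts true edges that point backwards in the order. The helper name `order_divergence` and the root-first convention (parents before children) are my assumptions, not something taken from the DiffAN codebase. This count also only makes sense if the adjacency rows and columns follow the same node order as the columns of `data`. `nx.to_numpy_array(graph, nodelist=...)` can pin that order explicitly, for example with the DataFrame's column names taken before `.to_numpy()`. I haven't checked whether cdt's Sachs graph already lists its nodes in column order.

```python
# Hypothetical helper: counts true edges i -> j whose child j comes *before*
# its parent i in a causal order. Assumes `order` is root-first (parents before
# children); reverse it first if your order is leaf-first.
from typing import List, Sequence


def order_divergence(order: Sequence[int], true_adj: List[List[int]]) -> int:
    # Position of each node index in the order.
    rank = {node: pos for pos, node in enumerate(order)}
    d = len(true_adj)
    return sum(
        1
        for i in range(d)
        for j in range(d)
        if true_adj[i][j] and rank[j] < rank[i]
    )


# Toy check on the chain 0 -> 1 -> 2.
chain = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
print(order_divergence([0, 1, 2], chain))  # 0: every parent precedes its child
print(order_divergence([2, 1, 0], chain))  # 2: both edges point backwards
```

If this is close to 0 on Sachs but the SHD is still around 37, the extra edges would most likely be coming from the pruning step rather than from the ordering.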