1
1
import { Edge , MarkerType } from '@xyflow/react'
2
2
3
3
import { NODE_GAP_X , NODE_GAP_Y , NODE_HEIGHT_MAP , NODE_WIDTH_MAP } from './constants'
4
- import { GraphVisualizerBaseNode , GraphVisualizerNode , GraphVisualizerProps } from './types'
4
+ import { GraphVisualizerExtendedNode , GraphVisualizerNode , GraphVisualizerProps } from './types'
5
5
6
6
/**
7
7
* Processes edges by assigning a default type and customizing the marker (arrow style).
@@ -29,22 +29,15 @@ export const processEdges = (edges: GraphVisualizerProps['edges']): Edge[] =>
29
29
* @param edges - List of all edges representing parent-child relationships.
30
30
* @returns The root node's ID, or `null` if no root is found.
31
31
*/
32
- const findRootNode = ( nodes : GraphVisualizerProps [ 'nodes' ] , edges : GraphVisualizerProps [ 'edges' ] ) : string | null => {
32
+ const findRootNodes = ( nodes : GraphVisualizerProps [ 'nodes' ] , edges : GraphVisualizerProps [ 'edges' ] ) => {
33
33
// Create a set of all node IDs
34
34
const nodeIds = new Set ( nodes . map ( ( node ) => node . id ) )
35
35
36
36
// Create a set of all child node IDs (targets in edges)
37
37
const childIds = new Set ( edges . map ( ( edge ) => edge . target ) )
38
38
39
- // The root node is the one that is in nodeIds but not in childIds
40
- const rootNodeId = Array . from ( nodeIds ) . find ( ( nodeId ) => ! childIds . has ( nodeId ) )
41
-
42
- if ( rootNodeId ) {
43
- return rootNodeId
44
- }
45
-
46
- // If no root node is found, return null (could indicate a cycle or disconnected nodes)
47
- return null
39
+ // Find all nodes that are NOT a child of any other node (i.e., root nodes)
40
+ return Array . from ( nodeIds ) . filter ( ( nodeId ) => ! childIds . has ( nodeId ) )
48
41
}
49
42
50
43
/**
@@ -56,7 +49,7 @@ const findRootNode = (nodes: GraphVisualizerProps['nodes'], edges: GraphVisualiz
56
49
* @param childrenMap - Map of node ID → child nodes.
57
50
* @param nodeMap - Map of node ID → node data.
58
51
*/
59
- const placeNodes = (
52
+ const placeNode = (
60
53
nodeId : string ,
61
54
x : number ,
62
55
y : number ,
@@ -88,19 +81,13 @@ const placeNodes = (
88
81
// Determine x-position for child nodes
89
82
const childX = x + nodeWidth + NODE_GAP_X
90
83
91
- // Get height values for each child node
92
- const childHeights = children . map ( ( id ) => NODE_HEIGHT_MAP [ nodeMap . get ( id ) . type ] )
93
-
94
- // Calculate total height required for children (with spacing)
95
- const totalHeight = childHeights . reduce ( ( sum , h ) => sum + h + NODE_GAP_Y , - NODE_GAP_Y )
96
-
97
- // Start positioning children from the topmost position
98
- let startY = y - totalHeight / 2
99
-
100
- children . forEach ( ( child , index ) => {
84
+ // Start placing children **below** the parent
85
+ let startY = y
86
+ children . forEach ( ( child ) => {
87
+ const childHeight = NODE_HEIGHT_MAP [ nodeMap . get ( child ) . type ]
101
88
// Position each child at the calculated coordinates
102
- placeNodes ( child , childX , startY + childHeights [ index ] / 2 , updatedPositions , childrenMap , nodeMap )
103
- startY += childHeights [ index ] + NODE_GAP_Y // Move Y down for next sibling
89
+ placeNode ( child , childX , startY , updatedPositions , childrenMap , nodeMap )
90
+ startY += childHeight + NODE_GAP_Y // Move the next sibling below
104
91
} )
105
92
}
106
93
@@ -116,14 +103,19 @@ const calculateNodePositions = (nodes: GraphVisualizerProps['nodes'], edges: Gra
116
103
nodes . forEach ( ( node ) => childrenMap . set ( node . id , [ ] ) )
117
104
edges . forEach ( ( edge ) => childrenMap . get ( edge . source ) . push ( edge . target ) )
118
105
119
- // Identify the root node (the node that is never a target in edges)
120
- const rootNode = findRootNode ( nodes , edges )
121
- if ( ! rootNode ) {
122
- throw new Error ( 'Either cyclic or disconnected nodes are present!' )
106
+ // Identify all the root nodes (the nodes that are never a target in edges)
107
+ const rootNodes = findRootNodes ( nodes , edges )
108
+ if ( ! rootNodes . length ) {
109
+ return { }
123
110
}
124
111
125
- // Start recursive positioning from the root node
126
- placeNodes ( rootNode , 0 , 0 , positions , childrenMap , nodeMap )
112
+ // Place multiple root nodes vertically spaced at x = 0
113
+ let startY = 0
114
+ rootNodes . forEach ( ( rootId ) => {
115
+ const nodeHeight = NODE_HEIGHT_MAP [ nodeMap . get ( rootId ) . type ]
116
+ placeNode ( rootId , 0 , startY , positions , childrenMap , nodeMap )
117
+ startY += nodeHeight + NODE_GAP_Y // Move next root node downward
118
+ } )
127
119
128
120
return positions
129
121
}
@@ -138,7 +130,7 @@ const calculateNodePositions = (nodes: GraphVisualizerProps['nodes'], edges: Gra
138
130
export const processNodes = (
139
131
nodes : GraphVisualizerProps [ 'nodes' ] ,
140
132
edges : GraphVisualizerProps [ 'edges' ] ,
141
- ) : GraphVisualizerBaseNode [ ] => {
133
+ ) : GraphVisualizerExtendedNode [ ] => {
142
134
// Compute node positions based on hierarchy
143
135
const positions = calculateNodePositions ( nodes , edges )
144
136
@@ -149,6 +141,6 @@ export const processNodes = (
149
141
selectable : node . type === 'dropdownNode' ,
150
142
// Assign computed position; default to (0,0) if not found (shouldn't happen in a valid tree)
151
143
position : positions [ node . id ] ?? { x : 0 , y : 0 } ,
152
- } ) as GraphVisualizerBaseNode ,
144
+ } ) as GraphVisualizerExtendedNode ,
153
145
)
154
146
}
0 commit comments