diff --git a/CurvedPlanarReformat/CurvedPlanarReformat.py b/CurvedPlanarReformat/CurvedPlanarReformat.py index 2b2c014..7f116c7 100644 --- a/CurvedPlanarReformat/CurvedPlanarReformat.py +++ b/CurvedPlanarReformat/CurvedPlanarReformat.py @@ -169,326 +169,62 @@ def __init__(self): # there is no need to compute displacement for each slice, # we just compute for every n-th to make computation faster and inverse computation more robust # (less contradiction because of there is less overlapping between neighbor slices) - self.transformSpacingFactor = 5.0 - - @staticmethod - def getPointsProjectedToPlane(pointsArray, transformWorldToPlane): + appLogic = slicer.app.applicationLogic() + resamplerName = "ResampleScalarVectorDWIVolume" + found = appLogic.IsVolumeResamplerRegistered(resamplerName) + if not found: + mesg = f"CurvedPlanarReformat: {resamplerName!r} is not registered" + raise LookupError(mesg) + collectionOfSliceLogics = appLogic.GetSliceLogics() + numSliceLogics = collectionOfSliceLogics.GetNumberOfItems() + if numSliceLogics == 0: + mesg = "CurvedPlanarReformat: no SliceLogics found" + raise LookupError(mesg) + self.sliceLogic = collectionOfSliceLogics.GetItemAsObject(0) + self.sliceLogic.CurvedPlanarReformationInit() + + def getPointsProjectedToPlane(self, pointsArray, transformWorldToPlane): """ Returns points projected to the plane coordinate system (plane normal = plane Z axis). pointsArray contains each point as a column vector. """ - import numpy as np - numberOfPoints = pointsArray.shape[1] - # Concatenate a 4th line containing 1s so that we can transform the positions using - # a single matrix multiplication. - pointsArray_World = np.row_stack((pointsArray,np.ones(numberOfPoints))) - - # Point positions in the plane coordinate system: - pointsArray_Plane = np.dot(transformWorldToPlane, pointsArray_World) - # Projected point positions in the plane coordinate system: - pointsArray_Plane[2,:] = np.zeros(numberOfPoints) - # Projected point positions in the world coordinate system: - pointsArrayProjected_World = np.dot(np.linalg.inv(transformWorldToPlane), pointsArray_Plane) - - # remove the last row (all ones) - pointsArrayProjected_World = pointsArrayProjected_World[0:3,:] - - return pointsArrayProjected_World + pointsArrayOut = vtk.vtkPoints() + success = self.sliceLogic.CurvedPlanarReformationGetPointsProjectedToPlane( + pointsArray, transformWorldToPlane, pointsArrayOut + ) + if not success: + raise ValueError("getPointsProjectedToPlane failed") + return pointsArrayOut def computeStraighteningTransform(self, transformToStraightenedNode, curveNode, sliceSizeMm, outputSpacingMm, stretching=False, rotationDeg=0.0, reslicingPlanesModelNode=None): """ Compute straightened volume (useful for example for visualization of curved vessels) stretching: if True then stretching transform will be computed, otherwise straightening """ - - # Create a temporary resampled curve - resamplingCurveSpacing = outputSpacingMm * self.transformSpacingFactor - originalCurvePoints = curveNode.GetCurvePointsWorld() - sampledPoints = vtk.vtkPoints() - if not slicer.vtkMRMLMarkupsCurveNode.ResamplePoints(originalCurvePoints, sampledPoints, resamplingCurveSpacing, False): - raise ValueError("Resampling curve failed") - resampledCurveNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsCurveNode", "CurvedPlanarReformat_resampled_curve_temp") - resampledCurveNode.SetNumberOfPointsPerInterpolatingSegment(1) - resampledCurveNode.SetCurveTypeToLinear() - resampledCurveNode.SetControlPointPositionsWorld(sampledPoints) - - curveNodePlane = vtk.vtkPlane() - slicer.modules.markups.logic().GetBestFitPlane(resampledCurveNode, curveNodePlane) - - # Z axis (from first curve point to last, this will be the straightened curve long axis) - curveStartPoint = np.zeros(3) - curveEndPoint = np.zeros(3) - resampledCurveNode.GetNthControlPointPositionWorld(0, curveStartPoint) - resampledCurveNode.GetNthControlPointPositionWorld(resampledCurveNode.GetNumberOfControlPoints()-1, curveEndPoint) - transformGridAxisZ = (curveEndPoint-curveStartPoint)/np.linalg.norm(curveEndPoint-curveStartPoint) - - if stretching: - # Y axis = best fit plane normal - transformGridAxisY = np.copy(curveNodePlane.GetNormal()) - - # X axis normalize - transformGridAxisX = np.cross(transformGridAxisZ, transformGridAxisY) - transformGridAxisX = transformGridAxisX/np.linalg.norm(transformGridAxisX) - - # Make sure that Z axis is orthogonal to X and Y - orthogonalizedTransformGridAxisZ = np.cross(transformGridAxisX, transformGridAxisY) - orthogonalizedTransformGridAxisZ = orthogonalizedTransformGridAxisZ/np.linalg.norm(orthogonalizedTransformGridAxisZ) - if np.dot(transformGridAxisZ, orthogonalizedTransformGridAxisZ) > 0: - transformGridAxisZ = orthogonalizedTransformGridAxisZ - else: - transformGridAxisZ = -orthogonalizedTransformGridAxisZ - transformGridAxisX = -transformGridAxisX - - else: - - # X axis = average X axis of curve, to minimize torsion (and so have a simple displacement field, which can be robustly inverted) - sumCurveAxisX_RAS = np.zeros(3) - numberOfPoints = resampledCurveNode.GetNumberOfControlPoints() - for gridK in range(numberOfPoints): - curvePointToWorld = vtk.vtkMatrix4x4() - resampledCurveNode.GetCurvePointToWorldTransformAtPointIndex(resampledCurveNode.GetCurvePointIndexFromControlPointIndex(gridK), curvePointToWorld) - curvePointToWorldArray = slicer.util.arrayFromVTKMatrix(curvePointToWorld) - curveAxisX_RAS = curvePointToWorldArray[0:3, 0] - sumCurveAxisX_RAS += curveAxisX_RAS - meanCurveAxisX_RAS = sumCurveAxisX_RAS/np.linalg.norm(sumCurveAxisX_RAS) - transformGridAxisX = meanCurveAxisX_RAS - - # Y axis normalize - transformGridAxisY = np.cross(transformGridAxisZ, transformGridAxisX) - transformGridAxisY = transformGridAxisY/np.linalg.norm(transformGridAxisY) - - # Make sure that X axis is orthogonal to Y and Z - transformGridAxisX = np.cross(transformGridAxisY, transformGridAxisZ) - transformGridAxisX = transformGridAxisX/np.linalg.norm(transformGridAxisX) - - # Rotate by rotationDeg around the Z axis - gridDirectionMatrixArray = np.eye(4) - gridDirectionMatrixArray[0:3, 0] = transformGridAxisX - gridDirectionMatrixArray[0:3, 1] = transformGridAxisY - gridDirectionMatrixArray[0:3, 2] = transformGridAxisZ - gridDirectionMatrix = slicer.util.vtkMatrixFromArray(gridDirectionMatrixArray) - # - gridDirectionTransform = vtk.vtkTransform() - gridDirectionTransform.Concatenate(gridDirectionMatrix) - gridDirectionTransform.RotateZ(rotationDeg) - # - gridDirectionMatrixArray = slicer.util.arrayFromVTKMatrix(gridDirectionTransform.GetMatrix()) - transformGridAxisX = gridDirectionMatrixArray[0:3, 0] - transformGridAxisY = gridDirectionMatrixArray[0:3, 1] - transformGridAxisZ = gridDirectionMatrixArray[0:3, 2] - - if stretching: - # Project curve points to grid YZ plane - transformFromGridYZPlane = np.eye(4) - transformFromGridYZPlane[0:3, 0] = transformGridAxisY - transformFromGridYZPlane[0:3, 1] = transformGridAxisZ - transformFromGridYZPlane[0:3, 2] = transformGridAxisX - transformFromGridYZPlane[0:3, 3] = curveNodePlane.GetOrigin() - transformToGridYZPlane = np.linalg.inv(transformFromGridYZPlane) - - originalCurvePointsArray = slicer.util.arrayFromMarkupsCurvePoints(curveNode) - curvePointsProjected_RAS = CurvedPlanarReformatLogic.getPointsProjectedToPlane(originalCurvePointsArray.T, transformToGridYZPlane).T - slicer.util.updateMarkupsControlPointsFromArray(resampledCurveNode, curvePointsProjected_RAS) - - # After projection, resampling is needed to get uniform distances - originalCurvePoints = resampledCurveNode.GetCurvePointsWorld() - sampledPoints = vtk.vtkPoints() - if not slicer.vtkMRMLMarkupsCurveNode.ResamplePoints(originalCurvePoints, sampledPoints, resamplingCurveSpacing, False): - raise ValueError("Resampling curve failed") - resampledCurveNode.SetControlPointPositionsWorld(sampledPoints) - - # Origin (makes the grid centered at the curve) - curveLength = resampledCurveNode.GetCurveLengthWorld() - transformGridOrigin = np.array(curveNodePlane.GetOrigin()) - transformGridOrigin -= transformGridAxisX * sliceSizeMm[0]/2.0 - transformGridOrigin -= transformGridAxisY * sliceSizeMm[1]/2.0 - transformGridOrigin -= transformGridAxisZ * curveLength/2.0 - - # Create grid transform - # Each corner of each slice is mapped from the original volume's reformatted slice - # to the straightened volume slice. - # The grid transform contains one vector at the corner of each slice. - # The transform is in the same space and orientation as the straightened volume. - - numberOfSlices = resampledCurveNode.GetNumberOfControlPoints() - gridDimensions = [2, 2, numberOfSlices] - gridSpacing = [sliceSizeMm[0], sliceSizeMm[1], resamplingCurveSpacing] - gridDirectionMatrixArray = np.eye(4) - gridDirectionMatrixArray[0:3, 0] = transformGridAxisX - gridDirectionMatrixArray[0:3, 1] = transformGridAxisY - gridDirectionMatrixArray[0:3, 2] = transformGridAxisZ - gridDirectionMatrix = slicer.util.vtkMatrixFromArray(gridDirectionMatrixArray) - - gridImage = vtk.vtkImageData() - gridImage.SetOrigin(transformGridOrigin) - gridImage.SetDimensions(gridDimensions) - gridImage.SetSpacing(gridSpacing) - gridImage.AllocateScalars(vtk.VTK_DOUBLE, 3) - transform = slicer.vtkOrientedGridTransform() - transform.SetDisplacementGridData(gridImage) - transform.SetGridDirectionMatrix(gridDirectionMatrix) - transformToStraightenedNode.SetAndObserveTransformFromParent(transform) - - if reslicingPlanesModelNode: - appender = vtk.vtkAppendPolyData() - - # Currently there is no API to set PreferredInitialNormalVector in the curve coordinate system, therefore - # a new coordinate system generator must be set up: - curveCoordinateSystemGeneratorWorld = slicer.vtkParallelTransportFrame() - curveCoordinateSystemGeneratorWorld.SetInputData(resampledCurveNode.GetCurveWorld()) - curveCoordinateSystemGeneratorWorld.SetPreferredInitialNormalVector(transformGridAxisX) - curveCoordinateSystemGeneratorWorld.Update() - curvePoly = curveCoordinateSystemGeneratorWorld.GetOutput() - pointData = curvePoly.GetPointData() - normals = pointData.GetAbstractArray(curveCoordinateSystemGeneratorWorld.GetNormalsArrayName()) - binormals = pointData.GetAbstractArray(curveCoordinateSystemGeneratorWorld.GetBinormalsArrayName()) - tangents = pointData.GetAbstractArray(curveCoordinateSystemGeneratorWorld.GetTangentsArrayName()) - - # Compute displacements - transformDisplacements_RAS = slicer.util.arrayFromGridTransform(transformToStraightenedNode) - for gridK in range(gridDimensions[2]): - - # The curve's built-in coordinate system generator could be used like this (if it had PreferredInitialNormalVector exposed): - # - # curvePointToWorld = vtk.vtkMatrix4x4() - # resampledCurveNode.GetCurvePointToWorldTransformAtPointIndex(resampledCurveNode.GetCurvePointIndexFromControlPointIndex(gridK), curvePointToWorld) - # curvePointToWorldArray = slicer.util.arrayFromVTKMatrix(curvePointToWorld) - # curveAxisX_RAS = curvePointToWorldArray[0:3, 0] - # curveAxisY_RAS = curvePointToWorldArray[0:3, 1] - # curvePoint_RAS = curvePointToWorldArray[0:3, 3] - # - # But now we get the values from our own coordinate system generator: - curvePointIndex = resampledCurveNode.GetCurvePointIndexFromControlPointIndex(gridK) - curveAxisX_RAS = np.array(normals.GetTuple3(curvePointIndex)) - curveAxisY_RAS = np.array(binormals.GetTuple3(curvePointIndex)) - curvePoint_RAS = np.array(curvePoly.GetPoint(curvePointIndex)) - - for gridJ in range(gridDimensions[1]): - for gridI in range(gridDimensions[0]): - straightenedVolume_RAS = (transformGridOrigin - + gridI*gridSpacing[0]*transformGridAxisX - + gridJ*gridSpacing[1]*transformGridAxisY - + gridK*gridSpacing[2]*transformGridAxisZ) - inputVolume_RAS = (curvePoint_RAS - + (gridI-0.5)*sliceSizeMm[0]*curveAxisX_RAS - + (gridJ-0.5)*sliceSizeMm[1]*curveAxisY_RAS) - if reslicingPlanesModelNode: - if gridI == 0 and gridJ == 0: - plane = vtk.vtkPlaneSource() - plane.SetOrigin(inputVolume_RAS) - elif gridI == 1 and gridJ == 0: - plane.SetPoint1(inputVolume_RAS) - elif gridI == 0 and gridJ == 1: - plane.SetPoint2(inputVolume_RAS) - transformDisplacements_RAS[gridK][gridJ][gridI] = inputVolume_RAS - straightenedVolume_RAS - - if reslicingPlanesModelNode: - plane.Update() - appender.AddInputData(plane.GetOutput()) - - slicer.util.arrayFromGridTransformModified(transformToStraightenedNode) - - # delete temporary curve - slicer.mrmlScene.RemoveNode(resampledCurveNode) - - if reslicingPlanesModelNode: - appender.Update() - if not reslicingPlanesModelNode.GetPolyData(): - reslicingPlanesModelNode.CreateDefaultDisplayNodes() - reslicingPlanesModelNode.GetDisplayNode().SetVisibility2D(True) - reslicingPlanesModelNode.SetAndObservePolyData(appender.GetOutput()) - + return self.sliceLogic.CurvedPlanarReformationComputeStraighteningTransform( + transformToStraightenedNode, + curveNode, + sliceSizeMm, + outputSpacingMm, + stretching, + rotationDeg, + reslicingPlanesModelNode, + ) def straightenVolume(self, outputStraightenedVolume, volumeNode, outputStraightenedVolumeSpacing, straighteningTransformNode): """ Compute straightened volume (useful for example for visualization of curved vessels) """ - gridTransform = straighteningTransformNode.GetTransformFromParentAs("vtkOrientedGridTransform") - if not gridTransform: - raise ValueError("Straightening transform is expected to contain a vtkOrientedGridTransform form parent") - - # Get transformation grid geometry - gridIjkToRasDirectionMatrix = gridTransform.GetGridDirectionMatrix() - gridTransformImage = gridTransform.GetDisplacementGrid() - gridOrigin = gridTransformImage.GetOrigin() - gridSpacing = gridTransformImage.GetSpacing() - gridDimensions = gridTransformImage.GetDimensions() - gridExtentMm = [gridSpacing[0]*(gridDimensions[0]-1), gridSpacing[1]*(gridDimensions[1]-1), gridSpacing[2]*(gridDimensions[2]-1)] - - # Compute IJK to RAS matrix of output volume - # Get grid axis directions - straightenedVolumeIJKToRASArray = slicer.util.arrayFromVTKMatrix(gridIjkToRasDirectionMatrix) - # Apply scaling - straightenedVolumeIJKToRASArray = np.dot(straightenedVolumeIJKToRASArray, - np.diag([outputStraightenedVolumeSpacing[0], outputStraightenedVolumeSpacing[1], outputStraightenedVolumeSpacing[2], 1])) - # Set origin - straightenedVolumeIJKToRASArray[0:3,3] = gridOrigin - - outputStraightenedImageData = vtk.vtkImageData() - outputStraightenedImageData.SetExtent( - 0, int(gridExtentMm[0]/outputStraightenedVolumeSpacing[0])-1, - 0, int(gridExtentMm[1]/outputStraightenedVolumeSpacing[1])-1, - 0, int(gridExtentMm[2]/outputStraightenedVolumeSpacing[2])-1) - outputStraightenedImageData.AllocateScalars(volumeNode.GetImageData().GetScalarType(), volumeNode.GetImageData().GetNumberOfScalarComponents()) - outputStraightenedVolume.SetAndObserveImageData(outputStraightenedImageData) - outputStraightenedVolume.SetIJKToRASMatrix(slicer.util.vtkMatrixFromArray(straightenedVolumeIJKToRASArray)) - - # Resample input volume to straightened volume - parameters = {} - parameters["inputVolume"] = volumeNode.GetID() - parameters["outputVolume"] = outputStraightenedVolume.GetID() - parameters["referenceVolume"] = outputStraightenedVolume.GetID() - parameters["transformationFile"] = straighteningTransformNode.GetID() - # Use nearest neighbor interpolation for label volumes (to avoid incorrect labels at boundaries) - # and higher-order (bspline) interpolation for scalar volumes. - parameters["interpolationType"] = "nn" if volumeNode.IsA('vtkMRMLLabelMapVolumeNode') else "bs" - resamplerModule = slicer.modules.resamplescalarvectordwivolume - parameterNode = slicer.cli.runSync(resamplerModule, None, parameters) - - outputStraightenedVolume.CreateDefaultDisplayNodes() - outputStraightenedVolume.GetDisplayNode().CopyContent(volumeNode.GetDisplayNode()) - slicer.mrmlScene.RemoveNode(parameterNode) + return self.sliceLogic.CurvedPlanarReformationStraightenVolume( + outputStraightenedVolume, volumeNode, outputStraightenedVolumeSpacing, straighteningTransformNode + ) def projectVolume(self, outputProjectedVolume, inputStraightenedVolume, projectionAxisIndex = 0): """Create panoramic volume by mean intensity projection along an axis of the straightened volume """ - - projectedImageData = vtk.vtkImageData() - outputProjectedVolume.SetAndObserveImageData(projectedImageData) - straightenedImageData = inputStraightenedVolume.GetImageData() - - outputImageDimensions = list(straightenedImageData.GetDimensions()) - outputImageDimensions[projectionAxisIndex] = 1 - projectedImageData.SetDimensions(outputImageDimensions) - - projectedImageData.AllocateScalars(straightenedImageData.GetScalarType(), straightenedImageData.GetNumberOfScalarComponents()) - outputProjectedVolumeArray = slicer.util.arrayFromVolume(outputProjectedVolume) - inputStraightenedVolumeArray = slicer.util.arrayFromVolume(inputStraightenedVolume) - - if projectionAxisIndex == 0: - outputProjectedVolumeArray[:, :, 0] = inputStraightenedVolumeArray.mean(2-projectionAxisIndex) - elif projectionAxisIndex == 1: - outputProjectedVolumeArray[:, 0, :] = inputStraightenedVolumeArray.mean(2-projectionAxisIndex) - else: - outputProjectedVolumeArray[0, :, :] = inputStraightenedVolumeArray.mean(2-projectionAxisIndex) - - slicer.util.arrayFromVolumeModified(outputProjectedVolume) - - # Shift projection image into the center of the input image - ijkToRas = vtk.vtkMatrix4x4() - inputStraightenedVolume.GetIJKToRASMatrix(ijkToRas) - curvePointToWorldArray = slicer.util.arrayFromVTKMatrix(ijkToRas) - origin = curvePointToWorldArray[0:3, 3] - offsetToCenterDirectionVector = curvePointToWorldArray[0:3, projectionAxisIndex] - offsetToCenterDirectionLength = inputStraightenedVolume.GetImageData().GetDimensions()[projectionAxisIndex] * inputStraightenedVolume.GetSpacing()[projectionAxisIndex] - newOrigin = origin + offsetToCenterDirectionVector * offsetToCenterDirectionLength - ijkToRas.SetElement(0, 3, newOrigin[0]) - ijkToRas.SetElement(1, 3, newOrigin[1]) - ijkToRas.SetElement(2, 3, newOrigin[2]) - outputProjectedVolume.SetIJKToRASMatrix(ijkToRas) - outputProjectedVolume.CreateDefaultDisplayNodes() - - return True + return self.sliceLogic.CurvedPlanarReformationProjectVolume( + outputProjectedVolume, inputStraightenedVolume, projectionAxisIndex + ) class CurvedPlanarReformatTest(ScriptedLoadableModuleTest): """