Skip to content

Commit dcebfde

Browse files
authored
Merge pull request #188 from tbirdso/point-set-registration
2 parents 0db8e1d + d50d4ef commit dcebfde

File tree

5 files changed

+473
-0
lines changed

5 files changed

+473
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_example(PerformRegistrationOnVectorImages)
2+
add_example(RegisterTwoPointSets)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
cmake_minimum_required(VERSION 3.10.2)
2+
3+
project(RegisterTwoPointSets)
4+
5+
find_package(ITK REQUIRED)
6+
include(${ITK_USE_FILE})
7+
8+
9+
add_executable(RegisterTwoPointSets Code.cxx)
10+
target_link_libraries(RegisterTwoPointSets ${ITK_LIBRARIES})
11+
12+
install(TARGETS RegisterTwoPointSets
13+
DESTINATION bin/ITKExamples/Registration/Common
14+
COMPONENT Runtime
15+
)
16+
17+
install(FILES Code.cxx Code.py CMakeLists.txt
18+
DESTINATION share/ITKExamples/Code/Registration/Common/RegisterTwoPointSets/
19+
COMPONENT Code
20+
)
21+
22+
23+
enable_testing()
24+
add_test(NAME RegisterTwoPointSetsTest
25+
COMMAND ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/RegisterTwoPointSets)
26+
27+
if(ITK_WRAP_PYTHON)
28+
add_test(NAME RegisterTwoPointSetsTest2DPython
29+
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/Code.py 2
30+
)
31+
add_test(NAME RegisterTwoPointSetsTest3DPython
32+
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/Code.py 3)
33+
endif()
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
/*=========================================================================
2+
*
3+
* Copyright NumFOCUS
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0.txt
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*=========================================================================*/
18+
19+
// Adapted from ITK itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTest.cxx
20+
21+
#include "itkJensenHavrdaCharvatTsallisPointSetToPointSetMetricv4.h"
22+
#include "itkGradientDescentOptimizerv4.h"
23+
#include "itkTransform.h"
24+
#include "itkAffineTransform.h"
25+
#include "itkRegistrationParameterScalesFromPhysicalShift.h"
26+
#include "itkCommand.h"
27+
28+
#include <fstream>
29+
30+
template <typename TFilter>
31+
class itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate : public itk::Command
32+
{
33+
public:
34+
using Self = itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate;
35+
36+
using Superclass = itk::Command;
37+
using Pointer = itk::SmartPointer<Self>;
38+
itkNewMacro(Self);
39+
40+
protected:
41+
itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate() = default;
42+
43+
public:
44+
void
45+
Execute(itk::Object * caller, const itk::EventObject & event) override
46+
{
47+
Execute((const itk::Object *)caller, event);
48+
}
49+
50+
void
51+
Execute(const itk::Object * object, const itk::EventObject & event) override
52+
{
53+
if (typeid(event) != typeid(itk::IterationEvent))
54+
{
55+
return;
56+
}
57+
const auto * optimizer = dynamic_cast<const TFilter *>(object);
58+
59+
if (!optimizer)
60+
{
61+
itkGenericExceptionMacro("Error dynamic_cast failed");
62+
}
63+
std::cout << "It: " << optimizer->GetCurrentIteration() << " metric value: " << optimizer->GetCurrentMetricValue();
64+
std::cout << std::endl;
65+
}
66+
};
67+
68+
69+
int
70+
main(int argc, char * argv[])
71+
{
72+
constexpr unsigned int Dimension = 2;
73+
74+
unsigned int numberOfIterations = 10;
75+
if (argc > 1)
76+
{
77+
numberOfIterations = std::stoi(argv[1]);
78+
}
79+
80+
using PointSetType = itk::PointSet<unsigned char, Dimension>;
81+
82+
using PointType = PointSetType::PointType;
83+
84+
PointSetType::Pointer fixedPoints = PointSetType::New();
85+
fixedPoints->Initialize();
86+
87+
PointSetType::Pointer movingPoints = PointSetType::New();
88+
movingPoints->Initialize();
89+
90+
91+
// two ellipses, one rotated slightly
92+
/*
93+
// Having trouble with these, as soon as there's a slight rotation added.
94+
unsigned long count = 0;
95+
for( float theta = 0; theta < 2.0 * itk::Math::pi; theta += 0.1 )
96+
{
97+
float radius = 100.0;
98+
PointType fixedPoint;
99+
fixedPoint[0] = 2 * radius * std::cos( theta );
100+
fixedPoint[1] = radius * std::sin( theta );
101+
fixedPoints->SetPoint( count, fixedPoint );
102+
103+
PointType movingPoint;
104+
movingPoint[0] = 2 * radius * std::cos( theta + (0.02 * itk::Math::pi) ) + 2.0;
105+
movingPoint[1] = radius * std::sin( theta + (0.02 * itk::Math::pi) ) + 2.0;
106+
movingPoints->SetPoint( count, movingPoint );
107+
108+
count++;
109+
}
110+
*/
111+
112+
// two circles with a small offset
113+
PointType offset;
114+
for (unsigned int d = 0; d < Dimension; d++)
115+
{
116+
offset[d] = 2.0;
117+
}
118+
unsigned long count = 0;
119+
for (float theta = 0; theta < 2.0 * itk::Math::pi; theta += 0.1)
120+
{
121+
PointType fixedPoint;
122+
float radius = 100.0;
123+
fixedPoint[0] = radius * std::cos(theta);
124+
fixedPoint[1] = radius * std::sin(theta);
125+
if (Dimension > 2)
126+
{
127+
fixedPoint[2] = radius * std::sin(theta);
128+
}
129+
fixedPoints->SetPoint(count, fixedPoint);
130+
131+
PointType movingPoint;
132+
movingPoint[0] = fixedPoint[0] + offset[0];
133+
movingPoint[1] = fixedPoint[1] + offset[1];
134+
if (Dimension > 2)
135+
{
136+
movingPoint[2] = fixedPoint[2] + offset[2];
137+
}
138+
movingPoints->SetPoint(count, movingPoint);
139+
140+
count++;
141+
}
142+
143+
using AffineTransformType = itk::AffineTransform<double, Dimension>;
144+
AffineTransformType::Pointer transform = AffineTransformType::New();
145+
transform->SetIdentity();
146+
147+
// Instantiate the metric
148+
using PointSetMetricType = itk::JensenHavrdaCharvatTsallisPointSetToPointSetMetricv4<PointSetType>;
149+
PointSetMetricType::Pointer metric = PointSetMetricType::New();
150+
metric->SetFixedPointSet(fixedPoints);
151+
metric->SetMovingPointSet(movingPoints);
152+
metric->SetPointSetSigma(1.0);
153+
metric->SetKernelSigma(10.0);
154+
metric->SetUseAnisotropicCovariances(false);
155+
metric->SetCovarianceKNeighborhood(5);
156+
metric->SetEvaluationKNeighborhood(10);
157+
metric->SetMovingTransform(transform);
158+
metric->SetAlpha(1.1);
159+
metric->Initialize();
160+
161+
// scales estimator
162+
using RegistrationParameterScalesFromShiftType =
163+
itk::RegistrationParameterScalesFromPhysicalShift<PointSetMetricType>;
164+
RegistrationParameterScalesFromShiftType::Pointer shiftScaleEstimator =
165+
RegistrationParameterScalesFromShiftType::New();
166+
shiftScaleEstimator->SetMetric(metric);
167+
// needed with pointset metrics
168+
shiftScaleEstimator->SetVirtualDomainPointSet(metric->GetVirtualTransformedPointSet());
169+
170+
// optimizer
171+
using OptimizerType = itk::GradientDescentOptimizerv4;
172+
OptimizerType::Pointer optimizer = OptimizerType::New();
173+
optimizer->SetMetric(metric);
174+
optimizer->SetNumberOfIterations(numberOfIterations);
175+
optimizer->SetScalesEstimator(shiftScaleEstimator);
176+
optimizer->SetMaximumStepSizeInPhysicalUnits(3.0);
177+
178+
using CommandType = itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate<OptimizerType>;
179+
CommandType::Pointer observer = CommandType::New();
180+
optimizer->AddObserver(itk::IterationEvent(), observer);
181+
182+
optimizer->SetMinimumConvergenceValue(0.0);
183+
optimizer->SetConvergenceWindowSize(10);
184+
optimizer->StartOptimization();
185+
186+
std::cout << "numberOfIterations: " << numberOfIterations << std::endl;
187+
std::cout << "Moving-source final value: " << optimizer->GetCurrentMetricValue() << std::endl;
188+
std::cout << "Moving-source final position: " << optimizer->GetCurrentPosition() << std::endl;
189+
std::cout << "Optimizer scales: " << optimizer->GetScales() << std::endl;
190+
std::cout << "Optimizer learning rate: " << optimizer->GetLearningRate() << std::endl;
191+
192+
// applying the resultant transform to moving points and verify result
193+
std::cout << "Fixed\tMoving\tMovingTransformed\tFixedTransformed\tDiff" << std::endl;
194+
bool passed = true;
195+
PointType::ValueType tolerance = 1e-2;
196+
AffineTransformType::InverseTransformBasePointer movingInverse = metric->GetMovingTransform()->GetInverseTransform();
197+
AffineTransformType::InverseTransformBasePointer fixedInverse = metric->GetFixedTransform()->GetInverseTransform();
198+
for (unsigned int n = 0; n < metric->GetNumberOfComponents(); n++)
199+
{
200+
// compare the points in virtual domain
201+
PointType transformedMovingPoint = movingInverse->TransformPoint(movingPoints->GetPoint(n));
202+
PointType transformedFixedPoint = fixedInverse->TransformPoint(fixedPoints->GetPoint(n));
203+
PointType difference;
204+
difference[0] = transformedMovingPoint[0] - transformedFixedPoint[0];
205+
difference[1] = transformedMovingPoint[1] - transformedFixedPoint[1];
206+
std::cout << fixedPoints->GetPoint(n) << "\t" << movingPoints->GetPoint(n) << "\t" << transformedMovingPoint << "\t"
207+
<< transformedFixedPoint << "\t" << difference << std::endl;
208+
if (fabs(difference[0]) > tolerance || fabs(difference[1]) > tolerance)
209+
{
210+
passed = false;
211+
}
212+
}
213+
if (!passed)
214+
{
215+
std::cerr << "Results do not match truth within tolerance." << std::endl;
216+
return EXIT_FAILURE;
217+
}
218+
219+
220+
return EXIT_SUCCESS;
221+
}

0 commit comments

Comments
 (0)