statismo code analysis: statismo-build-shape-model | HyperPlane

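A walkthrough of the source code of the statismo-build-shape-model program in statismo.

In short: the program first fixes a reference mesh (either the one supplied by the user or a Procrustes mean computed from the data) and then runs PCA, which centers the samples on their mean before decomposing them.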

Flowchart
[Figure: flowchart of statismo-build-shape-model]


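Command-line arguments

  • -l, --data-list DATA_LIST: path to a text file listing all mesh files used to build the model, one file path per line
  • -o, --output-file OUTPUT_FILE: path of the output model file
  • -p, --procrustes PROCRUSTES_MODE: how the data is aligned. With reference, all datasets are aligned to the specified reference mesh; with GPA, they are aligned to the population mean
  • -r, --reference FILE: the reference mesh, used when PROCRUSTES_MODE is reference
  • -n, --noise NOISE: noise variance of the PPCA (probabilistic principal component analysis) model; defaults to 0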

Code that reads the mesh files

typedef itk::MeshFileReader<MeshType> MeshReaderType;
typedef vector<MeshReaderType::Pointer> MeshReaderList;
MeshReaderList meshes;
meshes.reserve(fileNames.size());
for (StringList::const_iterator it = fileNames.begin(); it != fileNames.end();
     ++it)
{
    MeshReaderType::Pointer reader = MeshReaderType::New();
    reader->SetFileName(it->c_str());
    reader->Update();
    // The upstream comment below explains why this Update() call matters:
    // itk::PCAModelBuilder is not a Filter in the ITK world, so the pipeline
    // would not get executed if its main method is called. So the pipeline
    // before calling itk::PCAModelBuilder must be executed by the means of calls
    // to Update() (at least for last elements needed by itk::PCAModelBuilder).
    meshes.push_back(reader);
}

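Why Update() is needed (each reader here is the last pipeline element before the model builder, which is not a filter, so nothing downstream would trigger it):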
You only have to call Update() on the last filter in your pipeline. The rest of this answer is the explanation.
ITK uses a pipeline execution framework for filters. Assume we have three filters that are connected sequentially like the following:
input --> |filter1| --> |filter2| --> |filter3| --> output
If you call Update() on filter3, ITK starts from filter3 and checks if the input(s) to each filter have changed. If they have, ITK calls update on them in turn. See slide 5 of this link.
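
The pull-based behaviour described above can be illustrated with a toy model. The sketch below is not ITK code: ToyFilter, its members and runToyPipeline are hypothetical names invented for this illustration, and ITK's real mechanism (modification times, requested regions) is considerably more involved. It only shows that calling Update() on the last stage pulls every upstream stage, and that an unchanged pipeline is not re-executed.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for a pipeline stage (hypothetical, not an ITK class).
class ToyFilter {
public:
    explicit ToyFilter(const std::string& name)
        : name_(name), input_(nullptr), dirty_(true) {}

    void SetInput(ToyFilter* input) {
        assert(input != this); // a stage cannot feed itself
        input_ = input;
        dirty_ = true;
    }

    // Mark this stage as changed so that the next Update() re-executes it.
    void Modified() { dirty_ = true; }

    // Pull-based execution: update the upstream stage first, then execute this
    // stage if it changed or if anything upstream was re-executed.
    // Returns true if this stage was executed during the call.
    bool Update(std::vector<std::string>& executed) {
        const bool upstreamRan = (input_ != nullptr) && input_->Update(executed);
        if (dirty_ || upstreamRan) {
            executed.push_back(name_);
            dirty_ = false;
            return true;
        }
        return false;
    }

private:
    std::string name_;
    ToyFilter* input_;
    bool dirty_; // true until the stage has been executed once
};

// Builds filter1 -> filter2 -> filter3 and records which stages execute.
std::vector<std::string> runToyPipeline() {
    ToyFilter filter1("filter1");
    ToyFilter filter2("filter2");
    ToyFilter filter3("filter3");
    filter2.SetInput(&filter1);
    filter3.SetInput(&filter2);

    std::vector<std::string> executed;
    filter3.Update(executed); // executes filter1, filter2, filter3 in that order
    filter3.Update(executed); // nothing changed, so nothing is executed
    filter1.Modified();
    filter3.Update(executed); // the change propagates: all three execute again
    return executed;
}
```

In this picture the mesh readers play the role of filter3: since the model builder that consumes them is not part of the pipeline, nobody else would call Update() on their behalf.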


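Code that computes the mean mesh used as the reference mesh

originalMeshes stores smart pointers to the reader outputs, so no mesh data is copied at this point. Note, however, that superimposeMeshes (shown further down) replaces the entries of its own copy of the vector with the outputs of a transform filter, so as far as the code shown here goes, the alignment does not appear to modify the loaded meshes themselves.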

typedef itk::Mesh<float, Dimensions> MeshType;
///////////////////////////////////////////////////////////////

// Collect the outputs of the readers created earlier
// (only smart pointers are copied, not the mesh data)
vector<MeshType::Pointer> originalMeshes;
for (MeshReaderList::iterator it = meshes.begin(); it != meshes.end();
     ++it)
{
    MeshReaderType::Pointer reader = *it;
    originalMeshes.push_back(reader->GetOutput());
}

const unsigned uMaxGPAIterations = 20;
const unsigned uNumberOfPoints = 100; // use at most this many points for alignment
const float fBreakIfChangeBelow = 0.001f;

typedef itk::VersorRigid3DTransform<float> Rigid3DTransformType;
typedef itk::Image<float, Dimensions> ImageType;
typedef itk::LandmarkBasedTransformInitializer<Rigid3DTransformType,
                                               ImageType, ImageType>
    LandmarkBasedTransformInitializerType;
typedef itk::TransformMeshFilter<MeshType, MeshType, Rigid3DTransformType>
    FilterType;
// Compute the reference mesh
MeshType::Pointer referenceMesh =
    calculateProcrustesMeanMesh<MeshType,
                                LandmarkBasedTransformInitializerType,
                                Rigid3DTransformType, FilterType>(
        originalMeshes, uMaxGPAIterations, uNumberOfPoints,
        fBreakIfChangeBelow);
representer->SetReference(referenceMesh);

The function that computes the reference mesh

Procedure:

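  1. Randomly pick a subset of point indices; only these points are used for the alignment
  2. Align all meshes to the current reference mesh, which initially is the first mesh in the list
  3. Compute the mean of the aligned meshes
  4. Compute the difference between the current mean mesh and the current reference mesh
  5. Replace the reference mesh with the current mean mesh. If the difference changed by at least the threshold compared with the previous iteration and the maximum number of iterations has not been reached, go back to step 2; otherwise return the current mean mesh
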
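
The loop structure of steps 2 to 5 can be sketched in a few lines. This is a simplified illustration under stated assumptions, not statismo's implementation: shapes are 2D, all points are used (step 1 is skipped), and the rigid fit is replaced by a translation that matches centroids. All names (Point2, Shape, alignTo, procrustesMean and so on) are made up for the example. The stopping rule mirrors the one in the function below: it looks at how much the mean-to-reference difference changed since the previous iteration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct Point2 { double x, y; };
typedef std::vector<Point2> Shape; // points correspond by index

Point2 centroid(const Shape& s) {
    assert(!s.empty());
    Point2 c = {0.0, 0.0};
    for (std::size_t i = 0; i < s.size(); ++i) { c.x += s[i].x; c.y += s[i].y; }
    c.x /= s.size();
    c.y /= s.size();
    return c;
}

// Simplified alignment: translate `moving` so its centroid matches the reference's.
Shape alignTo(const Shape& moving, const Shape& reference) {
    const Point2 cm = centroid(moving), cr = centroid(reference);
    Shape out = moving;
    for (std::size_t i = 0; i < out.size(); ++i) {
        out[i].x += cr.x - cm.x;
        out[i].y += cr.y - cm.y;
    }
    return out;
}

Shape meanShape(const std::vector<Shape>& shapes) {
    assert(!shapes.empty());
    Shape mean(shapes.front().size(), Point2{0.0, 0.0});
    for (std::size_t s = 0; s < shapes.size(); ++s) {
        assert(shapes[s].size() == mean.size());
        for (std::size_t i = 0; i < mean.size(); ++i) {
            mean[i].x += shapes[s][i].x;
            mean[i].y += shapes[s][i].y;
        }
    }
    for (std::size_t i = 0; i < mean.size(); ++i) {
        mean[i].x /= shapes.size();
        mean[i].y /= shapes.size();
    }
    return mean;
}

// Sum of squared point distances divided by (number of points * dimension).
double meanSquaredDistance(const Shape& a, const Shape& b) {
    assert(a.size() == b.size() && !a.empty());
    double sum = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        const double dx = a[i].x - b[i].x, dy = a[i].y - b[i].y;
        sum += dx * dx + dy * dy;
    }
    return sum / (a.size() * 2);
}

Shape procrustesMean(const std::vector<Shape>& shapes, unsigned maxIterations,
                     double breakIfChangeBelow) {
    assert(!shapes.empty());
    Shape reference = shapes.front(); // initial reference: the first shape
    double previousDifference = -1.0;
    for (unsigned it = 0; it < maxIterations; ++it) {
        std::vector<Shape> aligned;
        for (std::size_t s = 0; s < shapes.size(); ++s)
            aligned.push_back(alignTo(shapes[s], reference));
        const Shape mean = meanShape(aligned);
        const double difference = meanSquaredDistance(mean, reference);
        const double delta = std::abs(difference - previousDifference);
        previousDifference = difference;
        reference = mean; // the mean becomes the next reference
        if (delta < breakIfChangeBelow) break;
    }
    return reference;
}
```

Because previousDifference starts at -1, the loop always runs at least twice before the change in the difference can fall below the threshold.
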
template <class MeshType, class LandmarkBasedTransformInitializerType,
          class TransformType, class FilterType>
typename MeshType::Pointer
calculateProcrustesMeanMesh(std::vector<typename MeshType::Pointer> meshes,
                            unsigned maxIterations, unsigned nrOfLandmarks,
                            float breakIfChangeBelow)
{
    // the initial mesh to which all others will be aligned to is the first one in
    // the list here. Any other mesh could be chosen as well
    // (i.e. the first mesh serves as the initial reference mesh)
    typename MeshType::Pointer referenceMesh = *meshes.begin();

    unsigned rngSeed = time(0);
    unsigned meshVerticesCount = referenceMesh->GetNumberOfPoints();
    srand(rngSeed); // seed the generator behind rand()
    std::set<unsigned> pointNumbers;
    // Use at most nrOfLandmarks points; their indices are chosen at random
    while (pointNumbers.size() < std::min(nrOfLandmarks, meshVerticesCount))
    {
        // Draw a random point index. Duplicates need no extra handling:
        // std::set::insert ignores a value that is already present, and the
        // loop runs until enough distinct indices have been collected.
        unsigned randomIndex = ((unsigned)rand()) % meshVerticesCount;
        pointNumbers.insert(randomIndex);
    }

    float fPreviousDifference = -1;

    // Iterate
    for (unsigned i = 0; i < maxIterations; ++i)
    {
        // calculate the difference to the previous iteration's mesh and break if
        // the difference is very small
        // Rigidly align every mesh to the current reference mesh
        std::vector<typename MeshType::Pointer> translatedMeshes =
            superimposeMeshes<MeshType, LandmarkBasedTransformInitializerType,
                              TransformType, FilterType>(meshes, referenceMesh,
                                                         pointNumbers);
        // Compute the mean of the aligned meshes
        typename MeshType::Pointer meanMesh =
            calculateMeanMesh<MeshType>(translatedMeshes);
        // Difference between the current mean mesh and the current reference mesh
        float fDifference =
            calculateMeshDistance<MeshType>(meanMesh, referenceMesh);
        float fDifferenceDelta = std::abs(fDifference - fPreviousDifference);
        fPreviousDifference = fDifference;
        referenceMesh = meanMesh;

        // Converged once the difference changes by less than the threshold
        // between two consecutive iterations
        if (fDifferenceDelta < breakIfChangeBelow)
        {
            break;
        }
    }
    return referenceMesh;
}
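
The landmark selection at the top of this function relies on std::set to discard repeated indices. The sketch below isolates that idea. It is a variation rather than the code statismo uses: it swaps srand/rand for the C++11 <random> facilities, which avoids the bias of rand() % n and the limited range of rand() on platforms where RAND_MAX is small. The function name sampleUniqueIndices and its seed parameter are assumptions made for the example.

```cpp
#include <algorithm>
#include <cassert>
#include <random>
#include <set>

// Draw min(count, populationSize) distinct indices from [0, populationSize).
std::set<unsigned> sampleUniqueIndices(unsigned count, unsigned populationSize,
                                       unsigned seed) {
    assert(populationSize > 0);
    std::mt19937 rng(seed);
    std::uniform_int_distribution<unsigned> pick(0, populationSize - 1);

    std::set<unsigned> indices;
    const unsigned target = std::min(count, populationSize);
    while (indices.size() < target) {
        // Inserting an index that is already present leaves the set unchanged,
        // so the loop simply keeps drawing until enough distinct ones exist.
        indices.insert(pick(rng));
    }
    return indices;
}
```

When count is close to populationSize the loop wastes many draws on repeats; with 100 landmarks on a mesh of thousands of vertices that cost is negligible.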

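The alignment function

Using the randomly sampled corresponding points, it computes the rigid transform between each mesh and the current reference mesh and then applies that transform to the mesh.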

template <class MeshType, class LandmarkBasedTransformInitializerType,
          class TransformType, class FilterType>
std::vector<typename MeshType::Pointer>
superimposeMeshes(std::vector<typename MeshType::Pointer> originalMeshes,
                  typename MeshType::Pointer referenceMesh,
                  std::set<unsigned> landmarkIndices)
{
    std::vector<typename MeshType::Pointer> translatedMeshes(
        originalMeshes.begin(), originalMeshes.end());
    // For every mesh, compute the rigid transform to the current reference
    // mesh and apply it
    for (typename std::vector<typename MeshType::Pointer>::iterator it =
             translatedMeshes.begin();
         it != translatedMeshes.end(); ++it)
    {
        typedef
            typename LandmarkBasedTransformInitializerType::LandmarkPointContainer
                LandmarkContainerType;
        LandmarkContainerType movingLandmarks;
        LandmarkContainerType fixedLandmarks;
        typename MeshType::Pointer movingMesh = *it;

        // The numbers of points and cells must match (beyond this check, the
        // code assumes that points with the same index already correspond)
        if (movingMesh->GetNumberOfPoints() != referenceMesh->GetNumberOfPoints() ||
            movingMesh->GetNumberOfCells() != referenceMesh->GetNumberOfCells())
        {
            itkGenericExceptionMacro(
                << "All meshes must have the same number of Edges & Vertices");
        }

        // Only use a subset of the meshes' points for the alignment since we don't
        // have that many degrees of freedom anyways and since calculating a SVD with
        // too many points is expensive
        // (these are the randomly selected point indices from above)
        for (std::set<unsigned>::const_iterator rng = landmarkIndices.begin();
             rng != landmarkIndices.end(); ++rng)
        {
            movingLandmarks.push_back(movingMesh->GetPoint(*rng));
            fixedLandmarks.push_back(referenceMesh->GetPoint(*rng));
        }

        // only rotate & translate the moving mesh to best fit with the fixed mesh;
        // there's no scaling taking place.
        // The rigid transform is estimated from the selected landmarks only
        typename LandmarkBasedTransformInitializerType::Pointer
            landmarkBasedTransformInitializer =
                LandmarkBasedTransformInitializerType::New();
        landmarkBasedTransformInitializer->SetFixedLandmarks(fixedLandmarks);
        landmarkBasedTransformInitializer->SetMovingLandmarks(movingLandmarks);
        typename TransformType::Pointer transform = TransformType::New();
        transform->SetIdentity();
        landmarkBasedTransformInitializer->SetTransform(transform);
        landmarkBasedTransformInitializer->InitializeTransform();

        // Apply the estimated transform to all points of the mesh
        typename FilterType::Pointer filter = FilterType::New();
        filter->SetInput(movingMesh);
        filter->SetTransform(transform);
        filter->Update();

        *it = filter->GetOutput();
    }

    return translatedMeshes;
}
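
To make "rotate and translate only, no scaling" concrete, here is a 2D analogue of a landmark-based rigid fit. It is a sketch for illustration and not the algorithm ITK runs for the 3D VersorRigid3DTransform case: in 2D the least-squares rotation has a closed form, namely the angle whose tangent is the ratio of the summed cross products to the summed dot products of the centered landmark pairs. Vec2, Rigid2D, centroidOf and fitRigid are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct Vec2 { double x, y; };

struct Rigid2D {
    double cosA, sinA; // rotation
    double tx, ty;     // translation, applied after the rotation

    Vec2 apply(const Vec2& p) const {
        Vec2 q = {cosA * p.x - sinA * p.y + tx, sinA * p.x + cosA * p.y + ty};
        return q;
    }
};

Vec2 centroidOf(const std::vector<Vec2>& pts) {
    assert(!pts.empty());
    Vec2 c = {0.0, 0.0};
    for (std::size_t i = 0; i < pts.size(); ++i) { c.x += pts[i].x; c.y += pts[i].y; }
    c.x /= pts.size();
    c.y /= pts.size();
    return c;
}

// Least-squares rotation + translation mapping `moving` onto `fixed`.
// Landmarks correspond by index; no scaling is estimated.
Rigid2D fitRigid(const std::vector<Vec2>& moving, const std::vector<Vec2>& fixed) {
    assert(!moving.empty() && moving.size() == fixed.size());
    const Vec2 cm = centroidOf(moving), cf = centroidOf(fixed);

    double dot = 0.0, cross = 0.0;
    for (std::size_t i = 0; i < moving.size(); ++i) {
        const double mx = moving[i].x - cm.x, my = moving[i].y - cm.y;
        const double fx = fixed[i].x - cf.x, fy = fixed[i].y - cf.y;
        dot += mx * fx + my * fy;
        cross += mx * fy - my * fx;
    }
    const double angle = std::atan2(cross, dot);

    Rigid2D t;
    t.cosA = std::cos(angle);
    t.sinA = std::sin(angle);
    // Choose the translation so that the rotated moving centroid lands on the
    // fixed centroid.
    t.tx = cf.x - (t.cosA * cm.x - t.sinA * cm.y);
    t.ty = cf.y - (t.sinA * cm.x + t.cosA * cm.y);
    return t;
}
```

As in superimposeMeshes, the transform would be estimated from a handful of landmarks and then applied to every point of the moving shape through apply.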

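The function that computes the mean mesh

It sums the coordinates over all meshes in one pass, then divides by the number of meshes in a second pass.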

template <class MeshType>
typename MeshType::Pointer
calculateMeanMesh(std::vector<typename MeshType::Pointer> meshes)
{
    // At least one mesh is required
    if (meshes.size() == 0)
    {
        itkGenericExceptionMacro(
            << "Can't calculate the mean since no meshes were provided.");
    }

    // Compensated summation keeps the accumulated floating-point rounding
    // error small when many values are added up
    typedef itk::CompensatedSummation<typename MeshType::PixelType>
        CompensatedSummationType;
    typedef std::vector<CompensatedSummationType> MeshPointsVectorType;

    typename MeshType::Pointer pFirstMesh = *meshes.begin();

    // prepare for summation
    MeshPointsVectorType vMeshPoints;
    // Length of the shape vector: number of points times the point dimension
    unsigned uDataSize =
        pFirstMesh->GetNumberOfPoints() * MeshType::PointDimension;
    vMeshPoints.reserve(uDataSize);
    for (int i = 0; i < uDataSize; ++i)
    {
        CompensatedSummationType sum;
        vMeshPoints.push_back(sum);
    }

    // Loop over all meshes and accumulate the sums
    for (typename std::vector<typename MeshType::Pointer>::const_iterator i =
             meshes.begin();
         i != meshes.end(); ++i)
    {
        typename MeshType::Pointer pMesh = *i;
        // Every mesh must yield a shape vector of the same length
        // (number of points times point dimension)
        if (vMeshPoints.size() !=
            pMesh->GetNumberOfPoints() * MeshType::PointDimension)
        {
            itkGenericExceptionMacro(
                << "All meshes must have the same number of Edges");
        }

        typename MeshPointsVectorType::iterator sum = vMeshPoints.begin();
        typename MeshType::PointsContainer::ConstIterator pointData =
            pMesh->GetPoints()->Begin();
        // sum up all meshes
        // Loop over every point of this mesh
        for (; pointData != pMesh->GetPoints()->End(); ++pointData)
        {
            const typename MeshType::PointType point = pointData->Value();
            // Loop over the coordinates of this point and add each one to
            // the matching accumulator in vMeshPoints
            for (typename MeshType::PointType::ConstIterator pointIter =
                     point.Begin();
                 pointIter != point.End(); ++pointIter, ++sum)
            {
                (*sum) += *pointIter;
            }
        }
    }

    float fInvNumberOfMeshes = 1.0f / meshes.size();
    // The mean mesh starts as a clone of the first mesh; its coordinates are
    // overwritten below
    typename MeshType::Pointer pMeanMesh = cloneMesh<MeshType>(pFirstMesh);

    // write the data to the mean mesh
    typename MeshPointsVectorType::iterator sum = vMeshPoints.begin();
    // Loop over every point of the mean mesh and store the averages
    for (typename MeshType::PointsContainer::Iterator pointData =
             pMeanMesh->GetPoints()->Begin();
         pointData != pMeanMesh->GetPoints()->End(); ++pointData)
    {
        // Loop over the coordinates of the point
        for (typename MeshType::PointType::Iterator pointIter =
                 pointData->Value().Begin();
             pointIter != pointData->Value().End(); ++pointIter, ++sum)
        {
            // mean = accumulated sum / number of meshes
            *pointIter = sum->GetSum() * fInvNumberOfMeshes;
        }
    }

    return pMeanMesh;
}
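
The benefit of a compensated accumulator is easiest to see on a small experiment. The class below is a minimal Kahan summation written for this illustration; it is a hypothetical stand-in that borrows the operator+= and GetSum() shape used above and is not the implementation of itk::CompensatedSummation. It assumes the code is compiled without aggressive floating-point options such as -ffast-math, which may optimize the compensation away.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal Kahan (compensated) accumulator for float values.
class KahanSum {
public:
    KahanSum() : sum_(0.0f), compensation_(0.0f) {}

    KahanSum& operator+=(float value) {
        const float corrected = value - compensation_; // re-apply the lost low-order bits
        const float updated = sum_ + corrected;        // low-order bits may be lost here
        compensation_ = (updated - sum_) - corrected;  // recover what was just lost
        sum_ = updated;
        return *this;
    }

    float GetSum() const { return sum_; }

private:
    float sum_;
    float compensation_;
};

// Plain left-to-right float accumulation, for comparison.
float naiveSum(const std::vector<float>& values) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < values.size(); ++i) sum += values[i];
    return sum;
}

float compensatedSum(const std::vector<float>& values) {
    KahanSum sum;
    for (std::size_t i = 0; i < values.size(); ++i) sum += values[i];
    return sum.GetSum();
}

// Mean computed from the compensated sum, as calculateMeanMesh does per coordinate.
float compensatedMean(const std::vector<float>& values) {
    assert(!values.empty());
    return compensatedSum(values) / static_cast<float>(values.size());
}
```

Adding one million copies of 0.1f is a standard stress case: the naive float sum drifts visibly away from 100000, while the compensated sum stays close to it.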

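Computing the difference between the mean mesh and the current reference mesh

The squared Euclidean distances between corresponding points are summed, and the sum is divided by the total dimension of the shape vector (number of points times point dimension).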

template <class MeshType>
float calculateMeshDistance(typename MeshType::Pointer mesh1,
                            typename MeshType::Pointer mesh2)
{
    // The numbers of points and cells must match (beyond this check, the
    // points are assumed to correspond index by index)
    if (mesh1->GetNumberOfPoints() != mesh2->GetNumberOfPoints() ||
        mesh1->GetNumberOfCells() != mesh2->GetNumberOfCells())
    {
        itkGenericExceptionMacro(
            << "Both meshes must have the same number of Edges & Vertices");
    }

    float fDifference = 0;
    typedef typename MeshType::PointsContainer::Iterator IteratorType;
    IteratorType point1 = mesh1->GetPoints()->Begin();
    IteratorType point2 = mesh2->GetPoints()->Begin();
    // Sum the squared distances of all corresponding point pairs
    for (; point1 != mesh1->GetPoints()->End(); ++point1, ++point2)
    {
        fDifference += point1->Value().SquaredEuclideanDistanceTo(point2->Value());
    }
    // Divide by (number of points times point dimension)
    fDifference /= (mesh1->GetNumberOfPoints() * MeshType::PointDimension);
    return fDifference;
}
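
Written as a formula, the value returned by calculateMeshDistance is the following, where the symbols are introduced here for illustration: $N$ is the number of points, $D$ is MeshType::PointDimension, and $p_i$, $q_i$ are the $i$-th points of mesh1 and mesh2.

```latex
d(\text{mesh1}, \text{mesh2}) = \frac{1}{N \, D} \sum_{i=1}^{N} \lVert p_i - q_i \rVert^2
```

This is a mean squared difference per coordinate rather than a distance in length units, which is worth keeping in mind when choosing the convergence threshold.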

Handing the representer (which holds the reference mesh) and the datasets to the data manager

typedef itk::Mesh<float, Dimensions> MeshType;
typedef itk::DataManager<MeshType> DataManagerType;
DataManagerType::Pointer dataManager = DataManagerType::New();
////////////////////////////////////////////////////////////////////
dataManager->SetRepresenter(representer);

for (MeshReaderList::const_iterator it = meshes.begin(); it != meshes.end();
     ++it) {
    MeshReaderType::Pointer reader = *it;
    dataManager->AddDataset(reader->GetOutput(), reader->GetFileName());
}

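Running PCA and saving the model

Open question: is there no alignment step before this PCA? In GPA mode the meshes are aligned while the mean is computed, but what gets added to the data manager above are the reader outputs. And in reference mode, does no alignment happen here at all? The code shown in this post does not answer this.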

// Model type
typedef itk::StatisticalModel<MeshType> StatisticalModelType;
StatisticalModelType::Pointer model;
// Run PCA
typedef itk::PCAModelBuilder<MeshType> PCAModelBuilder;
PCAModelBuilder::Pointer pcaModelBuilder = PCAModelBuilder::New();
// PCA is run directly on the data held by the data manager
model = pcaModelBuilder->BuildNewModel(dataManager->GetData(),
                                       opt.fNoiseVariance);
// Save the model to a file
itk::StatismoIO<MeshType>::SaveStatisticalModel(
    model, opt.strOutputFileName.c_str());

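The PCA function
A fairly standard PCA computation: compute the sample mean, subtract it from every sample, and pass the centered sample matrix on for decomposition.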

template <typename T> PCAModelBuilder<T>::PCAModelBuilder() : Superclass() {}

template <typename T>
typename PCAModelBuilder<T>::StatisticalModelType *
PCAModelBuilder<T>::BuildNewModel(const DataItemListType &sampleDataList,
                                  double noiseVariance, bool computeScores,
                                  EigenValueMethod method) const {
    // Number of samples
    unsigned n = sampleDataList.size();
    if (n <= 0) {
        throw StatisticalModelException(
            "Provided empty sample set. Cannot build the sample matrix");
    }

    unsigned p = sampleDataList.front()->GetSampleVector().rows();
    const Representer<T> *representer = sampleDataList.front()->GetRepresenter();

    // Compute the mean vector mu
    // (sum all sample vectors, then divide by n)
    VectorType mu = VectorType::Zero(p);

    for (typename DataItemListType::const_iterator it = sampleDataList.begin();
         it != sampleDataList.end(); ++it) {
        assert((*it)->GetSampleVector().rows() ==
               p); // all samples must have same number of rows
        assert((*it)->GetRepresenter() ==
               representer); // all samples have the same representer
        mu += (*it)->GetSampleVector();
    }
    mu /= n;

    // Build the mean free sample matrix X0
    // (subtract the mean from every sample, i.e. center the data)
    MatrixType X0(n, p);
    unsigned i = 0;
    for (typename DataItemListType::const_iterator it = sampleDataList.begin();
         it != sampleDataList.end(); ++it) {
        X0.row(i++) = (*it)->GetSampleVector() - mu;
    }

    // build the model
    // (the decomposition itself, e.g. SVD or eigendecomposition depending on
    // `method`, happens inside BuildNewModelInternal)
    StatisticalModelType *model =
        BuildNewModelInternal(representer, X0, mu, noiseVariance, method);

    ...

    return model;
}
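
The centering step and what the decomposition extracts can be tried out without any linear algebra library. The sketch below is an illustration under stated assumptions, not what BuildNewModelInternal does: it finds only the first principal direction, and it uses power iteration in place of an SVD or eigendecomposition. Vec, sampleMean and firstPrincipalDirection are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

typedef std::vector<double> Vec;

// Mean of the samples; each sample is one flattened shape vector of length p.
Vec sampleMean(const std::vector<Vec>& samples) {
    assert(!samples.empty());
    Vec mu(samples.front().size(), 0.0);
    for (std::size_t i = 0; i < samples.size(); ++i) {
        assert(samples[i].size() == mu.size()); // all samples share one length
        for (std::size_t j = 0; j < mu.size(); ++j) mu[j] += samples[i][j];
    }
    for (std::size_t j = 0; j < mu.size(); ++j) mu[j] /= samples.size();
    return mu;
}

// Leading principal direction of the centered data, found by power iteration.
// Each step applies X0^T (X0 v); the scale factor of the covariance matrix is
// irrelevant here because v is renormalized after every step.
Vec firstPrincipalDirection(const std::vector<Vec>& samples, unsigned iterations) {
    const Vec mu = sampleMean(samples);
    const std::size_t n = samples.size(), p = mu.size();

    // Mean-free sample matrix X0 (n rows, p columns)
    std::vector<Vec> X0(n, Vec(p));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < p; ++j) X0[i][j] = samples[i][j] - mu[j];

    Vec v(p, 1.0); // start vector; must not be orthogonal to the leading direction
    for (unsigned it = 0; it < iterations; ++it) {
        Vec w(p, 0.0);
        for (std::size_t i = 0; i < n; ++i) {
            double proj = 0.0; // projection of sample i onto v
            for (std::size_t j = 0; j < p; ++j) proj += X0[i][j] * v[j];
            for (std::size_t j = 0; j < p; ++j) w[j] += proj * X0[i][j];
        }
        double norm = 0.0;
        for (std::size_t j = 0; j < p; ++j) norm += w[j] * w[j];
        norm = std::sqrt(norm);
        if (norm == 0.0) break; // no variance along v: keep the current v
        for (std::size_t j = 0; j < p; ++j) v[j] = w[j] / norm;
    }
    return v;
}
```

For samples that lie on a straight line, the returned direction is parallel to that line regardless of where the line sits in space, which is exactly what subtracting mu buys.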