📄 cluster.cpp
字号:
#include "muscle.h"
#include "cluster.h"
#include "distfunc.h"
// Smaller of two floats. When the first operand is not strictly less than
// the second (including equal or unordered operands) the second is returned.
static inline float Min(float d1, float d2)
{
    if (d1 < d2)
        return d1;
    return d2;
}
// Larger of two floats. When the first operand is not strictly greater than
// the second (including equal or unordered operands) the second is returned.
static inline float Max(float d1, float d2)
{
    if (d1 > d2)
        return d1;
    return d2;
}
// Arithmetic mean of two floats.
// The operands are widened to double before they are added, so two large
// values whose float sum would overflow to infinity still average to a
// finite result. The result is narrowed back to float.
static inline float Mean(float d1, float d2)
{
    const double dSum = static_cast<double>(d1) + static_cast<double>(d2);
    return static_cast<float>(dSum / 2.0);
}
#if _DEBUG
void ClusterTree::Validate(unsigned uNodeCount)
{
unsigned n;
ClusterNode *pNode;
unsigned uDisjointListCount = 0;
for (pNode = m_ptrDisjoints; pNode; pNode = pNode->GetNextDisjoint())
{
ClusterNode *pPrev = pNode->GetPrevDisjoint();
ClusterNode *pNext = pNode->GetNextDisjoint();
if (0 != pPrev)
{
if (pPrev->GetNextDisjoint() != pNode)
{
Log("Prev->This mismatch, prev=\n");
pPrev->LogMe();
Log("This=\n");
pNode->LogMe();
Quit("ClusterTree::Validate()");
}
}
else
{
if (pNode != m_ptrDisjoints)
{
Log("[%u]->prev = 0 but != m_ptrDisjoints=%d\n",
pNode->GetIndex(),
m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff);
pNode->LogMe();
Quit("ClusterTree::Validate()");
}
}
if (0 != pNext)
{
if (pNext->GetPrevDisjoint() != pNode)
{
Log("Next->This mismatch, next=\n");
pNext->LogMe();
Log("This=\n");
pNode->LogMe();
Quit("ClusterTree::Validate()");
}
}
++uDisjointListCount;
if (uDisjointListCount > m_uNodeCount)
Quit("Loop in disjoint list");
}
unsigned uParentlessNodeCount = 0;
for (n = 0; n < uNodeCount; ++n)
if (0 == m_Nodes[n].GetParent())
++uParentlessNodeCount;
if (uDisjointListCount != uParentlessNodeCount)
Quit("Disjoints = %u Parentless = %u\n", uDisjointListCount,
uParentlessNodeCount);
}
#else // !_DEBUG
#define Validate(uNodeCount) // empty
#endif
void ClusterNode::LogMe() const
{
unsigned uClusterSize = GetClusterSize();
Log("[%02u] w=%5.3f CW=%5.3f LBW=%5.3f RBW=%5.3f LWT=%5.3f RWT=%5.3f L=%02d R=%02d P=%02d NxDj=%02d PvDj=%02d Sz=%02d {",
m_uIndex,
m_dWeight,
GetClusterWeight(),
GetLeftBranchWeight(),
GetRightBranchWeight(),
GetLeftWeight(),
GetRightWeight(),
m_ptrLeft ? m_ptrLeft->GetIndex() : 0xffffffff,
m_ptrRight ? m_ptrRight->GetIndex() : 0xffffffff,
m_ptrParent ? m_ptrParent->GetIndex() : 0xffffffff,
m_ptrNextDisjoint ? m_ptrNextDisjoint->GetIndex() : 0xffffffff,
m_ptrPrevDisjoint ? m_ptrPrevDisjoint->GetIndex() : 0xffffffff,
uClusterSize);
for (unsigned i = 0; i < uClusterSize; ++i)
Log(" %u", GetClusterLeaf(i)->GetIndex());
Log(" }\n");
}
// Number of leaves in the sub-tree rooted at this node.
// A node with no children is a leaf and counts as 1; otherwise the result is
// the sum of the counts of whichever children are present.
unsigned ClusterNode::GetClusterSize() const
{
    const bool bIsLeaf = (0 == m_ptrLeft && 0 == m_ptrRight);
    if (bIsLeaf)
        return 1;

    const unsigned uLeftLeaves = m_ptrLeft ? m_ptrLeft->GetClusterSize() : 0;
    const unsigned uRightLeaves = m_ptrRight ? m_ptrRight->GetClusterSize() : 0;
    const unsigned uTotal = uLeftLeaves + uRightLeaves;
    assert(uTotal > 0);
    return uTotal;
}
double ClusterNode::GetClusterWeight() const
{
double dWeight = 0.0;
if (0 != m_ptrLeft)
dWeight += m_ptrLeft->GetClusterWeight();
if (0 != m_ptrRight)
dWeight += m_ptrRight->GetClusterWeight();
return dWeight + GetWeight();
}
double ClusterNode::GetLeftBranchWeight() const
{
const ClusterNode *ptrLeft = GetLeft();
if (0 == ptrLeft)
return 0.0;
return GetWeight() - ptrLeft->GetWeight();
}
double ClusterNode::GetRightBranchWeight() const
{
const ClusterNode *ptrRight = GetRight();
if (0 == ptrRight)
return 0.0;
return GetWeight() - ptrRight->GetWeight();
}
double ClusterNode::GetRightWeight() const
{
const ClusterNode *ptrRight = GetRight();
if (0 == ptrRight)
return 0.0;
return ptrRight->GetClusterWeight() + GetWeight();
}
double ClusterNode::GetLeftWeight() const
{
const ClusterNode *ptrLeft = GetLeft();
if (0 == ptrLeft)
return 0.0;
return ptrLeft->GetClusterWeight() + GetWeight();
}
// Return the uLeafIndex'th leaf (zero-based, left sub-tree first) in the
// sub-tree under this node.
//
// A node with no children is itself the leaf. A node with a single child
// delegates to that child, whichever side it is on, so the result is always
// a childless node and agrees with the leaf count that GetClusterSize()
// reports. With two children, indexes below the left sub-tree's leaf count
// are looked up on the left and the remainder on the right.
//
// uLeafIndex is not range-checked; callers are expected to pass a value
// below GetClusterSize().
const ClusterNode *ClusterNode::GetClusterLeaf(unsigned uLeafIndex) const
{
    // Leaf: nothing to descend into.
    if (0 == m_ptrLeft && 0 == m_ptrRight)
        return this;

    // Exactly one child: every leaf of this sub-tree lives under it.
    if (0 == m_ptrRight)
        return m_ptrLeft->GetClusterLeaf(uLeafIndex);
    if (0 == m_ptrLeft)
        return m_ptrRight->GetClusterLeaf(uLeafIndex);

    // Two children: leaves are numbered left sub-tree first.
    const unsigned uLeftLeafCount = m_ptrLeft->GetClusterSize();
    if (uLeafIndex < uLeftLeafCount)
        return m_ptrLeft->GetClusterLeaf(uLeafIndex);
    return m_ptrRight->GetClusterLeaf(uLeafIndex - uLeftLeafCount);
}
// Unlink ptrNode from the doubly-linked disjoint list, moving the list head
// forward when ptrNode was the first element.
void ClusterTree::DeleteFromDisjoints(ClusterNode *ptrNode)
{
    ClusterNode *ptrBefore = ptrNode->GetPrevDisjoint();
    ClusterNode *ptrAfter = ptrNode->GetNextDisjoint();

    // Successor (if any) now points back past the removed node.
    if (ptrAfter)
        ptrAfter->SetPrevDisjoint(ptrBefore);

    // Predecessor (or the list head) now points forward past it.
    if (0 == ptrBefore)
        m_ptrDisjoints = ptrAfter;
    else
        ptrBefore->SetNextDisjoint(ptrAfter);

#if _DEBUG
    // Clearing the removed node's own links is not needed by the algorithm,
    // but keeps the state readable and is relied on by Validate().
    ptrNode->SetPrevDisjoint(0);
    ptrNode->SetNextDisjoint(0);
#endif
}
// Push ptrNode onto the front of the disjoint list; it becomes the new head.
void ClusterTree::AddToDisjoints(ClusterNode *ptrNode)
{
    ClusterNode *ptrOldHead = m_ptrDisjoints;

    ptrNode->SetNextDisjoint(ptrOldHead);
    ptrNode->SetPrevDisjoint(0);

    // The previous head, if there was one, gains the new node as predecessor.
    if (ptrOldHead)
        ptrOldHead->SetPrevDisjoint(ptrNode);

    m_ptrDisjoints = ptrNode;
}
// Construct an empty tree: no node array, an empty disjoint list, and zero
// node and leaf counts. Create() fills these in.
ClusterTree::ClusterTree()
{
    m_ptrDisjoints = 0;
    m_Nodes = 0;
    m_uNodeCount = 0;
    // Previously left indeterminate until Create() ran.
    m_uLeafCount = 0;
}
// Release the node array allocated by Create(). Deleting a null pointer is
// a no-op, so a tree that was never created is handled as well.
// NOTE(review): the class owns m_Nodes through a raw pointer; copy
// construction/assignment are not visible in this file -- confirm they are
// disabled or perform a deep copy, otherwise copies would double-delete.
ClusterTree::~ClusterTree()
{
delete[] m_Nodes;
}
void ClusterTree::LogMe() const
{
Log("Disjoints=%d\n", m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff);
for (unsigned i = 0; i < m_uNodeCount; ++i)
{
m_Nodes[i].LogMe();
}
}
// Return the root of the tree, which is the last node in the array (the
// final join made by Create()). Returns null when the tree is empty; the
// unguarded index m_uNodeCount - 1 would otherwise wrap around and produce
// a wild pointer.
ClusterNode *ClusterTree::GetRoot() const
{
    if (0 == m_Nodes || 0 == m_uNodeCount)
        return 0;
    return &m_Nodes[m_uNodeCount - 1];
}
// This is the UPGMA algorithm as described in Durbin et al. p166.
void ClusterTree::Create(const DistFunc &Dist)
{
unsigned i;
m_uLeafCount = Dist.GetCount();
m_uNodeCount = 2*m_uLeafCount - 1;
delete[] m_Nodes;
m_Nodes = new ClusterNode[m_uNodeCount];
for (i = 0; i < m_uNodeCount; ++i)
m_Nodes[i].SetIndex(i);
for (i = 0; i < m_uLeafCount - 1; ++i)
m_Nodes[i].SetNextDisjoint(&m_Nodes[i+1]);
for (i = 1; i < m_uLeafCount; ++i)
m_Nodes[i].SetPrevDisjoint(&m_Nodes[i-1]);
m_ptrDisjoints = &m_Nodes[0];
// Log("Initial state\n");
// LogMe();
// Log("\n");
DistFunc ClusterDist;
ClusterDist.SetCount(m_uNodeCount);
double dMaxDist = 0.0;
for (i = 0; i < m_uLeafCount; ++i)
for (unsigned j = 0; j < m_uLeafCount; ++j)
{
float dDist = Dist.GetDist(i, j);
ClusterDist.SetDist(i, j, dDist);
}
Validate(m_uLeafCount);
// Iteration. N-1 joins needed to create a binary tree from N leaves.
for (unsigned uJoinIndex = m_uLeafCount; uJoinIndex < m_uNodeCount;
++uJoinIndex)
{
// Find closest pair of clusters
unsigned uIndexClosest1;
unsigned uIndexClosest2;
bool bFound = false;
double dDistClosest = 9e99;
for (ClusterNode *ptrNode1 = m_ptrDisjoints; ptrNode1;
ptrNode1 = ptrNode1->GetNextDisjoint())
{
for (ClusterNode *ptrNode2 = ptrNode1->GetNextDisjoint(); ptrNode2;
ptrNode2 = ptrNode2->GetNextDisjoint())
{
unsigned i1 = ptrNode1->GetIndex();
unsigned i2 = ptrNode2->GetIndex();
double dDist = ClusterDist.GetDist(i1, i2);
if (dDist < dDistClosest)
{
bFound = true;
dDistClosest = dDist;
uIndexClosest1 = i1;
uIndexClosest2 = i2;
}
}
}
assert(bFound);
ClusterNode &Join = m_Nodes[uJoinIndex];
ClusterNode &Child1 = m_Nodes[uIndexClosest1];
ClusterNode &Child2 = m_Nodes[uIndexClosest2];
Join.SetLeft(&Child1);
Join.SetRight(&Child2);
Join.SetWeight(dDistClosest);
Child1.SetParent(&Join);
Child2.SetParent(&Join);
DeleteFromDisjoints(&Child1);
DeleteFromDisjoints(&Child2);
AddToDisjoints(&Join);
// Log("After join %d %d\n", uIndexClosest1, uIndexClosest2);
// LogMe();
// Calculate distance of every remaining disjoint cluster to the
// new cluster created by the join
for (ClusterNode *ptrNode = m_ptrDisjoints; ptrNode;
ptrNode = ptrNode->GetNextDisjoint())
{
unsigned uNodeIndex = ptrNode->GetIndex();
float dDist1 = ClusterDist.GetDist(uNodeIndex, uIndexClosest1);
float dDist2 = ClusterDist.GetDist(uNodeIndex, uIndexClosest2);
float dDist = Min(dDist1, dDist2);
ClusterDist.SetDist(uJoinIndex, uNodeIndex, dDist);
}
Validate(uJoinIndex+1);
}
GetRoot()->GetClusterWeight();
// LogMe();
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -