You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
5.4 KiB
255 lines
5.4 KiB
#include "minimize.inl" |
|
#include "ordering.inl" |
|
#include "treeProbabilities.inl" |
|
|
|
static void DefaultErrorFn(const char* msg) |
|
{ |
|
fprintf(stderr, "%s\n", msg); |
|
exit(1); |
|
} |
|
|
|
template <class T> MRFEnergy<T>::MRFEnergy(GlobalSize Kglobal, ErrorFunction errorFn) |
|
: m_errorFn(errorFn ? errorFn : DefaultErrorFn), |
|
m_mallocBlockFirst(NULL), |
|
m_nodeFirst(NULL), |
|
m_nodeLast(NULL), |
|
m_nodeNum(0), |
|
m_edgeNum(0), |
|
m_Kglobal(Kglobal), |
|
m_vectorMaxSizeInBytes(0), |
|
m_isEnergyConstructionCompleted(false), |
|
m_buf(NULL) |
|
{ |
|
} |
|
|
|
template <class T> MRFEnergy<T>::~MRFEnergy() |
|
{ |
|
while (m_mallocBlockFirst) |
|
{ |
|
MallocBlock* next = m_mallocBlockFirst->m_next; |
|
delete m_mallocBlockFirst; |
|
m_mallocBlockFirst = next; |
|
} |
|
} |
|
|
|
template <class T> typename MRFEnergy<T>::NodeId MRFEnergy<T>::AddNode(LocalSize K, NodeData data) |
|
{ |
|
if (m_isEnergyConstructionCompleted) |
|
{ |
|
m_errorFn("Error in AddNode(): graph construction completed - nodes cannot be added"); |
|
} |
|
|
|
int actualVectorSize = Vector::GetSizeInBytes(m_Kglobal, K); |
|
if (actualVectorSize < 0) |
|
{ |
|
m_errorFn("Error in AddNode() (invalid parameter?)"); |
|
} |
|
if (m_vectorMaxSizeInBytes < actualVectorSize) |
|
{ |
|
m_vectorMaxSizeInBytes = actualVectorSize; |
|
} |
|
int nodeSize = sizeof(Node) - sizeof(Vector) + actualVectorSize; |
|
Node* i = (Node *) Malloc(nodeSize); |
|
|
|
i->m_K = K; |
|
i->m_D.Initialize(m_Kglobal, K, data); |
|
|
|
i->m_firstForward = NULL; |
|
i->m_firstBackward = NULL; |
|
i->m_prev = m_nodeLast; |
|
if (m_nodeLast) |
|
{ |
|
m_nodeLast->m_next = i; |
|
} |
|
else |
|
{ |
|
m_nodeFirst = i; |
|
} |
|
m_nodeLast = i; |
|
i->m_next = NULL; |
|
|
|
i->m_ordering = m_nodeNum ++; |
|
|
|
return i; |
|
} |
|
|
|
template <class T> void MRFEnergy<T>::AddNodeData(NodeId i, NodeData data) |
|
{ |
|
i->m_D.Add(m_Kglobal, i->m_K, data); |
|
} |
|
|
|
template <class T> void MRFEnergy<T>::AddEdge(NodeId i, NodeId j, EdgeData data) |
|
{ |
|
if (m_isEnergyConstructionCompleted) |
|
{ |
|
m_errorFn("Error in AddNode(): graph construction completed - nodes cannot be added"); |
|
} |
|
|
|
MRFEdge* e; |
|
|
|
int actualEdgeSize = Edge::GetSizeInBytes(m_Kglobal, i->m_K, j->m_K, data); |
|
if (actualEdgeSize < 0) |
|
{ |
|
m_errorFn("Error in AddEdge() (invalid parameter?)"); |
|
} |
|
int MRFedgeSize = sizeof(MRFEdge) - sizeof(Edge) + actualEdgeSize; |
|
e = (MRFEdge*) Malloc(MRFedgeSize); |
|
|
|
e->m_message.Initialize(m_Kglobal, i->m_K, j->m_K, data, &i->m_D, &j->m_D); |
|
|
|
e->m_tail = i; |
|
e->m_nextForward = i->m_firstForward; |
|
i->m_firstForward = e; |
|
|
|
e->m_head = j; |
|
e->m_nextBackward = j->m_firstBackward; |
|
j->m_firstBackward = e; |
|
|
|
m_edgeNum ++; |
|
} |
|
|
|
///////////////////////////////////////////////////////////////////////////////// |
|
|
|
template <class T> void MRFEnergy<T>::ZeroMessages() |
|
{ |
|
Node* i; |
|
MRFEdge* e; |
|
|
|
if (!m_isEnergyConstructionCompleted) |
|
{ |
|
CompleteGraphConstruction(); |
|
} |
|
|
|
for (i=m_nodeFirst; i; i=i->m_next) |
|
{ |
|
for (e=i->m_firstForward; e; e=e->m_nextForward) |
|
{ |
|
e->m_message.GetMessagePtr()->SetZero(m_Kglobal, i->m_K); |
|
} |
|
} |
|
} |
|
|
|
template <class T> void MRFEnergy<T>::AddRandomMessages(unsigned int random_seed, REAL min_value, REAL max_value) |
|
{ |
|
Node* i; |
|
MRFEdge* e; |
|
int k; |
|
|
|
if (!m_isEnergyConstructionCompleted) |
|
{ |
|
CompleteGraphConstruction(); |
|
} |
|
|
|
srand(random_seed); |
|
|
|
for (i=m_nodeFirst; i; i=i->m_next) |
|
{ |
|
for (e=i->m_firstForward; e; e=e->m_nextForward) |
|
{ |
|
Vector* M = e->m_message.GetMessagePtr(); |
|
for (k=0; k<M->GetArraySize(m_Kglobal, i->m_K); k++) |
|
{ |
|
REAL x = (REAL)( min_value + rand()/((double)RAND_MAX) * (max_value - min_value) ); |
|
x += M->GetArrayValue(m_Kglobal, i->m_K, k); |
|
M->SetArrayValue(m_Kglobal, i->m_K, k, x); |
|
} |
|
} |
|
} |
|
} |
|
|
|
///////////////////////////////////////////////////////////////////////////////// |
|
|
|
template <class T> void MRFEnergy<T>::CompleteGraphConstruction() |
|
{ |
|
Node* i; |
|
Node* j; |
|
MRFEdge* e; |
|
MRFEdge* ePrev; |
|
|
|
if (m_isEnergyConstructionCompleted) |
|
{ |
|
m_errorFn("Fatal error in CompleteGraphConstruction"); |
|
} |
|
|
|
printf("Completing graph construction... "); |
|
|
|
if (m_buf) |
|
{ |
|
m_errorFn("CompleteGraphConstruction(): fatal error"); |
|
} |
|
|
|
m_buf = (char *) Malloc(m_vectorMaxSizeInBytes + |
|
( m_vectorMaxSizeInBytes > Edge::GetBufSizeInBytes(m_vectorMaxSizeInBytes) ? |
|
m_vectorMaxSizeInBytes : Edge::GetBufSizeInBytes(m_vectorMaxSizeInBytes) ) ); |
|
|
|
// set forward and backward edges properly |
|
#ifdef _DEBUG |
|
int ordering; |
|
for (i=m_nodeFirst, ordering=0; i; i=i->m_next, ordering++) |
|
{ |
|
if ( (i->m_ordering != ordering) |
|
|| (i->m_ordering == 0 && i->m_prev) |
|
|| (i->m_ordering != 0 && i->m_prev->m_ordering != ordering-1) ) |
|
{ |
|
m_errorFn("CompleteGraphConstruction(): fatal error (wrong ordering)"); |
|
} |
|
} |
|
if (ordering != m_nodeNum) |
|
{ |
|
m_errorFn("CompleteGraphConstruction(): fatal error"); |
|
} |
|
#endif |
|
for (i=m_nodeFirst; i; i=i->m_next) |
|
{ |
|
i->m_firstBackward = NULL; |
|
} |
|
for (i=m_nodeFirst; i; i=i->m_next) |
|
{ |
|
ePrev = NULL; |
|
for (e=i->m_firstForward; e; ) |
|
{ |
|
assert(i == e->m_tail); |
|
j = e->m_head; |
|
|
|
if (i->m_ordering < j->m_ordering) |
|
{ |
|
e->m_nextBackward = j->m_firstBackward; |
|
j->m_firstBackward = e; |
|
|
|
ePrev = e; |
|
e = e->m_nextForward; |
|
} |
|
else |
|
{ |
|
e->m_message.Swap(m_Kglobal, i->m_K, j->m_K); |
|
e->m_tail = j; |
|
e->m_head = i; |
|
|
|
MRFEdge* eNext = e->m_nextForward; |
|
|
|
if (ePrev) |
|
{ |
|
ePrev->m_nextForward = e->m_nextForward; |
|
} |
|
else |
|
{ |
|
i->m_firstForward = e->m_nextForward; |
|
} |
|
|
|
e->m_nextForward = j->m_firstForward; |
|
j->m_firstForward = e; |
|
|
|
e->m_nextBackward = i->m_firstBackward; |
|
i->m_firstBackward = e; |
|
|
|
e = eNext; |
|
} |
|
} |
|
} |
|
|
|
m_isEnergyConstructionCompleted = true; |
|
|
|
// ZeroMessages(); |
|
|
|
printf("done\n"); |
|
}
|
|
|