You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
299 lines
7.5 KiB
299 lines
7.5 KiB
template <class T> int MRFEnergy<T>::Minimize_TRW_S(Options& options, REAL& lowerBound, REAL& energy, REAL* min_marginals) |
|
{ |
|
Node* i; |
|
Node* j; |
|
MRFEdge* e; |
|
REAL vMin; |
|
int iter; |
|
REAL lowerBoundPrev; |
|
|
|
if (!m_isEnergyConstructionCompleted) |
|
{ |
|
CompleteGraphConstruction(); |
|
} |
|
|
|
printf("TRW_S algorithm\n"); |
|
|
|
SetMonotonicTrees(); |
|
|
|
Vector* Di = (Vector*) m_buf; |
|
void* buf = (void*) (m_buf + m_vectorMaxSizeInBytes); |
|
|
|
iter = 0; |
|
bool lastIter = false; |
|
|
|
// main loop |
|
for (iter=1; ; iter++) |
|
{ |
|
if (iter >= options.m_iterMax) lastIter = true; |
|
|
|
//////////////////////////////////////////////// |
|
// forward pass // |
|
//////////////////////////////////////////////// |
|
REAL* min_marginals_ptr = min_marginals; |
|
|
|
for (i=m_nodeFirst; i; i=i->m_next) |
|
{ |
|
Di->Copy(m_Kglobal, i->m_K, &i->m_D); |
|
for (e=i->m_firstForward; e; e=e->m_nextForward) |
|
{ |
|
Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr()); |
|
} |
|
for (e=i->m_firstBackward; e; e=e->m_nextBackward) |
|
{ |
|
Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr()); |
|
} |
|
|
|
// normalize Di, update lower bound |
|
// vMin = Di->ComputeAndSubtractMin(m_Kglobal, i->m_K); // do not compute lower bound |
|
// lowerBound += vMin; // during the forward pass |
|
|
|
// pass messages from i to nodes with higher m_ordering |
|
for (e=i->m_firstForward; e; e=e->m_nextForward) |
|
{ |
|
assert(e->m_tail == i); |
|
j = e->m_head; |
|
|
|
vMin = e->m_message.UpdateMessage(m_Kglobal, i->m_K, j->m_K, Di, e->m_gammaForward, 0, buf); |
|
|
|
// lowerBound += vMin; // do not compute lower bound during the forward pass |
|
} |
|
|
|
if (lastIter && min_marginals) |
|
{ |
|
min_marginals_ptr += Di->GetArraySize(m_Kglobal, i->m_K); |
|
} |
|
} |
|
|
|
//////////////////////////////////////////////// |
|
// backward pass // |
|
//////////////////////////////////////////////// |
|
lowerBound = 0; |
|
|
|
for (i=m_nodeLast; i; i=i->m_prev) |
|
{ |
|
Di->Copy(m_Kglobal, i->m_K, &i->m_D); |
|
for (e=i->m_firstBackward; e; e=e->m_nextBackward) |
|
{ |
|
Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr()); |
|
} |
|
for (e=i->m_firstForward; e; e=e->m_nextForward) |
|
{ |
|
Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr()); |
|
} |
|
|
|
// normalize Di, update lower bound |
|
vMin = Di->ComputeAndSubtractMin(m_Kglobal, i->m_K); |
|
lowerBound += vMin; |
|
|
|
// pass messages from i to nodes with smaller m_ordering |
|
for (e=i->m_firstBackward; e; e=e->m_nextBackward) |
|
{ |
|
assert(e->m_head == i); |
|
j = e->m_tail; |
|
|
|
vMin = e->m_message.UpdateMessage(m_Kglobal, i->m_K, j->m_K, Di, e->m_gammaBackward, 1, buf); |
|
|
|
lowerBound += vMin; |
|
} |
|
|
|
if (lastIter && min_marginals) |
|
{ |
|
min_marginals_ptr -= Di->GetArraySize(m_Kglobal, i->m_K); |
|
for (int k=0; k<Di->GetArraySize(m_Kglobal, i->m_K); k++) |
|
{ |
|
min_marginals_ptr[k] = Di->GetArrayValue(m_Kglobal, i->m_K, k); |
|
} |
|
} |
|
} |
|
|
|
//////////////////////////////////////////////// |
|
// check stopping criterion // |
|
//////////////////////////////////////////////// |
|
|
|
// print lower bound and energy, if necessary |
|
if ( lastIter || |
|
( iter>=options.m_printMinIter && |
|
(options.m_printIter<1 || iter%options.m_printIter==0) ) |
|
) |
|
{ |
|
energy = ComputeSolutionAndEnergy(); |
|
printf("iter %d: lower bound = %f, energy = %f\n", iter, lowerBound, energy); |
|
} |
|
|
|
if (lastIter) break; |
|
|
|
// check convergence of lower bound |
|
if (options.m_eps >= 0) |
|
{ |
|
if (iter > 1 && lowerBound - lowerBoundPrev <= options.m_eps) |
|
{ |
|
lastIter = true; |
|
} |
|
lowerBoundPrev = lowerBound; |
|
} |
|
} |
|
|
|
return iter; |
|
} |
|
|
|
template <class T> int MRFEnergy<T>::Minimize_BP(Options& options, REAL& energy, REAL* min_marginals) |
|
{ |
|
Node* i; |
|
Node* j; |
|
MRFEdge* e; |
|
REAL vMin; |
|
int iter; |
|
|
|
if (!m_isEnergyConstructionCompleted) |
|
{ |
|
CompleteGraphConstruction(); |
|
} |
|
|
|
printf("BP algorithm\n"); |
|
|
|
Vector* Di = (Vector*) m_buf; |
|
void* buf = (void*) (m_buf + m_vectorMaxSizeInBytes); |
|
|
|
iter = 0; |
|
bool lastIter = false; |
|
|
|
// main loop |
|
for (iter=1; ; iter++) |
|
{ |
|
if (iter >= options.m_iterMax) lastIter = true; |
|
|
|
//////////////////////////////////////////////// |
|
// forward pass // |
|
//////////////////////////////////////////////// |
|
REAL* min_marginals_ptr = min_marginals; |
|
|
|
for (i=m_nodeFirst; i; i=i->m_next) |
|
{ |
|
Di->Copy(m_Kglobal, i->m_K, &i->m_D); |
|
for (e=i->m_firstForward; e; e=e->m_nextForward) |
|
{ |
|
Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr()); |
|
} |
|
for (e=i->m_firstBackward; e; e=e->m_nextBackward) |
|
{ |
|
Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr()); |
|
} |
|
|
|
// pass messages from i to nodes with higher m_ordering |
|
for (e=i->m_firstForward; e; e=e->m_nextForward) |
|
{ |
|
assert(i == e->m_tail); |
|
j = e->m_head; |
|
|
|
const REAL gamma = 1; |
|
|
|
e->m_message.UpdateMessage(m_Kglobal, i->m_K, j->m_K, Di, gamma, 0, buf); |
|
} |
|
|
|
if (lastIter && min_marginals) |
|
{ |
|
min_marginals_ptr += Di->GetArraySize(m_Kglobal, i->m_K); |
|
} |
|
} |
|
|
|
//////////////////////////////////////////////// |
|
// backward pass // |
|
//////////////////////////////////////////////// |
|
|
|
for (i=m_nodeLast; i; i=i->m_prev) |
|
{ |
|
Di->Copy(m_Kglobal, i->m_K, &i->m_D); |
|
for (e=i->m_firstBackward; e; e=e->m_nextBackward) |
|
{ |
|
Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr()); |
|
} |
|
for (e=i->m_firstForward; e; e=e->m_nextForward) |
|
{ |
|
Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr()); |
|
} |
|
|
|
// pass messages from i to nodes with smaller m_ordering |
|
for (e=i->m_firstBackward; e; e=e->m_nextBackward) |
|
{ |
|
assert(i == e->m_head); |
|
j = e->m_tail; |
|
|
|
const REAL gamma = 1; |
|
|
|
vMin = e->m_message.UpdateMessage(m_Kglobal, i->m_K, j->m_K, Di, gamma, 1, buf); |
|
} |
|
|
|
if (lastIter && min_marginals) |
|
{ |
|
min_marginals_ptr -= Di->GetArraySize(m_Kglobal, i->m_K); |
|
for (int k=0; k<Di->GetArraySize(m_Kglobal, i->m_K); k++) |
|
{ |
|
min_marginals_ptr[k] = Di->GetArrayValue(m_Kglobal, i->m_K, k); |
|
} |
|
} |
|
} |
|
|
|
//////////////////////////////////////////////// |
|
// check stopping criterion // |
|
//////////////////////////////////////////////// |
|
|
|
// print energy, if necessary |
|
if ( lastIter || |
|
( iter>=options.m_printMinIter && |
|
(options.m_printIter<1 || iter%options.m_printIter==0) ) |
|
) |
|
{ |
|
energy = ComputeSolutionAndEnergy(); |
|
printf("iter %d: energy = %f\n", iter, energy); |
|
} |
|
|
|
// if finishFlag==true terminate |
|
if (lastIter) break; |
|
} |
|
|
|
return iter; |
|
} |
|
|
|
template <class T> typename T::REAL MRFEnergy<T>::ComputeSolutionAndEnergy() |
|
{ |
|
Node* i; |
|
Node* j; |
|
MRFEdge* e; |
|
REAL E = 0; |
|
|
|
Vector* DiBackward = (Vector*) m_buf; // cost of backward edges plus Di at the node |
|
Vector* Di = (Vector*) (m_buf + m_vectorMaxSizeInBytes); // all edges plus Di at the node |
|
|
|
for (i=m_nodeFirst; i; i=i->m_next) |
|
{ |
|
// Set Ebackward[ki] to be the sum of V(ki,j->m_solution) for backward edges (i,j). |
|
// Set Di[ki] to be the value of the energy corresponding to |
|
// part of the graph considered so far, assuming that nodes u |
|
// in this subgraph are fixed to u->m_solution |
|
|
|
DiBackward->Copy(m_Kglobal, i->m_K, &i->m_D); |
|
for (e=i->m_firstBackward; e; e=e->m_nextBackward) |
|
{ |
|
assert(i == e->m_head); |
|
j = e->m_tail; |
|
|
|
e->m_message.AddColumn(m_Kglobal, j->m_K, i->m_K, j->m_solution, DiBackward, 0); |
|
} |
|
|
|
// add forward edges |
|
Di->Copy(m_Kglobal, i->m_K, DiBackward); |
|
|
|
for (e=i->m_firstForward; e; e=e->m_nextForward) |
|
{ |
|
Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr()); |
|
} |
|
|
|
Di->ComputeMin(m_Kglobal, i->m_K, i->m_solution); |
|
|
|
// update energy |
|
E += DiBackward->GetValue(m_Kglobal, i->m_K, i->m_solution); |
|
} |
|
|
|
return E; |
|
}
|
|
|