You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
246 lines
8.4 KiB
246 lines
8.4 KiB
/****************************************************************** |
|
Vladimir Kolmogorov, 2005 |
|
vnk@microsoft.com |
|
|
|
(c) Microsoft Corporation. All rights reserved. |
|
*******************************************************************/ |
|
|
|
#ifndef __MRFENERGY_H__ |
|
#define __MRFENERGY_H__ |
|
|
|
#include "instances.h" |
|
|
|
|
|
// After MRFEnergy is allocated, there are two phases: |
|
// 1. Energy construction. Only AddNode(), AddNodeData() and AddEdge() may be called. |
|
// |
|
// Any call ZeroMessages(), SetAutomaticOrdering(), Minimize_TRW_S() or Minimize_BP() |
|
// completes graph construction; MRFEnergy goes to the second phase: |
|
// 2. Only functions AddNodeData(), ZeroMessages(), Minimize_TRW_S(), Minimize_BP() |
|
// or GetSolution() may be called. (The last function can be called only after |
|
// Minimize_TRW_S() or Minimize_BP()). |
|
|
|
|
|
template <class T> class MRFEnergy |
|
{ |
|
private: |
|
struct Node; |
|
|
|
public: |
|
typedef typename T::Label Label; |
|
typedef typename T::REAL REAL; |
|
typedef typename T::GlobalSize GlobalSize; |
|
typedef typename T::LocalSize LocalSize; |
|
typedef typename T::NodeData NodeData; |
|
typedef typename T::EdgeData EdgeData; |
|
|
|
typedef Node* NodeId; |
|
typedef void (*ErrorFunction)(const char* msg); |
|
|
|
// Constructor. Function errorFn is called with an error message, if an error occurs. |
|
MRFEnergy(GlobalSize Kglobal, ErrorFunction errorFn = NULL); |
|
|
|
// Destructor. |
|
~MRFEnergy(); |
|
|
|
////////////////////////////////////////////////////////// |
|
// Energy construction // |
|
////////////////////////////////////////////////////////// |
|
|
|
// Adds a node with parameters K and data |
|
// (see the corresponding message*.h file for description). |
|
// Note: information in data is copied into internal memory. |
|
// Cannot be called after energy construction is completed. |
|
NodeId AddNode(LocalSize K, NodeData data); |
|
|
|
// Modifies node parameter for existing node (namely, add information |
|
// in data to existing parameter). May be called at any time. |
|
// Node i must be NodeId returned by AddNode(). |
|
void AddNodeData(NodeId i, NodeData data); |
|
|
|
// Adds an edge between i and j. data determines edge parameters |
|
// (see the corresponding message*.h file for description). |
|
// Note: information in data is copied into internal memory. |
|
// Cannot be called after energy construction is completed. |
|
void AddEdge(NodeId i, NodeId j, EdgeData data); |
|
|
|
////////////////////////////////////////////////////////// |
|
// Energy construction end // |
|
////////////////////////////////////////////////////////// |
|
|
|
// Clears all messages. Completes energy construction (if not completed yet). |
|
void ZeroMessages(); |
|
|
|
// Adds to all message entries a value drawn uniformly from [min_value, max_value]. |
|
// Normally, min_value can be set to 0 (except for TypeBinaryFast, in which case min_value = -max_value) |
|
void AddRandomMessages(unsigned int random_seed, REAL min_value, REAL max_value); |
|
|
|
// The algorithm depends on the order of nodes. |
|
// By default nodes are processed in the order in which they were added. |
|
// The function below permutes this order using certain heuristics. |
|
// It may speed up the algorithm if, for example, the original order is random. |
|
// |
|
// Completes energy construction. |
|
// Cannot be called after energy construction is completed. |
|
void SetAutomaticOrdering(); |
|
|
|
// The structure below specifies (1) stopping criteria and |
|
// (2) how often to compute solution and print its energy. |
|
struct Options |
|
{ |
|
Options() |
|
{ |
|
// default parameters |
|
m_eps = -1; // not used |
|
m_iterMax = 1000000; |
|
m_printIter = 5; // After 10 iterations start printing the lower bound |
|
m_printMinIter = 10; // and the energy every 5 iterations. |
|
} |
|
|
|
// stopping criterion |
|
REAL m_eps; // stop if the increase in the lower bound during one iteration is less or equal than m_eps. |
|
// Used only if m_eps >= 0, and only for TRW-S algorithm. |
|
int m_iterMax; // maximum number of iterations |
|
|
|
// Option for printing lower bound and the energy. |
|
// Note: computing solution and its energy is slow |
|
// (it is comparable to the cost of one iteration). |
|
int m_printIter; // print lower bound and energy every m_printIter iterations |
|
int m_printMinIter; // do not print lower bound and energy before m_printMinIter iterations |
|
}; |
|
|
|
// Returns number of iterations. Sets lowerBound and energy. |
|
// If the user provides array min_marginals, then the code |
|
// sets this array accordingly. (The size of the array depends on the type |
|
// used. Normally, it's (# nodes)*(# labels). Exception: for TypeBinaryFast it's (# nodes). |
|
int Minimize_TRW_S(Options& options, REAL& lowerBound, REAL& energy, REAL* min_marginals = NULL); |
|
|
|
// Returns number of iterations. Sets energy. |
|
int Minimize_BP(Options& options, REAL& energy, REAL* min_marginals = NULL); |
|
|
|
// Returns an integer in [0,Ki). Can be called only after Minimize(). |
|
Label GetSolution(NodeId i); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////// |
|
////////////////////////////////////////////////////////// |
|
////////////////////////////////////////////////////////// |
|
// Implementation // |
|
////////////////////////////////////////////////////////// |
|
////////////////////////////////////////////////////////// |
|
////////////////////////////////////////////////////////// |
|
private: |
|
|
|
typedef typename T::Vector Vector; |
|
typedef typename T::Edge Edge; |
|
|
|
struct MRFEdge; |
|
struct MallocBlock; |
|
|
|
ErrorFunction m_errorFn; |
|
MallocBlock* m_mallocBlockFirst; |
|
Node* m_nodeFirst; |
|
Node* m_nodeLast; |
|
int m_nodeNum; |
|
int m_edgeNum; |
|
GlobalSize m_Kglobal; |
|
int m_vectorMaxSizeInBytes; |
|
|
|
bool m_isEnergyConstructionCompleted; |
|
|
|
char* m_buf; // buffer of size m_vectorMaxSizeInBytes |
|
// + max(m_vectorMaxSizeInBytes, Edge::GetBufSizeInBytes(m_vectorMaxSizeInBytes)) |
|
|
|
void CompleteGraphConstruction(); // nodes and edges cannot be added after calling this function |
|
void SetMonotonicTrees(); |
|
|
|
REAL ComputeSolutionAndEnergy(); // sets Node::m_solution, returns value of the energy |
|
|
|
|
|
|
|
struct Node |
|
{ |
|
int m_ordering; // unique integer in [0,m_nodeNum-1) |
|
|
|
MRFEdge* m_firstForward; // first edge going to nodes with greater m_ordering |
|
MRFEdge* m_firstBackward; // first edge going to nodes with smaller m_ordering |
|
|
|
Node* m_prev; // previous and next |
|
Node* m_next; // nodes according to m_ordering |
|
|
|
Label m_solution; // integer in [0,m_D.m_K) |
|
LocalSize m_K; // local information about number of labels |
|
|
|
Vector m_D; // must be the last member in the struct since its size is not fixed |
|
}; |
|
|
|
struct MRFEdge |
|
{ |
|
MRFEdge* m_nextForward; // next forward edge with the same tail |
|
MRFEdge* m_nextBackward; // next backward edge with the same head |
|
Node* m_tail; |
|
Node* m_head; |
|
|
|
REAL m_gammaForward; // = rho_{ij} / rho_{i} where i=m_tail, j=m_head |
|
REAL m_gammaBackward; // = rho_{ij} / rho_{j} where i=m_tail, j=m_head |
|
|
|
Edge m_message; // must be the last member in the struct since its size is not fixed. |
|
// Stores edge information and either forward or backward message. |
|
// Most of the time it's the backward message; it gets replaced |
|
// by the forward message only temporarily inside Minimize_TRW_S() and Minimize_BP(). |
|
}; |
|
|
|
// Use our own Malloc since |
|
// (a) new in C++ is slow and allocates minimum memory of 64 bytes (in Visual C++) |
|
// (b) we want simple (one function) deallocation instead of going through every allocated element |
|
struct MallocBlock |
|
{ |
|
static const int minBlockSizeInBytes = 4096 - 3*sizeof(void*); |
|
MallocBlock* m_next; |
|
char* m_current; // first element of available memory in this block |
|
char* m_last; // first element outside of allocated memory for this block |
|
}; |
|
char* Malloc(int bytesNum); |
|
}; |
|
|
|
|
|
|
|
|
|
template <class T> inline char* MRFEnergy<T>::Malloc(int bytesNum) |
|
{ |
|
if (!m_mallocBlockFirst || m_mallocBlockFirst->m_current+bytesNum > m_mallocBlockFirst->m_last) |
|
{ |
|
int size = (bytesNum > MallocBlock::minBlockSizeInBytes) ? bytesNum : MallocBlock::minBlockSizeInBytes; |
|
MallocBlock* b = (MallocBlock*) new char[sizeof(MallocBlock) + size]; |
|
if (!b) m_errorFn("Not enough memory"); |
|
b->m_current = (char*) b + sizeof(MallocBlock); |
|
b->m_last = b->m_current + size; |
|
|
|
b->m_next = m_mallocBlockFirst; |
|
m_mallocBlockFirst = b; |
|
} |
|
|
|
char* ptr = m_mallocBlockFirst->m_current; |
|
m_mallocBlockFirst->m_current += bytesNum; |
|
return ptr; |
|
} |
|
|
|
template <class T> inline typename T::Label MRFEnergy<T>::GetSolution(NodeId i) |
|
{ |
|
return i->m_solution; |
|
} |
|
|
|
#include "MRFEnergy.inl" |
|
|
|
#endif
|
|
|