TooN 2.0.0-beta8
QR_Lapack.h
00001 #ifndef TOON_INCLUDE_QR_LAPACK_H
00002 #define TOON_INCLUDE_QR_LAPACK_H
00003 
00004 
#include <TooN/TooN.h>
#include <TooN/lapack.h>

#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
00008 
00009 namespace TooN{
00010 
00011 /**
00012 Performs %QR decomposition.
00013 
00014 @warning this will only work if the number of columns is greater than 
00015 the number of rows!
00016 
00017 The QR decomposition operates on a matrix A. It can be performed with
00018 or without column pivoting. In general:
00019 \f[
00020 AP = QR
00021 \f]
00022 Where \f$P\f$ is a permutation matrix constructed to permute the columns
00023 of A. In practise, \f$P\f$ is stored as a vector of integer elements.
00024 
00025 With column pivoting, the elements of the leading diagonal of \f$R\f$ will
00026 be sorted from largest in magnitude to smallest in magnitude.
00027 
00028 @ingroup gDecomps
00029 */
00030 template<int Rows=Dynamic, int Cols=Rows, class Precision=double>
00031 class QR_Lapack{
00032 
00033     private:
00034         static const int square_Size = (Rows>=0 && Cols>=0)?(Rows<Cols?Rows:Cols):Dynamic;
00035 
00036     public: 
00037         /// Construct the %QR decomposition of a matrix. This initialises the class, and
00038         /// performs the decomposition immediately.
00039         /// @param m The matrix to decompose
00040         /// @param p Whether or not to perform pivoting
00041         template<int R, int C, class P, class B> 
00042         QR_Lapack(const Matrix<R,C,P,B>& m, bool p=0)
00043         :copy(m),tau(square_size()), Q(square_size(), square_size()), do_pivoting(p), pivot(Zeros(square_size()))
00044         {
00045             //pivot is set to all zeros, which means all columns are free columns
00046             //and can take part in column pivoting.
00047 
00048             compute();
00049         }
00050         
00051         ///Return R
00052         const Matrix<Rows, Cols, Precision, ColMajor>& get_R()
00053         {
00054             return copy;
00055         }
00056         
00057         ///Return Q
00058         const Matrix<square_Size, square_Size, Precision, ColMajor>& get_Q()
00059         {
00060             return Q;
00061         }   
00062 
00063         ///Return the permutation vector. The definition is that column \f$i\f$ of A is
00064         ///column \f$P(i)\f$ of \f$QR\f$.
00065         const Vector<Cols, int>& get_P()
00066         {
00067             return pivot;
00068         }
00069 
00070     private:
00071 
00072         void compute()
00073         {   
00074             FortranInteger M = copy.num_rows();
00075             FortranInteger N = copy.num_cols();
00076             
00077             FortranInteger LWORK=-1;
00078             FortranInteger INFO;
00079             FortranInteger lda = M;
00080 
00081             Precision size;
00082             
00083             //Set up the pivot vector
00084             if(do_pivoting)
00085                 pivot = Zeros;
00086             else
00087                 for(int i=0; i < pivot.size(); i++)
00088                     pivot[i] = i+1;
00089 
00090             
00091             //Compute the working space
00092             geqp3_(&M, &N, copy.get_data_ptr(), &lda, pivot.get_data_ptr(), tau.get_data_ptr(), &size, &LWORK, &INFO);
00093 
00094             LWORK = (FortranInteger) size;
00095 
00096             Precision* work = new Precision[LWORK];
00097             
00098             geqp3_(&M, &N, copy.get_data_ptr(), &lda, pivot.get_data_ptr(), tau.get_data_ptr(), work, &LWORK, &INFO);
00099 
00100 
00101             if(INFO < 0)
00102                 std::cerr << "error in QR, INFO was " << INFO << std::endl;
00103 
00104             //The upper "triangle+" of copy is R
00105             //The lower right and tau contain enough information to reconstruct Q
00106             
00107             //LAPACK provides a handy function to do the reconstruction
00108             Q = copy.template slice<0,0,square_Size, square_Size>(0,0,square_size(), square_size());
00109             
00110             FortranInteger K = square_size();
00111             M=K;
00112             N=K;
00113             lda = K;
00114             orgqr_(&M, &N, &K, Q.get_data_ptr(), &lda, tau.get_data_ptr(), work, &LWORK, &INFO);
00115 
00116             if(INFO < 0)
00117                 std::cerr << "error in QR, INFO was " << INFO << std::endl;
00118 
00119             delete [] work;
00120             
00121             //Now zero out the lower triangle
00122             for(int r=1; r < square_size(); r++)
00123                 for(int c=0; c<r; c++)
00124                     copy[r][c] = 0;
00125 
00126             //Now fix the pivot matrix.
00127             //We need to go from FORTRAN to C numbering. 
00128             for(int i=0; i < pivot.size(); i++)
00129                 pivot[i]--;
00130         }
00131 
00132         Matrix<Rows, Cols, Precision, ColMajor> copy;
00133         Vector<square_Size, Precision> tau;
00134         Matrix<square_Size, square_Size, Precision, ColMajor> Q;
00135         bool do_pivoting;
00136         Vector<Cols, FortranInteger> pivot;
00137         
00138 
00139         int square_size()
00140         {
00141             return std::min(copy.num_rows(), copy.num_cols());  
00142         }
00143 };
00144 
00145 }
00146 
00147 
00148 #endif