gemm_mixed_cache< do_trans_A, do_trans_B, use_alpha, use_beta > Class Template Reference
[Gemm_mixed]

Matrix multiplication where the matrices have different element types. Uses caching for speedup. Matrix 'C' is assumed to have already been set to the correct size (i.e., taking any transposes into account). More...

#include <gemm_mixed.hpp>

List of all members.

Static Public Member Functions

template<typename out_eT , typename in_eT1 , typename in_eT2 >
static arma_hot void apply (Mat< out_eT > &C, const Mat< in_eT1 > &A, const Mat< in_eT2 > &B, const out_eT alpha=out_eT(1), const out_eT beta=out_eT(0))

Detailed Description

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
class gemm_mixed_cache< do_trans_A, do_trans_B, use_alpha, use_beta >

Matrix multiplication where the matrices have different element types. Uses caching for speedup. Matrix 'C' is assumed to have already been set to the correct size (i.e., taking any transposes into account).

Definition at line 28 of file gemm_mixed.hpp.


Member Function Documentation

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
template<typename out_eT , typename in_eT1 , typename in_eT2 >
static arma_hot void gemm_mixed_cache< do_trans_A, do_trans_B, use_alpha, use_beta >::apply ( Mat< out_eT > &  C,
const Mat< in_eT1 > &  A,
const Mat< in_eT2 > &  B,
const out_eT  alpha = out_eT(1),
const out_eT  beta = out_eT(0) 
) [inline, static]

Definition at line 38 of file gemm_mixed.hpp.

References Mat< eT >::at(), Mat< eT >::colptr(), podarray< eT >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().

00045     {  // C = op(A)*op(B), optionally scaled by alpha and/or added to beta*C per use_alpha/use_beta; op() is a transpose when do_trans_A/do_trans_B is true
00046     arma_extra_debug_sigprint();
00047     
00048     const u32 A_n_rows = A.n_rows;
00049     const u32 A_n_cols = A.n_cols;
00050     
00051     const u32 B_n_rows = B.n_rows;
00052     const u32 B_n_cols = B.n_cols;
00053     
00054     if( (do_trans_A == false) && (do_trans_B == false) )  // C = A*B
00055       {
00056       podarray<in_eT1> tmp(A_n_cols);  // scratch buffer caching one row of A
00057       in_eT1* A_rowdata = tmp.memptr();
00058       
00059       for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00060         {
00061         
00062         for(u32 col_A=0; col_A < A_n_cols; ++col_A)  // copy row row_A of A once; it is reused for every column of B below
00063           {
00064           A_rowdata[col_A] = A.at(row_A,col_A);
00065           }
00066         
00067         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00068           {
00069           const in_eT2* B_coldata = B.colptr(col_B);
00070           
00071           out_eT acc = out_eT(0);  // dot product of row row_A of A with column col_B of B
00072           for(u32 i=0; i < B_n_rows; ++i)  // assumes A_n_cols == B_n_rows (A_rowdata has A_n_cols entries); not checked here
00073             {
00074             acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);  // upgrade_val presumably converts each operand to a common element type -- confirm in upgrade_val
00075             }
00076         
00077           if( (use_alpha == false) && (use_beta == false) )  // these conditions depend only on template parameters
00078             {
00079             C.at(row_A,col_B) = acc;
00080             }
00081           else
00082           if( (use_alpha == true) && (use_beta == false) )
00083             {
00084             C.at(row_A,col_B) = alpha * acc;
00085             }
00086           else
00087           if( (use_alpha == false) && (use_beta == true) )
00088             {
00089             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);  // beta paths read the existing C, so C must hold valid values on entry
00090             }
00091           else
00092           if( (use_alpha == true) && (use_beta == true) )
00093             {
00094             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00095             }
00096           
00097           }
00098         }
00099       }
00100     else
00101     if( (do_trans_A == true) && (do_trans_B == false) )  // C = trans(A)*B
00102       {
00103       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00104         {
00105         // col_A is interpreted as row_A when storing the results in matrix C
00106         
00107         const in_eT1* A_coldata = A.colptr(col_A);  // column col_A of A is read directly, so no scratch copy is made in this branch
00108         
00109         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00110           {
00111           const in_eT2* B_coldata = B.colptr(col_B);
00112           
00113           out_eT acc = out_eT(0);  // dot product of column col_A of A with column col_B of B
00114           for(u32 i=0; i < B_n_rows; ++i)  // assumes A_n_rows == B_n_rows; not checked here
00115             {
00116             acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00117             }
00118         
00119           if( (use_alpha == false) && (use_beta == false) )
00120             {
00121             C.at(col_A,col_B) = acc;
00122             }
00123           else
00124           if( (use_alpha == true) && (use_beta == false) )
00125             {
00126             C.at(col_A,col_B) = alpha * acc;
00127             }
00128           else
00129           if( (use_alpha == false) && (use_beta == true) )
00130             {
00131             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00132             }
00133           else
00134           if( (use_alpha == true) && (use_beta == true) )
00135             {
00136             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00137             }
00138           
00139           }
00140         }
00141       }
00142     else
00143     if( (do_trans_A == false) && (do_trans_B == true) )  // C = A*trans(B)
00144       {
00145       Mat<in_eT2> B_tmp = trans(B);  // materialises trans(B) as a temporary matrix, then reuses the no-transpose path
00146       gemm_mixed_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00147       }
00148     else
00149     if( (do_trans_A == true) && (do_trans_B == true) )  // C = trans(A)*trans(B)
00150       {
00151       // mat B_tmp = trans(B);
00152       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00153       
00154       
00155       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
00156       // transpose operations are not needed
00157       
00158       podarray<in_eT2> tmp(B.n_cols);  // scratch buffer caching one row of B
00159       in_eT2* B_rowdata = tmp.memptr();
00160       
00161       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00162         {
00163         
00164         for(u32 col_B=0; col_B < B_n_cols; ++col_B)  // copy row row_B of B once; it is reused for every column of A below
00165           {
00166           B_rowdata[col_B] = B.at(row_B,col_B);
00167           }
00168         
00169         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00170           {
00171           const in_eT1* A_coldata = A.colptr(col_A);
00172           
00173           out_eT acc = out_eT(0);  // sum_i B(row_B,i)*A(i,col_A), i.e. element (col_A,row_B) of trans(A)*trans(B)
00174           for(u32 i=0; i < A_n_rows; ++i)  // assumes A_n_rows == B_n_cols (B_rowdata has B_n_cols entries); not checked here
00175             {
00176             acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00177             }
00178         
00179           if( (use_alpha == false) && (use_beta == false) )
00180             {
00181             C.at(col_A,row_B) = acc;  // transposed store: C is indexed as A_n_cols x B_n_rows
00182             }
00183           else
00184           if( (use_alpha == true) && (use_beta == false) )
00185             {
00186             C.at(col_A,row_B) = alpha * acc;
00187             }
00188           else
00189           if( (use_alpha == false) && (use_beta == true) )
00190             {
00191             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00192             }
00193           else
00194           if( (use_alpha == true) && (use_beta == true) )
00195             {
00196             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00197             }
00198           
00199           }
00200         }
00201       
00202       }
00203     }