20 #ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
21 #define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
// Default value for `min` (the enclosing declaration is not visible in this
// excerpt; the trailing comma suggests a defaulted parameter -- TODO confirm).
// NOTE(review): DBL_MIN (<cfloat>) is the smallest positive normalized double,
// so -DBL_MIN is a tiny negative number, not the most negative representable
// value. If this is meant as an "unbounded below" limit, -DBL_MAX is
// presumably what was intended -- confirm against how `min` is used.
52 double min = -DBL_MIN,
// Initialize: templated on MatType, the type of `dataset`; takes the dataset
// by const reference and a `rank` value.
// Only the two dimension reads below are visible in this excerpt, so what the
// function does with n, m and rank cannot be stated here.
57 template<
typename MatType>
58 void Initialize(
const MatType& dataset,
const size_t rank)
// n: number of rows of `dataset`.
60 const size_t n = dataset.n_rows;
// m: number of columns of `dataset`.
61 const size_t m = dataset.n_cols;
// Accumulation of deltaW for a matrix V accessed element-wise as V(i, j).
// Presumably the body of the generic WUpdate; the signature and the
// definitions of n, m, r, kw and val are not visible in this excerpt -- TODO
// confirm.
76 template<
typename MatType>
// deltaW is constructed with n rows and r columns.
// NOTE(review): no zeroing call is visible here, and depending on the
// Armadillo version the (n_rows, n_cols) constructor may not zero-fill;
// confirm deltaW is zeroed before the += accumulation below.
88 arma::mat deltaW(n, r);
// Outer loop: one pass per row index i in [0, n).
91 for(
size_t i = 0;i < n;i++)
// Inner loop: every column index j in [0, m).
93 for(
size_t j = 0;j < m;j++)
// Entries of V equal to zero are skipped; val is assigned inside the
// condition so the entry is read only once.
96 if((val = V(i, j)) != 0)
// Add (V(i, j) - dot(W.row(i), H.col(j))) times the transposed column j of H
// to row i of deltaW.
97 deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
98 arma::trans(H.col(j));
// When kw is nonzero, subtract kw * W.row(i) from row i of deltaW.
100 if(
kw != 0) deltaW.row(i) -=
kw * W.row(i);
// Accumulation of deltaH for a matrix V accessed element-wise as V(i, j).
// Presumably the body of the generic HUpdate; the signature and the
// definitions of n, m, r, kh and val are not visible in this excerpt -- TODO
// confirm.
116 template<
typename MatType>
// deltaH is constructed with r rows and m columns.
// NOTE(review): no zeroing call is visible here, and depending on the
// Armadillo version the (n_rows, n_cols) constructor may not zero-fill;
// confirm deltaH is zeroed before the += accumulation below.
128 arma::mat deltaH(r, m);
// Outer loop: one pass per column index j in [0, m).
131 for(
size_t j = 0;j < m;j++)
// Inner loop: every row index i in [0, n).
133 for(
size_t i = 0;i < n;i++)
// Entries of V equal to zero are skipped; val is assigned inside the
// condition so the entry is read only once.
136 if((val = V(i, j)) != 0)
// Add (V(i, j) - dot(W.row(i), H.col(j))) times the transposed row i of W
// to column j of deltaH.
137 deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) *
138 arma::trans(W.row(i));
// When kh is nonzero, subtract kh * H.col(j) from column j of deltaH.
140 if(
kh != 0) deltaH.col(j) -=
kh * H.col(j);
// WUpdate for arma::sp_mat input: instead of indexing V(i, j) for every
// (i, j), it walks V with a const_iterator from V.begin() to V.end().
// The remaining parameters and the definitions of n, r, W, H and kw are not
// visible in this excerpt.
166 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(
const arma::sp_mat& V,
// deltaW is constructed with n rows and r columns.
// NOTE(review): no zeroing call is visible here, and depending on the
// Armadillo version the (n_rows, n_cols) constructor may not zero-fill;
// confirm deltaW is zeroed before the += accumulation below.
176 arma::mat deltaW(n, r);
// Visit each entry the iterator yields; *it is the entry's value and
// it.row() / it.col() give its position.
179 for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
181 size_t row = it.row();
182 size_t col = it.col();
// Add (value - dot(W.row(row), H.col(col))) times the transposed column
// `col` of H to the entry's row of deltaW.
// NOTE(review): the target row is selected with it.row() although the same
// index is already held in `row`.
183 deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
184 arma::trans(H.col(col));
// When kw is nonzero, subtract kw * W.row(i) from every row i in [0, n) of
// deltaW, in a separate pass after the iterator loop.
187 if(kw != 0)
for(
size_t i = 0; i < n; i++)
189 deltaW.row(i) -= kw * W.row(i);
// HUpdate for arma::sp_mat input: instead of indexing V(i, j) for every
// (i, j), it walks V with a const_iterator from V.begin() to V.end().
// The remaining parameters and the definitions of m, r, W, H and kh are not
// visible in this excerpt.
197 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(
const arma::sp_mat& V,
// deltaH is constructed with r rows and m columns.
// NOTE(review): no zeroing call is visible here, and depending on the
// Armadillo version the (n_rows, n_cols) constructor may not zero-fill;
// confirm deltaH is zeroed before the += accumulation below.
207 arma::mat deltaH(r, m);
// Visit each entry the iterator yields; *it is the entry's value and
// it.row() / it.col() give its position.
210 for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
212 size_t row = it.row();
213 size_t col = it.col();
// Add (value - dot(W.row(row), H.col(col))) times the transposed row `row`
// of W to column `col` of deltaH.
214 deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
215 arma::trans(W.row(row));
// When kh is nonzero, subtract kh * H.col(j) from every column j in [0, m)
// of deltaH, in a separate pass after the iterator loop.
218 if(kh != 0)
for(
size_t j = 0; j < m; j++)
220 deltaH.col(j) -= kh * H.col(j);