MLPACK  1.0.11
svd_batch_learning.hpp
Go to the documentation of this file.
1 
20 #ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
21 #define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
22 
23 #include <mlpack/core.hpp>
24 
25 namespace mlpack {
26 namespace amf {
27 
38 {
39  public:
48  SVDBatchLearning(double u = 0.0002,
49  double kw = 0,
50  double kh = 0,
51  double momentum = 0.9,
52  double min = -DBL_MIN,
53  double max = DBL_MAX)
54  : u(u), kw(kw), kh(kh), min(min), max(max), momentum(momentum)
55  {}
56 
57  template<typename MatType>
58  void Initialize(const MatType& dataset, const size_t rank)
59  {
60  const size_t n = dataset.n_rows;
61  const size_t m = dataset.n_cols;
62 
63  mW.zeros(n, rank);
64  mH.zeros(rank, m);
65  }
66 
76  template<typename MatType>
77  inline void WUpdate(const MatType& V,
78  arma::mat& W,
79  const arma::mat& H)
80  {
81  size_t n = V.n_rows;
82  size_t m = V.n_cols;
83 
84  size_t r = W.n_cols;
85 
86  mW = momentum * mW;
87 
88  arma::mat deltaW(n, r);
89  deltaW.zeros();
90 
91  for(size_t i = 0;i < n;i++)
92  {
93  for(size_t j = 0;j < m;j++)
94  {
95  double val;
96  if((val = V(i, j)) != 0)
97  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
98  arma::trans(H.col(j));
99  }
100  if(kw != 0) deltaW.row(i) -= kw * W.row(i);
101  }
102 
103  mW += u * deltaW;
104  W += mW;
105  }
106 
116  template<typename MatType>
117  inline void HUpdate(const MatType& V,
118  const arma::mat& W,
119  arma::mat& H)
120  {
121  size_t n = V.n_rows;
122  size_t m = V.n_cols;
123 
124  size_t r = W.n_cols;
125 
126  mH = momentum * mH;
127 
128  arma::mat deltaH(r, m);
129  deltaH.zeros();
130 
131  for(size_t j = 0;j < m;j++)
132  {
133  for(size_t i = 0;i < n;i++)
134  {
135  double val;
136  if((val = V(i, j)) != 0)
137  deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) *
138  arma::trans(W.row(i));
139  }
140  if(kh != 0) deltaH.col(j) -= kh * H.col(j);
141  }
142 
143  mH += u*deltaH;
144  H += mH;
145  }
146 
147  private:
148  double u;
149  double kw;
150  double kh;
151  double min;
152  double max;
153  double momentum;
154 
155  arma::mat mW;
156  arma::mat mH;
157 };
158 
161 
165 template<>
166 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
167  arma::mat& W,
168  const arma::mat& H)
169 {
170  size_t n = V.n_rows;
171 
172  size_t r = W.n_cols;
173 
174  mW = momentum * mW;
175 
176  arma::mat deltaW(n, r);
177  deltaW.zeros();
178 
179  for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
180  {
181  size_t row = it.row();
182  size_t col = it.col();
183  deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
184  arma::trans(H.col(col));
185  }
186 
187  if(kw != 0) for(size_t i = 0; i < n; i++)
188  {
189  deltaW.row(i) -= kw * W.row(i);
190  }
191 
192  mW += u * deltaW;
193  W += mW;
194 }
195 
196 template<>
197 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
198  const arma::mat& W,
199  arma::mat& H)
200 {
201  size_t m = V.n_cols;
202 
203  size_t r = W.n_cols;
204 
205  mH = momentum * mH;
206 
207  arma::mat deltaH(r, m);
208  deltaH.zeros();
209 
210  for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
211  {
212  size_t row = it.row();
213  size_t col = it.col();
214  deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
215  arma::trans(W.row(row));
216  }
217 
218  if(kh != 0) for(size_t j = 0; j < m; j++)
219  {
220  deltaH.col(j) -= kh * H.col(j);
221  }
222 
223  mH += u*deltaH;
224  H += mH;
225 }
226 
227 } // namespace amf
228 } // namespace mlpack
229 
230 #endif
231 
232 
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void Initialize(const MatType &dataset, const size_t rank)
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9, double min=-DBL_MIN, double max=DBL_MAX)
SVD Batch learning constructor.
This class implements SVD batch learning with momentum.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.