20 #ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
21 #define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
36 double min = -DBL_MIN,
41 template<
typename MatType>
42 void Initialize(
const MatType& dataset,
const size_t rank)
44 const size_t n = dataset.n_rows;
45 const size_t m = dataset.n_cols;
60 template<
typename MatType>
72 arma::mat deltaW(n, r);
75 for(
size_t i = 0;i < n;i++)
77 for(
size_t j = 0;j < m;j++)
80 if((val = V(i, j)) != 0)
81 deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
82 arma::trans(H.col(j));
84 if(
kw != 0) deltaW.row(i) -=
kw * W.row(i);
100 template<
typename MatType>
112 arma::mat deltaH(r, m);
115 for(
size_t j = 0;j < m;j++)
117 for(
size_t i = 0;i < n;i++)
120 if((val = V(i, j)) != 0)
121 deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) *
122 arma::trans(W.row(i));
124 if(
kh != 0) deltaH.col(j) -=
kh * H.col(j);
144 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(
const arma::sp_mat& V,
154 arma::mat deltaW(n, r);
157 for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
159 size_t row = it.row();
160 size_t col = it.col();
161 deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
162 arma::trans(H.col(col));
165 if(kw != 0)
for(
size_t i = 0; i < n; i++)
167 deltaW.row(i) -= kw * W.row(i);
175 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(
const arma::sp_mat& V,
185 arma::mat deltaH(r, m);
188 for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
190 size_t row = it.row();
191 size_t col = it.col();
192 deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
193 arma::trans(W.row(row));
196 if(kh != 0)
for(
size_t j = 0; j < m; j++)
198 deltaH.col(j) -= kh * H.col(j);
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void Initialize(const MatType &dataset, const size_t rank)
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9, double min=-DBL_MIN, double max=DBL_MAX)
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.