Public Member Functions | Public Attributes | Protected Member Functions | Protected Attributes | List of all members
WireCell::ElasticNetModel Class Reference

#include <ElasticNetModel.h>

Inheritance diagram for WireCell::ElasticNetModel:
WireCell::LinearModel (base class) → WireCell::ElasticNetModel → WireCell::LassoModel (derived class)

Public Member Functions

 ElasticNetModel (double lambda=1., double alpha=1., int max_iter=100000, double TOL=1e-3, bool non_negtive=true)
 
 ~ElasticNetModel ()
 
void SetLambdaWeight (Eigen::VectorXd w)
 
void SetLambdaWeight (int i, double weight)
 
void SetX (Eigen::MatrixXd X)
 
virtual void Fit ()
 
- Public Member Functions inherited from WireCell::LinearModel
 LinearModel ()
 
virtual ~LinearModel ()
 
Eigen::VectorXd & Gety ()
 
Eigen::MatrixXd & GetX ()
 
Eigen::VectorXd & Getbeta ()
 
virtual void SetData (Eigen::MatrixXd X, Eigen::VectorXd y)
 
virtual void Sety (Eigen::VectorXd y)
 
virtual void Setbeta (Eigen::VectorXd beta)
 
Eigen::VectorXd Predict ()
 
double chi2_base ()
 
double MeanResidual ()
 

Public Attributes

double lambda
 
double alpha
 
int max_iter
 
double TOL
 
bool non_negtive
 
Eigen::VectorXd lambda_weight
 
- Public Attributes inherited from WireCell::LinearModel
std::string name
 

Protected Member Functions

double _soft_thresholding (double x, double lambda_)
 

Protected Attributes

std::vector< bool > _active_beta
 
- Protected Attributes inherited from WireCell::LinearModel
Eigen::VectorXd _y
 
Eigen::MatrixXd _X
 
Eigen::VectorXd _beta
 

Detailed Description

Definition at line 10 of file ElasticNetModel.h.

Constructor & Destructor Documentation

WireCell::ElasticNetModel::ElasticNetModel ( double  lambda = 1.,
double  alpha = 1.,
int  max_iter = 100000,
double  TOL = 1e-3,
bool  non_negtive = true 
)
WireCell::ElasticNetModel::~ElasticNetModel ( )

Definition at line 23 of file ElasticNetModel.cxx.

24 {}

Member Function Documentation

double WireCell::ElasticNetModel::_soft_thresholding ( double  x,
double  lambda_ 
)
protected

Definition at line 98 of file ElasticNetModel.cxx.

99 {
100 
101  if (delta > lambda_) {
102  return delta - lambda_;
103  }
104  else {
105  if (non_negtive) {
106  return 0;
107  }
108  else {
109  if (delta < -lambda_) {
110  return delta + lambda_;
111  }
112  else {
113  return 0;
114  }
115  }
116  }
117 }
void WireCell::ElasticNetModel::Fit ( )
virtual

Reimplemented from WireCell::LinearModel.

Reimplemented in WireCell::LassoModel.

Definition at line 26 of file ElasticNetModel.cxx.

27 {
28  // initialize solution to zero unless user set beta already
29  Eigen::VectorXd beta = _beta;
30  if (0 == beta.size()) {
31  beta = VectorXd::Zero(_X.cols());
32  }
33 
34  // initialize active_beta to true
35  int nbeta = beta.size();
36  _active_beta = vector<bool>(nbeta, true);
37 
38  // use alias for easy notation
39  Eigen::VectorXd y = Gety();
40  Eigen::MatrixXd X = GetX();
41 
42  // coordinate descent
43 
44  //int N = y.size();
45  VectorXd norm(nbeta);
46  for (int j=0; j<nbeta; j++) {
47  norm(j) = X.col(j).squaredNorm();
48  if (norm(j) < 1e-6) {
49  cerr << "warning: the " << j << "th variable is not used, please consider removing it." << endl;
50  norm(j) = 1;
51  }
52  }
53  double tol2 = TOL*TOL*nbeta;
54 
55  int double_check = 0;
56  for (int i=0; i<max_iter; i++) {
57  VectorXd betalast = beta;
58  for (int j=0; j<nbeta; j++) {
59  if (!_active_beta[j]) {continue;}
60  VectorXd X_j = X.col(j);
61  VectorXd beta_tmp = beta;
62  beta_tmp(j) = 0;
63  VectorXd r_j = (y - X * beta_tmp);
64  double delta_j = X_j.dot(r_j);
65  // beta(j) = _soft_thresholding(delta_j, N*lambda*alpha*lambda_weight(j)) / (1+lambda*(1-alpha)) / norm(j);
66  beta(j) = _soft_thresholding(delta_j/norm(j), lambda*alpha*lambda_weight(j)) / (1+lambda*(1-alpha));
67 
68  //cout << i << " " << j << " " << beta(j) << std::endl;
69  if(fabs(beta(j)) < 1e-6) { _active_beta[j] = false; }
70  // else { cout << beta(j) << endl;}
71  // beta(j) = _soft_thresholding(delta_j, N*lambda*alpha, j) / (1+lambda*(1-alpha)) / norm(j);
72  // if (j==0) cout << beta(j) << ", " << arg1 << endl;
73  }
74  double_check++;
75  // cout << endl;
76  VectorXd diff = beta - betalast;
77 
78  //std::cout << i << " " << diff.squaredNorm() << " " << tol2 << std::endl;
79  if (diff.squaredNorm()<tol2) {
80  if (double_check!=1) {
81  double_check = 0;
82  for (int k=0; k<nbeta; k++) {
83  _active_beta[k] = true;
84  }
85  }
86  else {
87  // cout << "found minimum at iteration: " << i << endl;
88  break;
89  }
90 
91  }
92  }
93 
94  // save results in the model
95  Setbeta(beta);
96 }
Eigen::MatrixXd _X
Definition: LinearModel.h:33
std::vector< bool > _active_beta
Eigen::MatrixXd & GetX()
Definition: LinearModel.h:15
virtual void Setbeta(Eigen::VectorXd beta)
Definition: LinearModel.h:21
Eigen::VectorXd _beta
Definition: LinearModel.h:34
double y
const double e
Eigen::VectorXd & Gety()
Definition: LinearModel.h:14
double _soft_thresholding(double x, double lambda_)
auto norm(Vector const &v)
Return norm of the specified vector.
Eigen::VectorXd lambda_weight
QTextStream & endl(QTextStream &s)
void WireCell::ElasticNetModel::SetLambdaWeight ( Eigen::VectorXd  w)
inline

Definition at line 22 of file ElasticNetModel.h.

22 { lambda_weight = w; }
Eigen::VectorXd lambda_weight
void WireCell::ElasticNetModel::SetLambdaWeight ( int  i,
double  weight 
)
inline

Definition at line 23 of file ElasticNetModel.h.

23 { lambda_weight(i) = weight; }
Eigen::VectorXd lambda_weight
weight
Definition: test.py:257
void WireCell::ElasticNetModel::SetX ( Eigen::MatrixXd  X)
inline virtual

Reimplemented from WireCell::LinearModel.

Definition at line 24 of file ElasticNetModel.h.

24 { LinearModel::SetX(X); SetLambdaWeight(Eigen::VectorXd::Zero(X.cols()) + Eigen::VectorXd::Constant(X.cols(),1.)); }
void SetLambdaWeight(Eigen::VectorXd w)
virtual void SetX(Eigen::MatrixXd X)
Definition: LinearModel.h:20

Member Data Documentation

std::vector<bool> WireCell::ElasticNetModel::_active_beta
protected

Definition at line 29 of file ElasticNetModel.h.

double WireCell::ElasticNetModel::alpha

Definition at line 16 of file ElasticNetModel.h.

double WireCell::ElasticNetModel::lambda

Definition at line 15 of file ElasticNetModel.h.

Eigen::VectorXd WireCell::ElasticNetModel::lambda_weight

Definition at line 20 of file ElasticNetModel.h.

int WireCell::ElasticNetModel::max_iter

Definition at line 17 of file ElasticNetModel.h.

bool WireCell::ElasticNetModel::non_negtive

Definition at line 19 of file ElasticNetModel.h.

double WireCell::ElasticNetModel::TOL

Definition at line 18 of file ElasticNetModel.h.


The documentation for this class was generated from the following files: