Main Page
Related Pages
Modules
Namespaces
Classes
Files
Examples
File List
File Members
wire-cell-build
ress
src
ElasticNetModel.cxx
Go to the documentation of this file.
1
#include "
WireCellRess/ElasticNetModel.h
"
2
3
#include <Eigen/Dense>
4
using namespace
Eigen
;
5
6
#include <iostream>
7
using namespace
std
;
8
9
/* Minimize the following problem:
10
* 1/(2) * ||Y - X * beta||_2^2 + N * lambda * (
11
* alpha * ||beta||_1 + 0.5 * (1-alpha) * ||beta||_2^2
12
* )
13
* To control L1 and L2 separately, this is equivalent to a * L1 + b * L2,
14
* where lambda = a + b and alpha = a / (a + b)
15
*/
16
17
/**
 * Construct an elastic-net model.
 *
 * The arguments are stored unchanged in the members of the same name and the
 * model's display name is set to "Elastic net".
 *
 * @param lambda       overall regularization strength.
 * @param alpha        share of the penalty given to the L1 term
 *                     (the L2 term gets 1 - alpha).
 * @param max_iter     upper bound on the number of sweeps performed by Fit().
 * @param TOL          convergence tolerance used by Fit().
 * @param non_negtive  when true, _soft_thresholding() never returns a
 *                     negative value.
 */
WireCell::ElasticNetModel::ElasticNetModel(double lambda, double alpha, int max_iter, double TOL, bool non_negtive)
  : lambda(lambda)
  , alpha(alpha)
  , max_iter(max_iter)
  , TOL(TOL)
  , non_negtive(non_negtive)
{
    this->name = "Elastic net";
}
22
23
/// Destructor: the body is empty, there is no explicit cleanup to perform.
WireCell::ElasticNetModel::~ElasticNetModel() {}
25
26
void
WireCell::ElasticNetModel::Fit
()
27
{
28
// initialize solution to zero unless user set beta already
29
Eigen::VectorXd beta =
_beta
;
30
if
(0 == beta.size()) {
31
beta = VectorXd::Zero(
_X
.cols());
32
}
33
34
// initialize active_beta to true
35
int
nbeta = beta.size();
36
_active_beta
= vector<bool>(nbeta,
true
);
37
38
// use alias for easy notation
39
Eigen::VectorXd
y
=
Gety
();
40
Eigen::MatrixXd
X
=
GetX
();
41
42
// cooridate decsent
43
44
//int N = y.size();
45
VectorXd
norm
(nbeta);
46
for
(
int
j=0; j<nbeta; j++) {
47
norm
(j) = X.col(j).squaredNorm();
48
if
(
norm
(j) < 1
e
-6) {
49
cerr <<
"warning: the "
<< j <<
"th variable is not used, please consider removing it."
<<
endl
;
50
norm
(j) = 1;
51
}
52
}
53
double
tol2 =
TOL
*
TOL
*nbeta;
54
55
int
double_check = 0;
56
for
(
int
i=0; i<
max_iter
; i++) {
57
VectorXd betalast = beta;
58
for
(
int
j=0; j<nbeta; j++) {
59
if
(!
_active_beta
[j]) {
continue
;}
60
VectorXd X_j = X.col(j);
61
VectorXd beta_tmp = beta;
62
beta_tmp(j) = 0;
63
VectorXd r_j = (y - X * beta_tmp);
64
double
delta_j = X_j.dot(r_j);
65
// beta(j) = _soft_thresholding(delta_j, N*lambda*alpha*lambda_weight(j)) / (1+lambda*(1-alpha)) / norm(j);
66
beta(j) =
_soft_thresholding
(delta_j/
norm
(j),
lambda
*
alpha
*
lambda_weight
(j)) / (1+
lambda
*(1-
alpha
));
67
68
//cout << i << " " << j << " " << beta(j) << std::endl;
69
if
(fabs(beta(j)) < 1
e
-6) {
_active_beta
[j] =
false
; }
70
// else { cout << beta(j) << endl;}
71
// beta(j) = _soft_thresholding(delta_j, N*lambda*alpha, j) / (1+lambda*(1-alpha)) / norm(j);
72
// if (j==0) cout << beta(j) << ", " << arg1 << endl;
73
}
74
double_check++;
75
// cout << endl;
76
VectorXd diff = beta - betalast;
77
78
//std::cout << i << " " << diff.squaredNorm() << " " << tol2 << std::endl;
79
if
(diff.squaredNorm()<tol2) {
80
if
(double_check!=1) {
81
double_check = 0;
82
for
(
int
k
=0;
k
<nbeta;
k
++) {
83
_active_beta
[
k
] =
true
;
84
}
85
}
86
else
{
87
// cout << "found minimum at iteration: " << i << endl;
88
break
;
89
}
90
91
}
92
}
93
94
// save results in the model
95
Setbeta
(beta);
96
}
97
98
double
WireCell::ElasticNetModel::_soft_thresholding
(
double
delta,
double
lambda_)
99
{
100
101
if
(delta > lambda_) {
102
return
delta - lambda_;
103
}
104
else
{
105
if
(
non_negtive
) {
106
return
0;
107
}
108
else
{
109
if
(delta < -lambda_) {
110
return
delta + lambda_;
111
}
112
else
{
113
return
0;
114
}
115
}
116
}
117
}
WireCell::ElasticNetModel::max_iter
int max_iter
Definition:
ElasticNetModel.h:17
WireCell::LinearModel::_X
Eigen::MatrixXd _X
Definition:
LinearModel.h:33
WireCell::ElasticNetModel::TOL
double TOL
Definition:
ElasticNetModel.h:18
WireCell::ElasticNetModel::_active_beta
std::vector< bool > _active_beta
Definition:
ElasticNetModel.h:29
WireCell::LinearModel::GetX
Eigen::MatrixXd & GetX()
Definition:
LinearModel.h:15
WireCell::LinearModel::Setbeta
virtual void Setbeta(Eigen::VectorXd beta)
Definition:
LinearModel.h:21
Eigen
std
STL namespace.
WireCell::LinearModel::_beta
Eigen::VectorXd _beta
Definition:
LinearModel.h:34
y
double y
Definition:
GapWidth_module.cc:109
e
const double e
Definition:
gUpMuFluxGen.cxx:165
WireCell::ElasticNetModel::alpha
double alpha
Definition:
ElasticNetModel.h:16
WireCell::LinearModel::Gety
Eigen::VectorXd & Gety()
Definition:
LinearModel.h:14
wirecell.validate.cmaps.X
X
Definition:
cmaps.py:74
WireCell::ElasticNetModel::ElasticNetModel
ElasticNetModel(double lambda=1., double alpha=1., int max_iter=100000, double TOL=1e-3, bool non_negtive=true)
Definition:
ElasticNetModel.cxx:17
WireCell::ElasticNetModel::Fit
virtual void Fit()
Definition:
ElasticNetModel.cxx:26
WireCell::ElasticNetModel::_soft_thresholding
double _soft_thresholding(double x, double lambda_)
Definition:
ElasticNetModel.cxx:98
geo::vect::norm
auto norm(Vector const &v)
Return norm of the specified vector.
Definition:
geo_vectors_utils.h:1202
WireCell::LinearModel::name
std::string name
Definition:
LinearModel.h:27
WireCell::ElasticNetModel::~ElasticNetModel
~ElasticNetModel()
Definition:
ElasticNetModel.cxx:23
ElasticNetModel.h
WireCell::ElasticNetModel::non_negtive
bool non_negtive
Definition:
ElasticNetModel.h:19
WireCell::ElasticNetModel::lambda_weight
Eigen::VectorXd lambda_weight
Definition:
ElasticNetModel.h:20
muoncounters.k
k
Definition:
muoncounters.py:27
WireCell::ElasticNetModel::lambda
double lambda
Definition:
ElasticNetModel.h:15
endl
QTextStream & endl(QTextStream &s)
Definition:
qtextstream.cpp:2030
Generated by
1.8.11