/usr/include/dlib/svm/rls.h is in libdlib-dev 18.18-2.
This file is owned by root:root, with mode 0o644.
The actual contents of the file can be viewed below.
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_RLs_Hh_
#define DLIB_RLs_Hh_
#include "rls_abstract.h"
#include "../matrix.h"
#include "function.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class rls
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            An online (recursive) least squares linear regression trainer.  Each
            call to train() folds one (x, y) sample into the weight vector get_w(),
            and operator() evaluates dot(x, get_w()).

            Internally, R holds the inverse of the regularized sum of x*trans(x)
            seen so far.  It is initialized to C*identity on the first call to
            train().  Each train() call scales that sum by forget_factor before
            adding the new sample, so older samples are down-weighted when
            forget_factor < 1.
    !*/
public:

    explicit rls(
        double forget_factor_,
        double C_ = 1000
    ) :
        C(C_),
        forget_factor(forget_factor_)
    /*!
        requires
            - 0 < forget_factor_ <= 1
            - 0 < C_
        ensures
            - get_forget_factor() == forget_factor_
            - get_c() == C_
            - get_w().size() == 0
    !*/
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(0 < forget_factor_ && forget_factor_ <= 1 &&
                    0 < C_,
            "\t rls::rls()"
            << "\n\t invalid arguments were given to this function"
            << "\n\t forget_factor_: " << forget_factor_
            << "\n\t C_: " << C_
            << "\n\t this: " << this
        );
    }

    rls(
    ) :
        C(1000),
        forget_factor(1)
    /*!
        ensures
            - get_c() == 1000
            - get_forget_factor() == 1 (no forgetting)
            - get_w().size() == 0
    !*/
    {
    }

    double get_c(
    ) const
    /*!
        ensures
            - returns the regularization parameter C (R starts out as C*identity).
    !*/
    {
        return C;
    }

    double get_forget_factor(
    ) const
    /*!
        ensures
            - returns the forget factor applied on every call to train().
    !*/
    {
        return forget_factor;
    }

    template <typename EXP>
    void train (
        const matrix_exp<EXP>& x,
        double y
    )
    /*!
        requires
            - is_col_vector(x) == true
            - get_w().size() == 0 || get_w().size() == x.size()
        ensures
            - incorporates the sample (x, y) into get_w().
            - on the first call, w and R are sized to x.size()
              (w = 0, R = C*identity).
    !*/
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(is_col_vector(x) &&
                    (get_w().size() == 0 || get_w().size() == x.size()),
            "\t void rls::train()"
            << "\n\t invalid arguments were given to this function"
            << "\n\t is_col_vector(x): " << is_col_vector(x)
            << "\n\t x.size(): " << x.size()
            << "\n\t get_w().size(): " << get_w().size()
            << "\n\t this: " << this
        );

        if (R.size() == 0)
        {
            R = identity_matrix<double>(x.size())*C;
            w.set_size(x.size());
            w = 0;
        }

        // multiply by forget factor and incorporate x*trans(x) into R.
        // This is a Sherman-Morrison rank-1 update of the inverse held in R.
        const double l = 1.0/forget_factor;
        const double temp = 1 + l*trans(x)*R*x;
        matrix<double,0,1> tmp = R*x;
        R = l*R - l*l*(tmp*trans(tmp))/temp;

        // Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
        // identity matrix back in to keep the regularization alive.  When forget_factor == 1
        // nothing was forgotten and the amount to add back is exactly zero.  Skipping the
        // call then avoids an O(n^3) no-op and a 1/0 division inside add_eye_to_inv().
        // That division only "works" via IEEE infinities, which -ffast-math does not honor.
        if (forget_factor != 1)
            add_eye_to_inv(R, (1-forget_factor)/C);

        // R should always be symmetric.  This line improves numeric stability of this algorithm.
        R = 0.5*(R + trans(R));

        w = w + R*x*(y - trans(x)*w);
    }

    const matrix<double,0,1>& get_w(
    ) const
    /*!
        ensures
            - returns the current weight vector (size 0 until train() is called).
    !*/
    {
        return w;
    }

    template <typename EXP>
    double operator() (
        const matrix_exp<EXP>& x
    ) const
    /*!
        requires
            - is_col_vector(x) == true
            - get_w().size() == x.size()
        ensures
            - returns dot(x, get_w())
    !*/
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(is_col_vector(x) && get_w().size() == x.size(),
            "\t double rls::operator()()"
            << "\n\t invalid arguments were given to this function"
            << "\n\t is_col_vector(x): " << is_col_vector(x)
            << "\n\t x.size(): " << x.size()
            << "\n\t get_w().size(): " << get_w().size()
            << "\n\t this: " << this
        );

        return dot(x,w);
    }

    decision_function<linear_kernel<matrix<double,0,1> > > get_decision_function (
    ) const
    /*!
        requires
            - get_w().size() != 0
        ensures
            - returns a linear-kernel decision function with a single basis
              vector equal to get_w(), alpha == 1 and b == 0.
    !*/
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(get_w().size() != 0,
            "\t decision_function rls::get_decision_function()"
            << "\n\t invalid arguments were given to this function"
            << "\n\t get_w().size(): " << get_w().size()
            << "\n\t this: " << this
        );

        decision_function<linear_kernel<matrix<double,0,1> > > df;
        df.alpha.set_size(1);
        df.basis_vectors.set_size(1);
        df.b = 0;
        df.alpha = 1;
        df.basis_vectors(0) = w;

        return df;
    }

    friend inline void serialize(const rls& item, std::ostream& out)
    {
        // Stream layout (version 1): version, w, R, C, forget_factor.
        int version = 1;
        serialize(version, out);
        serialize(item.w, out);
        serialize(item.R, out);
        serialize(item.C, out);
        serialize(item.forget_factor, out);
    }

    friend inline void deserialize(rls& item, std::istream& in)
    {
        // Reads the layout written by serialize(); only version 1 is understood.
        int version = 0;
        deserialize(version, in);
        if (version != 1)
            throw dlib::serialization_error("Unknown version number found while deserializing rls object.");
        deserialize(item.w, in);
        deserialize(item.R, in);
        deserialize(item.C, in);
        deserialize(item.forget_factor, in);
    }

private:

    void add_eye_to_inv(
        matrix<double>& m,
        double eps
    )
    /*!
        requires
            - eps > 0 (1/eps is used below)
            - m is square and symmetric (trans(colm(m,r)) stands in for rowm(m,r))
        ensures
            - Let m == inv(M) on entry.
            - On exit, m == inv(M + eps*identity_matrix<double>(m.nr())).
              The update is done in place; nothing is returned.
    !*/
    {
        // Add eps to one diagonal entry of M at a time.  Each step is a
        // Sherman-Morrison rank-1 update of the inverse using the unit vector e_r.
        for (long r = 0; r < m.nr(); ++r)
        {
            m = m - colm(m,r)*trans(colm(m,r))/(1/eps + m(r,r));
        }
    }

    matrix<double,0,1> w;
    matrix<double> R;
    double C;
    double forget_factor;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_RLs_Hh_
|