00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef GP_LWPR_H
00011 #define GP_LWPR_H
00012
#include <cmath>
#include <iosfwd>

#include "array.h"
00014
00015
00016
00017 class Lmodel;
00018
00023 void getStdDeviations(doubleA& norm,const doubleA& X);
00024
00025
00026
00029 class Lwpr {
00030 public:
00032 Lwpr();
00033
00035 void learn(const doubleA& X,const doubleA& Y);
00036
00038 void predict(const doubleA& X,doubleA& Y);
00039
00041 double confidence(const doubleA &x){
00042 doubleA nx;
00043 normalizeInput(nx,x);
00044 return ::sqrt(s2pred(nx));
00045 }
00046
00048 void save(char *);
00049
00051 void load(char *,bool alsoParameters=true);
00052
00054 void report(std::ostream& os);
00055
00057 int get_rfs_no(int out_dim);
00058
00060 double get_proj_average(int out_dim);
00061
00063 int verbosity;
00064
00066 uint rfsno;
00067
00080 doubleA norm;
00081
00085 bool add_proj;
00086
00088 bool updateD;
00089
00091 bool useNorm;
00092
00100 double cut_off;
00105 double w_prune;
00110 double w_gen;
00113 double initD;
00116 double init_lambda;
00119 double tau_lambda;
00122 double final_lambda;
00123
00130 double add_threshold;
00131
00133 bool meta;
00134
00136 double meta_rate;
00137
00141 int initR;
00142
00151 double alpha;
00152
00155 double gamma;
00156
00159 bool blend;
00160
00161 private:
00162 void initialize(uint inDim,uint outDim);
00163 double s2pred(const doubleA& x);
00164 double predict(const MT::Array<Lmodel>& rfs, const doubleA& x);
00165 void updaterfs(MT::Array<Lmodel>& rfs, const doubleA& x, double y);
00166 void normalizeInput(doubleA& normIn,const doubleA& unNormIn);
00167
00168 bool isInitialized;
00169 uint in_dim;
00170 uint out_dim;
00171
00172 MT::Array<MT::Array<Lmodel> > models;
00173 };
00174
00175
00176
00177 #ifndef MT_doxy
00178
00179
00180 class Lmodel {
00181 public:
00182 Lwpr *lwpr;
00183
00184 doubleA center;
00185 doubleA mean_x;
00186 double mean_y;
00187 doubleA W;
00188 double MyMSE_R;
00189 doubleA lambda;
00190 doubleA M;
00191 doubleA n_data;
00192
00193
00194 MT::Array<doubleA> uprojections;
00195 MT::Array<doubleA> pprojections;
00196 MT::Array<doubleA> sXresYres;
00197 doubleA betas;
00198 doubleA MSE;
00199
00200 doubleA azz;
00201 doubleA azres;
00202 MT::Array<doubleA> axz;
00203
00204
00205 double sum_e2;
00206 doubleA sum_ecv2;
00207 doubleA aH;
00208 doubleA aG;
00209 double aE;
00210
00211
00212 double gamma;
00213
00214 doubleA alpha;
00215
00216 doubleA b;
00217 doubleA h;
00218
00219 double apk;
00220 doubleA dist;
00221
00222 public:
00223 int R;
00224 bool trustworthy;
00225 bool degenerated;
00226
00227 public:
00228 Lmodel(){};
00229 Lmodel(Lwpr* _lwpr,const doubleA& x, double y);
00230 Lmodel(Lwpr* _lwpr,const Lmodel& model, const doubleA& x, double y);
00231 double activation(const doubleA& x);
00232 void updateMean(const doubleA& x, double y);
00233 void updateStat(double w);
00234 void updateError(const doubleA& x, double y);
00235 void updateError_matlab(double, double);
00236 void update(const doubleA& x, double y);
00237 double predict(const doubleA& x ) const;
00238 void update_dist(const doubleA& x, double y, double e_cv, double e);
00239 double gete(const doubleA& x, double y);
00240 double gete_cv(const doubleA& x, double y);
00241 bool check_add_projection(const doubleA& x, double y);
00242 void add_projection(const doubleA& x, double y);
00243 void printProj();
00244 friend std::ostream& operator<<(std::ostream& outs, const Lmodel& lm);
00245
00246 void save(std::ostream&);
00247 void load(Lwpr* _lwpr, std::istream&);
00248
00249 double s2pred(const doubleA& x, double w);
00250 void updateApk(const doubleA& x,double w);
00251 const doubleA& getCenter() const{ return center; }
00252
00253 private :
00254 void computeZs(doubleA& zs, const doubleA& x);
00255 void loadParameters(uint N);
00256 void initialise(int N);
00257
00258 doubleA convert_vec(const doubleA& x);
00259 void getdJdM(doubleA& dJdM,const doubleA& D,const doubleA& M,
00260 const doubleA& x, const doubleA& center,
00261 const doubleA& z, const doubleA& azz, double w,
00262 double W, double e_cv, double e, double gamma,
00263 const doubleA& derivative_ok, double& sum_dJ1dw,
00264 doubleA& dwdM);
00265
00266 double getsum_dJ1dw(double e_cv, double e, double w, double W,
00267 const doubleA& z_vec, const doubleA& azz_vec,
00268 const doubleA& derivative_ok);
00269 void update_dist_stat(double e_cv, double w, const doubleA& z_vec,
00270 const doubleA& azz_vec, double transient_multiplier,
00271 const doubleA& derivative_ok);
00272 void getdJ2dM(doubleA& dJ2dM,const doubleA& D, const doubleA& M, double gamma);
00273 void getdwdM(doubleA& dwdM,const doubleA& M, const doubleA& x,
00274 const doubleA& center, double w);
00275 void getdDdMkl(doubleA& dDdMkl,const doubleA& M, int k, int l);
00276
00277
00278 void meta_update(double e_cv, double W, double w, double e, const doubleA& aH,
00279 double aE, const doubleA& z_vec,
00280 const doubleA& azz_vec, const doubleA& derivative_ok,
00281 const doubleA& x, const doubleA& center, double dJ1dw,
00282 const doubleA& dwdM, double transient_multiplier, const doubleA& dJdM);
00283 void getdJ2dJ2dMdM(doubleA& dJ2dJ2dMdM,const doubleA& D, const doubleA& M,
00284 double gamma);
00285 void getdwdwdMdM(doubleA& dwdwdMdM,const doubleA& M, const doubleA& x,
00286 const doubleA& center, double w);
00287
00288 void cholesky(doubleA& L,const doubleA& A);
00289 intA find(const doubleA& A, double value, bool gt);
00290 double absmax(const doubleA& A);
00291 };
00292 #endif
00293
00294 #ifdef MT_IMPLEMENTATION
00295 #include "lwpr.cpp"
00296 #endif
00297
00298 #endif