Faster scaled evaluation, etc.
This commit is contained in:
parent
4042798417
commit
3ca7c0f6f4
9 changed files with 76 additions and 27 deletions
|
|
@ -29,7 +29,7 @@ class EoSym : public EO<Fitness>, public Sym {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
void set(const Sym& sym) {
|
void set(const Sym& sym) {
|
||||||
invalidate();
|
EO<Fitness>::invalidate();
|
||||||
static_cast<Sym*>(this)->operator=(sym);
|
static_cast<Sym*>(this)->operator=(sym);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -54,7 +54,7 @@ void EoSym<Fitness>::readFrom(istream& is) {
|
||||||
template <class Fitness>
|
template <class Fitness>
|
||||||
inline std::ostream& operator<<(std::ostream& os, const EoSym<Fitness>& f) { f.printOn(os); return os; }
|
inline std::ostream& operator<<(std::ostream& os, const EoSym<Fitness>& f) { f.printOn(os); return os; }
|
||||||
template <class Fitness>
|
template <class Fitness>
|
||||||
inline istream& operator>>(std::istream& is, EoSym<Fitness>& f) { f.readFrom(is); return os; }
|
inline istream& operator>>(std::istream& is, EoSym<Fitness>& f) { f.readFrom(is); return is; }
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -16,13 +16,6 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
#if __GNUC__ == 3
|
|
||||||
#include <backward/hash_map.h>
|
|
||||||
#else
|
|
||||||
#include <hash_map.h>
|
|
||||||
using std::hash_map;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "Sym.h"
|
#include "Sym.h"
|
||||||
#include "FunDef.h"
|
#include "FunDef.h"
|
||||||
#include "sym_compile.h"
|
#include "sym_compile.h"
|
||||||
|
|
|
||||||
|
|
@ -17,13 +17,6 @@
|
||||||
|
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#if __GNUC__ == 3
|
|
||||||
#include <backward/hash_map.h>
|
|
||||||
#else
|
|
||||||
#include <hash_map.h>
|
|
||||||
using std::hash_map;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "Sym.h"
|
#include "Sym.h"
|
||||||
#include "FunDef.h"
|
#include "FunDef.h"
|
||||||
#include <LanguageTable.h>
|
#include <LanguageTable.h>
|
||||||
|
|
|
||||||
|
|
@ -23,8 +23,6 @@
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
eoRng rng(time(0));
|
|
||||||
|
|
||||||
extern Sym default_const();
|
extern Sym default_const();
|
||||||
|
|
||||||
class LanguageImpl {
|
class LanguageImpl {
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,8 @@ class Sym;
|
||||||
/** Base class for selecting nodes */
|
/** Base class for selecting nodes */
|
||||||
class NodeSelector {
|
class NodeSelector {
|
||||||
public:
|
public:
|
||||||
|
virtual ~NodeSelector() {}
|
||||||
|
|
||||||
virtual unsigned select_node(Sym sym) const = 0;
|
virtual unsigned select_node(Sym sym) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@
|
||||||
#include "Sym.h"
|
#include "Sym.h"
|
||||||
#include "sym_compile.h"
|
#include "sym_compile.h"
|
||||||
#include "TargetInfo.h"
|
#include "TargetInfo.h"
|
||||||
|
#include "stats.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
@ -101,10 +102,72 @@ class ErrorMeasureImpl {
|
||||||
|
|
||||||
multi_function all = compile(pop);
|
multi_function all = compile(pop);
|
||||||
std::vector<double> y(pop.size());
|
std::vector<double> y(pop.size());
|
||||||
std::vector<double> err(pop.size());
|
|
||||||
|
Scaling noScaling = Scaling(new NoScaling);
|
||||||
|
|
||||||
const std::valarray<double>& t = train_info.targets();
|
const std::valarray<double>& t = train_info.targets();
|
||||||
|
|
||||||
|
if (measure == ErrorMeasure::mean_squared_scaled) {
|
||||||
|
std::vector<Var> var(pop.size());
|
||||||
|
std::vector<Cov> cov(pop.size());
|
||||||
|
|
||||||
|
Var vart;
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < t.size(); ++i) {
|
||||||
|
vart.update(t[i]);
|
||||||
|
|
||||||
|
all(&data.get_inputs(i)[0], &y[0]); // evalutate
|
||||||
|
|
||||||
|
for (unsigned j = 0; j < pop.size(); ++j) {
|
||||||
|
var[j].update(y[j]);
|
||||||
|
cov[j].update(y[j], t[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ErrorMeasure::result> result(pop.size());
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < pop.size(); ++i) {
|
||||||
|
|
||||||
|
// calculate scaling
|
||||||
|
double b = cov[i].get_cov() / var[i].get_var();
|
||||||
|
|
||||||
|
if (!finite(b)) {
|
||||||
|
result[i].scaling = noScaling;
|
||||||
|
result[i].error = vart.get_var(); // largest error
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
double a = vart.get_mean() - b * var[i].get_mean();
|
||||||
|
|
||||||
|
result[i].scaling = Scaling( new LinearScaling(a,b));
|
||||||
|
|
||||||
|
// calculate error
|
||||||
|
double c = cov[i].get_cov();
|
||||||
|
c *= c;
|
||||||
|
|
||||||
|
double err = vart.get_var() - c / var[i].get_var();
|
||||||
|
result[i].error = err;
|
||||||
|
if (!finite(err)) {
|
||||||
|
cout << "b " << b << endl;
|
||||||
|
cout << "var t " << vart.get_var() << endl;
|
||||||
|
cout << "var i " << var[i].get_var() << endl;
|
||||||
|
cout << "cov " << cov[i].get_cov() << endl;
|
||||||
|
|
||||||
|
for (unsigned j = 0; j < t.size(); ++j) {
|
||||||
|
all(&data.get_inputs(i)[0], &y[0]); // evalutate
|
||||||
|
|
||||||
|
cout << y[i] << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<double> err(pop.size());
|
||||||
|
|
||||||
for (unsigned i = 0; i < train_cases(); ++i) {
|
for (unsigned i = 0; i < train_cases(); ++i) {
|
||||||
// evaluate
|
// evaluate
|
||||||
all(&data.get_inputs(i)[0], &y[0]);
|
all(&data.get_inputs(i)[0], &y[0]);
|
||||||
|
|
@ -124,10 +187,9 @@ class ErrorMeasureImpl {
|
||||||
std::vector<ErrorMeasure::result> result(pop.size());
|
std::vector<ErrorMeasure::result> result(pop.size());
|
||||||
|
|
||||||
double n = train_cases();
|
double n = train_cases();
|
||||||
Scaling no = Scaling(new NoScaling);
|
|
||||||
for (unsigned i = 0; i < pop.size(); ++i) {
|
for (unsigned i = 0; i < pop.size(); ++i) {
|
||||||
result[i].error = err[i] / n;
|
result[i].error = err[i] / n;
|
||||||
result[i].scaling = no;
|
result[i].scaling = noScaling;
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
@ -194,7 +256,7 @@ class ErrorMeasureImpl {
|
||||||
dresult = multi_function_eval(decloned);
|
dresult = multi_function_eval(decloned);
|
||||||
break;
|
break;
|
||||||
case ErrorMeasure::mean_squared_scaled:
|
case ErrorMeasure::mean_squared_scaled:
|
||||||
dresult = single_function_eval(decloned);
|
dresult = multi_function_eval(decloned);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -241,7 +303,6 @@ ErrorMeasure::result ErrorMeasure::calc_error(Sym sym) {
|
||||||
res.error = not_a_number;
|
res.error = not_a_number;
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return pimpl->eval(y);
|
return pimpl->eval(y);
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,8 @@ class TargetInfo;
|
||||||
class ScalingBase {
|
class ScalingBase {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
virtual ~ScalingBase() {}
|
||||||
|
|
||||||
std::valarray<double> apply(const std::valarray<double>& x) {
|
std::valarray<double> apply(const std::valarray<double>& x) {
|
||||||
std::valarray<double> xtmp = x;
|
std::valarray<double> xtmp = x;
|
||||||
transform(xtmp);
|
transform(xtmp);
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#if __GNUC__ == 3
|
#if __GNUC__ >= 3
|
||||||
#include <backward/hash_map.h>
|
#include <backward/hash_map.h>
|
||||||
#else
|
#else
|
||||||
#include <hash_map.h>
|
#include <hash_map.h>
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ struct SymValue
|
||||||
unsigned getSize() const { return size; }
|
unsigned getSize() const { return size; }
|
||||||
unsigned getDepth() const { return depth; }
|
unsigned getDepth() const { return depth; }
|
||||||
|
|
||||||
private :
|
|
||||||
|
|
||||||
// for reference counting
|
// for reference counting
|
||||||
unsigned refcount;
|
unsigned refcount;
|
||||||
|
|
@ -92,8 +92,8 @@ struct SymValue
|
||||||
// some simple stats
|
// some simple stats
|
||||||
unsigned size;
|
unsigned size;
|
||||||
unsigned depth;
|
unsigned depth;
|
||||||
|
|
||||||
UniqueNodeStats* uniqueNodeStats;
|
UniqueNodeStats* uniqueNodeStats;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Reference in a new issue