Faster scaled evaluation, etc.

This commit is contained in:
maartenkeijzer 2005-10-07 13:31:20 +00:00
commit 3ca7c0f6f4
9 changed files with 76 additions and 27 deletions

View file

@ -29,7 +29,7 @@ class EoSym : public EO<Fitness>, public Sym {
public: public:
void set(const Sym& sym) { void set(const Sym& sym) {
invalidate(); EO<Fitness>::invalidate();
static_cast<Sym*>(this)->operator=(sym); static_cast<Sym*>(this)->operator=(sym);
} }
@ -54,7 +54,7 @@ void EoSym<Fitness>::readFrom(istream& is) {
template <class Fitness> template <class Fitness>
inline std::ostream& operator<<(std::ostream& os, const EoSym<Fitness>& f) { f.printOn(os); return os; } inline std::ostream& operator<<(std::ostream& os, const EoSym<Fitness>& f) { f.printOn(os); return os; }
template <class Fitness> template <class Fitness>
inline istream& operator>>(std::istream& is, EoSym<Fitness>& f) { f.readFrom(is); return os; } inline istream& operator>>(std::istream& is, EoSym<Fitness>& f) { f.readFrom(is); return is; }
#endif #endif

View file

@ -16,13 +16,6 @@
*/ */
#if __GNUC__ == 3
#include <backward/hash_map.h>
#else
#include <hash_map.h>
using std::hash_map;
#endif
#include "Sym.h" #include "Sym.h"
#include "FunDef.h" #include "FunDef.h"
#include "sym_compile.h" #include "sym_compile.h"

View file

@ -17,13 +17,6 @@
#include <sstream> #include <sstream>
#if __GNUC__ == 3
#include <backward/hash_map.h>
#else
#include <hash_map.h>
using std::hash_map;
#endif
#include "Sym.h" #include "Sym.h"
#include "FunDef.h" #include "FunDef.h"
#include <LanguageTable.h> #include <LanguageTable.h>

View file

@ -23,8 +23,6 @@
using namespace std; using namespace std;
eoRng rng(time(0));
extern Sym default_const(); extern Sym default_const();
class LanguageImpl { class LanguageImpl {

View file

@ -23,6 +23,8 @@ class Sym;
/** Base class for selecting nodes */ /** Base class for selecting nodes */
class NodeSelector { class NodeSelector {
public: public:
virtual ~NodeSelector() {}
virtual unsigned select_node(Sym sym) const = 0; virtual unsigned select_node(Sym sym) const = 0;
}; };

View file

@ -24,6 +24,7 @@
#include "Sym.h" #include "Sym.h"
#include "sym_compile.h" #include "sym_compile.h"
#include "TargetInfo.h" #include "TargetInfo.h"
#include "stats.h"
using namespace std; using namespace std;
@ -101,10 +102,72 @@ class ErrorMeasureImpl {
multi_function all = compile(pop); multi_function all = compile(pop);
std::vector<double> y(pop.size()); std::vector<double> y(pop.size());
std::vector<double> err(pop.size());
Scaling noScaling = Scaling(new NoScaling);
const std::valarray<double>& t = train_info.targets(); const std::valarray<double>& t = train_info.targets();
if (measure == ErrorMeasure::mean_squared_scaled) {
std::vector<Var> var(pop.size());
std::vector<Cov> cov(pop.size());
Var vart;
for (unsigned i = 0; i < t.size(); ++i) {
vart.update(t[i]);
all(&data.get_inputs(i)[0], &y[0]); // evalutate
for (unsigned j = 0; j < pop.size(); ++j) {
var[j].update(y[j]);
cov[j].update(y[j], t[i]);
}
}
std::vector<ErrorMeasure::result> result(pop.size());
for (unsigned i = 0; i < pop.size(); ++i) {
// calculate scaling
double b = cov[i].get_cov() / var[i].get_var();
if (!finite(b)) {
result[i].scaling = noScaling;
result[i].error = vart.get_var(); // largest error
continue;
}
double a = vart.get_mean() - b * var[i].get_mean();
result[i].scaling = Scaling( new LinearScaling(a,b));
// calculate error
double c = cov[i].get_cov();
c *= c;
double err = vart.get_var() - c / var[i].get_var();
result[i].error = err;
if (!finite(err)) {
cout << "b " << b << endl;
cout << "var t " << vart.get_var() << endl;
cout << "var i " << var[i].get_var() << endl;
cout << "cov " << cov[i].get_cov() << endl;
for (unsigned j = 0; j < t.size(); ++j) {
all(&data.get_inputs(i)[0], &y[0]); // evalutate
cout << y[i] << endl;
}
exit(1);
}
}
return result;
}
std::vector<double> err(pop.size());
for (unsigned i = 0; i < train_cases(); ++i) { for (unsigned i = 0; i < train_cases(); ++i) {
// evaluate // evaluate
all(&data.get_inputs(i)[0], &y[0]); all(&data.get_inputs(i)[0], &y[0]);
@ -124,10 +187,9 @@ class ErrorMeasureImpl {
std::vector<ErrorMeasure::result> result(pop.size()); std::vector<ErrorMeasure::result> result(pop.size());
double n = train_cases(); double n = train_cases();
Scaling no = Scaling(new NoScaling);
for (unsigned i = 0; i < pop.size(); ++i) { for (unsigned i = 0; i < pop.size(); ++i) {
result[i].error = err[i] / n; result[i].error = err[i] / n;
result[i].scaling = no; result[i].scaling = noScaling;
} }
return result; return result;
@ -194,7 +256,7 @@ class ErrorMeasureImpl {
dresult = multi_function_eval(decloned); dresult = multi_function_eval(decloned);
break; break;
case ErrorMeasure::mean_squared_scaled: case ErrorMeasure::mean_squared_scaled:
dresult = single_function_eval(decloned); dresult = multi_function_eval(decloned);
break; break;
} }
@ -241,7 +303,6 @@ ErrorMeasure::result ErrorMeasure::calc_error(Sym sym) {
res.error = not_a_number; res.error = not_a_number;
return res; return res;
} }
} }
return pimpl->eval(y); return pimpl->eval(y);

View file

@ -28,7 +28,9 @@ class TargetInfo;
class ScalingBase { class ScalingBase {
public: public:
virtual ~ScalingBase() {}
std::valarray<double> apply(const std::valarray<double>& x) { std::valarray<double> apply(const std::valarray<double>& x) {
std::valarray<double> xtmp = x; std::valarray<double> xtmp = x;
transform(xtmp); transform(xtmp);

View file

@ -20,7 +20,7 @@
#include <cassert> #include <cassert>
#if __GNUC__ == 3 #if __GNUC__ >= 3
#include <backward/hash_map.h> #include <backward/hash_map.h>
#else #else
#include <hash_map.h> #include <hash_map.h>

View file

@ -84,7 +84,7 @@ struct SymValue
unsigned getSize() const { return size; } unsigned getSize() const { return size; }
unsigned getDepth() const { return depth; } unsigned getDepth() const { return depth; }
private :
// for reference counting // for reference counting
unsigned refcount; unsigned refcount;
@ -92,8 +92,8 @@ struct SymValue
// some simple stats // some simple stats
unsigned size; unsigned size;
unsigned depth; unsigned depth;
UniqueNodeStats* uniqueNodeStats; UniqueNodeStats* uniqueNodeStats;
}; };