Faster scaled evaluation, etc.

This commit is contained in:
maartenkeijzer 2005-10-07 13:31:20 +00:00
commit 3ca7c0f6f4
9 changed files with 76 additions and 27 deletions

View file

@ -29,7 +29,7 @@ class EoSym : public EO<Fitness>, public Sym {
public:
void set(const Sym& sym) {
invalidate();
EO<Fitness>::invalidate();
static_cast<Sym*>(this)->operator=(sym);
}
@ -54,7 +54,7 @@ void EoSym<Fitness>::readFrom(istream& is) {
template <class Fitness>
inline std::ostream& operator<<(std::ostream& os, const EoSym<Fitness>& f) { f.printOn(os); return os; }
template <class Fitness>
inline istream& operator>>(std::istream& is, EoSym<Fitness>& f) { f.readFrom(is); return os; }
inline istream& operator>>(std::istream& is, EoSym<Fitness>& f) { f.readFrom(is); return is; }
#endif

View file

@ -16,13 +16,6 @@
*/
#if __GNUC__ == 3
#include <backward/hash_map.h>
#else
#include <hash_map.h>
using std::hash_map;
#endif
#include "Sym.h"
#include "FunDef.h"
#include "sym_compile.h"

View file

@ -17,13 +17,6 @@
#include <sstream>
#if __GNUC__ == 3
#include <backward/hash_map.h>
#else
#include <hash_map.h>
using std::hash_map;
#endif
#include "Sym.h"
#include "FunDef.h"
#include <LanguageTable.h>

View file

@ -23,8 +23,6 @@
using namespace std;
eoRng rng(time(0));
extern Sym default_const();
class LanguageImpl {

View file

@ -23,6 +23,8 @@ class Sym;
/** Base class for selecting nodes */
class NodeSelector {
public:
virtual ~NodeSelector() {}
virtual unsigned select_node(Sym sym) const = 0;
};

View file

@ -24,6 +24,7 @@
#include "Sym.h"
#include "sym_compile.h"
#include "TargetInfo.h"
#include "stats.h"
using namespace std;
@ -101,10 +102,72 @@ class ErrorMeasureImpl {
multi_function all = compile(pop);
std::vector<double> y(pop.size());
std::vector<double> err(pop.size());
Scaling noScaling = Scaling(new NoScaling);
const std::valarray<double>& t = train_info.targets();
if (measure == ErrorMeasure::mean_squared_scaled) {
std::vector<Var> var(pop.size());
std::vector<Cov> cov(pop.size());
Var vart;
for (unsigned i = 0; i < t.size(); ++i) {
vart.update(t[i]);
all(&data.get_inputs(i)[0], &y[0]); // evalutate
for (unsigned j = 0; j < pop.size(); ++j) {
var[j].update(y[j]);
cov[j].update(y[j], t[i]);
}
}
std::vector<ErrorMeasure::result> result(pop.size());
for (unsigned i = 0; i < pop.size(); ++i) {
// calculate scaling
double b = cov[i].get_cov() / var[i].get_var();
if (!finite(b)) {
result[i].scaling = noScaling;
result[i].error = vart.get_var(); // largest error
continue;
}
double a = vart.get_mean() - b * var[i].get_mean();
result[i].scaling = Scaling( new LinearScaling(a,b));
// calculate error
double c = cov[i].get_cov();
c *= c;
double err = vart.get_var() - c / var[i].get_var();
result[i].error = err;
if (!finite(err)) {
cout << "b " << b << endl;
cout << "var t " << vart.get_var() << endl;
cout << "var i " << var[i].get_var() << endl;
cout << "cov " << cov[i].get_cov() << endl;
for (unsigned j = 0; j < t.size(); ++j) {
all(&data.get_inputs(i)[0], &y[0]); // evalutate
cout << y[i] << endl;
}
exit(1);
}
}
return result;
}
std::vector<double> err(pop.size());
for (unsigned i = 0; i < train_cases(); ++i) {
// evaluate
all(&data.get_inputs(i)[0], &y[0]);
@ -124,10 +187,9 @@ class ErrorMeasureImpl {
std::vector<ErrorMeasure::result> result(pop.size());
double n = train_cases();
Scaling no = Scaling(new NoScaling);
for (unsigned i = 0; i < pop.size(); ++i) {
result[i].error = err[i] / n;
result[i].scaling = no;
result[i].scaling = noScaling;
}
return result;
@ -194,7 +256,7 @@ class ErrorMeasureImpl {
dresult = multi_function_eval(decloned);
break;
case ErrorMeasure::mean_squared_scaled:
dresult = single_function_eval(decloned);
dresult = multi_function_eval(decloned);
break;
}
@ -241,7 +303,6 @@ ErrorMeasure::result ErrorMeasure::calc_error(Sym sym) {
res.error = not_a_number;
return res;
}
}
return pimpl->eval(y);

View file

@ -28,7 +28,9 @@ class TargetInfo;
class ScalingBase {
public:
virtual ~ScalingBase() {}
std::valarray<double> apply(const std::valarray<double>& x) {
std::valarray<double> xtmp = x;
transform(xtmp);

View file

@ -20,7 +20,7 @@
#include <cassert>
#if __GNUC__ == 3
#if __GNUC__ >= 3
#include <backward/hash_map.h>
#else
#include <hash_map.h>

View file

@ -84,7 +84,7 @@ struct SymValue
unsigned getSize() const { return size; }
unsigned getDepth() const { return depth; }
private :
// for reference counting
unsigned refcount;
@ -92,8 +92,8 @@ struct SymValue
// some simple stats
unsigned size;
unsigned depth;
UniqueNodeStats* uniqueNodeStats;
};