From 3ca7c0f6f473f79a15662fc0536c28bf590e0769 Mon Sep 17 00:00:00 2001 From: maartenkeijzer Date: Fri, 7 Oct 2005 13:31:20 +0000 Subject: [PATCH] Faster scaled evaluation, etc. --- eo/contrib/mathsym/eo_interface/eoSym.h | 4 +- eo/contrib/mathsym/eval/sym_compile.cpp | 7 -- eo/contrib/mathsym/fun/FunDef.cpp | 7 -- eo/contrib/mathsym/gen/LanguageTable.cpp | 2 - eo/contrib/mathsym/gen/NodeSelector.h | 2 + .../mathsym/regression/ErrorMeasure.cpp | 71 +++++++++++++++++-- eo/contrib/mathsym/regression/Scaling.h | 4 +- eo/contrib/mathsym/sym/Sym.h | 2 +- eo/contrib/mathsym/sym/SymImpl.h | 4 +- 9 files changed, 76 insertions(+), 27 deletions(-) diff --git a/eo/contrib/mathsym/eo_interface/eoSym.h b/eo/contrib/mathsym/eo_interface/eoSym.h index 11443af4..e30e06e1 100644 --- a/eo/contrib/mathsym/eo_interface/eoSym.h +++ b/eo/contrib/mathsym/eo_interface/eoSym.h @@ -29,7 +29,7 @@ class EoSym : public EO, public Sym { public: void set(const Sym& sym) { - invalidate(); + EO::invalidate(); static_cast(this)->operator=(sym); } @@ -54,7 +54,7 @@ void EoSym::readFrom(istream& is) { template inline std::ostream& operator<<(std::ostream& os, const EoSym& f) { f.printOn(os); return os; } template -inline istream& operator>>(std::istream& is, EoSym& f) { f.readFrom(is); return os; } +inline istream& operator>>(std::istream& is, EoSym& f) { f.readFrom(is); return is; } #endif diff --git a/eo/contrib/mathsym/eval/sym_compile.cpp b/eo/contrib/mathsym/eval/sym_compile.cpp index 4d3ce916..6351d716 100644 --- a/eo/contrib/mathsym/eval/sym_compile.cpp +++ b/eo/contrib/mathsym/eval/sym_compile.cpp @@ -16,13 +16,6 @@ */ -#if __GNUC__ == 3 -#include -#else -#include -using std::hash_map; -#endif - #include "Sym.h" #include "FunDef.h" #include "sym_compile.h" diff --git a/eo/contrib/mathsym/fun/FunDef.cpp b/eo/contrib/mathsym/fun/FunDef.cpp index 967f5986..9c1bfa7f 100644 --- a/eo/contrib/mathsym/fun/FunDef.cpp +++ b/eo/contrib/mathsym/fun/FunDef.cpp @@ -17,13 +17,6 @@ #include -#if __GNUC__ == 3 -#include -#else -#include -using std::hash_map; -#endif - #include "Sym.h" #include "FunDef.h" #include diff --git a/eo/contrib/mathsym/gen/LanguageTable.cpp b/eo/contrib/mathsym/gen/LanguageTable.cpp index c5906afb..7bda548f 100644 --- a/eo/contrib/mathsym/gen/LanguageTable.cpp +++ b/eo/contrib/mathsym/gen/LanguageTable.cpp @@ -23,8 +23,6 @@ using namespace std; -eoRng rng(time(0)); - extern Sym default_const(); class LanguageImpl { diff --git a/eo/contrib/mathsym/gen/NodeSelector.h b/eo/contrib/mathsym/gen/NodeSelector.h index 0a2863d9..5a93575b 100644 --- a/eo/contrib/mathsym/gen/NodeSelector.h +++ b/eo/contrib/mathsym/gen/NodeSelector.h @@ -23,6 +23,8 @@ class Sym; /** Base class for selecting nodes */ class NodeSelector { public: + virtual ~NodeSelector() {} + virtual unsigned select_node(Sym sym) const = 0; }; diff --git a/eo/contrib/mathsym/regression/ErrorMeasure.cpp b/eo/contrib/mathsym/regression/ErrorMeasure.cpp index b5937ce8..71d659a8 100644 --- a/eo/contrib/mathsym/regression/ErrorMeasure.cpp +++ b/eo/contrib/mathsym/regression/ErrorMeasure.cpp @@ -24,6 +24,7 @@ #include "Sym.h" #include "sym_compile.h" #include "TargetInfo.h" +#include "stats.h" using namespace std; @@ -101,10 +102,72 @@ class ErrorMeasureImpl { multi_function all = compile(pop); std::vector y(pop.size()); - std::vector err(pop.size()); + + Scaling noScaling = Scaling(new NoScaling); const std::valarray& t = train_info.targets(); + if (measure == ErrorMeasure::mean_squared_scaled) { + std::vector var(pop.size()); + std::vector cov(pop.size()); + + Var vart; + + for (unsigned i = 0; i < t.size(); ++i) { + vart.update(t[i]); + + all(&data.get_inputs(i)[0], &y[0]); // evalutate + + for (unsigned j = 0; j < pop.size(); ++j) { + var[j].update(y[j]); + cov[j].update(y[j], t[i]); + } + } + + std::vector result(pop.size()); + + for (unsigned i = 0; i < pop.size(); ++i) { + + // calculate scaling + double b = cov[i].get_cov() / var[i].get_var(); + + if (!finite(b)) { + result[i].scaling = noScaling; + result[i].error = vart.get_var(); // largest error + continue; + } + + double a = vart.get_mean() - b * var[i].get_mean(); + + result[i].scaling = Scaling( new LinearScaling(a,b)); + + // calculate error + double c = cov[i].get_cov(); + c *= c; + + double err = vart.get_var() - c / var[i].get_var(); + result[i].error = err; + if (!finite(err)) { + cout << "b " << b << endl; + cout << "var t " << vart.get_var() << endl; + cout << "var i " << var[i].get_var() << endl; + cout << "cov " << cov[i].get_cov() << endl; + + for (unsigned j = 0; j < t.size(); ++j) { + all(&data.get_inputs(i)[0], &y[0]); // evalutate + + cout << y[i] << endl; + } + + exit(1); + } + } + + return result; + } + + std::vector err(pop.size()); + for (unsigned i = 0; i < train_cases(); ++i) { // evaluate all(&data.get_inputs(i)[0], &y[0]); @@ -124,10 +187,9 @@ class ErrorMeasureImpl { std::vector result(pop.size()); double n = train_cases(); - Scaling no = Scaling(new NoScaling); for (unsigned i = 0; i < pop.size(); ++i) { result[i].error = err[i] / n; - result[i].scaling = no; + result[i].scaling = noScaling; } return result; @@ -194,7 +256,7 @@ class ErrorMeasureImpl { dresult = multi_function_eval(decloned); break; case ErrorMeasure::mean_squared_scaled: - dresult = single_function_eval(decloned); + dresult = multi_function_eval(decloned); break; } @@ -241,7 +303,6 @@ ErrorMeasure::result ErrorMeasure::calc_error(Sym sym) { res.error = not_a_number; return res; } - } return pimpl->eval(y); diff --git a/eo/contrib/mathsym/regression/Scaling.h b/eo/contrib/mathsym/regression/Scaling.h index 53d53d8b..efa98a01 100644 --- a/eo/contrib/mathsym/regression/Scaling.h +++ b/eo/contrib/mathsym/regression/Scaling.h @@ -28,7 +28,9 @@ class TargetInfo; class ScalingBase { public: - + + virtual ~ScalingBase() {} + std::valarray apply(const std::valarray& x) { std::valarray xtmp = x; transform(xtmp); diff --git a/eo/contrib/mathsym/sym/Sym.h b/eo/contrib/mathsym/sym/Sym.h index 0bf46bf0..3bbf2049 100644 --- a/eo/contrib/mathsym/sym/Sym.h +++ b/eo/contrib/mathsym/sym/Sym.h @@ -20,7 +20,7 @@ #include -#if __GNUC__ == 3 +#if __GNUC__ >= 3 #include #else #include diff --git a/eo/contrib/mathsym/sym/SymImpl.h b/eo/contrib/mathsym/sym/SymImpl.h index bc3538fb..105796c7 100644 --- a/eo/contrib/mathsym/sym/SymImpl.h +++ b/eo/contrib/mathsym/sym/SymImpl.h @@ -84,7 +84,7 @@ struct SymValue unsigned getSize() const { return size; } unsigned getDepth() const { return depth; } - private : + // for reference counting unsigned refcount; @@ -92,8 +92,8 @@ struct SymValue // some simple stats unsigned size; unsigned depth; - UniqueNodeStats* uniqueNodeStats; + };