Added symbolic differentiation

This commit is contained in:
maartenkeijzer 2005-10-09 07:03:35 +00:00
commit ffdce66c98
12 changed files with 186 additions and 61 deletions

View file

@ -1,5 +1,5 @@
COMPILEFLAGS=-Wno-deprecated -g -Wall -Wshadow #-DINTERVAL_DEBUG COMPILEFLAGS=-Wno-deprecated -g -Wall -Wshadow #-DINTERVAL_DEBUG
OPTFLAGS= -O2 -DNDEBUG OPTFLAGS= -O3 -DNDEBUG
PROFILE_FLAGS=-pg PROFILE_FLAGS=-pg
@ -18,7 +18,7 @@ CXXSOURCES=FunDef.cpp Sym.cpp SymImpl.cpp SymOps.cpp sym_compile.cpp TreeBuilder
Dataset.cpp ErrorMeasure.cpp Scaling.cpp TargetInfo.cpp BoundsCheck.cpp util.cpp NodeSelector.cpp\ Dataset.cpp ErrorMeasure.cpp Scaling.cpp TargetInfo.cpp BoundsCheck.cpp util.cpp NodeSelector.cpp\
eoSymCrossover.cpp sym_operations.cpp eoSymMutate.cpp eoSymCrossover.cpp sym_operations.cpp eoSymMutate.cpp
TESTPROGRAMS=test/test_compile test/testeo test/test_simplify TESTPROGRAMS=test/test_compile test/testeo test/test_simplify test/test_diff
OBJS= $(CXXSOURCES:.cpp=.o) c_compile.o OBJS= $(CXXSOURCES:.cpp=.o) c_compile.o
@ -50,6 +50,9 @@ test/testeo: test/testeo.o ${SYMLIB}
test/test_simplify: test/test_simplify.o $(SYMLIB) test/test_simplify: test/test_simplify.o $(SYMLIB)
$(CXX) -o test/test_simplify test/test_simplify.o $(SYMLIB) ${LIBS} $(CXX) -o test/test_simplify test/test_simplify.o $(SYMLIB) ${LIBS}
test/test_diff: test/test_diff.o $(SYMLIB)
$(CXX) -o test/test_diff test/test_diff.o $(SYMLIB) ${LIBS}
# eo # eo
../../src/libeo.a: ../../src/libeo.a:
make -C ../../src make -C ../../src

View file

@ -29,13 +29,14 @@ class eoSymPopEval : public eoPopEvalFunc<EoType> {
BoundsCheck& check; BoundsCheck& check;
ErrorMeasure& measure; ErrorMeasure& measure;
unsigned size_cap;
public: public:
eoSymPopEval(BoundsCheck& _check, ErrorMeasure& _measure) : eoSymPopEval(BoundsCheck& _check, ErrorMeasure& _measure, unsigned _size_cap) :
check(_check), measure(_measure) {} check(_check), measure(_measure), size_cap(_size_cap) {}
/** apparently this thing works on two populations, why? /** apparently this thing works on two populations,
* *
* In any case, currently only implemented the population wide * In any case, currently only implemented the population wide
* evaluation version, as that one is much faster. This because the * evaluation version, as that one is much faster. This because the
@ -51,7 +52,7 @@ class eoSymPopEval : public eoPopEvalFunc<EoType> {
for (unsigned i = 0; i < p1.size(); ++i) { for (unsigned i = 0; i < p1.size(); ++i) {
if (p1[i].invalid()) { if (p1[i].invalid()) {
if (check.in_bounds(p1[i])) { if (p1[i].size() < size_cap && check.in_bounds(p1[i])) {
unevaluated.push_back(i); unevaluated.push_back(i);
tmppop.push_back( static_cast<Sym>(p1[i]) ); tmppop.push_back( static_cast<Sym>(p1[i]) );
} else { } else {
@ -63,7 +64,7 @@ class eoSymPopEval : public eoPopEvalFunc<EoType> {
for (unsigned i = 0; i < p2.size(); ++i) { for (unsigned i = 0; i < p2.size(); ++i) {
if (p2[i].invalid()) { if (p2[i].invalid()) {
if (check.in_bounds(p2[i])) { if (p2[i].size() < size_cap && check.in_bounds(p2[i])) {
unevaluated.push_back(p1.size() + i); unevaluated.push_back(p1.size() + i);
tmppop.push_back( static_cast<Sym>(p2[i]) ); tmppop.push_back( static_cast<Sym>(p2[i]) );

View file

@ -54,7 +54,6 @@ bool mutate(Sym& sym, double p, const LanguageTable& table) {
bool mutate_constants(Sym& sym, double stdev) { bool mutate_constants(Sym& sym, double stdev) {
vector<double> values = get_constants(sym); vector<double> values = get_constants(sym);
if (values.empty()) { if (values.empty()) {

View file

@ -29,7 +29,6 @@ class eoSymSubtreeMutate : public eoMonOp<EoType> {
TreeBuilder& subtree_builder; TreeBuilder& subtree_builder;
NodeSelector& node_selector; NodeSelector& node_selector;
public : public :
eoSymSubtreeMutate(TreeBuilder& _subtree_builder, NodeSelector& _node_selector) eoSymSubtreeMutate(TreeBuilder& _subtree_builder, NodeSelector& _node_selector)

View file

@ -75,7 +75,9 @@ void write_entry(Sym sym, ostream& os, HashMap& map, unsigned out) {
multi_function compile(const std::vector<Sym>& syms) { multi_function compile(const std::vector<Sym>& syms) {
ostringstream os; // static stream to avoid fragmentation of these LARGE strings
static ostringstream os;
os.str("");
os << make_prototypes(); os << make_prototypes();
@ -88,11 +90,8 @@ multi_function compile(const std::vector<Sym>& syms) {
} }
os << ";}"; os << ";}";
string func_str = os.str();
//cout << "compiling " << func_str << endl; return (multi_function) symc_make(os.str().c_str(), "func");
return (multi_function) symc_make(func_str.c_str(), "func");
} }
single_function compile(Sym sym) { single_function compile(Sym sym) {
@ -140,7 +139,9 @@ string print_code(Sym sym, HashMap& map) {
void compile(const std::vector<Sym>& syms, std::vector<single_function>& functions) { void compile(const std::vector<Sym>& syms, std::vector<single_function>& functions) {
symc_init(); symc_init();
ostringstream os; static ostringstream os;
os.str("");
os << make_prototypes(); os << make_prototypes();
HashMap map; HashMap map;
for (unsigned i = 0; i < syms.size(); ++i) { for (unsigned i = 0; i < syms.size(); ++i) {
@ -153,6 +154,7 @@ void compile(const std::vector<Sym>& syms, std::vector<single_function>& functio
//cout << "compiling " << os.str() << endl; //cout << "compiling " << os.str() << endl;
} }
os << ends;
#ifdef INTERVAL_DEBUG #ifdef INTERVAL_DEBUG
//cout << "Compiling " << os.str() << endl; //cout << "Compiling " << os.str() << endl;
#endif #endif

View file

@ -482,6 +482,12 @@ string prototypes = "double pow(double, double);";
string get_prototypes() { return prototypes; } string get_prototypes() { return prototypes; }
unsigned add_prototype(string str) { prototypes += string("double ") + str + "(double);"; return prototypes.size(); } unsigned add_prototype(string str) { prototypes += string("double ") + str + "(double);"; return prototypes.size(); }
token_t add_function(FunDef* function, token_t where) {
if (language.size() <= where) language.resize(where+1);
language[where] = function;
return 0;
}
#define FUNCDEF(funcname) struct funcname##_struct { \ #define FUNCDEF(funcname) struct funcname##_struct { \
double operator()(double val) const { return funcname(val); }\ double operator()(double val) const { return funcname(val); }\
@ -489,20 +495,18 @@ unsigned add_prototype(string str) { prototypes += string("double ") + str + "(d
Interval operator()(Interval val) const { return funcname(val); }\ Interval operator()(Interval val) const { return funcname(val); }\
string name() const { return string(#funcname); }\ string name() const { return string(#funcname); }\
};\ };\
const token_t funcname##_token = add_function( new Unary<funcname##_struct>);\ static const token_t funcname##_token_static = add_function( new Unary<funcname##_struct>, funcname##_token);\
unsigned funcname##_size = add_prototype(#funcname); unsigned funcname##_size = add_prototype(#funcname);
FunDef* make_var(int idx) { return new Var(idx); } FunDef* make_var(int idx) { return new Var(idx); }
FunDef* make_const(double value) { return new Const(value); } FunDef* make_const(double value) { return new Const(value); }
const token_t sum_token = add_function( new Sum ); static token_t ssum_token = add_function( new Sum , sum_token);
const token_t prod_token = add_function( new Prod); static token_t sprod_token = add_function( new Prod, prod_token);
const token_t inv_token = add_function( new Unary<Inv>); static token_t sinv_token = add_function( new Unary<Inv>, inv_token);
const token_t min_token = add_function( new Unary<Min>); static token_t smin_token = add_function( new Unary<Min>, min_token);
const token_t pow_token = add_function( new Power); static token_t spow_token = add_function( new Power, pow_token);
const token_t ifltz_token = add_function( new IsNeg); static token_t sifltz_token = add_function( new IsNeg, ifltz_token);
FUNCDEF(sin); FUNCDEF(sin);
FUNCDEF(cos); FUNCDEF(cos);

View file

@ -88,25 +88,36 @@ extern Sym SymVar(unsigned idx);
/** simplifies a sym (sym_operations.cpp) */ /** simplifies a sym (sym_operations.cpp) */
extern Sym simplify(Sym sym); extern Sym simplify(Sym sym);
/** differentiates a sym to a token (sym_operations.cpp) */ /** differentiates a sym to a token (sym_operations.cpp)
extern Sym differentiate(Sym sym, token_t var_token); * The token can be a variable or a constant
*/
extern Sym differentiate(Sym sym, token_t dx);
struct differentiation_error{}; // thrown in case of ifltz
/* Add function to the language table (and take a guess at the arity) */ /* Add function to the language table (and take a guess at the arity) */
class LanguageTable; class LanguageTable;
extern void add_function_to_table(LanguageTable& table, token_t token); extern void add_function_to_table(LanguageTable& table, token_t token);
// token names enum {
extern const token_t sum_token; sum_token,
extern const token_t prod_token; prod_token,
extern const token_t inv_token; inv_token,
extern const token_t min_token; min_token,
extern const token_t pow_token; pow_token,
extern const token_t ifltz_token; ifltz_token,
sin_token, cos_token, tan_token,
asin_token, acos_token, atan_token,
sinh_token, cosh_token, tanh_token,
acosh_token, asinh_token, atanh_token,
exp_token, log_token,
sqr_token, sqrt_token
};
#define HEADERFUNC(name) extern const token_t name##_token;\
inline Sym name(Sym arg) { return Sym(name##_token, arg); } #define HEADERFUNC(name) inline Sym name(Sym arg) { return Sym(name##_token, arg); }
/* This defines the tokens: sin_token, cos_token, etc. */ /* This defines the tokens: sin_token, cos_token, etc. */
HEADERFUNC(inv);
HEADERFUNC(sin); HEADERFUNC(sin);
HEADERFUNC(cos); HEADERFUNC(cos);
HEADERFUNC(tan); HEADERFUNC(tan);

View file

@ -74,3 +74,100 @@ Sym simplify(Sym sym) {
} }
Sym derivative(token_t token, Sym x) {
Sym one = Sym(prod_token);
switch (token) {
case inv_token : return Sym(inv_token, sqr(x));
case sin_token : return -cos(x);
case cos_token : return sin(x);
case tan_token : return one + sqr(tan(x));
case asin_token : return inv( sqrt(one - sqr(x)));
case acos_token: return -inv( sqrt(one - sqr(x)));
case atan_token : return inv( sqrt(one + sqr(x)));
case cosh_token : return -sinh(x);
case sinh_token : return cosh(x);
case tanh_token : return one - sqr( tanh(x) );
case asinh_token : return inv( sqrt( one + sqr(x) ));
case acosh_token : return inv( sqrt(x-one) * sqrt(x + one) );
case atanh_token : return inv(one - sqr(x));
case exp_token : return exp(x);
case log_token : return inv(x);
case sqr_token : return SymConst(2.0) * x;
case sqrt_token : return SymConst(0.5) * inv( sqrt(x));
}
throw differentiation_error();
return x;
}
extern Sym differentiate(Sym sym, token_t dx) {
token_t token = sym.token();
Sym zero = Sym(sum_token);
Sym one = Sym(prod_token);
if (token == dx) {
return one;
}
SymVec args = sym.args();
if (args.size() == 0) { // df/dx with f != x
return zero;
}
switch (token) {
case sum_token:
{
for (unsigned i = 0; i < args.size(); ++i) {
args[i] = differentiate(args[i], dx);
}
if (args.size() == 1) return args[0];
return Sym(sum_token, args);
}
case min_token :
{
return -differentiate(args[0],dx);
}
case prod_token:
{
if (args.size() == 1) return differentiate(args[0], dx);
if (args.size() == 2) {
return args[0] * differentiate(args[1], dx) + args[1] * differentiate(args[0], dx);
}
// else
Sym c = args.back();
args.pop_back();
Sym f = Sym(prod_token, args);
Sym df = differentiate( f, dx);
return c * df + f * differentiate(c,dx);
}
case pow_token :
{
return pow(args[0], args[1]) * args[1] * inv(args[0]);
}
case ifltz_token :
{ // cannot be differentiated
throw differentiation_error(); // TODO define proper exception
}
default: // unary function: apply chain rule
{
Sym arg = args[0];
return derivative(token, arg) * differentiate(arg, dx);
}
}
}

View file

@ -166,6 +166,7 @@ class ErrorMeasureImpl {
return result; return result;
} }
std::vector<double> err(pop.size()); std::vector<double> err(pop.size());
for (unsigned i = 0; i < train_cases(); ++i) { for (unsigned i = 0; i < train_cases(); ++i) {
@ -319,6 +320,6 @@ double ErrorMeasure::worst_performance() const {
return pimpl->train_info.tvar(); return pimpl->train_info.tvar();
} }
return 1e+20; // TODO: make this general return 1e+20;
} }

View file

@ -194,6 +194,13 @@ int main(int argc, char* argv[]) {
't', 't',
"Population").value(); "Population").value();
unsigned maximumSize = parser.createParam(
0xffffffffu,
"maximum-size",
"Maximum size after crossover",
's',
"Population").value();
unsigned meas_param = parser.createParam( unsigned meas_param = parser.createParam(
2u, 2u,
"measure", "measure",
@ -205,6 +212,7 @@ int main(int argc, char* argv[]) {
'm', 'm',
"Regression").value(); "Regression").value();
ErrorMeasure::measure meas = ErrorMeasure::mean_squared_scaled; ErrorMeasure::measure meas = ErrorMeasure::mean_squared_scaled;
if (meas_param == 0) meas = ErrorMeasure::absolute; if (meas_param == 0) meas = ErrorMeasure::absolute;
if (meas_param == 1) meas = ErrorMeasure::mean_squared; if (meas_param == 1) meas = ErrorMeasure::mean_squared;
@ -253,7 +261,7 @@ int main(int argc, char* argv[]) {
// todo, make this parameter, etc // todo, make this parameter, etc
double std = 0.01; double std = 0.01;
eoSymConstantMutate<EoType> constmutate(std); eoSymConstantMutate<EoType> constmutate(std);
genetic_operator.add( constmutate, 0.1); genetic_operator.add(constmutate, 0.1);
eoSymNodeMutate<EoType> nodemutate(table); eoSymNodeMutate<EoType> nodemutate(table);
genetic_operator.add(nodemutate, node_mut_prob); genetic_operator.add(nodemutate, node_mut_prob);
@ -269,7 +277,7 @@ int main(int argc, char* argv[]) {
IntervalBoundsCheck check(dataset.input_minima(), dataset.input_maxima()); IntervalBoundsCheck check(dataset.input_minima(), dataset.input_maxima());
ErrorMeasure measure(dataset, train_percentage, meas); ErrorMeasure measure(dataset, train_percentage, meas);
eoSymPopEval<EoType> evaluator(check, measure); eoSymPopEval<EoType> evaluator(check, measure, maximumSize);
eoDetTournamentSelect<EoType> selectOne(tournamentsize); eoDetTournamentSelect<EoType> selectOne(tournamentsize);
eoGeneralBreeder<EoType> breeder(selectOne, genetic_operator); eoGeneralBreeder<EoType> breeder(selectOne, genetic_operator);

View file

@ -17,14 +17,14 @@
#include <utils/eoRNG.h> #include <utils/eoRNG.h>
#include "FunDef.h" #include <FunDef.h>
#include "sym_compile.h" #include <sym_compile.h>
#include "Dataset.h" #include <Dataset.h>
#include "ErrorMeasure.h" #include <ErrorMeasure.h>
#include "LanguageTable.h" #include <LanguageTable.h>
#include "BoundsCheck.h" #include <BoundsCheck.h>
#include "TreeBuilder.h" #include <TreeBuilder.h>
#include <iostream> #include <iostream>
@ -34,7 +34,7 @@ void test_xover();
int main() { int main() {
Dataset dataset; Dataset dataset;
dataset.load_data("problem4.dat"); dataset.load_data("test_data.txt");
cout << "Records/Fields " << dataset.n_records() << ' ' << dataset.n_fields() << endl; cout << "Records/Fields " << dataset.n_records() << ' ' << dataset.n_fields() << endl;

View file

@ -16,17 +16,17 @@
*/ */
#include "LanguageTable.h" #include <LanguageTable.h>
#include "TreeBuilder.h" #include <TreeBuilder.h>
#include "FunDef.h" #include <FunDef.h>
#include "Dataset.h" #include <Dataset.h>
#include "eoSymInit.h" #include <eoSymInit.h>
#include "eoSym.h" #include <eoSym.h>
#include "eoPop.h" #include <eoPop.h>
#include "eoSymMutate.h" #include <eoSymMutate.h>
#include "eoSymCrossover.h" #include <eoSymCrossover.h>
#include "eoSymEval.h" #include <eoSymEval.h>
typedef EoSym<double> EoType; typedef EoSym<double> EoType;
@ -106,7 +106,7 @@ int main() {
IntervalBoundsCheck check(dataset.input_minima(), dataset.input_maxima()); IntervalBoundsCheck check(dataset.input_minima(), dataset.input_maxima());
ErrorMeasure measure(dataset, 0.90, ErrorMeasure::mean_squared_scaled); ErrorMeasure measure(dataset, 0.90, ErrorMeasure::mean_squared_scaled);
eoSymPopEval<EoType> evaluator(check, measure); eoSymPopEval<EoType> evaluator(check, measure, 20000);
eoPop<EoType> dummy; eoPop<EoType> dummy;
evaluator(pop, dummy); evaluator(pop, dummy);