Added symbolic differentiation

This commit is contained in:
maartenkeijzer 2005-10-09 07:03:35 +00:00
commit ffdce66c98
12 changed files with 186 additions and 61 deletions

View file

@ -1,5 +1,5 @@
COMPILEFLAGS=-Wno-deprecated -g -Wall -Wshadow #-DINTERVAL_DEBUG
OPTFLAGS= -O2 -DNDEBUG
OPTFLAGS= -O3 -DNDEBUG
PROFILE_FLAGS=-pg
@ -18,7 +18,7 @@ CXXSOURCES=FunDef.cpp Sym.cpp SymImpl.cpp SymOps.cpp sym_compile.cpp TreeBuilder
Dataset.cpp ErrorMeasure.cpp Scaling.cpp TargetInfo.cpp BoundsCheck.cpp util.cpp NodeSelector.cpp\
eoSymCrossover.cpp sym_operations.cpp eoSymMutate.cpp
TESTPROGRAMS=test/test_compile test/testeo test/test_simplify
TESTPROGRAMS=test/test_compile test/testeo test/test_simplify test/test_diff
OBJS= $(CXXSOURCES:.cpp=.o) c_compile.o
@ -50,6 +50,9 @@ test/testeo: test/testeo.o ${SYMLIB}
test/test_simplify: test/test_simplify.o $(SYMLIB)
$(CXX) -o test/test_simplify test/test_simplify.o $(SYMLIB) ${LIBS}
test/test_diff: test/test_diff.o $(SYMLIB)
$(CXX) -o test/test_diff test/test_diff.o $(SYMLIB) ${LIBS}
# eo
../../src/libeo.a:
make -C ../../src

View file

@ -29,13 +29,14 @@ class eoSymPopEval : public eoPopEvalFunc<EoType> {
BoundsCheck& check;
ErrorMeasure& measure;
unsigned size_cap;
public:
eoSymPopEval(BoundsCheck& _check, ErrorMeasure& _measure) :
check(_check), measure(_measure) {}
eoSymPopEval(BoundsCheck& _check, ErrorMeasure& _measure, unsigned _size_cap) :
check(_check), measure(_measure), size_cap(_size_cap) {}
/** apparently this thing works on two populations, why?
/** apparently this thing works on two populations,
*
* In any case, currently only implemented the population wide
* evaluation version, as that one is much faster. This because the
@ -51,7 +52,7 @@ class eoSymPopEval : public eoPopEvalFunc<EoType> {
for (unsigned i = 0; i < p1.size(); ++i) {
if (p1[i].invalid()) {
if (check.in_bounds(p1[i])) {
if (p1[i].size() < size_cap && check.in_bounds(p1[i])) {
unevaluated.push_back(i);
tmppop.push_back( static_cast<Sym>(p1[i]) );
} else {
@ -63,7 +64,7 @@ class eoSymPopEval : public eoPopEvalFunc<EoType> {
for (unsigned i = 0; i < p2.size(); ++i) {
if (p2[i].invalid()) {
if (check.in_bounds(p2[i])) {
if (p2[i].size() < size_cap && check.in_bounds(p2[i])) {
unevaluated.push_back(p1.size() + i);
tmppop.push_back( static_cast<Sym>(p2[i]) );

View file

@ -54,7 +54,6 @@ bool mutate(Sym& sym, double p, const LanguageTable& table) {
bool mutate_constants(Sym& sym, double stdev) {
vector<double> values = get_constants(sym);
if (values.empty()) {

View file

@ -29,7 +29,6 @@ class eoSymSubtreeMutate : public eoMonOp<EoType> {
TreeBuilder& subtree_builder;
NodeSelector& node_selector;
public :
eoSymSubtreeMutate(TreeBuilder& _subtree_builder, NodeSelector& _node_selector)

View file

@ -75,7 +75,9 @@ void write_entry(Sym sym, ostream& os, HashMap& map, unsigned out) {
multi_function compile(const std::vector<Sym>& syms) {
ostringstream os;
// static stream to avoid fragmentation of these LARGE strings
static ostringstream os;
os.str("");
os << make_prototypes();
@ -88,11 +90,8 @@ multi_function compile(const std::vector<Sym>& syms) {
}
os << ";}";
string func_str = os.str();
//cout << "compiling " << func_str << endl;
return (multi_function) symc_make(func_str.c_str(), "func");
return (multi_function) symc_make(os.str().c_str(), "func");
}
single_function compile(Sym sym) {
@ -140,7 +139,9 @@ string print_code(Sym sym, HashMap& map) {
void compile(const std::vector<Sym>& syms, std::vector<single_function>& functions) {
symc_init();
ostringstream os;
static ostringstream os;
os.str("");
os << make_prototypes();
HashMap map;
for (unsigned i = 0; i < syms.size(); ++i) {
@ -153,10 +154,11 @@ void compile(const std::vector<Sym>& syms, std::vector<single_function>& functio
//cout << "compiling " << os.str() << endl;
}
os << ends;
#ifdef INTERVAL_DEBUG
//cout << "Compiling " << os.str() << endl;
#endif
symc_compile(os.str().c_str());
symc_link();

View file

@ -482,6 +482,12 @@ string prototypes = "double pow(double, double);";
string get_prototypes() { return prototypes; }
unsigned add_prototype(string str) { prototypes += string("double ") + str + "(double);"; return prototypes.size(); }
/**
 * Stores `function` in the `language` table at index `where`, enlarging the
 * table first when that index is not yet present.
 *
 * Always returns 0; callers in this file use the call as the initializer of a
 * static token_t so that registration runs during static initialization.
 */
token_t add_function(FunDef* function, token_t where) {
    const bool slot_missing = !(where < language.size());
    if (slot_missing) {
        language.resize(where + 1);
    }
    language[where] = function;
    return 0;
}
#define FUNCDEF(funcname) struct funcname##_struct { \
double operator()(double val) const { return funcname(val); }\
@ -489,20 +495,18 @@ unsigned add_prototype(string str) { prototypes += string("double ") + str + "(d
Interval operator()(Interval val) const { return funcname(val); }\
string name() const { return string(#funcname); }\
};\
const token_t funcname##_token = add_function( new Unary<funcname##_struct>);\
static const token_t funcname##_token_static = add_function( new Unary<funcname##_struct>, funcname##_token);\
unsigned funcname##_size = add_prototype(#funcname);
FunDef* make_var(int idx) { return new Var(idx); }
FunDef* make_const(double value) { return new Const(value); }
const token_t sum_token = add_function( new Sum );
const token_t prod_token = add_function( new Prod);
const token_t inv_token = add_function( new Unary<Inv>);
const token_t min_token = add_function( new Unary<Min>);
const token_t pow_token = add_function( new Power);
const token_t ifltz_token = add_function( new IsNeg);
static token_t ssum_token = add_function( new Sum , sum_token);
static token_t sprod_token = add_function( new Prod, prod_token);
static token_t sinv_token = add_function( new Unary<Inv>, inv_token);
static token_t smin_token = add_function( new Unary<Min>, min_token);
static token_t spow_token = add_function( new Power, pow_token);
static token_t sifltz_token = add_function( new IsNeg, ifltz_token);
FUNCDEF(sin);
FUNCDEF(cos);

View file

@ -88,25 +88,36 @@ extern Sym SymVar(unsigned idx);
/** simplifies a sym (sym_operations.cpp) */
extern Sym simplify(Sym sym);
/** differentiates a sym to a token (sym_operations.cpp) */
extern Sym differentiate(Sym sym, token_t var_token);
/** differentiates a sym to a token (sym_operations.cpp)
* The token can be a variable or a constant
*/
extern Sym differentiate(Sym sym, token_t dx);
struct differentiation_error{}; // thrown in case of ifltz
/* Add function to the language table (and take a guess at the arity) */
class LanguageTable;
extern void add_function_to_table(LanguageTable& table, token_t token);
// token names
extern const token_t sum_token;
extern const token_t prod_token;
extern const token_t inv_token;
extern const token_t min_token;
extern const token_t pow_token;
extern const token_t ifltz_token;
// Token identifiers for the built-in functions. Each enumerator is the index
// at which the matching implementation is registered (see add_function in
// FunDef.cpp), so the values are fixed at compile time and usable as
// `case` labels.
enum {
sum_token,
prod_token,
inv_token,
min_token,
pow_token,
ifltz_token,
// unary functions, registered through the FUNCDEF macro
sin_token, cos_token, tan_token,
asin_token, acos_token, atan_token,
sinh_token, cosh_token, tanh_token,
acosh_token, asinh_token, atanh_token,
exp_token, log_token,
sqr_token, sqrt_token
};
#define HEADERFUNC(name) extern const token_t name##_token;\
inline Sym name(Sym arg) { return Sym(name##_token, arg); }
#define HEADERFUNC(name) inline Sym name(Sym arg) { return Sym(name##_token, arg); }
/* This defines the tokens: sin_token, cos_token, etc. */
HEADERFUNC(inv);
HEADERFUNC(sin);
HEADERFUNC(cos);
HEADERFUNC(tan);

View file

@ -74,3 +74,100 @@ Sym simplify(Sym sym) {
}
/**
 * Builds the symbolic derivative f'(x) of the unary function identified by
 * `token`, expressed in terms of its argument `x`.
 *
 * Only the outer derivative is produced; the chain-rule factor (the
 * derivative of `x` itself) is multiplied in by the caller.
 *
 * @param token  token of the unary function to differentiate
 * @param x      the argument expression of that function
 * @return       expression for f'(x)
 * @throws differentiation_error when `token` is not one of the unary
 *         functions handled below
 */
Sym derivative(token_t token, Sym x) {
    Sym one = Sym(prod_token); // empty product, used here as the constant 1

    switch (token) {
        case inv_token   : return -inv(sqr(x));                       // (1/x)'   = -1/x^2
        case sin_token   : return cos(x);                             // sin'     =  cos
        case cos_token   : return -sin(x);                            // cos'     = -sin
        case tan_token   : return one + sqr(tan(x));                  // tan'     =  1 + tan^2
        case asin_token  : return inv( sqrt(one - sqr(x)));           // asin'    =  1/sqrt(1-x^2)
        case acos_token  : return -inv( sqrt(one - sqr(x)));          // acos'    = -1/sqrt(1-x^2)
        case atan_token  : return inv(one + sqr(x));                  // atan'    =  1/(1+x^2)
        case cosh_token  : return sinh(x);                            // cosh'    =  sinh
        case sinh_token  : return cosh(x);                            // sinh'    =  cosh
        case tanh_token  : return one - sqr( tanh(x) );               // tanh'    =  1 - tanh^2
        case asinh_token : return inv( sqrt( one + sqr(x) ));         // asinh'   =  1/sqrt(1+x^2)
        case acosh_token : return inv( sqrt(x-one) * sqrt(x + one) ); // acosh'   =  1/(sqrt(x-1)sqrt(x+1))
        case atanh_token : return inv(one - sqr(x));                  // atanh'   =  1/(1-x^2)
        case exp_token   : return exp(x);                             // exp'     =  exp
        case log_token   : return inv(x);                             // log'     =  1/x
        case sqr_token   : return SymConst(2.0) * x;                  // (x^2)'   =  2x
        case sqrt_token  : return SymConst(0.5) * inv( sqrt(x));      // sqrt'    =  1/(2 sqrt(x))
    }

    // Not a unary function we know how to differentiate.
    throw differentiation_error();
}
/**
 * Symbolically differentiates `sym` with respect to the token `dx`.
 *
 * A node whose token equals `dx` differentiates to one; any other leaf
 * (a node without arguments) differentiates to zero. Sums, negation,
 * products and powers are handled explicitly; every other token is treated
 * as a unary function and differentiated with the chain rule.
 *
 * The result is not simplified.
 *
 * @param sym  expression to differentiate
 * @param dx   token to differentiate with respect to
 * @return     expression for d(sym)/d(dx)
 * @throws differentiation_error for ifltz nodes and for unary tokens that
 *         derivative() does not know
 */
extern Sym differentiate(Sym sym, token_t dx) {

    token_t token = sym.token();

    Sym zero = Sym(sum_token);  // empty sum, used here as the constant 0
    Sym one  = Sym(prod_token); // empty product, used here as the constant 1

    if (token == dx) {
        return one;
    }

    SymVec args = sym.args();

    if (args.size() == 0) { // df/dx with f != x
        return zero;
    }

    switch (token) {

        case sum_token:
            { // (f + g + ...)' = f' + g' + ...
                for (unsigned i = 0; i < args.size(); ++i) {
                    args[i] = differentiate(args[i], dx);
                }

                if (args.size() == 1) return args[0];
                return Sym(sum_token, args);
            }
        case min_token :
            { // (-f)' = -(f')
                return -differentiate(args[0],dx);
            }
        case prod_token:
            {
                if (args.size() == 1) return differentiate(args[0], dx);

                if (args.size() == 2) { // (f g)' = f g' + g f'
                    return args[0] * differentiate(args[1], dx) + args[1] * differentiate(args[0], dx);
                }
                // else: split off the last factor and recurse on the rest,
                // (f c)' = c f' + f c'
                Sym c = args.back();
                args.pop_back();
                Sym f = Sym(prod_token, args);
                Sym df = differentiate( f, dx);

                return c * df + f * differentiate(c,dx);
            }
        case pow_token :
            { // (f^g)' = f^g * ( g * f'/f  +  g' * ln f )
                Sym base     = args[0];
                Sym exponent = args[1];

                Sym dbase     = differentiate(base, dx);
                Sym dexponent = differentiate(exponent, dx);

                // Power-rule part, including the chain-rule factor f'.
                Sym power_part = pow(base, exponent) * exponent * inv(base) * dbase;

                // When the exponent trivially does not depend on dx (its
                // derivative is the empty sum), leave the logarithmic term
                // out: it would contribute 0 * ln(f), which evaluates to NaN
                // for a non-positive base instead of 0.
                const bool exponent_is_constant =
                    dexponent.token() == sum_token && dexponent.args().size() == 0;

                if (exponent_is_constant) {
                    return power_part;
                }

                return power_part + pow(base, exponent) * Sym(log_token, base) * dexponent;
            }
        case ifltz_token :
            { // cannot be differentiated
                throw differentiation_error(); // TODO define proper exception
            }

        default: // unary function: apply chain rule
            {
                Sym arg = args[0];
                return derivative(token, arg) * differentiate(arg, dx);
            }
    }

}

View file

@ -165,7 +165,8 @@ class ErrorMeasureImpl {
return result;
}
std::vector<double> err(pop.size());
for (unsigned i = 0; i < train_cases(); ++i) {
@ -319,6 +320,6 @@ double ErrorMeasure::worst_performance() const {
return pimpl->train_info.tvar();
}
return 1e+20; // TODO: make this general
return 1e+20;
}

View file

@ -194,6 +194,13 @@ int main(int argc, char* argv[]) {
't',
"Population").value();
unsigned maximumSize = parser.createParam(
0xffffffffu,
"maximum-size",
"Maximum size after crossover",
's',
"Population").value();
unsigned meas_param = parser.createParam(
2u,
"measure",
@ -204,7 +211,8 @@ int main(int argc, char* argv[]) {
",
'm',
"Regression").value();
ErrorMeasure::measure meas = ErrorMeasure::mean_squared_scaled;
if (meas_param == 0) meas = ErrorMeasure::absolute;
if (meas_param == 1) meas = ErrorMeasure::mean_squared;
@ -253,7 +261,7 @@ int main(int argc, char* argv[]) {
// todo, make this parameter, etc
double std = 0.01;
eoSymConstantMutate<EoType> constmutate(std);
genetic_operator.add( constmutate, 0.1);
genetic_operator.add(constmutate, 0.1);
eoSymNodeMutate<EoType> nodemutate(table);
genetic_operator.add(nodemutate, node_mut_prob);
@ -269,7 +277,7 @@ int main(int argc, char* argv[]) {
IntervalBoundsCheck check(dataset.input_minima(), dataset.input_maxima());
ErrorMeasure measure(dataset, train_percentage, meas);
eoSymPopEval<EoType> evaluator(check, measure);
eoSymPopEval<EoType> evaluator(check, measure, maximumSize);
eoDetTournamentSelect<EoType> selectOne(tournamentsize);
eoGeneralBreeder<EoType> breeder(selectOne, genetic_operator);

View file

@ -17,14 +17,14 @@
#include <utils/eoRNG.h>
#include "FunDef.h"
#include "sym_compile.h"
#include <FunDef.h>
#include <sym_compile.h>
#include "Dataset.h"
#include "ErrorMeasure.h"
#include "LanguageTable.h"
#include "BoundsCheck.h"
#include "TreeBuilder.h"
#include <Dataset.h>
#include <ErrorMeasure.h>
#include <LanguageTable.h>
#include <BoundsCheck.h>
#include <TreeBuilder.h>
#include <iostream>
@ -34,7 +34,7 @@ void test_xover();
int main() {
Dataset dataset;
dataset.load_data("problem4.dat");
dataset.load_data("test_data.txt");
cout << "Records/Fields " << dataset.n_records() << ' ' << dataset.n_fields() << endl;

View file

@ -16,17 +16,17 @@
*/
#include "LanguageTable.h"
#include "TreeBuilder.h"
#include "FunDef.h"
#include "Dataset.h"
#include <LanguageTable.h>
#include <TreeBuilder.h>
#include <FunDef.h>
#include <Dataset.h>
#include "eoSymInit.h"
#include "eoSym.h"
#include "eoPop.h"
#include "eoSymMutate.h"
#include "eoSymCrossover.h"
#include "eoSymEval.h"
#include <eoSymInit.h>
#include <eoSym.h>
#include <eoPop.h>
#include <eoSymMutate.h>
#include <eoSymCrossover.h>
#include <eoSymEval.h>
typedef EoSym<double> EoType;
@ -106,7 +106,7 @@ int main() {
IntervalBoundsCheck check(dataset.input_minima(), dataset.input_maxima());
ErrorMeasure measure(dataset, 0.90, ErrorMeasure::mean_squared_scaled);
eoSymPopEval<EoType> evaluator(check, measure);
eoSymPopEval<EoType> evaluator(check, measure, 20000);
eoPop<EoType> dummy;
evaluator(pop, dummy);