Added symbolic differentiation

This commit is contained in:
maartenkeijzer 2005-10-09 07:03:35 +00:00
commit ffdce66c98
12 changed files with 186 additions and 61 deletions

View file

@ -482,6 +482,12 @@ string prototypes = "double pow(double, double);";
string get_prototypes() { return prototypes; }
unsigned add_prototype(string str) { prototypes += string("double ") + str + "(double);"; return prototypes.size(); }
token_t add_function(FunDef* function, token_t where) {
if (language.size() <= where) language.resize(where+1);
language[where] = function;
return 0;
}
#define FUNCDEF(funcname) struct funcname##_struct { \
double operator()(double val) const { return funcname(val); }\
@ -489,20 +495,18 @@ unsigned add_prototype(string str) { prototypes += string("double ") + str + "(d
Interval operator()(Interval val) const { return funcname(val); }\
string name() const { return string(#funcname); }\
};\
const token_t funcname##_token = add_function( new Unary<funcname##_struct>);\
static const token_t funcname##_token_static = add_function( new Unary<funcname##_struct>, funcname##_token);\
unsigned funcname##_size = add_prototype(#funcname);
FunDef* make_var(int idx) { return new Var(idx); }
FunDef* make_const(double value) { return new Const(value); }
const token_t sum_token = add_function( new Sum );
const token_t prod_token = add_function( new Prod);
const token_t inv_token = add_function( new Unary<Inv>);
const token_t min_token = add_function( new Unary<Min>);
const token_t pow_token = add_function( new Power);
const token_t ifltz_token = add_function( new IsNeg);
static token_t ssum_token = add_function( new Sum , sum_token);
static token_t sprod_token = add_function( new Prod, prod_token);
static token_t sinv_token = add_function( new Unary<Inv>, inv_token);
static token_t smin_token = add_function( new Unary<Min>, min_token);
static token_t spow_token = add_function( new Power, pow_token);
static token_t sifltz_token = add_function( new IsNeg, ifltz_token);
FUNCDEF(sin);
FUNCDEF(cos);

View file

@ -88,25 +88,36 @@ extern Sym SymVar(unsigned idx);
/** simplifies a sym (sym_operations.cpp) */
extern Sym simplify(Sym sym);
/** differentiates a sym to a token (sym_operations.cpp) */
extern Sym differentiate(Sym sym, token_t var_token);
/** differentiates a sym to a token (sym_operations.cpp)
* The token can be a variable or a constant
*/
extern Sym differentiate(Sym sym, token_t dx);
struct differentiation_error{}; // thrown in case of ifltz
/* Add function to the language table (and take a guess at the arity) */
class LanguageTable;
extern void add_function_to_table(LanguageTable& table, token_t token);
// token names
extern const token_t sum_token;
extern const token_t prod_token;
extern const token_t inv_token;
extern const token_t min_token;
extern const token_t pow_token;
extern const token_t ifltz_token;
enum {
sum_token,
prod_token,
inv_token,
min_token,
pow_token,
ifltz_token,
sin_token, cos_token, tan_token,
asin_token, acos_token, atan_token,
sinh_token, cosh_token, tanh_token,
acosh_token, asinh_token, atanh_token,
exp_token, log_token,
sqr_token, sqrt_token
};
#define HEADERFUNC(name) extern const token_t name##_token;\
inline Sym name(Sym arg) { return Sym(name##_token, arg); }
#define HEADERFUNC(name) inline Sym name(Sym arg) { return Sym(name##_token, arg); }
/* This defines the tokens: sin_token, cos_token, etc. */
HEADERFUNC(inv);
HEADERFUNC(sin);
HEADERFUNC(cos);
HEADERFUNC(tan);

View file

@ -74,3 +74,100 @@ Sym simplify(Sym sym) {
}
Sym derivative(token_t token, Sym x) {
Sym one = Sym(prod_token);
switch (token) {
case inv_token : return Sym(inv_token, sqr(x));
case sin_token : return -cos(x);
case cos_token : return sin(x);
case tan_token : return one + sqr(tan(x));
case asin_token : return inv( sqrt(one - sqr(x)));
case acos_token: return -inv( sqrt(one - sqr(x)));
case atan_token : return inv( sqrt(one + sqr(x)));
case cosh_token : return -sinh(x);
case sinh_token : return cosh(x);
case tanh_token : return one - sqr( tanh(x) );
case asinh_token : return inv( sqrt( one + sqr(x) ));
case acosh_token : return inv( sqrt(x-one) * sqrt(x + one) );
case atanh_token : return inv(one - sqr(x));
case exp_token : return exp(x);
case log_token : return inv(x);
case sqr_token : return SymConst(2.0) * x;
case sqrt_token : return SymConst(0.5) * inv( sqrt(x));
}
throw differentiation_error();
return x;
}
extern Sym differentiate(Sym sym, token_t dx) {
token_t token = sym.token();
Sym zero = Sym(sum_token);
Sym one = Sym(prod_token);
if (token == dx) {
return one;
}
SymVec args = sym.args();
if (args.size() == 0) { // df/dx with f != x
return zero;
}
switch (token) {
case sum_token:
{
for (unsigned i = 0; i < args.size(); ++i) {
args[i] = differentiate(args[i], dx);
}
if (args.size() == 1) return args[0];
return Sym(sum_token, args);
}
case min_token :
{
return -differentiate(args[0],dx);
}
case prod_token:
{
if (args.size() == 1) return differentiate(args[0], dx);
if (args.size() == 2) {
return args[0] * differentiate(args[1], dx) + args[1] * differentiate(args[0], dx);
}
// else
Sym c = args.back();
args.pop_back();
Sym f = Sym(prod_token, args);
Sym df = differentiate( f, dx);
return c * df + f * differentiate(c,dx);
}
case pow_token :
{
return pow(args[0], args[1]) * args[1] * inv(args[0]);
}
case ifltz_token :
{ // cannot be differentiated
throw differentiation_error(); // TODO define proper exception
}
default: // unary function: apply chain rule
{
Sym arg = args[0];
return derivative(token, arg) * differentiate(arg, dx);
}
}
}