Created
January 31, 2014 10:44
-
-
Save H2CO3/8729897 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//
// derive.c
// symbolic differentiation of a one-variable expression, just because
// (just because I wanted to demonstrate what the Sparkling API is good for)
//
// usage: derive '<expression in x>'
//
// created by H2CO3 on 31/01/2014
// use for good, not for evil
//
#include <stdio.h> | |
#include <stdlib.h> | |
#include <string.h> | |
#include <assert.h> | |
#include <spn/api.h> | |
#include <spn/parser.h> | |
#include <spn/ast.h> | |
// assumptions:
// - the argument of the function is called 'x'
// - only basic arithmetic (+, -, *, /) and unary +/- are used
// - the only functions called are those that have an entry in the table in
//   derivative_func() (e.g. sin, cos, exp), each with exactly one argument
// - literal numbers and any identifier other than 'x' are treated as constants
// Returns nonzero if `name` consists only of identifier characters.
// derivative_func() builds call nodes whose callee "name" is really expression
// text such as "-sin" or "1 / cos^2"; such names fail this test.
static int is_plain_name(const char *name)
{
	static const char ident_chars[] =
		"abcdefghijklmnopqrstuvwxyz"
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
		"0123456789_";
	return name[strspn(name, ident_chars)] == '\0';
}

// Returns the operator character of a binary arithmetic node.
// Spelled out per node kind rather than indexing a table with
// `node - SPN_NODE_ADD`, which silently relied on the enumerator order.
static char binary_op_char(SpnAST *ast)
{
	switch (ast->node) {
	case SPN_NODE_ADD: return '+';
	case SPN_NODE_SUB: return '-';
	case SPN_NODE_MUL: return '*';
	case SPN_NODE_DIV: return '/';
	default:           return '?';
	}
}

// Prints `ast` as infix text to stdout.
// `parens` is nonzero when the node is an operand of another operator; binary
// operations and synthesized derivative calls are then wrapped in parentheses
// so the printed text keeps the tree's grouping.
// Exits the program on a node kind it does not know how to print.
static void real_dump(SpnAST *ast, int parens)
{
	switch (ast->node) {
	case SPN_NODE_ADD:
	case SPN_NODE_SUB:
	case SPN_NODE_MUL:
	case SPN_NODE_DIV:
		if (parens)
			printf("(");
		real_dump(ast->left, 1);
		printf(" %c ", binary_op_char(ast));
		real_dump(ast->right, 1);
		if (parens)
			printf(")");
		break;
	case SPN_NODE_UNPLUS:
		// unary plus does not change the value: print only the operand
		real_dump(ast->left, 1);
		break;
	case SPN_NODE_UNMINUS:
		printf("-");
		real_dump(ast->left, 1);
		break;
	case SPN_NODE_LITERAL:
		assert(ast->value.t == SPN_TYPE_NUMBER);
		if (ast->value.f & SPN_TFLG_FLOAT)
			printf("%g", ast->value.v.fltv);
		else
			printf("%ld", ast->value.v.intv);
		break;
	case SPN_NODE_FUNCCALL: {
		assert(ast->left->node == SPN_NODE_IDENT);
		assert(ast->right->left == NULL); // only ONE argument, s'il vous plait
		const char *callee = ast->left->name->cstr;
		// A synthesized callee such as "-sin" or "1 / cos^2" is really an
		// expression; group it when it is an operand so the output cannot
		// read as e.g. "--sin(x)" or "x - -sin(x)".
		int wrap = parens && !is_plain_name(callee);
		if (wrap)
			printf("(");
		printf("%s(", callee);
		real_dump(ast->right->right, 0);
		printf(")");
		if (wrap)
			printf(")");
		break;
	}
	case SPN_NODE_IDENT:
		printf("%s", ast->name->cstr);
		break;
	default:
		fflush(stdout); // keep partial output ordered before the diagnostic
		fprintf(stderr, "\n\nError: unrecognized node/operation: %d\n", ast->node);
		exit(EXIT_FAILURE);
	}
}
// Prints `ast` as a top-level expression (no enclosing parentheses),
// followed by a blank line.
static void dump(SpnAST *ast)
{
	const int top_level_parens = 0;

	real_dump(ast, top_level_parens);
	fputs("\n\n", stdout);
}
// Returns a deep copy of the tree rooted at `orig`.
// The node's value and name object are shared with the original rather than
// duplicated, so each is retained on behalf of the copy. Child subtrees are
// copied recursively; a missing child stays as spn_ast_new() left it.
static SpnAST *copy_ast(SpnAST *orig)
{
	SpnAST *dup = spn_ast_new(orig->node, orig->lineno);

	spn_value_retain(&orig->value);
	dup->value = orig->value;

	if (orig->name != NULL) {
		spn_object_retain(orig->name);
		dup->name = orig->name;
	}

	if (orig->left != NULL)
		dup->left = copy_ast(orig->left);

	if (orig->right != NULL)
		dup->right = copy_ast(orig->right);

	return dup;
}
// Returns 1 if `ast` is a literal or an identifier other than 'x', 0 otherwise.
// Compound expressions always yield 0, even when they do not mention 'x'.
static int is_constant(SpnAST *ast)
{
	switch (ast->node) {
	case SPN_NODE_LITERAL:
		return 1;
	case SPN_NODE_IDENT:
		return strcmp(ast->name->cstr, "x") != 0;
	default:
		return 0;
	}
}
static SpnAST *make_literal_zero(unsigned long lineno) | |
{ | |
SpnAST *ast = spn_ast_new(SPN_NODE_LITERAL, lineno); | |
ast->value = (SpnValue){ .t = SPN_TYPE_NUMBER, .f = 0, .v.intv = 0 }; | |
return ast; | |
} | |
static SpnAST *make_literal_one(unsigned long lineno) | |
{ | |
SpnAST *ast = spn_ast_new(SPN_NODE_LITERAL, lineno); | |
ast->value = (SpnValue){ .t = SPN_TYPE_NUMBER, .f = 0, .v.intv = 1 }; | |
return ast; | |
} | |
// Maps the callee identifier of a one-argument call to a new identifier node
// whose name is the *text* of the derivative's outer function.
// The name is not a real identifier: real_dump() prints a call as
// "<name>(<arg>)", so "-sin" prints as "-sin(u)" and "1 / " as "1 / (u)".
// Only derivatives expressible as "<prefix>(u)" fit this scheme.
// "cos^2" is display notation for cos squared, never evaluated.
// Exits the program for a function that is not in the table.
static SpnAST *derivative_func(SpnAST *ast)
{
	assert(ast->node == SPN_NODE_IDENT);

	static const struct {
		const char *f;
		const char *fprime;
	} dict[] = {
		{ "sin",  "cos" },
		{ "cos",  "-sin" },
		{ "tan",  "1 / cos^2" },
		{ "exp",  "exp" },
		{ "ln",   "1 / " },
		{ "log",  "1 / " }, // assumes log is the natural log, as in C's log()
		{ "sinh", "cosh" },
		{ "cosh", "sinh" },
	};

	for (size_t i = 0; i < sizeof dict / sizeof dict[0]; i++) {
		if (strcmp(ast->name->cstr, dict[i].f) == 0) {
			SpnAST *ret = spn_ast_new(SPN_NODE_IDENT, ast->lineno);
			// the table entries are string literals, so no copy is made
			// and the string does not take ownership of them (dealloc = 0)
			ret->name = spn_string_new_nocopy(dict[i].fprime, 0);
			return ret;
		}
	}

	fflush(stdout); // keep already-printed output ordered before the diagnostic
	fprintf(stderr, "\n\nUnrecognized function: %s\n", ast->name->cstr);
	exit(EXIT_FAILURE);
}
// Forward declaration: the per-rule helpers below recurse into derivative().
static SpnAST *derivative(SpnAST *ast);

// Allocates a binary node of kind `node` whose children are `left` and `right`;
// both become part of the returned tree.
static SpnAST *make_binary_node(int node, unsigned long lineno, SpnAST *left, SpnAST *right)
{
	SpnAST *ret = spn_ast_new(node, lineno);
	ret->left = left;
	ret->right = right;
	return ret;
}

// Allocates a unary-minus node whose operand is `operand`.
static SpnAST *make_negation(SpnAST *operand, unsigned long lineno)
{
	SpnAST *ret = spn_ast_new(SPN_NODE_UNMINUS, lineno);
	ret->left = operand;
	return ret;
}

// Sum/difference rule: (f + g)' = f' + g', (f - g)' = f' - g'.
// A constant operand contributes nothing and is dropped.
static SpnAST *derivative_sum(SpnAST *ast)
{
	int left_const = is_constant(ast->left);
	int right_const = is_constant(ast->right);

	if (left_const && right_const)
		return make_literal_zero(ast->lineno);

	if (left_const) {
		// (c + g)' = g', but (c - g)' = -g': the sign of g must survive
		SpnAST *gder = derivative(ast->right);
		if (ast->node == SPN_NODE_SUB)
			return make_negation(gder, ast->lineno);
		return gder;
	}

	if (right_const)
		return derivative(ast->left);

	SpnAST *fder = derivative(ast->left);
	SpnAST *gder = derivative(ast->right);
	return make_binary_node(ast->node, ast->lineno, fder, gder);
}

// Product rule: (f * g)' = f' * g + f * g', with constant factors pulled out.
static SpnAST *derivative_product(SpnAST *ast)
{
	int left_const = is_constant(ast->left);
	int right_const = is_constant(ast->right);

	if (left_const && right_const)
		return make_literal_zero(ast->lineno);

	if (left_const) {
		// (c * g)' = c * g'
		SpnAST *c = copy_ast(ast->left);
		SpnAST *gder = derivative(ast->right);
		return make_binary_node(SPN_NODE_MUL, ast->lineno, c, gder);
	}

	if (right_const) {
		// (f * c)' = f' * c
		SpnAST *fder = derivative(ast->left);
		SpnAST *c = copy_ast(ast->right);
		return make_binary_node(SPN_NODE_MUL, ast->lineno, fder, c);
	}

	SpnAST *fder = derivative(ast->left);
	SpnAST *g = copy_ast(ast->right);
	SpnAST *fder_g = make_binary_node(SPN_NODE_MUL, ast->lineno, fder, g);

	SpnAST *f = copy_ast(ast->left);
	SpnAST *gder = derivative(ast->right);
	SpnAST *f_gder = make_binary_node(SPN_NODE_MUL, ast->lineno, f, gder);

	return make_binary_node(SPN_NODE_ADD, ast->lineno, fder_g, f_gder);
}

// Quotient rule: (f / g)' = (f' * g - f * g') / (g * g); for a constant
// divisor this reduces to (f / c)' = f' / c.
static SpnAST *derivative_quotient(SpnAST *ast)
{
	if (is_constant(ast->right)) {
		if (is_constant(ast->left))
			return make_literal_zero(ast->lineno);
		SpnAST *num_der = derivative(ast->left);
		SpnAST *c = copy_ast(ast->right);
		return make_binary_node(SPN_NODE_DIV, ast->lineno, num_der, c);
	}

	SpnAST *fder = derivative(ast->left);
	SpnAST *g = copy_ast(ast->right);
	SpnAST *fder_g = make_binary_node(SPN_NODE_MUL, ast->lineno, fder, g);

	SpnAST *f = copy_ast(ast->left);
	SpnAST *gder = derivative(ast->right);
	SpnAST *f_gder = make_binary_node(SPN_NODE_MUL, ast->lineno, f, gder);

	SpnAST *diff = make_binary_node(SPN_NODE_SUB, ast->lineno, fder_g, f_gder);

	SpnAST *g1 = copy_ast(ast->right);
	SpnAST *g2 = copy_ast(ast->right);
	SpnAST *g_squared = make_binary_node(SPN_NODE_MUL, ast->lineno, g1, g2);

	return make_binary_node(SPN_NODE_DIV, ast->lineno, diff, g_squared);
}

// Chain rule for a one-argument call: (f(u))' = f'(u) * u'.
static SpnAST *derivative_call(SpnAST *ast)
{
	assert(ast->left->node == SPN_NODE_IDENT);
	assert(ast->right->left == NULL); // only ONE argument, s'il vous plait

	SpnAST *arg = ast->right->right;
	if (is_constant(arg))
		return make_literal_zero(ast->lineno);

	// f'(u): same argument list, callee replaced by the derivative's text
	SpnAST *fder = spn_ast_new(SPN_NODE_FUNCCALL, ast->lineno);
	fder->left = derivative_func(ast->left);
	fder->right = copy_ast(ast->right);

	// optimization: a non-constant identifier can only be 'x', and x' = 1,
	// so don't emit f'(x) * 1
	if (arg->node == SPN_NODE_IDENT)
		return fder;

	SpnAST *uder = derivative(arg);
	return make_binary_node(SPN_NODE_MUL, ast->lineno, fder, uder);
}

// Returns a newly built AST for d/dx of `ast` (see the assumptions listed at the
// top of the file). The input tree is not modified; the caller frees the result.
// Exits the program on a node kind it cannot differentiate.
static SpnAST *derivative(SpnAST *ast)
{
	switch (ast->node) {
	case SPN_NODE_ADD:
	case SPN_NODE_SUB:
		return derivative_sum(ast);
	case SPN_NODE_MUL:
		return derivative_product(ast);
	case SPN_NODE_DIV:
		return derivative_quotient(ast);
	case SPN_NODE_UNPLUS:
		return derivative(ast->left);
	case SPN_NODE_UNMINUS:
		return make_negation(derivative(ast->left), ast->lineno);
	case SPN_NODE_LITERAL:
		return make_literal_zero(ast->lineno);
	case SPN_NODE_IDENT:
		// dx/dx = 1; any other identifier is a constant
		if (is_constant(ast))
			return make_literal_zero(ast->lineno);
		return make_literal_one(ast->lineno);
	case SPN_NODE_FUNCCALL:
		return derivative_call(ast);
	default:
		fflush(stdout); // keep already-printed output ordered before the diagnostic
		fprintf(stderr, "\n\nError: unrecognized node/operation: %d\n", ast->node);
		exit(EXIT_FAILURE);
	}
}
int main(int argc, char *argv[]) | |
{ | |
char *expr = strdup(argv[1]); | |
expr = realloc(expr, strlen(argv[1]) + 1 + 1); | |
expr[strlen(argv[1])] = ';'; | |
expr[strlen(argv[1]) + 1] = 0; | |
SpnParser *parser = spn_parser_new(); | |
SpnAST *ast = spn_parser_parse(parser, expr); | |
spn_parser_free(parser); | |
free(expr); | |
printf("f(x) = "); | |
dump(ast->left); | |
SpnAST *der = derivative(ast->left); | |
spn_ast_free(ast); | |
printf("f'(x) = "); | |
dump(der); | |
spn_ast_free(der); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment