/*
 *  Module:   expression evaluate
 *
 *  Library:  mathbase.a (project Riemann math routines)
 */

#include "master.h"

/*
 * dcl_ef(name) opens a K&R-style definition of a file-local elementary
 * function `name`.  Each one receives its operands as an array of SCALARs
 * and returns a SCALAR: unary functions read args[0]; binary functions
 * read args[0] as the left operand and args[1] as the right operand.
 */
#define dcl_ef(name) static SCALAR name (args) SCALAR *args;

/* arithmetic */
dcl_ef(e_neg)	{ return -args[0]; }
dcl_ef(e_add)	{ return args[0] + args[1]; }
dcl_ef(e_sub)	{ return args[0] - args[1]; }
dcl_ef(e_mul)	{ return args[0] * args[1]; }
dcl_ef(e_div)	{ return args[0] / args[1]; }

/* powers, logarithms and their shortcuts */
dcl_ef(e_pow)	{ return pow(args[0], args[1]); }
dcl_ef(e_ln)	{ return log(args[0]); }
dcl_ef(e_inv)	{ return 1/args[0]; }
dcl_ef(e_sqr)	{ return args[0]*args[0]; }
dcl_ef(e_exp)	{ return exp(args[0]); }
dcl_ef(e_sqrt)	{ return sqrt(args[0]); }

/* scaling by two */
dcl_ef(e_dbl)	{ return args[0]*2; }
dcl_ef(e_half)	{ return args[0]/2; }

/* trigonometry */
dcl_ef(e_sin)	{ return sin(args[0]); }
dcl_ef(e_cos)	{ return cos(args[0]); }

static double (*evalfuncs[])() = {
    e_neg, e_add, e_sub, e_mul, e_div, e_pow, e_ln, e_inv, e_sqr, e_exp,
    e_sqrt, e_dbl, e_half, e_sin, e_cos
    };

/*
 *   Function:	EvaluateExpr
 *
 *   Purpose:   This function evaluates the expression pexp_e at point
 *		pvct_p and returns the result of the evaluation.
 *
 *   Author:	Scott Goehring
 *
 */

SCALAR EvaluateExpr(pexp_e, pvct_p)
    PEXPR pexp_e;
    PVECTOR pvct_p;
{
    SCALAR *args;
    int k;

    /* leaves: a constant carries its value, a variable indexes the point */
    if (pexp_e->t == N_C)
	return pexp_e->x.c;
    if (pexp_e->t == N_V)
	return pvct_p->vars[pexp_e->x.v];

    /* anything that is not a function node evaluates to zero */
    if (pexp_e->t != N_F)
	return 0.0;

    /*
     * Function node: evaluate the operands left to right into a
     * stack-allocated array (alloca storage lives until this call
     * returns), then dispatch on the function code.
     */
    args = (SCALAR *) alloca((unsigned)(pexp_e->x.f.n * sizeof(SCALAR)));
    for (k = 0; k < pexp_e->x.f.n; k++)
	args[k] = EvaluateExpr(pexp_e->x.f.a[k], pvct_p);

    return (*evalfuncs[pexp_e->x.f.f])(args);
}



/*
 *   Function:	SimplifyExpr
 *
 *   Purpose:   This function simplifies the expression (*e) in place:
 *              *e is replaced by the simplified tree and the old tree
 *              is disposed of.  Nothing is returned.
 *
 *   Author:	Scott Goehring
 *
 */

/*
 * inverse[f] is the function code g such that f(g(x)) may be rewritten to
 * x, or -1 when no such rewrite is done.  SimplifyExpr only consults it
 * for one-argument nodes.  It is indexed by the same function code as
 * evalfuncs[], so entry k describes evalfuncs[k].
 *
 * The sqrt entry is -1: sqrt(sqr(x)) equals |x|, not x, so rewriting it to
 * x produced wrong values for negative x.  The remaining pairs only change
 * results where the unsimplified form is undefined or overflows
 * (e.g. exp(ln x) for x <= 0, sqr(sqrt x) for x < 0).
 */
static int inverse[] = {
    F_NEG,	/* e_neg:  -(-x)         -> x */
    -1,		/* e_add */
    F_NEG,	/* e_sub   (only consulted for one-argument nodes) */
    -1,		/* e_mul */
    F_INV,	/* e_div   (only consulted for one-argument nodes) */
    -1,		/* e_pow */
    F_EXP,	/* e_ln:   ln(exp x)     -> x */
    F_INV,	/* e_inv:  inv(inv x)    -> x */
    F_SQRT,	/* e_sqr:  sqr(sqrt x)   -> x */
    F_LN,	/* e_exp:  exp(ln x)     -> x */
    -1,		/* e_sqrt: sqrt(sqr x) is |x|, so no rewrite */
    F_HALF,	/* e_dbl:  dbl(half x)   -> x */
    F_DBL,	/* e_half: half(dbl x)   -> x */
    -1,		/* e_sin */
    -1		/* e_cos */
};

void SimplifyExpr(e)
    PEXPR *e;
{
    PEXPR r;
    int i;
    int nv;

  loop:
    r = (*e);
    switch(r->t) {
      case N_C:				/* constant */
      case N_V:				/* variable */
	break;
      case N_F:				/* function */
	for (i = 0; i < r->x.f.n; i++)
	    if (r->x.f.a[i]->t == N_F)
		SimplifyExpr(&(r->x.f.a[i]));
	nv = 0;
	for (i = 0; i < r->x.f.n; i++)
	    if (r->x.f.a[i]->t == N_C)
		nv++;

	/* all parameters constant */
	if (nv == r->x.f.n) {
	    SCALAR *cv = (SCALAR *) alloca((unsigned)(r->x.f.n*sizeof(SCALAR)));
	    for (i = 0; i < r->x.f.n; i++)
		cv[i] = r->x.f.a[i]->x.c;
	    r = ExprConstNode((*evalfuncs[r->x.f.f])(cv));
#ifdef S_DEBUG
	    fprintf(stderr, "Constants collected\n");
#endif	    
	}
	 
	/* f(g(x)) where g is f's inverse */
	else if (r->x.f.n == 1 && r->x.f.a[0]->t == N_F &&
		 r->x.f.a[0]->x.f.f == inverse[r->x.f.f]) {
	    r = CopyExpr(r->x.f.a[0]->x.f.a[0]);
#ifdef S_DEBUG
	    fprintf(stderr, "Function/inverse pair eliminated\n");
#endif	    
	}
	else if ((r->x.f.f == F_MUL &&
		  ((r->x.f.a[0]->t == N_C && r->x.f.a[0]->x.c == 0.0) ||
		   (r->x.f.a[1]->t == N_C && r->x.f.a[1]->x.c == 0.0))) ||
		 (r->x.f.f == F_DIV && 
		  r->x.f.a[0]->t == N_C && r->x.f.a[0]->x.c == 0.0) ||
		 (r->x.f.f == F_SUB &&
		  EqualExpr(r->x.f.a[0],r->x.f.a[1])) ||
		 (r->x.f.f == F_POW &&
		  r->x.f.a[0]->t == N_C && r->x.f.a[0]->x.c == 0.0)) {
	    r = ExprConstNode(0.0);
#ifdef S_DEBUG
	    fprintf(stderr, 
              "Mult by 0, div of 0, power of 0, or a sub a elim\n");
#endif	    
	}
	else if (r->x.f.n == 2 && r->x.f.a[0]->t == N_C &&
		 (r->x.f.a[0]->x.c == 0.0 && r->x.f.f == F_ADD) ||
		 (r->x.f.a[0]->x.c == 1.0 && r->x.f.f == F_MUL)) {
	    r = CopyExpr(r->x.f.a[1]);
#ifdef S_DEBUG
	    fprintf(stderr, "Add of 0 or mult by 1 eliminated\n");
#endif	    
	}

	else if (r->x.f.n == 2 && r->x.f.a[1]->t == N_C &&
		 (r->x.f.a[1]->x.c == 1.0 && r->x.f.f == F_POW ||
		  r->x.f.a[1]->x.c == 1.0 && r->x.f.f == F_MUL ||
		  r->x.f.a[1]->x.c == 0.0 && r->x.f.f == F_ADD)) {
	    r = CopyExpr(r->x.f.a[0]);
#ifdef S_DEBUG
	    fprintf(stderr,
              "Add of 0, mult by 1, or raise to 1st eliminated\n");
#endif	    
	}
	else if (r->x.f.n == 2 && r->x.f.a[0]->t == N_C &&
		 (r->x.f.f == F_SUB && r->x.f.a[0]->x.c == 0.0) ||
		 (r->x.f.f == F_MUL && r->x.f.a[0]->x.c == -1.0)) {
	    r = ExprF1Node(F_NEG, CopyExpr(r->x.f.a[1]));
#ifdef S_DEBUG
	    fprintf(stderr, "Sub from 0 or mult by -1 simplified\n");
#endif	    
	}
	else if ((r->x.f.f == F_DIV &&
		  r->x.f.a[0]->t == N_C &&
		  r->x.f.a[0]->x.c == 1.0) ||
		 (r->x.f.f == F_POW &&
		  r->x.f.a[1]->t == N_C &&
		  r->x.f.a[0]->x.c == -1.0)) {
	    r = ExprF1Node(F_INV, CopyExpr(r->x.f.a[1]));
#ifdef S_DEBUG
	    fprintf(stderr, "Div of 1 or raise to -1st simplified\n");
#endif	    
	}
	else if (r->x.f.f == F_ADD &&
		 EqualExpr(r->x.f.a[0], r->x.f.a[1])) {
	    r = ExprF1Node(F_DBL, CopyExpr(r->x.f.a[0]));
#ifdef S_DEBUG
	    fprintf(stderr, "a plus a simplified\n");
#endif	    
	}
	else if (r->x.f.f == F_MUL &&
		 EqualExpr(r->x.f.a[0], r->x.f.a[1])) {
	    r = ExprF1Node(F_SQR, CopyExpr(r->x.f.a[0]));
#ifdef S_DEBUG
	    fprintf(stderr, "a times a simplified\n");
#endif	    
	}
    }
    if (r != (*e)) {
	DisposeExpr(*e);
	*e = r;
	goto loop;
    }
}

/*
 *   Function:	EqualExpr
 *
 *   Purpose:   Structural equality of two expression trees.  Returns
 *              nonzero when both nodes have the same type and
 *                - constants: equal values (compared with ==),
 *                - variables: the same variable index,
 *                - functions: the same function code, the same number of
 *                  arguments, and pairwise-equal arguments (recursively).
 *              Nodes of any other type compare unequal.  The original
 *              code returned an uninitialized value in that case.
 */

BOOL EqualExpr (pexp_1, pexp_2)
PEXPR pexp_1, pexp_2;
{
    int i;

    if (pexp_1->t != pexp_2->t)
	return 0;

    switch (pexp_1->t) {
      case N_C:
	return pexp_1->x.c == pexp_2->x.c;
      case N_V:
	return pexp_1->x.v == pexp_2->x.v;
      case N_F:
	if (pexp_1->x.f.f != pexp_2->x.f.f ||
	    pexp_1->x.f.n != pexp_2->x.f.n)
	    return 0;
	for (i = 0; i < pexp_1->x.f.n; i++)
	    if (!EqualExpr(pexp_1->x.f.a[i], pexp_2->x.f.a[i]))
		return 0;
	return 1;
      default:
	/* unknown node type: never treat as equal */
	return 0;
    }
}

