////////////////////////////////////////////////////////////////////////////////
//
// File:            spirit_evaluator.c++
// Purpose:         Arithmetic Expression Parsing and Evaluation
// 
// Original Author: David Bergman http://blog.davber.com/about
// Second Author:   Joel Young <jdy@cs.brown.edu>
//
// License:         MIT
// Copyright:       See Below
//
////////////////////////////////////////////////////////////////////////////////
//
// History:
//   2008.02.27:
//     -- Added support for right-to-left operators for exponentiation
//        and negation
//     -- Added expect clauses to the grammar to throw errors for typos
//     -- Added output in postfix, prefix, tree, and XML notation
//     -- Added detection for divide by zero and fractional roots of
//        negatives
//     -- Added compilation mode for command line argument version
//     -- Changed program name from parse_arith.cpp to spirit_evaluator.c++
//   2008.02.27:
//     -- David Bergman releases under MIT license
//     -- see comment at URL below
//   2006.07.06:
//     -- David Bergman posted original version on his blog
//     -- URL: http://blog.davber.com/2006/07/06/the-spirit-of-parsing/
//
////////////////////////////////////////////////////////////////////////////////
// 
// This program implements a simple arithmetic language using Boost.Spirit.
// 
// It does so by creating an AST and then implementing an evaluating visitor 
// for such nodes.  It also outputs the expression in infix, postfix,
// and prefix notation.  Furthermore, it prints an XML-encoded version of
// the AST as well as a textual representation of the tree. 
//
// (It's not really a visitor, since it handles the traversal itself) 
//
// Compile with:
//
// g++ -o spirit_evaluator spirit_evaluator.c++ -O2
////////////////////////////////////////////////////////////////////////////////
// 
// Copyright (c) 2006-2008 David Bergman
// Copyright (c)      2008 Joel Young
// 
// License Information (MIT License):
// 
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the “Software”), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// 
// THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//
////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////
//
// Compile flags:
//
// Uncomment to enable parse time exceptions for consecutive 
// operands such as (6 6) or ((3-5)(5x6)).  Uncommenting will
// result in a 50x slowdown.  Also, note that this expression
// error will result in an incomplete parse and is checked
// for post-parse.
//#define CALC_DETECT_CONSECUTIVE_OPERANDS
//
////////
//
// Uncomment to enable command-line mode for use in scripts:
//#define CALC_SHELL_MODE
// You can then put lines like:
// c () { ~/bin/spirit_evaluator "$*" 1000; }
// t () { ~/bin/spirit_evaluator "$*" 10 1; }
// in your .bashrc to have a handy command line calculator after you
// copy your executable into your bin directory. 
// c 6x6^5 
// outputs
// 46656
// and 
// t 6x6^5
// outputs
// 46656
//   x      
//  / \__   
// 6     ^  
//      / \ 
//     6   5
// 
////////
//
// Uncomment to enable debugging output from spirit's parse system:
//#define BOOST_SPIRIT_DEBUG
//
////////////////////////////////////////////////////////////////////////////////

// Standard Includes:
#include <string>                             // string, getline
#include <iostream>                           // cout, cerr
#include <sstream>                            // istringstream, ostringstream
#include <iomanip>                            // setw, setfill
#include <map>                                // map
#include <cmath>                              // abs, pow, floor
#include <functional>                         // multiplies, plus, minus 
#include <algorithm>                          // max
#include <stdexcept>                          // runtime_error

// Boost Includes:
#include <boost/lexical_cast.hpp>                     // lexical_cast<>
#include <boost/function.hpp>                         // function<>
#include <boost/spirit/core.hpp>                      // rule, grammar
#include <boost/spirit/tree/ast.hpp>                  // ast_parse
#include <boost/spirit/tree/tree_to_xml.hpp>          // tree_to_xml
#include <boost/spirit/error_handling/exceptions.hpp> // assertion, 
                                                      // parser_error,
                                                      // throw_

using namespace std;
using namespace boost::spirit;
using namespace boost;

// Parse-tree match type produced by ast_parse() over `const char*` input
// (main parses std::string::c_str()).
typedef tree_match<const char*>    treematch_t;
// Iterator over AST nodes; the tree walkers below (infix, prefix, postfix,
// depth, print, evaluate) all take one of these.
typedef treematch_t::tree_iterator tree_iter_t;

// Descriptors carried by the parser_error exceptions that are thrown
// while parsing (by the expect_* assertions and, optionally, by cc).
enum Errors {
  close_expected,        // a ')' was required but not found
  expression_expected    // an operand / sub-expression was required
#ifdef CALC_DETECT_CONSECUTIVE_OPERANDS
  , binary_operator_expected  // two operands with no operator between them
#endif
};

// Assertions to use during the parse:
// Wrapping a parser p as expect_close(p) / expect_expression(p) makes a
// failed match throw parser_error<Errors, ...> carrying the given
// descriptor; main() catches it and reports the error position.
assertion<Errors> expect_close(close_expected);
assertion<Errors> expect_expression(expression_expected);

// Semantic action for the consecutive-operand check: it fires when a
// second operand directly follows another (e.g. "6 6") and aborts the
// parse with binary_operator_expected, reporting the position at which
// the unexpected operand begins.  The end of the match is not needed.
#ifdef CALC_DETECT_CONSECUTIVE_OPERANDS
void cc(char const* where, char const* /* end of match: unused */) {
  throw_(where, binary_operator_expected);
}
#endif

// Grammar of the calculator language.  Rules, from loosest binding to
// tightest:
//   exp_p    additive        '+' '-'   left-to-right
//   term_p   multiplicative  'x' '/'   left-to-right  ('x' is "times")
//   unary_p  negation        '-'       right-to-left (prefix)
//   power_p  exponentiation  '^'       right-to-left
//   factor_p unsigned real literal, or a parenthesized expression
// Because power_p's base is factor_p while unary_p wraps power_p, "-2^2"
// parses as -(2^2); because the exponent is a unary_p, "2^-3" is accepted.
// root_node_d[] makes each operator the parent of its operands in the AST.
struct expression : public grammar<expression> {
  // Explicit identifiers to switch properly when evaluating the AST
  // (they become the AST node ids the tree walkers test against).
  static const int factorID = 1;
  static const int termID   = 2;
  static const int expID    = 3;
  static const int powerID  = 4;
  static const int unaryID  = 5;

  // Meta function from scanner type to a proper rule type
  template<typename ScannerT> struct definition {
    rule<ScannerT, parser_context<>, parser_tag<factorID > > factor_p;
    rule<ScannerT, parser_context<>, parser_tag<powerID  > > power_p;
    rule<ScannerT, parser_context<>, parser_tag<expID    > > exp_p;
    rule<ScannerT, parser_context<>, parser_tag<termID   > > term_p;
    rule<ScannerT, parser_context<>, parser_tag<unaryID  > > unary_p;
    rule<ScannerT> start_p;
    #ifdef CALC_DETECT_CONSECUTIVE_OPERANDS
    rule<ScannerT> operator_p, parenthetical_p, missing_ops_p;
    #endif

    // Arithmetic expression grammar:
    // (`self` is required by the grammar interface but not used here.)
    definition(const expression& self) {

      #ifdef CALC_DETECT_CONSECUTIVE_OPERANDS
      // short hand for the missing operator checks
      operator_p      = ch_p('^') | 
                        ch_p('x') | ch_p('/') | 
                        ch_p('-') | ch_p('+');
      parenthetical_p = ch_p('(') >> start_p >> ch_p(')');

      // Each alternative matches an operand followed by a second operand
      // that is not the start of "operator operand"; the action cc then
      // throws binary_operator_expected at the second operand.
      missing_ops_p   =  // Check for consecutive operands
          (ureal_p >> 
           (ureal_p - (operator_p >> ureal_p))[&cc])
        | (parenthetical_p >> 
           (ureal_p - (operator_p >> ureal_p))[&cc])
        | (ureal_p >> 
           (parenthetical_p - (operator_p >> parenthetical_p))[&cc])
        | (parenthetical_p >> 
           (parenthetical_p - (operator_p >> parenthetical_p))[&cc]);
      #endif
      
      // leaf_node_d keeps a literal as a single leaf; inner_node_d drops
      // the parentheses so only the inner expression's tree remains.
      factor_p =  // numbers or parentheticals:
        #ifdef CALC_DETECT_CONSECUTIVE_OPERANDS
        missing_ops_p |
        #endif
        leaf_node_d[ureal_p] | (inner_node_d['(' >> 
                                             expect_expression(start_p) >> 
                                             expect_close(ch_p(')'))]
                               );
      // NOTE(review): when no '^' follows, the first alternative fails
      // after factor_p has already matched, and the second alternative
      // parses factor_p again.  With nested parentheses this re-parsing
      // appears to compound at every nesting level.
      power_p = // exponentials    (right-to-left)
        (factor_p >> root_node_d[ch_p('^')]) >> expect_expression(unary_p) | 
        factor_p;
      // Right-to-left because the operand recurses into unary_p itself.
      unary_p = // unary operators (right-to-left)
        (root_node_d[ch_p('-')]) >> expect_expression(unary_p) | power_p;
      // Left-to-right: the kleene star repeatedly makes the new operator
      // the root above the tree built so far.
      term_p =  // multiplicatives (left-to-right) 
        unary_p >> *(root_node_d[ch_p('x')|'/'] >> expect_expression(unary_p));
      exp_p =   // additives       (left-to-right)
        term_p  >> *(root_node_d[ch_p('+')|'-'] >> expect_expression(term_p));
      // The whole input must begin with an expression, otherwise
      // expression_expected is thrown.
      start_p = expect_expression(exp_p);

      #ifdef BOOST_SPIRIT_DEBUG
      BOOST_SPIRIT_DEBUG_RULE(factor_p);
      BOOST_SPIRIT_DEBUG_RULE(power_p);
      BOOST_SPIRIT_DEBUG_RULE(unary_p);
      BOOST_SPIRIT_DEBUG_RULE(term_p);
      BOOST_SPIRIT_DEBUG_RULE(exp_p);
      BOOST_SPIRIT_DEBUG_RULE(start_p);
      #endif
    }
    
    // Specify the starting rule for the parse
    const rule<ScannerT> & start() const { return start_p; }
  };
};

// Map binary operators to operations
// Keyed by the operator character stored in an AST node; filled in at the
// start of main() and looked up by evaluate().
static map<char, function<double (double, double)> > op;

// Map unary operators to operations
// (main() registers only '-', mapped to negation.)
static map<char, function<double (double)> >        uop;

// Render the subtree rooted at i as fully parenthesized infix text.
// Literals are emitted verbatim, and every operator application is wrapped
// in parentheses: unary "(-5)", binary "(1 + 2)".  The one exception to the
// spacing is '^', which is written without surrounding spaces: "(2^3)".
string infix(const tree_iter_t& i)
{
  const string sym(i->value.begin(), i->value.end());

  if (i->value.id() == expression::factorID)
    return sym;                                      // numeric literal

  tree_iter_t first = i->children.begin();
  if (i->value.id() == expression::unaryID)
    return "(" + sym + infix(first) + ")";           // prefix operator

  tree_iter_t second = i->children.begin() + 1;
  const string pad = (*i->value.begin() == '^') ? "" : " ";
  return "(" + infix(first) + pad + sym + pad + infix(second) + ")";
}
// Render the root of a complete parse result.
string infix(tree_parse_info<> t) { return infix(t.trees.begin()); }

// Render the subtree rooted at i in parenthesized prefix (Polish)
// notation: each operator application becomes "(op operand...)" with the
// parts separated by single spaces, e.g. "(+ 1 (x 2 3))".  Literals are
// emitted bare.
string prefix(const tree_iter_t& i) {
  const string sym(i->value.begin(), i->value.end());
  if (i->value.id() == expression::factorID) {
    return sym;
  }

  string out = "(" + sym + " " + prefix(i->children.begin());
  if (i->value.id() != expression::unaryID) {
    // Binary operators carry a second operand.
    out += " ";
    out += prefix(i->children.begin() + 1);
  }
  out += ")";
  return out;
}
// Render the root of a complete parse result.
string prefix(tree_parse_info<> t) { return prefix(t.trees.begin()); }

// Render the subtree rooted at i in postfix (reverse Polish) notation:
// the operand(s) first, then the operator, all separated by single spaces,
// e.g. "1 2 3 x +".  No parentheses are needed.
string postfix(const tree_iter_t& i) {
  if (i->value.id() == expression::factorID)
    return string(i->value.begin(), i->value.end());   // numeric literal

  string out = postfix(i->children.begin());
  if (i->value.id() != expression::unaryID)
    out += " " + postfix(i->children.begin() + 1);     // second operand
  out += " ";
  out += string(i->value.begin(), i->value.end());     // the operator
  return out;
}
// Render the root of a complete parse result.
string postfix(tree_parse_info<> t) { return postfix(t.trees.begin()); }

// Number of levels in the subtree rooted at i: a literal has depth 1, and
// each operator adds one level above its deepest operand.
unsigned int depth(const tree_iter_t& i) {
  if (i->value.id() == expression::factorID)
    return 1;

  unsigned int deepest = depth(i->children.begin());
  if (i->value.id() != expression::unaryID) {
    const unsigned int right = depth(i->children.begin() + 1);
    if (right > deepest)
      deepest = right;
  }
  return deepest + 1;
}
// Depth of the whole tree of a complete parse result.
unsigned int depth(tree_parse_info<> t) { return depth(t.trees.begin()); }

// Helpers for tree printing:
// One rendered subtree, as produced by print():
//   s  the rows of the drawing, each terminated by '\n'
//   l  column at which the subtree's top symbol is drawn (the operator,
//      or roughly the middle of a numeric literal)
//   r  width from that column to the right edge of the first row,
//      counting the column itself, so l + r is the first row's width
struct return_t {
  string s;
  unsigned int l;
  unsigned int r;

  return_t(const string& text, unsigned int left, unsigned int right)
    : s(text), l(left), r(right) { }
};
// Forward declaration: the tree_parse_info overload of print() below is
// defined before the iterator overload it calls.
static return_t print(const tree_iter_t&);

// The print functions return a string containing a multiline tree
// representation of the AST.  It works by recursively building up
// blocks of text containing sub-expressions and then merging the
// subexpressions.  The alignment is based on the assumption that all
// operators are one character wide.
//
// Render a complete parse result; only the text of the drawing is kept.
string print(tree_parse_info<> t) { return print(t.trees.begin()).s; }
// Render the subtree rooted at i.  Returns the rows of the drawing (each
// ending in '\n'), the column l of this subtree's top symbol, and the width
// r from that column to the right edge of the first row.  Parents use l and
// r to position their connector rows.
return_t print(const tree_iter_t& i) {
  if (i->value.id() == expression::factorID) {
    // Numeric literal
    // A single row; the anchor column is size/2, rounded down.
    string l = string(i->value.begin(), i->value.end());
    unsigned int s = l.size();
    unsigned int half = 
      static_cast<unsigned int>(std::floor(((double)l.size())/2.0));
    l = l + "\n";
    return return_t(l,half,s-half);
  } else if (i->value.id() == expression::unaryID) {
    // Unary expression
    // Draws the operator and a '|' row, both at the operand's anchor
    // column, stacked on top of the operand's drawing.
    return_t operand = print(i->children.begin());
    unsigned int op_length = operand.l + operand.r;   // NOTE(review): unused

    ostringstream ret;
    ret << setw(operand.l) << "" 
        << string(i->value.begin(),i->value.end()) 
        << setw(operand.r-1) << "" << "\n"; 
    ret << setw(operand.l) << "" 
        << "|" 
        << setw(operand.r-1) << "" << "\n"; 
    ret << operand.s ; 

    // The operator sits at the operand's anchor, so l and r pass through.
    return return_t(ret.str(),operand.l,operand.r);
  } else {
    // Binary expression
    // Draws both operands side by side beneath the operator.
    return_t lrand = print(i->children.begin());
    return_t rrand = print(i->children.begin()+1);
    // Walk both drawings row by row; l and r hold the current rows.
    istringstream  left(lrand.s);
    istringstream right(rrand.s);
    string l, r;
    getline(left,l);
    getline(right,r);
    // Widths of the two operands' first rows.
    unsigned int lsize = l.size();
    unsigned int rsize = r.size();
    // Columns between the two drawings: 3 when either first row is narrow,
    // otherwise 1 (presumably so the "/ \" connector always fits).
    unsigned int gap = (lsize < 3 or rsize < 2) ? 3 : 1;
    unsigned int width = lsize+rsize+gap;   // NOTE(review): unused

    ostringstream o;
    // output operator row
    // The operator lands at column lsize (gap 1) or lsize+1 (gap 3), which
    // is also the anchor column returned below.
    o << setw(lsize+(gap == 1 ? 0 : 1)) << "" 
      << string(i->value.begin(),i->value.end()) 
      << setw(rsize+(gap == 1 ? 0 : 1)) << "" << "\n";
    // output arrow row
    // "/ \" is centered under the operator.  '_' runs connect it to just
    // right of the left operand's anchor and to just left of the right
    // operand's anchor.  NOTE(review): these are unsigned + negative-int
    // sums; they stay non-negative only because gap == 1 is chosen solely
    // when lsize >= 3 and rsize >= 2.  Confirm before changing either.
    o << setw(lrand.l+(gap == 1 ?  1 :  1)) << "" 
      << setw(lrand.r+(gap == 1 ? -2 : -1)) << setfill('_') << "" 
      << "/ \\" 
      << setw(rrand.l+(gap == 1 ? -1 :  0)) << setfill('_') << "" 
      << setw(rrand.r+(gap == 1 ?  0 :  0)) << setfill(' ') << "" << "\n";
    // output subexpression rows
    // Emit both operands' rows side by side, padding with spaces the
    // drawing that has run out of rows.  Relies on every row ending in
    // '\n': a row read together with end-of-file is erased below.
    do {
      if (l.empty()) { o << setw(lsize+0) << "";  } else { o << "" << l; }
      o << setw(gap) << "";
      if (r.empty()) { o << setw(rsize+0) << " "; } else { o << r; }
      o << '\n';
      
      if (left)  { getline(left,l);  } if (left.eof())  { l.erase(); }
      if (right) { getline(right,r); } if (right.eof()) { r.erase(); }
    } while (right or left);
    // The operator column becomes this drawing's anchor.
    return return_t(o.str(),lsize+(gap==1 ? 0 : 1),rsize+(gap==1 ? 1 : 2));
  }
}

// Convert AST to xml: serialise the parse tree with Spirit's tree_to_xml.
// Each node is named after the grammar rule that produced it, and the
// fully parenthesized infix form is passed as the input text.
string xmlout(tree_parse_info<> t) {
  // Element names for each grammar rule id.
  std::map<parser_id, std::string> names;
  names[expression::unaryID]  = "unary";
  names[expression::powerID]  = "power";
  names[expression::expID]    = "additive";
  names[expression::termID]   = "multiplicative";
  names[expression::factorID] = "double";

  const string source_text = infix(t);
  ostringstream xml;
  tree_to_xml(xml, t.trees, source_text, names);
  return xml.str();
}

// Exception class for division by zero.
// Thrown by the '/' operator (::divides) when the divisor is zero.
struct divide_by_zero : public std::runtime_error {
  // Default message: "division by zero".
  divide_by_zero() :
    std::runtime_error("division by zero") { }
  // Custom message, e.g. the default message annotated with the offending
  // sub-expression.  Explicit: a string should never silently become an
  // exception object.
  explicit divide_by_zero(const std::string& what_arg) :
    std::runtime_error(what_arg) { }
};

// Exception class for root_of_negative.
// Thrown by the '^' operator (::raises) for a negative base raised to a
// non-integral exponent, which has no real-valued result.
struct root_of_negative : public std::runtime_error {
  // Default message: "fractional root of negative".
  root_of_negative() :
    std::runtime_error("fractional root of negative") { }
  // Custom message, e.g. the default message annotated with the offending
  // sub-expression.  Explicit to forbid implicit string conversion.
  explicit root_of_negative(const std::string& what_arg) :
    std::runtime_error(what_arg) { }
};

// The evaluation function for the AST
double evaluate(const tree_iter_t& i)
{
  if (i->value.id() == expression::factorID) {
    // Simple numeric literal
    return lexical_cast<double>(string(i->value.begin(), i->value.end()));
  } else if (i->value.id() == expression::unaryID) {
    // Unary expression
    return uop[*i->value.begin()](evaluate(i->children.begin()));
  } else {
    // Binary expression
    try {
    return op[*i->value.begin()](evaluate(i->children.begin()),
                                 evaluate(i->children.begin() + 1));
    } catch (divide_by_zero& dvz) {
      throw divide_by_zero(std::string(dvz.what()) + 
                           std::string(" in ") + infix(i));
    } catch (root_of_negative& dvz) {
      throw root_of_negative(std::string(dvz.what()) + 
                           std::string(" in ") + infix(i));
    }
  }
}
double evaluate(tree_parse_info<> t) { return evaluate(t.trees.begin()); }

// Division function object for the '/' operator.
// Unlike std::divides, it rejects a zero divisor (of either sign, since
// -0.0 == 0.0) by throwing divide_by_zero, instead of producing an
// infinity or NaN.  Stateless, so the call operator is const.
template <typename T>
struct divides {
  T operator()(const T& l, const T& r) const {
    if (r == 0.0) {
      throw divide_by_zero();
    }
    return l / r;
  }
};

// Exponentiation function object for the '^' operator.
// A negative base raised to a non-integral exponent has no real result
// (std::pow reports a domain error for it), so that case throws
// root_of_negative instead.  Stateless, so the call operator is const.
template <typename T>
struct raises {
  T operator()(const T& base, const T& exponent) const {
    const T magnitude = std::abs(exponent);
    // A positive fractional part implies a non-zero exponent.
    const bool fractional = (magnitude - std::floor(magnitude)) > 0.0;
    if (base < 0.0 and fractional) {
      throw root_of_negative();
    }
    return std::pow(static_cast<double>(base), exponent);
  }
};

int main(int argc, char* argv[])
{
  // Initialize the binary operators:
  op['^'] = ::raises<double>();
  op['x'] = multiplies<double>();
  op['/'] = ::divides<double>();
  op['+'] = plus<double>();
  op['-'] = minus<double>();

  // Initialize the unary operators:
  uop['-'] = negate<double>();

  string input;
  #ifndef CALC_SHELL_MODE
  do {
    cout << "Please enter an expression (empty expression to exit): ";
    std::getline(std::cin,input); cout << input << endl;
    if (input == "") break;
  #else
    if (argc > 1) {
      input = string(argv[1]);
    }
    unsigned int precision = 10;
    if (argc > 2) {
      precision = lexical_cast<unsigned int>(argv[2]);
    }
  #endif
    expression my_exp;
    tree_parse_info<const char*> tree;
    // Most errors can be caught in the parse with parser exceptions
    try {
      tree = ast_parse(input.c_str(), my_exp, space_p);
    } catch (parser_error<Errors,char const*> x) {
      switch (x.descriptor) {
        case close_expected:
          cout << "Expected close parenthesis" << endl;
          break;
        case expression_expected:
          cout << "Expected operand" << endl;
          break;
        #ifdef CALC_DETECT_CONSECUTIVE_OPERANDS
        case binary_operator_expected:
          cout << "Consecutive Operands Not Allowed" << endl;
          break;
        #endif
      }
      std::cout << "-->| " << input << endl;
      std::cout << "at | " << setw(std::distance(input.c_str(),x.where)+1) 
                           << "^" << endl;
      #ifndef CALC_SHELL_MODE
      continue;
      #else
      return 1;
      #endif
    }

    if (tree.full) { // Did the parse complete:
      try {
        #ifndef CALC_SHELL_MODE
        cout << evaluate(tree) << endl;
        #else
        cout << setprecision(precision) << evaluate(tree) << endl;
        #endif
      } catch (std::runtime_error& e) {
        std::cout << e.what() << std::endl;
        #ifdef CALC_SHELL_MODE
        return 1;
        #endif
      }
      #ifndef CALC_SHELL_MODE
      cout << "Infix:   " << infix(tree) << endl;
      cout << "Postfix: " << postfix(tree) << endl;
      cout << "Prefix:  " << prefix(tree) << endl;
      cout << "XML:\n"    << xmlout(tree) << endl;
      cout << "Tree:\n"   << print(tree) << endl;
      #else
      if (argc > 3) {
        cout << print(tree) << endl;
      }
      #endif
    } else {         // Or was the expression rejected?
      // But open parenthesis and missing operators are very hard to
      // check for during the parse, so it is easier to check 
      // for an incomplete parse afterwards:
      switch (*tree.stop) {
        case ')':
          cout << "Expected open parenthesis\n";
          break;
        case '(':
        case '0': case '1': case '2': case '3': case '4':
        case '5': case '6': case '7': case '8': case '9':
          cout << "Expected binary operator\n";
          break;
        default:
          cout << "Not a number\n";
      }
      std::cout << "-->| " << input << endl;
      std::cout << "at | " << setw(std::distance(input.c_str(),tree.stop)+1) 
                           << "^" << endl;
    }
  #ifndef CALC_SHELL_MODE
  } while (input != "");
  #endif
}
