/*
 * This file is part of PARAM.
 *
 * PARAM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PARAM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PARAM.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright 2009-2011 Ernst Moritz Hahn (emh@cs.uni-sb.de)
 */

#include <cstdlib>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "cvc3/vcl.h"
#include "CVC3Converter.h"
#include "Base.h"
#include "RationalFunction.h"

using namespace std;

namespace prismparser {
  // CVC3 validity checker defined elsewhere (presumably in the PRISM parser
  // module). It is used below to simplify POW exponents and ITE conditions
  // before they are inspected.
  extern CVC3::VCL vc;
}

namespace rational {
  // Default constructor. Nothing needs explicit setup here: the parameter
  // lookup table is (re)built by operator() / addSymbols when needed.
  CVC3Converter::CVC3Converter() {}

  /*
   * Recursively translate the CVC3 arithmetic expression cvc3Expr into a
   * RationalFunction.
   *
   * Handled expressions:
   *  - rational constants,
   *  - variables present in paramNumbers (mapped to the monomial whose
   *    exponent is 1 at the variable's position and 0 elsewhere),
   *  - unary minus, n-ary PLUS and TIMES, left-associative MINUS and DIVIDE,
   *  - POW whose exponent simplifies to an integer constant,
   *  - ITE whose condition simplifies to true or false.
   *
   * Throws runtime_error for variables missing from paramNumbers,
   * non-integer exponents, ITE conditions that cannot be decided and any
   * other expression kind.
   */
  RationalFunction CVC3Converter::convert(const CVC3::Expr &cvc3Expr) {
    // Numeric CVC3 expression kinds handled below. The arithmetic kinds
    // follow CVC3's theory_arith.h (REAL = 3000, INT, SUBRANGE, UMINUS,
    // PLUS, MINUS, MULT, DIVIDE, POW, ...).
    const int KIND_ITE = 121;
    const int KIND_UMINUS = 3003;
    const int KIND_PLUS = 3004;
    const int KIND_MINUS = 3005;
    const int KIND_MULT = 3006;
    const int KIND_DIVIDE = 3007;
    const int KIND_POW = 3008;

    if (cvc3Expr.isRational()) {
      CVC3::Rational rate_rat = cvc3Expr.getRational();
      RationalFunction num(rate_rat.getNumerator().getInt());
      RationalFunction den(rate_rat.getDenominator().getInt());
      return num / den;
    }

    if (cvc3Expr.isVar()) {
      string name(cvc3Expr.toString());
      ParamNumbersMap::iterator iter(paramNumbers.find(name));
      if (paramNumbers.end() == iter) {
        throw runtime_error("Variable \"" + name + "\" was not found in parameter list.");
      }
      // Monomial with exponent 1 for this parameter and 0 for all others.
      vector<unsigned> monomial;
      monomial.resize(Base::getNumSymbols());
      monomial[iter->second] = 1;
      RationalFunction res(1, monomial);
      return res;
    }

    const int kind(cvc3Expr.getKind());
    switch (kind) {
    case KIND_UMINUS: {
      // -x is computed as 0 - x.
      RationalFunction result(0);
      result -= convert(cvc3Expr[0]);
      return result;
    }
    case KIND_PLUS: {
      RationalFunction result(0);
      const unsigned arity(cvc3Expr.arity());
      for (unsigned partNr(0); partNr < arity; partNr++) {
        result += convert(cvc3Expr[partNr]);
      }
      return result;
    }
    case KIND_MINUS: {
      // Left-associative: child0 - child1 - ... - childN.
      RationalFunction result(convert(cvc3Expr[0]));
      const unsigned arity(cvc3Expr.arity());
      for (unsigned partNr(1); partNr < arity; partNr++) {
        result -= convert(cvc3Expr[partNr]);
      }
      return result;
    }
    case KIND_MULT: {
      RationalFunction result(1);
      const unsigned arity(cvc3Expr.arity());
      for (unsigned partNr(0); partNr < arity; partNr++) {
        result *= convert(cvc3Expr[partNr]);
      }
      return result;
    }
    case KIND_DIVIDE: {
      // Left-associative: child0 / child1 / ... / childN.
      RationalFunction result(convert(cvc3Expr[0]));
      const unsigned arity(cvc3Expr.arity());
      for (unsigned partNr(1); partNr < arity; partNr++) {
        result /= convert(cvc3Expr[partNr]);
      }
      return result;
    }
    case KIND_POW: {
      // CVC3 stores the exponent as child 0 and the base as child 1.
      CVC3::Expr exponent(prismparser::vc.simplify(cvc3Expr[0]));
      if (!exponent.isRational()
          || (1 != exponent.getRational().getDenominator().getInt())) {
        throw runtime_error("Expression \"" + exponent.toString() + "\" in \""
                            + cvc3Expr.toString() + "\" is not an integer.");
      }
      const int powInt(exponent.getRational().getNumerator().getInt());
      // |powInt| computed in unsigned arithmetic: well defined even for
      // INT_MIN, where abs() would overflow.
      unsigned upow(powInt < 0 ? 0u - static_cast<unsigned>(powInt)
                               : static_cast<unsigned>(powInt));

      // x^0 == 1; the base is deliberately not converted in that case.
      RationalFunction res(1);
      if (0 != upow) {
        // Exponentiation by squaring; invariant: res * r^upow == base^|powInt|.
        RationalFunction r(convert(cvc3Expr[1]));
        while (1 != upow) {
          if (1 & upow) {
            res *= r;
          }
          upow >>= 1;
          r *= r;
        }
        res *= r;
      }

      if (0 > powInt) {
        res = 1 / res;
      }
      return res;
    }
    case KIND_ITE: {
      CVC3::Expr truthVal(prismparser::vc.simplify(cvc3Expr[0]));
      if (truthVal.isTrue()) {
        return convert(cvc3Expr[1]);
      }
      if (truthVal.isFalse()) {
        return convert(cvc3Expr[2]);
      }
      // Report through an exception, like the other errors in this
      // function, instead of terminating the whole process.
      throw runtime_error("Truth value of \"" + cvc3Expr[0].toString()
                          + "\" could not be evaluated.");
    }
    default:
      break;
    }

    std::ostringstream msg;
    msg << "Expression \"" << cvc3Expr << "\" of kind " << kind
        << " is not supported.";
    throw runtime_error(msg.str());
  }

  /*
   * Convert cvc3Expr into a RationalFunction over the parameters currently
   * listed in Base::symbols.
   *
   * The name -> position table (paramNumbers) is rebuilt from Base::symbols
   * on every call, so symbols registered since the previous call are seen.
   */
  RationalFunction CVC3Converter::operator()(const CVC3::Expr &cvc3Expr) {
    paramNumbers.clear();
    const std::vector<std::string> &names(Base::symbols);
    unsigned position(0);
    for (std::vector<std::string>::const_iterator it(names.begin());
         it != names.end(); ++it, ++position) {
      // insert() keeps the first position should a name occur twice.
      paramNumbers.insert(std::make_pair(*it, position));
    }
    return convert(cvc3Expr);
  }

  /*
   * Walk expr and register every variable not seen before: its name is
   * appended to Base::symbols and recorded in paramNumbers with the index it
   * received there. Non-variable nodes are traversed child by child.
   */
  void CVC3Converter::addSymbols(const CVC3::Expr &expr) {
    if (!expr.isVar()) {
      for (CVC3::Expr::iterator child(expr.begin()); child != expr.end(); ++child) {
        addSymbols(*child);
      }
      return;
    }

    const std::string varName(expr.toString());
    if (paramNumbers.find(varName) != paramNumbers.end()) {
      return;  // already registered
    }
    Base::symbols.push_back(varName);
    // The new symbol is the last entry of Base::symbols.
    paramNumbers.insert(std::make_pair(varName, Base::getNumSymbols() - 1));
  }
}
