/*
 * This file is part of a parser for an extension of the PRISM language.
 *
 * This is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * The parser is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with the program this parser is part of.
 * If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright 2007-2010 Bjoern Wachter (Bjoern.Wachter@comlab.ox.ac.uk)
 * Copyright 2009-2010 Ernst Moritz Hahn (emh@cs.uni-sb.de)
 */

#include <limits>
#include <string>
#include <stdio.h>

#include "AST.h"
#include "Util.h"

#include <cvc3/vcl.h>
#include <cvc3/theory_arith.h>
#include <cvc3/theory_bitvector.h>
#include <cvc3/theory_arith.h>

#include "SymbolTable.h"
#include "ExprManager.h"
#include "Node.h"

#include "Property.h"
#include "Model.h"

#include "PRISMParser.h"
extern void PRISMparse();
extern FILE *PRISMin;

// Name of the input file. Defined here but not referenced in the code shown
// in this file; presumably read by the lexer/grammar for diagnostics
// (TODO confirm).
std::string file_name;

namespace prismparser {
  // Current input line of the lexer; defined elsewhere (presumably in the
  // generated lexer — TODO confirm). PRISMParser::run() resets it to 1
  // before every parse.
  extern int line_number;
}

namespace prismparser {
  using namespace std;

  // Definition of the parser's static AST. PRISMParser::run() hands it to
  // translateModel() after PRISMparse() and clears it afterwards; presumably
  // it is filled by the grammar actions (TODO confirm).
  prismparser_ast::Model PRISMParser::astModel;

  // NOTE(review): not referenced in this file; presumably used by the
  // grammar actions for constant substitution — TODO confirm.
  prismparser_ast::Substitution constants;



  // Maps variable names to the CVC3 expressions created for them by
  // translateVariables(); used to resolve variable references.
  typedef std::tr1::unordered_map<string, CVC3::Expr> VarTable;

  // Defined further below; declared here because PRISMParser::run() uses it.
  void translateModel(prismparser_ast::Model& am, Model& m, VarTable& vt);

  /**
   * Creates a parser with an empty variable table.
   *
   * The table is stored through the `table` member and cast back to
   * VarTable* wherever it is used, so the header presumably declares it as
   * an opaque pointer (TODO confirm).
   */
  PRISMParser::PRISMParser() {
    VarTable* fresh_table = new VarTable();
    table = fresh_table;
  }

  /**
   * Releases the variable table allocated by the constructor.
   */
  PRISMParser::~PRISMParser() {
    VarTable* owned_table = static_cast<VarTable*>(table);
    owned_table->clear();
    delete owned_table;
  }

  void PRISMParser::run(const string& file, prismparser::Model &model) {
    line_number = 1;
    if (!(PRISMin = fopen(file.c_str(), "r"))) {
      throw RuntimeError("File "+file+" not found\n");
    }

    PRISMparse();
    
    translateModel(PRISMParser::astModel, model,*(VarTable*)table);
    PRISMParser::astModel.clear(); // avoid double insertions into ::model
    fclose(PRISMin);
  }

  /**
   * Translates a base (non-temporal) AST expression into a CVC3 expression.
   *
   * All operands are translated first, then combined according to the
   * operator of @p ae. Equality/disequality of Boolean operands is expressed
   * with iff; Min/Max are expressed as if-then-else terms.
   *
   * @param ae expression to translate (must not be null)
   * @param vt variable table used to resolve variable references
   * @return the translated expression
   * @throws ParseError for unknown variables, and for kinds without a
   *         translation (Null, Mod, Apply and any other unhandled kind)
   */
  CVC3::Expr translateBaseExpr(boost::shared_ptr<prismparser_ast::Expr> ae, const VarTable& vt) {
    CVC3::Expr result;
    assert(ae.get());
    // The node is only read: bind it by reference instead of copying it at
    // every level of the recursion.
    prismparser_ast::Expr& e(*ae);

    vector<CVC3::Expr> children(e.children.size ());
    for (unsigned i = 0; i < e.children.size (); i++) {
      children[i] = translateBaseExpr(e.children[i],vt);
    }

    switch(e.kind) {
    case prismparser_ast::Null:
      break;
    case prismparser_ast::Var:
      {
        VarTable::const_iterator entry(vt.find(e.getIdentifier()));
        if(entry!=vt.end()) {
          result = entry->second;
        } else {
          // Dump the known variables to help diagnose the failed lookup.
          cout << "Table size " << vt.size() << "\n";
          for(VarTable::const_iterator j=vt.begin(); j!=vt.end();++j) {
            cout << "Table contents: " << j->first << "->"
                 << j->second.toString() << "\n";
          }
          throw ParseError("translateBaseExpr: unknown variable " + e.toString());
        }
      }
      break;
    case prismparser_ast::Bool:
      result = e.getBool() ? vc.trueExpr() : vc.falseExpr();
      break;
    case prismparser_ast::Int:
      result = vc.ratExpr(e.getInt(),1);
      break;
    case prismparser_ast::Double:
      {
        // build the rational from the literal's string form (parsed in base 10)
        string repr(prismparser::floatToString(e.getDouble()));
        result = vc.ratExpr(repr,10);
      }
      break;

    case prismparser_ast::Not:
      result = vc.notExpr(children[0]);
      break;
    case prismparser_ast::And:
      result = vc.andExpr(children);
      break;
    case prismparser_ast::Or:
      result = vc.orExpr(children);
      break;
    case prismparser_ast::Eq:
      if(children[0].getType().isBool()) {
        result = vc.iffExpr(children[0],children[1]);
      } else {
        result = vc.eqExpr(children[0],children[1]);
      }
      break;
    case prismparser_ast::Neq:
      if(children[0].getType().isBool()) {
        result = vc.notExpr(vc.iffExpr(children[0],children[1]));
      } else {
        result = vc.notExpr(vc.eqExpr(children[0],children[1]));
      }
      break;
    case prismparser_ast::Lt:
      result = vc.ltExpr(children[0],children[1]);
      break;
    case prismparser_ast::Gt:
      result = vc.gtExpr(children[0],children[1]);
      break;
    case prismparser_ast::Le:
      result = vc.leExpr(children[0],children[1]);
      break;
    case prismparser_ast::Ge:
      result = vc.geExpr(children[0],children[1]);
      break;

    case prismparser_ast::Plus:
      result = vc.plusExpr(children[0],children[1]);
      break;
    case prismparser_ast::Minus:
      result = vc.minusExpr(children[0],children[1]);
      break;
    case prismparser_ast::Uminus:
      result = vc.uminusExpr(children[0]);
      break;
    case prismparser_ast::Mult:
      result = vc.multExpr(children[0],children[1]);
      break;
    case prismparser_ast::Div:
      result = vc.divideExpr(children[0],children[1]);
      break;
    case prismparser_ast::Mod:
      // no translation: reported as untranslatable below
      break;

    case prismparser_ast::Ite:
      result = vc.iteExpr(children[0],children[1],children[2]);
      break;
    case prismparser_ast::Min:
      result = vc.iteExpr(vc.leExpr(children[0],children[1]),children[0],children[1]);
      break;
    case prismparser_ast::Max:
      result = vc.iteExpr(vc.geExpr(children[0],children[1]),children[0],children[1]);
      break;

    case prismparser_ast::Apply:
      // no translation: reported as untranslatable below
      break;
    default:
      break;
    }
    if(result.isNull()) {
      throw ParseError("translateBaseExpr: could not translate " + e.toString());
    }
    return result;
  }

  /**
   * Transforms AST bound expression to Bound of concrete syntax tree.
   *
   * @param bound_expr abstract expression to be transformed
   * @param minimize true iff result should be for minimizing probabilities
   */
  Bound boundFromAST(const prismparser_ast::Expr& bound_expr, bool minimize) {
    Bound::Kind k;
    double bound(0.0);
    switch(bound_expr.kind) {
    case prismparser_ast::Gt:
      k = Bound::GR;  // >  bound ... greater
      assert(bound_expr.children[1]->isDouble());
      bound = bound_expr.children[1]->getDouble();
      break;
    case prismparser_ast::Ge:
      k = Bound::GEQ; // >= bound ... greater or equal
      assert(bound_expr.children[1]->isDouble());
      bound = bound_expr.children[1]->getDouble();
      break;
    case prismparser_ast::Lt:
      k = Bound::LE;  // <  bound ... strictly less
      assert(bound_expr.children[1]->isDouble());
      bound = bound_expr.children[1]->getDouble();
      break;
    case prismparser_ast::Le:
      k = Bound::LEQ; // <= bound ... less or equal
      assert(bound_expr.children[1]->isDouble());
      bound = bound_expr.children[1]->getDouble();
      break;
    case prismparser_ast::Eq:
      k = Bound::EQ; // =  bound ... equal
      assert(bound_expr.children[1]->isDouble());
      bound = bound_expr.children[1]->getDouble();
      break;
    default:
      k = Bound::DK;   // = ?      ... value to be computed
      break;
    }

    Bound b(k, bound, minimize);

    return b;
  }


  Property* translateProperty(boost::shared_ptr<prismparser_ast::Expr> ae, const VarTable& vt) {
    Property* result(0);
    assert(ae.get());
    prismparser_ast::Expr e(*ae.get());

    switch(e.kind) {
    case prismparser_ast::Next:
      if(e.children.size()==1)
	result = new Next(translateProperty(e.children[0],vt));
      else if(e.children.size()==3) {
	double a(-1), b (-1);
	Time::Kind k;
	switch(e.children[1]->kind) {
	case prismparser_ast::Null:
	  break;
	case prismparser_ast::Int:
	  a = e.children[1]->getInt();
	  break;
	case prismparser_ast::Double:
	  a = e.children[1]->getDouble();
	  break;
	default:
	  break;
	}

	switch(e.children[2]->kind) {
	case prismparser_ast::Null:
	  break;
	case prismparser_ast::Int:
	  b = e.children[2]->getInt();
	  break;
	case prismparser_ast::Double:
	  b = e.children[2]->getDouble();
	  break;
	default:
	  break;
	}

	if(a!=-1.0 && b!=-1.0) {
	  k = Time::INTERVAL;
	} else if(b != -1.0) {
	  a = 0;
	  k = Time::LE;
	} else if(a !=-1.0) {
	  k = Time::GE;
	  b = numeric_limits<double>::max();
	} else {
	  throw ParseError("Bad time bound");
	}
	Time t(k,a,b);
	result = new Next(t,translateProperty(e.children[0],vt));

      }

      break;
    case prismparser_ast::Until:
      {
	Property* p1(translateProperty(e.children[0],vt));
	Property* p2(translateProperty(e.children[1],vt));

	if(e.children.size()==2)
	  result = new Until(p1,p2);
	else if(e.children.size()==4){
	  double a(-1.0), b (-1.0);
	  Time::Kind k;
	  switch(e.children[2]->kind) {
	  case prismparser_ast::Null:
	    break;
	  case prismparser_ast::Int:
	    a = e.children[2]->getInt();
	    break;
	  case prismparser_ast::Double:
	    a = e.children[2]->getDouble();
	    break;
	  default:
	    break;
	  }

	  switch(e.children[3]->kind) {
	  case prismparser_ast::Null:
	    break;
	  case prismparser_ast::Int:
	    b = e.children[3]->getInt();
	    break;
	  case prismparser_ast::Double:
	    b = e.children[3]->getDouble();
	    break;
	  default:
	    break;
	  }

	  if(a!=-1.0 && b!=-1.0) {
	    k = Time::INTERVAL;
	  } else if(b != -1.0) {
	    a = 0;
	    k = Time::LE;
	  } else if(a !=-1.0) {
	    b = numeric_limits<double>::max();
	    k = Time::GE;
	  } else {
	    throw ParseError("Bad time bound");
	  }
	  Time t(k,a,b);
	  result = new Until(t,p1,p2);
	}
      }
      break;
    case prismparser_ast::P:
    case prismparser_ast::Pmin:
    case prismparser_ast::Pmax:
    case prismparser_ast::Steady:
    case prismparser_ast::SteadyMax:
    case prismparser_ast::SteadyMin: {
      const prismparser_ast::Expr& bound_expr(*e.children[0].get());
      Bound b(boundFromAST(bound_expr, (e.kind == prismparser_ast::Pmin) || (e.kind == prismparser_ast::SteadyMin)));
      if ((prismparser_ast::Steady == e.kind)
	  || (prismparser_ast::SteadyMax == e.kind)
	  || (prismparser_ast::SteadyMin == e.kind)) {
	result = new SteadyState(b, translateProperty(e.children[1],vt));;
      } else {
	result = new Quant(b, translateProperty(e.children[1],vt));;
      }
      break;
    }
      break;
    case prismparser_ast::ReachabilityReward:
      {
	Property* p(translateProperty(e.children[0],vt));
	result = new ReachabilityReward(p);
      }
      break;
    case prismparser_ast::CumulativeReward:
      {
	assert(e.children[0]->isDouble());
	double d(e.children[0]->getDouble());
	result = new CumulativeReward(d);
      }
      break;
    case prismparser_ast::InstantaneousReward:
      {
	assert(e.children[0]->isDouble());
	double d(e.children[0]->getDouble());
	result = new InstantaneousReward(d);
      }
      break;
    case prismparser_ast::SteadyStateReward:
    case prismparser_ast::SteadyStateRewardMax:
    case prismparser_ast::SteadyStateRewardMin:
      {
	const prismparser_ast::Expr& bound_expr(*e.children[0].get());
	Bound b(boundFromAST(bound_expr, (e.kind == prismparser_ast::SteadyStateRewardMin)));
	result = new SteadyStateReward(b);
	break;
      }
    case prismparser_ast::Not:
      {
      Property *inner = translateProperty(e.children[0],vt);
      if (expr == inner->kind) {
	PropExpr *innerProp = (PropExpr *) inner;
	result = new PropExpr(vc.notExpr(innerProp->getExpr()));
	delete inner;
      } else {
	result = new PropNeg(inner);
      }
      break;
      }
    case prismparser_ast::And:
      {
      Property *innerA = translateProperty(e.children[0],vt);
      Property *innerB = translateProperty(e.children[1],vt);
      if ((expr == innerA->kind) && (expr == innerB->kind)) {
	PropExpr *innerAProp = (PropExpr *) innerA;
	PropExpr *innerBProp = (PropExpr *) innerB;
	result = new PropExpr(vc.andExpr(innerAProp->getExpr(), innerBProp->getExpr()));
	delete innerA;
	delete innerB;
      } else {
	result = new PropBinary(PropBinary::AND,innerA, innerB);
      }
      break;
      }
    case prismparser_ast::Or:
      {
      Property *innerA = translateProperty(e.children[0],vt);
      Property *innerB = translateProperty(e.children[1],vt);
      if ((expr == innerA->kind) && (expr == innerB->kind)) {
	PropExpr *innerAProp = (PropExpr *) innerA;
	PropExpr *innerBProp = (PropExpr *) innerB;
	result = new PropExpr(vc.orExpr(innerAProp->getExpr(), innerBProp->getExpr()));
      } else {
	result = new PropBinary(PropBinary::OR,innerA, innerB);
      }
      break;
      }
    default:
      {
	CVC3::Expr nested_expr(translateBaseExpr(ae,vt));
	result = new PropExpr(nested_expr);
	break;
      }
    }
    return result;
  }

  /**
   * Translates an AST node that must denote a plain (state) expression.
   *
   * @param ae expression to translate
   * @param vt variable table used to resolve variable references
   * @return the translated CVC3 expression
   * @throws ParseError if @p ae denotes a temporal/probabilistic property
   *         instead of a plain expression, or cannot be translated
   */
  CVC3::Expr translateExpr(boost::shared_ptr<prismparser_ast::Expr> ae, const VarTable& vt) {
    Property *prop = translateProperty(ae, vt);
    // Checked at runtime (not only asserted): a property in expression
    // position must never be downcast to PropExpr.
    if (!prop || prop->kind != expr) {
      delete prop;
      throw ParseError("Expected an expression, found property " + ae->toString());
    }
    CVC3::Expr result(static_cast<PropExpr *>(prop)->getExpr());
    delete prop;

    return result;
  }

  /**
   * Translates one alternative of a command: its assignments and its weight.
   *
   * @param aa alternative to translate
   * @param vt variable table used to resolve variable references
   * @return newly allocated Alternative; the caller takes ownership
   * @throws ParseError if an assignment or the weight cannot be translated;
   *         the partially built Alternative is released first
   */
  Alternative* translateAlternative(boost::shared_ptr<prismparser_ast::Alternative> aa, const VarTable& vt)
  {
    const prismparser_ast::Alternative& alternative(*aa.get());
    const prismparser_ast::Update& update (alternative.update);
    Alternative* result(new Alternative());

    try {
      for(prismparser_ast::Assignment::const_iterator i=update.assignment.begin();i!=update.assignment.end();++i)
        {
          CVC3::Expr lhs (translateExpr(i->first,vt));
          CVC3::Expr rhs (translateExpr(i->second,vt));
          result->Assign(lhs,rhs);
        }
      CVC3::Expr weight(translateExpr(alternative.weight,vt));
      result->setWeight(weight);
    } catch (...) {
      delete result;
      throw;
    }
    return result;
  }

  /**
   * Translates a guarded command: its alternatives, guard and action label.
   *
   * @param ac command to translate
   * @param vt variable table used to resolve variable references
   * @return newly allocated Command; the caller takes ownership
   * @throws ParseError naming the offending alternative or guard; the
   *         partially built Command is released first (assumes Command owns
   *         the alternatives added to it — TODO confirm)
   */
  Command* translateCommand(boost::shared_ptr<prismparser_ast::Command> ac, const VarTable& vt)
  {
    const prismparser_ast::Command& command (*ac.get());
    Command* result(new Command());

    try {
      for (prismparser_ast::Alternatives::const_iterator i (command.alternatives.begin ());
           i != command.alternatives.end (); ++i)
        {
          try {
            result->addAlternative (translateAlternative(*i,vt));
          } catch(ParseError& p) {
            throw ParseError("Alternative of command "+(*i)->toString() + "\n"
                             + " Reason: " +  p.toString() + "\n");
          }
        }

      try {
        result->setGuard(translateExpr(command.guard,vt));
      } catch(ParseError& p) {
        throw ParseError("Guard "+command.guard->toString() + "\n"
                         + " Reason: " +  p.toString() + "\n");
      }

      result->setAction(command.label);
    } catch (...) {
      delete result;
      throw;
    }
    return result;
  }

  /**
   * Translates a module and all of its commands.
   *
   * @param am module to translate
   * @param vt variable table used to resolve variable references
   * @return newly allocated Module; the caller takes ownership
   * @throws ParseError if a command cannot be translated; the partially
   *         built Module is released first
   */
  Module* translateModule(boost::shared_ptr<prismparser_ast::Module> am, const VarTable& vt)
  {
    const prismparser_ast::Module& module(*am.get());
    Module* result(new Module(module.name));
    try {
      for (prismparser_ast::Commands::const_iterator i (module.commands.begin ());
           i != module.commands.end (); ++i) {
        result->addCommand(translateCommand(*i,vt));
      }
    } catch (...) {
      delete result;
      throw;
    }

    return result;
  }

  void translateVariables(const prismparser_ast::Variables& vars, Model& model, VarTable& vt) {
    for (prismparser_ast::Variables::const_iterator i (vars.begin ());
	 i != vars.end (); ++i) {
	const prismparser_ast::Variable& var(*i->second.get());

	CVC3::Expr var_expr;

	switch(var.type->kind) {
	case prismparser_ast::Type::Boolean:
	  {
	    var_expr = (vc.varExpr(i->first,vc.boolType()));
	    model.addVariable(i->first,var_expr);
	    model.setDefaultInitialValue (var_expr, var.init.get() ? translateExpr(var.init,vt) : vc.falseExpr());
	  }
	  break;
	case prismparser_ast::Type::Integer:
	  {
	    var_expr = (vc.varExpr(i->first,vc.intType()));
	    model.addVariable(i->first,var_expr);
	    model.setDefaultInitialValue (var_expr, var.init.get() ? translateExpr(var.init,vt) : vc.ratExpr(0,1));
	  }
	  break;
	case prismparser_ast::Type::Double: {
	  var_expr = (vc.varExpr(i->first,vc.realType()));
	  model.addVariable(i->first,var_expr);
	  model.setDefaultInitialValue (var_expr, var.init.get() ? translateExpr(var.init,vt) : vc.ratExpr(0,1) );
	}
	  break;
	case prismparser_ast::Type::Bitvector: {
	  var_expr = (vc.varExpr(i->first,vc.bitvecType(var.type->bitvector_data.width)));
	  model.addVariable(i->first,var_expr);
	}
	  break;
	case prismparser_ast::Type::Range:
	  {
	    CVC3::Expr upper, lower;
	    try {

	      lower = translateExpr(var.type->range_data.lower,vt);
	      upper = translateExpr(var.type->range_data.upper,vt);
	    } catch (ParseError& p) {
	      throw ParseError("Range of variable "+ var.toString() + "\n"
				     + " Reason: " +  p.toString() + "\n");
	    }
	    var_expr = (vc.varExpr(i->first,vc.intType()));
	    model.addVariable(i->first,var_expr,lower,upper);

	    model.setDefaultInitialValue (var_expr, var.init.get() ? translateExpr(var.init,vt) : lower);
	  }
	  break;
	}

	vt.insert(pair<string,CVC3::Expr>(i->first, var_expr));
	if(var.is_parameter)
	  model.parameter_variables.insert(var_expr);

      }

  }

  /**
   * Translates a complete parsed model into @p model.
   *
   * Steps:
   *  1. Set the model type.
   *  2. Declare global and module-local variables (filling @p vt).
   *  3. Translate the modules.
   *  4. Translate the initial condition, invariants, actions, predicates,
   *     properties, and state and transition rewards.
   *
   * @param am    AST of the parsed model
   * @param model model to be filled
   * @param vt    variable table; receives one entry per declared variable
   * @throws ParseError if some part of the model cannot be translated; the
   *         message names the offending part
   */
  void translateModel(prismparser_ast::Model& am, Model& model, VarTable& vt) {
    switch(am.model_type) {
    case prismparser_ast::DTMC:
      model.setModelType(DTMC);
      break;
    case prismparser_ast::MDP:
      model.setModelType(MDP);
      break;
    case prismparser_ast::CTMC:
      model.setModelType(CTMC);
      break;
    case prismparser_ast::CTMDP:
      model.setModelType(CTMDP);
      break;
    case prismparser_ast::Unspecified:
      // no model type given in the input: treat it as an MDP
      model.setModelType(MDP);
      break;
    }

    /* 1) Variable table
     *
     * build the variable table by traversing the model
     * collecting variables from each module */

    /* global variables */
    translateVariables(am.globals,model,vt);

    /* local module variables */
    for (prismparser_ast::Modules::const_iterator i (am.modules.begin ());
         i != am.modules.end (); ++i)
      {
        translateVariables(i->second->locals,model,vt);
      }

    /* 2) translate modules and add them to the model */
    for (prismparser_ast::Modules::const_iterator i (am.modules.begin ()); i != am.modules.end (); ++i)
      {
        model.addModule(translateModule(i->second,vt));
      }

    /* 3) translate the rest */

    // initial condition (optional)
    try {
      if(am.initial.get()) {
        CVC3::Expr e(translateExpr(am.initial,vt));
        model.setInitial(e);
      }
    } catch(ParseError& p) {
      throw ParseError("Initial condition "+ am.initial->toString() + "\n"
                       + " Reason: " +  p.toString() + "\n");
    }

    // Exprs invariants
    for (prismparser_ast::Exprs::const_iterator i=am.invariants.begin();i!=am.invariants.end();++i) {
      CVC3::Expr e;
      try {
        e = translateExpr(*i,vt);
      } catch(ParseError& p) {
        throw ParseError("Invariant "+ (*i)->toString() + "\n"
                         + " Reason: " +  p.toString() + "\n");
      }
      model.addInvariant(e);
    }

    // Actions actions
    for (prismparser_ast::Actions::const_iterator i=am.actions.begin();i!=am.actions.end();++i) {
      model.addAction(*i);
    }

    // Exprs predicates
    for (prismparser_ast::Exprs::const_iterator i=am.predicates.begin();i!=am.predicates.end();++i) {
      CVC3::Expr e;
      try {
        e = translateExpr(*i,vt);
      } catch(ParseError& p) {
        throw ParseError("Predicate "+ (*i)->toString() + "\n"
                         + " Reason: " +  p.toString() + "\n");
      }
      model.addPredicate(e);
    }

    // Exprs properties
    for (prismparser_ast::Exprs::const_iterator i=am.properties.begin();i!=am.properties.end();++i) {
      Property* p(0);
      try {
        p = translateProperty(*i,vt);
      } catch(ParseError& err) {
        throw ParseError("Property "+ (*i)->toString() + "\n"
                         + " Reason: " +  err.toString() + "\n");
      }
      model.addProperty(p);
    }

    // StateRewards state_rewards
    for (prismparser_ast::StateRewards::const_iterator i=am.state_rewards.begin();i!=am.state_rewards.end();++i) {
      CVC3::Expr guard (translateExpr(i->first,vt));
      CVC3::Expr reward(translateExpr(i->second,vt));
      model.addStateReward(guard, reward);
    }

    // TransitionRewards transition_rewards; an empty action label means the
    // reward is not tied to a specific action
    for (prismparser_ast::TransitionRewards::const_iterator i=am.transition_rewards.begin();i!=am.transition_rewards.end();++i) {
      Action a(i->first);
      CVC3::Expr guard(translateExpr(i->second.first,vt));
      CVC3::Expr reward(translateExpr(i->second.second,vt));

      if ("" == a) {
        model.addTransReward(guard, reward);
      } else {
        model.addTransReward(a,guard, reward);
      }
    }
  }
}

