/*
 * This file is part of a parser for an extension of the PRISM language.
 *
 * This is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * The parser is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with the program this parser is part of.
 * If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright 2007-2010 Bjoern Wachter (Bjoern.Wachter@comlab.ox.ac.uk)
 * Copyright 2009-2010 Ernst Moritz Hahn (emh@cs.uni-sb.de)
 */

#include <cvc3/vcl.h>
#include <cvc3/theory_arith.h>
#include <cvc3/theory_bitvector.h>
#include <cvc3/theory_arith.h>

#include "SymbolTable.h"
#include "ExprManager.h"
#include "Util.h"

/** BEGIN FIX **/
namespace std {
  namespace tr1 {
    template<> struct hash< pair<CVC3::Expr,CVC3::Expr> > {
      size_t operator()( const pair<CVC3::Expr,CVC3::Expr>& x) const {
	size_t hash = 0;
	hash = x.first.hash() + (hash << 6) + (hash << 16) - hash;
	hash = x.second.hash() +  (hash << 6) + (hash << 16) - hash;
	return hash;
      }
    };
    template<> struct hash<CVC3::Expr> {
      size_t operator()(CVC3::Expr x) const {
	return x.hash();
      }
    };
  }
}
/** END FIX **/

namespace prismparser {
  using namespace std;

  /* Initialize the CVC3 engine: one validity checker, vc, shared by every
     function in this file. flags comes from
     CVC3::ValidityChecker::createFlags() and is defined before vc because
     vc's constructor reads it (namespace-scope objects in one file are
     initialised in definition order, so do not reorder these two lines). */
  CVC3::CLFlags flags(CVC3::ValidityChecker::createFlags());
  CVC3::VCL vc(flags);

  /** Set up the shared checker: re-apply the flags to vc and open the
      outermost context scope, which Done() pops again. */
  void ExprManager::Init() {
    vc.reprocessFlags();
    vc.push();
  }

  /** Pop the context scope pushed by Init(). */
  void ExprManager::Done() {
    vc.pop();
  }

  /** Build the conjunction of the expressions in vec.
      An empty vector yields the constant true, a single element is returned
      unchanged (no AND node is created), and two or more elements are
      combined with vc.andExpr(vec). */
  CVC3::Expr ExprManager::Conjunction(const std::vector<CVC3::Expr>& vec) {
    if (vec.empty())
      return vc.trueExpr();
    if (vec.size() == 1)
      return vec[0];
    return vc.andExpr(vec);
  }

  /** Set overload of Conjunction: the elements are conjoined in the set's
      (sorted) iteration order. */
  CVC3::Expr ExprManager::Conjunction(const std::set<CVC3::Expr>& s) {
    const std::vector<CVC3::Expr> ordered(s.begin(), s.end());
    return Conjunction(ordered);
  }


  /** Build the disjunction of the expressions in vec.
      An empty vector yields the constant false, a single element is returned
      unchanged (no OR node is created), and two or more elements are
      combined with vc.orExpr(vec). */
  CVC3::Expr ExprManager::Disjunction(const std::vector<CVC3::Expr>& vec) {
    if (vec.empty())
      return vc.falseExpr();
    if (vec.size() == 1)
      return vec[0];
    return vc.orExpr(vec);
  }



  /** Flatten nested ANDs: append every top-level conjunct of e to result,
      left to right. An expression that is not an AND is its own single
      conjunct. e must not be null. */
  void ExprManager::getTopLevelConjuncts(const CVC3::Expr& e, std::vector<CVC3::Expr>& result) {
    assert(!e.isNull());
    if (e.getKind() != CVC3::AND) {
      result.push_back(e);
      return;
    }
    const std::vector<CVC3::Expr>& kids = e.getKids();
    for (std::vector<CVC3::Expr>::const_iterator kid = kids.begin(); kid != kids.end(); ++kid)
      getTopLevelConjuncts(*kid, result);
  }

  /** Set overload: insert every top-level conjunct of e into result.
      Nested ANDs are flattened recursively. Every child of an AND is
      visited, not just the first two: Conjunction() hands vectors of any
      length to vc.andExpr(vec), so an AND may have more than two children.
      e must not be null (same precondition as the vector overload). */
  void ExprManager::getTopLevelConjuncts(const CVC3::Expr& e, std::set<CVC3::Expr>& result) {
    assert(!e.isNull());
    switch(e.getKind()) {
    case CVC3::AND: {
      const std::vector< CVC3::Expr > & kids = e.getKids();
      for(std::vector<CVC3::Expr>::const_iterator kid = kids.begin(); kid != kids.end(); ++kid)
        getTopLevelConjuncts(*kid, result);
      break;
    }
    default:
      result.insert(e);
      break;
    }
  }



  /** Split e into conjunctive parts grouped by variable support.
      The top-level conjuncts of e are partitioned by SplitExprSet (groups are
      intended to have pairwise-disjoint lvalue support) and each group is
      re-conjoined, so result[k] is the conjunction of group k. Any previous
      contents of result are replaced. */
  void ExprManager::getDisjointSupportDecomposition(const CVC3::Expr& e,
                                                    std::vector<CVC3::Expr>& result) {
    std::vector<CVC3::Expr> parts;
    getTopLevelConjuncts(e, parts);

    std::vector< std::vector<CVC3::Expr> > groups;
    SplitExprSet(parts, groups);

    result.clear();
    result.reserve(groups.size());
    for (std::vector< std::vector<CVC3::Expr> >::const_iterator g = groups.begin(); g != groups.end(); ++g)
      result.push_back(Conjunction(*g));
  }



  // Memo tables for the SMT-backed queries below, keyed by expression:
  // TrueCache holds cached results of IsTrue(e), FalseCache those of
  // IsFalse(e). queryCaches() consults both before Solve() is called.
  CVC3::ExprHashMap<bool> TrueCache;
  CVC3::ExprHashMap<bool> FalseCache;


  /** Try to answer a yes/no query about e from two complementary caches.
      @param cache1 cache of the query being answered; a hit is returned
                    directly (true -> l_true, false -> l_false)
      @param cache2 cache of the complementary query; only a stored `true`
                    is informative, and it refutes the query (l_false)
      @param e      expression looked up
      @return l_undef when neither cache decides the query */
  lbool ExprManager::queryCaches(const CVC3::ExprHashMap<bool>& cache1,
                                 const CVC3::ExprHashMap<bool>& cache2,
                                 const CVC3::Expr& e) {
    CVC3::ExprHashMap<bool>::const_iterator hit = cache1.find(e);
    if (hit != cache1.end())
      return hit->second ? l_true : l_false;

    // A formula known to be valid cannot be unsatisfiable and vice versa,
    // so a positive answer to the complementary query settles this one.
    CVC3::ExprHashMap<bool>::const_iterator other = cache2.find(e);
    if (other != cache2.end() && other->second)
      return l_false;

    return l_undef;
  }

  /** Check satisfiability of e with the shared checker.
      e is asserted inside a fresh context scope that is popped again before
      returning, and the query "false" is posed under that assumption.
      @return l_false if the query is VALID (e is unsatisfiable),
              l_true  if the query is INVALID (e is satisfiable),
              l_undef for any other solver answer */
  lbool ExprManager::Solve(const CVC3::Expr& e) {
    vc.push();
    vc.assertFormula(e);
    const int answer = vc.query(vc.falseExpr());
    vc.pop();

    if (answer == CVC3::VALID)
      return l_false;
    if (answer == CVC3::INVALID)
      return l_true;
    return l_undef;
  }



  /** Decide whether e is valid (holds under every assignment).
      Tries, in order: the boolean constants syntactically, the
      TrueCache/FalseCache memo tables, and finally the SMT solver on !e.
      Every solver answer is memoised in TrueCache (like IsFalse does for
      FalseCache), so repeated queries for an expression that is not valid
      do not reach the solver again. */
  bool ExprManager::IsTrue(const CVC3::Expr& e) {
    // phase 1: syntactic
    if(e.isBoolConst()) return e.isTrue();

    // phase 2: service from cache
    switch(queryCaches(TrueCache,FalseCache,e)) {
    case l_true: return true;
    case l_false: return false;
    default: break;
    }

    // phase 3: ask the SMT solver -- e is valid iff !e is unsatisfiable
    bool result = false;
    switch(Solve(!e)) {
    case l_false:
      result = true;
      break;
    case l_true:
      result = false;
      break;
    default:
      // undecided by the solver; in release builds treated (and cached)
      // as "not valid"
      assert(false);
      break;
    }
    TrueCache.insert(e,result);
    return result;
  }



  /** Decide whether e is unsatisfiable (false under every assignment).
      Tries the boolean constants syntactically, then the FalseCache/TrueCache
      memo tables, then the SMT solver. The solver's answer is memoised in
      FalseCache; an undecided answer is recorded as "not false". */
  bool ExprManager::IsFalse(const CVC3::Expr& e) {
    if (e.isBoolConst())
      return e.isFalse();

    const lbool cached = queryCaches(FalseCache, TrueCache, e);
    if (cached == l_true)
      return true;
    if (cached == l_false)
      return false;

    // e is false exactly when the solver proves it unsatisfiable.
    const bool unsat = (Solve(e) == l_false);
    FalseCache.insert(e, unsat);
    return unsat;
  }

  // Memo table for DisjointWith(): maps a pair of expressions, ordered so the
  // smaller one comes first, to whether the two were found disjoint.
  typedef map<pair<CVC3::Expr,CVC3::Expr>,bool > DisjCacheType;
  DisjCacheType DisjCache;

  /** Decide whether e1 and e2 cannot hold together (e1 AND e2 unsatisfiable).
      Identical expressions are reported as not disjoint. The pair is
      normalised (smaller expression first) and memoised in DisjCache.
      The cheap DisjointWithSyntactic check is tried first; if it is
      inconclusive, e1 AND e2 is split into parts grouped by variable support
      and the pair is reported disjoint when some part is unsatisfiable. */
  bool ExprManager::DisjointWith(const CVC3::Expr& e1, const CVC3::Expr& e2) {
    if (e1 == e2)
      return false;

    const pair<CVC3::Expr,CVC3::Expr> key = (e1 > e2)
      ? pair<CVC3::Expr,CVC3::Expr>(e2, e1)
      : pair<CVC3::Expr,CVC3::Expr>(e1, e2);

    DisjCacheType::const_iterator cached = DisjCache.find(key);
    if (cached != DisjCache.end())
      return cached->second;

    bool disjoint = false;
    const lbool syntactic = DisjointWithSyntactic(key.first, key.second);
    if (syntactic == l_true) {
      disjoint = true;
    } else if (syntactic == l_undef) {
      std::vector<CVC3::Expr> parts;
      getDisjointSupportDecomposition(vc.andExpr(key.first, key.second), parts);
      for (std::vector<CVC3::Expr>::const_iterator part = parts.begin(); part != parts.end(); ++part) {
        if (IsFalse(*part)) {
          disjoint = true;
          break;
        }
      }
    }

    DisjCache[key] = disjoint;
    return disjoint;
  }

  /** Recognise a strict inequality  bound > x  or  bound < x.
      kind always receives e's kind. On success, bound is set to the left
      operand and x to the right one, so CVC3::GT means x lies in
      (-inf, bound) and CVC3::LT means x lies in (bound, +inf).
      @return true iff e is a GT or LT expression; otherwise bound and x
              are left untouched */
  bool IsStrictInequality(const CVC3::Expr& e,
                          CVC3::Expr& bound,
                          CVC3::Expr& x,
                          int& kind) {
    kind = e.getKind();
    if (kind != CVC3::GT && kind != CVC3::LT)
      return false;

    const std::vector<CVC3::Expr>& operands = e.getKids();
    const CVC3::Expr lhs = operands[0];
    const CVC3::Expr rhs = operands[1];
    bound = lhs;
    x = rhs;
    return true;
  }

  /*
    Decompose a non-strict bound constraint on a term x into

      lower <= x <= upper

    Accepted shapes (for an AND, only its first two children are examined):
      (lo <= x) AND (hi >= x)   ->  lower = lo, upper = hi
      (hi >= x) AND (lo <= x)   ->  lower = lo, upper = hi
      hi >= x                   ->  lower = x,  upper = hi   (no lower bound)
      lo <= x                   ->  lower = lo, upper = x    (no upper bound)

    A missing bound is encoded by setting it to x itself, so callers detect
    an open side with  lower == x  or  upper == x.
    Returns false, leaving all outputs untouched, for any other shape.
  */
  bool ExprManager::IsInterval(const CVC3::Expr& e,
                               CVC3::Expr& lower,
                               CVC3::Expr& x,
                               CVC3::Expr& upper) {
    const int kind = e.getKind();

    if (kind == CVC3::GE) {          // lhs >= x : upper bound only
      const std::vector<CVC3::Expr>& ops = e.getKids();
      const CVC3::Expr lhs = ops[0];
      const CVC3::Expr rhs = ops[1];
      lower = rhs;
      x     = rhs;
      upper = lhs;
      return true;
    }

    if (kind == CVC3::LE) {          // lhs <= x : lower bound only
      const std::vector<CVC3::Expr>& ops = e.getKids();
      const CVC3::Expr lhs = ops[0];
      const CVC3::Expr rhs = ops[1];
      lower = lhs;
      x     = rhs;
      upper = rhs;
      return true;
    }

    if (kind != CVC3::AND)
      return false;

    const std::vector<CVC3::Expr>& conjuncts = e.getKids();
    const CVC3::Expr first  = conjuncts[0];
    const CVC3::Expr second = conjuncts[1];
    const int k1 = first.getKind();
    const int k2 = second.getKind();

    // Sort the two conjuncts into the lower-bound (LE) and upper-bound (GE) part.
    CVC3::Expr below, above;
    if (k1 == CVC3::LE && k2 == CVC3::GE) {
      below = first;
      above = second;
    } else if (k1 == CVC3::GE && k2 == CVC3::LE) {
      below = second;
      above = first;
    } else {
      return false;
    }

    const std::vector<CVC3::Expr>& belowOps = below.getKids();
    const std::vector<CVC3::Expr>& aboveOps = above.getKids();
    if (!(belowOps[1] == aboveOps[1]))
      return false;                  // the two bounds constrain different terms

    lower = belowOps[0];
    x     = first.getKids()[1];
    upper = aboveOps[0];
    return true;
  }

  /*
    Solver-free test of whether e1 and e2 are disjoint, for the case where
    both bound the same term x.

    Returns l_true  if the two ranges definitely do not intersect,
            l_false if they were found to intersect, and also for e1 == e2
                    and for expressions sharing no lvalue (LvalueCompatible),
            l_undef if the shapes are not recognised or the intersection
                    condition does not simplify to a constant.

    Each side may be a non-strict interval (IsInterval; a missing bound is
    encoded as the bound being x itself) or a strict half-line
    (IsStrictInequality: GT gives (-inf,bound), LT gives (bound,inf)).
    For every combination `test` is a condition that holds iff the two
    ranges intersect (reasoning over the reals and assuming each range is
    non-empty).
  */
  lbool ExprManager::DisjointWithSyntactic(const CVC3::Expr& e1, const CVC3::Expr& e2) {
    if(e1 == e2) return l_false;

    // Constraints over unrelated lvalues are not reported as disjoint here.
    if(!LvalueCompatible(e1,e2))
      return l_false;

    lbool result = l_undef;

    // e1 bounds x through a and b (or a and kind1);
    // e2 bounds y through c and d (or d and kind2).
    CVC3::Expr a,x,b,c,y,d;

    int kind1 = -1;
    int kind2 = -1;

    // Intersection condition; stays null if no case below applies.
    CVC3::Expr test;

    // e1: a <= x <= b
    if (IsInterval(e1,a,x,b)) {
      // e2: c <= y <= d
      if(IsInterval(e2,c,y,d) && x == y) {
        if(a==x || d==x)      // cases: (-inf,b], [c,d] and [a,b], [c,inf)
          test = vc.leExpr(c,b);
        else if(b==x || c==x) // cases: [a,inf), [c,d] and [a,b], (-inf,d]
          test = vc.leExpr(a,d);
        else                  // case: [a,b], [c,d] -- a lies in [c,d] or c lies in [a,b]
          test = vc.orExpr(vc.andExpr(vc.leExpr(c,a),vc.leExpr(a,d)),
                           vc.andExpr(vc.leExpr(a,c),vc.leExpr(c,b)) );
      }
      // e2: d > y, i.e. y in (-inf,d)   or   d < y, i.e. y in (d,inf)
      else if(IsStrictInequality(e2,d,y,kind2) && x == y) {
        switch(kind2) {
        case CVC3::GT:
          if(a==x) {        // (-inf,d), (-inf,b]
            test = vc.trueExpr(); // definitely intersects
          } else if(b==x) { // (-inf,d), [a,inf)
            test = vc.ltExpr(a,d);
          } else {          // (-inf,d), [a,b]
            test = vc.ltExpr(a,d);
          }
          break;
        case CVC3::LT:
          if(a==x) {        // (d,inf), (-inf,b]
            test = vc.ltExpr(d,b);
          } else if(b==x) { // (d,inf), [a,inf)
            test = vc.trueExpr();
          } else {          // (d,inf), [a,b]
            test = vc.ltExpr(d,b);
          }
          break;
        default:
          break;
        }
      }
    }
    // e1: a > x, i.e. x in (-inf,a)   or   a < x, i.e. x in (a,inf)
    else if(IsStrictInequality(e1,a,x,kind1)) {
      // e2: c <= y <= d
      if(IsInterval(e2,c,y,d) && x == y) {
        switch(kind1) {
        case CVC3::GT:
          if(d==x) {        // (-inf,a), [c,inf)
            test = vc.ltExpr(c,a);
          } else if(c==x) { // (-inf,a), (-inf,d]
            test = vc.trueExpr();
          } else {          // (-inf,a), [c,d]
            test = vc.ltExpr(c,a);
          }
          break;
        case CVC3::LT:
          if(d==x) {        // (a,inf), [c,inf)
            test = vc.trueExpr();
          } else if(c==x) { // (a,inf), (-inf,d]
            test = vc.ltExpr(a,d);
          } else {          // (a,inf), [c,d]
            test = vc.ltExpr(a,d);
          }
          break;
        default:
          break;
        }
      }
      // e2: strict as well
      else if (IsStrictInequality(e2,d,y,kind2) && x == y) {
        switch(kind1) {
        case CVC3::GT:
          switch(kind2) {
          case CVC3::GT: // (-inf,a), (-inf,d)
            test = vc.trueExpr();
            break;
          case CVC3::LT: // (-inf,a), (d,inf): non-empty iff d < a
            test = vc.ltExpr(d,a);
            break;
          default:
            break;
          }
          break;
        case CVC3::LT:
          switch(kind2) {
          case CVC3::GT: // (a,inf), (-inf,d): non-empty iff a < d
            test = vc.ltExpr(a,d);
            break;
          case CVC3::LT: // (a,inf), (d,inf)
            test = vc.trueExpr();
            break;
          default:
            break;
          }
          break;
        default:
          break;
        }
      }
    }

    // Only a condition that simplifies to a constant decides the question.
    if(!test.isNull()) {
      test = vc.simplify(test);
      if(test == vc.falseExpr()) {
        result = l_true;   // ranges cannot intersect: disjoint
      } else if(test == vc.trueExpr()) {
        result = l_false;  // ranges intersect: not disjoint
      }
    }

    return result;
  }



  /** Decide whether e1 and e2 are logically equivalent, i.e. whether
      e1 <=> e2 is valid. Identical expressions short-circuit without
      consulting IsTrue. */
  bool ExprManager::EquivalentTo(const CVC3::Expr& e1, const CVC3::Expr& e2) {
    return (e1 == e2) || IsTrue(vc.iffExpr(e1,e2));
  }

  // Memo table for LvalueCompatible(), keyed by a pair of distinct
  // expressions ordered the same way regardless of argument order.
  typedef HashMap<pair<CVC3::Expr,CVC3::Expr>, bool > LvalueCache;

  LvalueCache LvalueCompatCache;



  /** Decide whether e1 and e2 mention at least one common lvalue (UCONST).
      Identical expressions are trivially compatible (not cached). A UCONST
      operand is compatible iff it occurs in the other expression; otherwise
      the lvalue sets of both sides (ComputeLvalues) are intersected.
      Results are memoised in LvalueCompatCache under an order-independent
      key, so (e1,e2) and (e2,e1) share one entry. */
  bool ExprManager::LvalueCompatible(const CVC3::Expr& e1, const CVC3::Expr& e2) {
    if (e1 == e2)
      return true;

    // Canonical key: the larger expression first.
    const bool e1Larger = (e2 < e1);
    const pair<CVC3::Expr,CVC3::Expr> key(e1Larger ? e1 : e2,
                                          e1Larger ? e2 : e1);

    LvalueCache::const_iterator hit = LvalueCompatCache.find(key);
    if (hit != LvalueCompatCache.end())
      return hit->second;

    bool shared = false;
    if (e1.getKind() == CVC3::UCONST) {
      shared = e1.subExprOf(e2);
    } else if (e2.getKind() == CVC3::UCONST) {
      shared = e2.subExprOf(e1);
    } else {
      set<CVC3::Expr> lv1;
      ComputeLvalues(e1, lv1);
      set<CVC3::Expr> lv2;
      ComputeLvalues(e2, lv2);

      // Probe the larger set with the elements of the smaller one.
      const bool firstBigger = lv1.size() > lv2.size();
      const set<CVC3::Expr>& probe  = firstBigger ? lv2 : lv1;
      const set<CVC3::Expr>& lookup = firstBigger ? lv1 : lv2;
      for (set<CVC3::Expr>::const_iterator it = probe.begin(); !shared && it != probe.end(); ++it)
        shared = (lookup.find(*it) != lookup.end());
    }

    LvalueCompatCache.insert(pair<pair<CVC3::Expr,CVC3::Expr>, bool> (key, shared));
    return shared;
  }


  /** Collect the non-connective leaves of a formula into exprs.
      Recursion descends through NOT, AND, OR, IMPLIES, IFF, XOR and ITE;
      any other sub-expression is inserted as a whole. The constants true
      and false are skipped. */
  void ExprManager::CollectExprs(const CVC3::Expr& e, std::tr1::unordered_set<CVC3::Expr>& exprs) {
    const int kind = e.getKind();
    if (kind == CVC3::TRUE_EXPR || kind == CVC3::FALSE_EXPR)
      return;

    const bool connective =
      kind == CVC3::NOT || kind == CVC3::AND || kind == CVC3::OR ||
      kind == CVC3::IMPLIES || kind == CVC3::IFF || kind == CVC3::XOR ||
      kind == CVC3::ITE;
    if (!connective) {
      exprs.insert(e);
      return;
    }

    const std::vector<CVC3::Expr>& kids = e.getKids();
    for (std::vector<CVC3::Expr>::const_iterator kid = kids.begin(); kid != kids.end(); ++kid)
      CollectExprs(*kid, exprs);
  }

  /** Ordered-set overload of CollectExprs: collect the non-connective leaves
      of a formula (recursing through NOT, AND, OR, IMPLIES, IFF, XOR, ITE),
      skipping the constants true and false. */
  void ExprManager::CollectExprs(const CVC3::Expr& e, set<CVC3::Expr>& exprs) {
    const int kind = e.getKind();
    if (kind == CVC3::TRUE_EXPR || kind == CVC3::FALSE_EXPR)
      return;

    const bool connective =
      kind == CVC3::NOT || kind == CVC3::AND || kind == CVC3::OR ||
      kind == CVC3::IMPLIES || kind == CVC3::IFF || kind == CVC3::XOR ||
      kind == CVC3::ITE;
    if (!connective) {
      exprs.insert(e);
      return;
    }

    const std::vector<CVC3::Expr>& kids = e.getKids();
    for (std::vector<CVC3::Expr>::const_iterator kid = kids.begin(); kid != kids.end(); ++kid)
      CollectExprs(*kid, exprs);
  }

  /** Insert into result every UCONST sub-expression of e (presumably the
      declared variables). A UCONST is not searched further. */
  void ExprManager::ComputeLvalues(const CVC3::Expr& e, set<CVC3::Expr>& result) {
    if (e.getKind() == CVC3::UCONST) {
      result.insert(e);
      return;
    }
    const std::vector<CVC3::Expr>& kids = e.getKids();
    BOOST_FOREACH (const CVC3::Expr& kid, kids) {
      ComputeLvalues(kid, result);
    }
  }

  void ExprManager::SplitExprSet(const vector<CVC3::Expr> &arg,vector< vector<CVC3::Expr> > &res) {


    // get the support of e
    std::set<CVC3::Expr> support;
    std::vector<std::set<CVC3::Expr> > arg_support(arg.size());

    for(unsigned i=0; i!=arg.size();++i) {
      ExprManager::ComputeLvalues(arg[i], arg_support[i]);
      support.insert(arg_support[i].begin(),arg_support[i].end());
    }

    std::vector<CVC3::Expr> table(support.size());
    int counter = 0;
    for(std::set<CVC3::Expr>::const_iterator i=support.begin();i!=support.end();++i) {
      table[counter++] = *i;
    }

    typedef Signature<unsigned> Sig;

    std::vector<Sig> partition(arg.size(),table.size());

    // add the expressions to partition
    for(unsigned i=0; i<arg.size();++i) {
      Sig& sig(partition[i]);
      sig.insert(i);
      for(unsigned k=0;k<table.size();++k) {
	sig.s[k] = arg_support[i].find(table[k])!=arg_support[i].end();
      }
    }

    std::vector<Sig> new_partition;
    Sig::partitionRefinement2(partition,new_partition);

    res.resize(new_partition.size());

    unsigned i = 0;
    BOOST_FOREACH (Sig& sig, new_partition) {
      std::set<unsigned int>& current_set (sig.cset);
      res[i].clear();
      BOOST_FOREACH (unsigned j , current_set)
	res[i].push_back(arg[j]);
      ++i;
    }
  }

  /** Build the conjunction of literals selected by c over vec.
      Position i contributes vec[i] when c[i] is l_true, !vec[i] when it is
      l_false, and nothing when it is l_undef. Only the first vec.size()
      positions of c are read. */
  CVC3::Expr ExprManager::getExprCube(const Cube& c, const std::vector<CVC3::Expr>& vec) {
    vector<CVC3::Expr> literals;
    literals.reserve(vec.size());
    for (unsigned pos = 0; pos < vec.size(); ++pos) {
      switch (c[pos]) {
      case l_true:
        literals.push_back(vec[pos]);
        break;
      case l_false:
        literals.push_back(!vec[pos]);
        break;
      case l_undef:
        break;
      }
    }
    return Conjunction(literals);
  }
}
