/*
 * This file is part of PARAM.
 *
 * PARAM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PARAM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.

 * You should have received a copy of the GNU General Public License
 * along with PARAM.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright 2009 Ernst Moritz Hahn (emh@cs.uni-sb.de)
 */

#include <stdlib.h>
#include <sys/types.h>
#include <iostream>
#include <iomanip>
#include <boost/graph/graph_traits.hpp>
#include <boost/graph/adjacency_list.hpp>
#include <boost/program_options.hpp>
#include <limits>

#include "lang/PRISMParser.h"
#include "lang/Property.h"
#include "lang/Model.h"

#include "SparseMC.h"
#include "PASSStateExplorer.h"
#include "rationalFunction/RationalFunction.h"
#include "infamation/Model2C.h"

extern int line_number;       // current line number
extern std::string file_name; // name of the file currently being parsed

namespace parametric {
  
  using namespace std;
  using namespace util;
  using namespace lang;
  using namespace Rational;
  
  // Model being explored. It is filled by loadModel and read by parseFormula
  // and explore. It is defined at namespace scope with external linkage, so
  // other translation units may refer to it as well.
  lang::Model model;
  // Model file name with its extension removed; set in loadModel.
  // NOTE(review): not read in this file; presumably used elsewhere — confirm.
  string model_name;
  
  /**
   * Create an explorer that stores everything it explores in @a checker.
   * The checker is held by reference and must outlive this explorer.
   * @param checker model checker whose graph, analysis type and statistics
   *        are filled in during exploration
   */
  PASSStateExplorer::PASSStateExplorer(SparseMC &checker) : mc(checker) {}
  
  /**
   * Extract until formula.
   * Takes an until formula "P.. [left U[0, time_bound] right]" and extracts
   * left, right and time_bound.
   */
  void PASSStateExplorer::parseFormula() {
    Properties props = model.getProperties();
    Property *prop = props[0].get();
    if (quant == prop->kind) {
      Quant *propQuant = (Quant *) prop;    
      Bound bound = propQuant->getBound();
      const Property *innerProperty = propQuant->getProp();
      Until *untilProp = (Until *) innerProperty;
      const lang::Time &time = untilProp->getTime();
      if (time.kind == lang::Time::UNBOUNDED) {
        mc.setAnalysisType(SparseMC::unboundedUntilAnalysis);
        mc.time = -1.0;
      } else if ((time.kind == lang::Time::LE)
		 || ((time.kind == lang::Time::INTERVAL)
		     && (0.0 == time.t1))) {	
	mc.setAnalysisType(SparseMC::boundedUntilAnalysis);	
	mc.time = time.t2;
      } else {
	throw runtime_error("Property type not implemented");        
      }
      const Property *leftProp = untilProp->getProp1();
      const Property *rightProp = untilProp->getProp2();
      PropExpr *leftExpr = (PropExpr *) leftProp;
      PropExpr *rightExpr = (PropExpr *) rightProp;
      leftF = leftExpr->getExpr();  
      rightF = rightExpr->getExpr();
    } else if (reachability_reward == prop->kind) {
      ReachabilityReward *rew = (ReachabilityReward *) prop;
      mc.time = -1.0;
      rightF = ((PropExpr *) rew->getProp())->getExpr();
      leftF = vc.trueExpr();
      mc.setAnalysisType(SparseMC::reachabilityRewardAnalysis);
    } else {
      throw runtime_error("Property type not implemented");
    }
  }
    
  /**
   * Load a PASS model and appendant property file.
   * @param model_filename model file to be loaded
   * @param property_filename property file to be loaded
   */
  void PASSStateExplorer::loadModel
  (const string &model_filename, const string &property_filename) {
    PRISM::PRISMParser p;
    if (NULL == freopen(model_filename.c_str(),"r",stdin)) {
      throw runtime_error("Failed to open file \"" + model_filename + "\"");
    }
    line_number = 1;
    model_name = RemoveExtension(model_filename);
    file_name = model_filename;
    p.run(model_filename, model);
    fclose(stdin);
    model.Flatten();
    p.run(property_filename, model);
    fclose(stdin);
  }

  /**
   * Allocate the nondeterminism variables that encode scheduler choices of
   * an MDP as additional parameters.
   *
   * Every state outside the sink (set 0) and target (set 1) state sets that
   * has k > 1 active commands needs k - 1 variables. The variables are named
   * "nd0", "nd1", ..., appended to mc.nondetParams, and registered as new
   * rational-function symbols.
   *
   * Only MDPs get such variables. constructMC uses them only when the model
   * type is MDP, and for other model types multiple active commands are not
   * a scheduler choice. Previously, spurious parameters were created for
   * these models.
   *
   * @param model2c model whose states are expected to be fully explored
   *        (see exploreAllStates)
   */
  void PASSStateExplorer::prepareNondetVars
  (infamation::Model2C &model2c) {
    unsigned numNondetVars(0);

    if (MDP == mc.model_type) {
      const unsigned numStates(model2c.getNumStates());
      for (unsigned state(0); state < numStates; state++) {
        // Sink and target states are absorbing; their choices do not matter.
        if (model2c.inStateSet(0, state) || model2c.inStateSet(1, state)) {
          continue;
        }
        model2c.getStateSuccessors(state);
        const unsigned numActiveCommands(model2c.getNumActiveCommands());
        if (numActiveCommands > 1) {
          numNondetVars += numActiveCommands - 1;
        }
      }
    }

    for (unsigned ndVarNr(0); ndVarNr < numNondetVars; ndVarNr++) {
      ostringstream ndVarNrStr;
      ndVarNrStr << ndVarNr;
      mc.nondetParams.push_back("nd" + ndVarNrStr.str());
    }
    RationalFunction::addNewSymbolsWhileRunning(mc.nondetParams);
  }
  
  /**
   * Build the parametric Markov chain in mc.graph from the explored model.
   *
   * A breadth-first traversal starts at the initial states. Successors in
   * the sink state set (set 0) are redirected to a single sink vertex, and
   * successors in the target set (set 1) to a single target vertex.
   * Parallel transitions between the same pair of vertices are merged by
   * summing their values. In reward analyses, the reward of a merged edge
   * is the value-weighted average of the merged rewards.
   *
   * For MDPs, a state with m > 1 active commands becomes a parametric
   * mixture over its commands. It uses m - 1 fresh nondeterminism variables
   * x_1 .. x_{m-1} (allocated by prepareNondetVars) in stick-breaking form:
   * - choice j < m gets factor (1-x_1)...(1-x_{j-1}) * x_j;
   * - choice m gets factor (1-x_1)...(1-x_{m-1}).
   * For m = 2 this is x and 1 - x. Exactly m - 1 variables are consumed per
   * such state. (Previously one variable was consumed per successor of the
   * second choice, which overran the allocation.)
   *
   * @param model2c explored model; state sets 0 and 1 must be the sink and
   *        target sets, respectively
   * @throws runtime_error if a state needs more nondeterminism variables
   *         than prepareNondetVars allocated
   */
  void PASSStateExplorer::constructMC(infamation::Model2C &model2c) {
    const unsigned sinkNr(numeric_limits<unsigned>::max()-1);
    const unsigned targetNr(numeric_limits<unsigned>::max());
    prepareNondetVars(model2c);
    HashMap<unsigned,vertex_descriptor> numberToVertex;
    
    vertex_descriptor sinkVertex(add_vertex(*mc.graph));
    numberToVertex.insert(make_pair(sinkNr, sinkVertex));
    vertex_descriptor targetVertex(add_vertex(*mc.graph));
    mc.targetStates.insert(targetVertex);
    numberToVertex.insert(make_pair(targetNr, targetVertex));
    
    vector<unsigned> init(model2c.getInitialStates());
    for (unsigned initNr(0); initNr < init.size(); initNr++) {
      unsigned initState(init[initNr]);
      vertex_descriptor initVertex(add_vertex(*mc.graph));
      mc.initStates.insert(initVertex);
      numberToVertex.insert(make_pair(initState, initVertex));
    }
    
    vector<unsigned> pres(init);
    vector<unsigned> next;
    
    RationalFunction *succRewardsList(NULL);
    unsigned *succStatesList(model2c.getSuccStatesList());
    RationalFunction *succRatesList(model2c.getSuccRatesListRational());
    if (mc.isRewardAnalysis()) {
      succRewardsList = model2c.getSuccRewardsListRational();
    }
    unsigned *nonDetBounds(NULL);
    if (MDP == mc.model_type) {
      nonDetBounds = model2c.getNonDetBounds();
    }
    
    // Index (relative to the first nondeterminism variable) of the next
    // unused variable.
    unsigned nextNonDetVariable(0);
    do {
      for (unsigned stateNr(0); stateNr < pres.size(); stateNr++) {
        unsigned state(pres[stateNr]);
        model2c.getStateSuccessors(state);
        vertex_descriptor stateVertex(numberToVertex[state]);
        RationalFunction stateReward(mc.isRewardAnalysis() ? model2c.getReward(state) : 0);
        const unsigned numActiveCommands(model2c.getNumActiveCommands());
        const bool isNonDetState((MDP == mc.model_type) && (numActiveCommands > 1));
        if (isNonDetState
            && (nextNonDetVariable + numActiveCommands - 1 > mc.nondetParams.size())) {
          throw runtime_error("Not enough nondeterminism variables prepared");
        }
        // Number (1-based) of the command the current successor belongs to.
        unsigned nonDetNr(1);
        // Command for which multFactor was last computed (0 = none yet).
        unsigned factorNonDetNr(0);
        // Product of (1 - x_i) over the commands already passed.
        RationalFunction remaining(1);
        RationalFunction multFactor(1);
        for (unsigned succNr(0); succNr < model2c.getNumSuccStates(); succNr++) {
          if (isNonDetState && (factorNonDetNr != nonDetNr)) {
            factorNonDetNr = nonDetNr;
            if (nonDetNr < numActiveCommands) {
              // Not the last command: use a fresh variable for this choice.
              vector<unsigned> expv(RationalFunction::getNumSymbols(), 0);
              const unsigned varNr(RationalFunction::getNumSymbols()
                                   - mc.nondetParams.size()
                                   + nextNonDetVariable + nonDetNr - 1);
              expv[varNr] = 1;
              RationalFunction nonDetVar(1, expv);
              if (1 == nonDetNr) {
                multFactor = nonDetVar;
                remaining = RationalFunction(1) - nonDetVar;
              } else {
                multFactor = remaining * nonDetVar;
                remaining = remaining * (RationalFunction(1) - nonDetVar);
              }
            } else {
              // Last command takes whatever probability mass is left.
              multFactor = remaining;
            }
          }
          
          unsigned succState(succStatesList[succNr]);
          RationalFunction rate(succRatesList[succNr]);
          RationalFunction reward(mc.isRewardAnalysis() ? succRewardsList[succNr] + stateReward: 0);
          const bool isSinkState(model2c.inStateSet(0, succState));
          const bool isTargetState(model2c.inStateSet(1, succState));
          if (isSinkState) {
            succState = sinkNr;
          } else if (isTargetState) {
            succState = targetNr;
          } else if (numberToVertex.find(succState) == numberToVertex.end()) {
            /* successor state not yet seen */
            next.push_back(succState);
            vertex_descriptor succVertex(add_vertex(*mc.graph));
            numberToVertex.insert(make_pair(succState, succVertex));
          }
          
          TransProp prop;
          prop.setValue(rate * multFactor);
          if (mc.isRewardAnalysis()) {
            prop.setReward(reward);
          }
          vertex_descriptor succVertex(numberToVertex[succState]);
          pair<edge_descriptor, bool> previousEdge;
          previousEdge = edge(stateVertex, succVertex, *mc.graph);
          if (!previousEdge.second) {
            add_edge(stateVertex, succVertex, prop, *mc.graph);
          } else {
            // Merge parallel transitions: sum values and weight rewards
            // by value.
            TransProp &previousProp((*mc.graph)[previousEdge.first]);
            if (mc.isRewardAnalysis()) {
              previousProp.setReward((previousProp.getValue() * previousProp.getReward()
                                      + prop.getValue() * prop.getReward()));
              previousProp.setReward(previousProp.getReward()
                                     / (prop.getValue() + previousProp.getValue()));
            }
            previousProp.setValue(previousProp.getValue() + prop.getValue());
          }
          if (MDP == mc.model_type) {
            // nonDetBounds[k] is the end (exclusive) of command k's
            // successors.
            if (succNr+1 == nonDetBounds[nonDetNr]) {
              nonDetNr++;
            }
          }
        }
        if (isNonDetState) {
          nextNonDetVariable += numActiveCommands - 1;
        }
      }
      
      pres.swap(next);
      next.clear();
    } while (0 != pres.size());
  }
  
  /**
   * Let model2c discover every state reachable from the initial states.
   *
   * The search is breadth-first over a single FIFO worklist. A successor is
   * scheduled for expansion when both of these hold:
   * - its number is at least the number of states model2c knew before the
   *   current expansion, i.e. it was just discovered;
   * - it is in neither the sink (set 0) nor the target (set 1) state set.
   * Initial states are always expanded.
   * NOTE(review): relies on getStateSuccessors numbering newly discovered
   * states from the previous getNumStates() upwards — confirm in Model2C.
   */
  void PASSStateExplorer::exploreAllStates(infamation::Model2C &model2c) {
    vector<unsigned> worklist(model2c.getInitialStates());
    unsigned *const successors(model2c.getSuccStatesList());

    for (vector<unsigned>::size_type head(0); head != worklist.size(); ++head) {
      // Copy the state number; push_back below may reallocate the worklist.
      const unsigned current(worklist[head]);
      const unsigned knownBefore(model2c.getNumStates());
      model2c.getStateSuccessors(current);
      for (unsigned i(0); i < model2c.getNumSuccStates(); ++i) {
        const unsigned succ(successors[i]);
        const bool inSink(model2c.inStateSet(0, succ));
        const bool inTarget(model2c.inStateSet(1, succ));
        const bool isNew(succ >= knownBefore);
        if (isNew && !(inSink || inTarget)) {
          worklist.push_back(succ);
        }
      }
    }
  }

  void PASSStateExplorer::explore() {
    mc.statistics.exploreTime.Start();
    string model_filename(mc.vm["model-file"].as<string>());
    string formula_filename;
    if (0 != mc.vm.count("formula-file")) {
      formula_filename = mc.vm["formula-file"].as<string>();
    }
    loadModel(model_filename, formula_filename);
    
    mc.model_type = model.getModelType();
    
    if ("" != formula_filename) {
      parseFormula();
    } else {
      leftF = vc.trueExpr();
      rightF = vc.falseExpr();
    }
    
    mc.prepareSymbols();
    CVC3::Expr sinkF(vc.simplify(!(leftF || rightF)));
    infamation::Model2C model2c(model);
    model2c.addStateSet(sinkF);
    model2c.addStateSet(rightF);
    model2c.setValueType(infamation::Model2C::paramVal);
    model2c.setCompress(true);
    model2c.setRewards(mc.isRewardAnalysis());
    model2c.build();
    exploreAllStates(model2c);

    constructMC(model2c);
    
    mc.statistics.numStatesModel = num_vertices(*mc.graph);
    mc.statistics.numTransitionsModel = num_edges(*mc.graph);
    mc.statistics.exploreTime.Stop();





#if 0
    pair<vertex_iterator, vertex_iterator> vp;
    for (vp = vertices(*mc.graph); vp.first != vp.second; ++vp.first) {
      cout << "state" << endl;
      vertex_descriptor vMDP(*vp.first);
      pair<out_edge_iterator, out_edge_iterator> it;
      for (it = out_edges(vMDP, *mc.graph); it.first != it.second; ++it.first) {
	cout << "choice" << endl;
	edge_descriptor eChoice(*it.first);
	vertex_descriptor vChoice(target(eChoice, *mc.graph));
	pair<out_edge_iterator, out_edge_iterator> it2;
	for (it2 = out_edges(vChoice, *mc.graph); it2.first != it2.second; ++it2.first) {
	  edge_descriptor eProb(*it2.first);
	  vertex_descriptor vSucc(target(eProb, *mc.graph));
	  const TransProp &prop((*mc.graph)[eProb].getValue());
	  const RationalFunction &val(prop.getValue());
	  cout << "yy> " << val << "  " << endl;
	}
      }
    }
#endif
  }
}
