/*
 * This file is part of PARAM.
 *
 * PARAM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PARAM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PARAM. If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright 2009-2010 Ernst Moritz Hahn (emh@cs.uni-sb.de)
 */

#include <stdlib.h>
#include <sys/types.h>
#include <iostream>
#include <limits>
#include "prismparser/PRISMParser.h"
#include "prismparser/Property.h"
#include "prismparser/Model.h"
#include "Controller.h"
#include "Statistics.h"
#include "PASSStateExplorer.h"
#include "rationalFunction/RationalFunction.h"
#include "model2x/Model2X.h"

namespace parametric {  
  using namespace std;
  using namespace prismparser;
  using namespace model2x;
  using namespace Rational;
  
  /**
   * Create an explorer bound to the given controller.
   * The explorer writes its results (analysis type, time bound, Markov
   * chain, statistics) into that controller. An empty model is allocated
   * here, filled by loadModel() and released by the destructor.
   * @param _MC controller this explorer works on
   */
  PASSStateExplorer::PASSStateExplorer(Controller &_MC)
    : mc(_MC)
  {
    model = new Model();
    minimize = false;
  }
  
  /**
   * Release the model allocated in the constructor.
   * NOTE(review): `model` is a raw owning pointer; unless the class
   * declaration disables copying, copying an explorer would lead to a
   * double delete here -- confirm against the header.
   */
  PASSStateExplorer::~PASSStateExplorer()
  {
    delete this->model;
  }
  /**
   * Extract until formula.
   * Takes an until formula "P.. [left U[0, time_bound] right]" and extracts
   * left, right and time_bound.
   */
  void PASSStateExplorer::parseFormula() {
    Properties props = model->getProperties();
    Property *prop = props[0].get();
    if (quant == prop->kind) {
      Quant *propQuant = (Quant *) prop;    
      Bound bound = propQuant->getBound();
      const Property *innerProperty = propQuant->getProp();
      Until *untilProp = (Until *) innerProperty;
      const prismparser::Time &time = untilProp->getTime();
      if (time.kind == prismparser::Time::UNBOUNDED) {
        mc.setAnalysisType(Controller::unboundedUntilAnalysis);
        mc.time = -1.0;
      } else if ((time.kind == prismparser::Time::LE)
		 || ((time.kind == prismparser::Time::INTERVAL)
		     && (0.0 == time.t1))) {	
	mc.setAnalysisType(Controller::boundedUntilAnalysis);	
	mc.time = time.t2;
      } else {
	throw runtime_error("Property type not implemented");        
      }
      const Property *leftProp = untilProp->getProp1();
      const Property *rightProp = untilProp->getProp2();
      PropExpr *leftExpr = (PropExpr *) leftProp;
      PropExpr *rightExpr = (PropExpr *) rightProp;
      leftF = leftExpr->getExpr();  
      rightF = rightExpr->getExpr();
    } else if (reachability_reward == prop->kind) {
      ReachabilityReward *rew = (ReachabilityReward *) prop;
      mc.time = -1.0;
      rightF = ((PropExpr *) rew->getProp())->getExpr();
      leftF = vc.trueExpr();
      mc.setAnalysisType(Controller::reachabilityRewardAnalysis);
    } else {
      throw runtime_error("Property type not implemented");
    }
  }
    
  /**
   * Load a PASS model and appendant property file.
   * @param model_filename model file to be loaded
   * @param property_filename property file to be loaded
   */
  void PASSStateExplorer::loadModel
  (const string &model_filename, const string &property_filename) {
    prismparser::PRISMParser p;
    if (NULL == freopen(model_filename.c_str(),"r",stdin)) {
      throw runtime_error("Failed to open file \"" + model_filename + "\"");
    }
    p.run(model_filename, *model);
    fclose(stdin);
    model->Flatten();
    p.run(property_filename, *model);
    fclose(stdin);
  }

  /**
   * Build the parametric Markov chain mc.pmc from the explored state space.
   *
   * The initial states of model2x are copied to mc.initStates. Then one row
   * is emitted for every state 0 .. getNumStates()-1, in order, and each row
   * is closed with finishState():
   * - states in state set 0 (sink) or 1 (target) get a single self-loop with
   *   probability 1 (and reward 0 in reward analyses); target states are
   *   also recorded in mc.targetStates;
   * - for all other states, successor entries leading to the same state are
   *   merged. Probabilities are summed. In reward analyses the transition
   *   reward becomes the probability-weighted mean of
   *   (transition reward + state reward) over the merged entries.
   *
   * Expected to run after exploreAllStates() has discovered all states.
   */
  void PASSStateExplorer::constructMC(Model2X &model2x) {
    const unsigned numInitStates(model2x.getNumInitStates());
    const unsigned *init(model2x.getInitStates());
    for (unsigned initNr(0); initNr < numInitStates; initNr++) {
      mc.initStates.insert(init[initNr]);
    }

    const bool rewardAnalysis(mc.isRewardAnalysis());

    // Successor arrays are fetched once; getStateSuccessors() presumably
    // refills them in place for each queried state.
    const unsigned *succStatesList(model2x.getSuccStatesList());
    const RationalFunction *succRatesList(model2x.getSuccRatesList());
    const RationalFunction *succRewardsList(NULL);
    if (rewardAnalysis) {
      succRewardsList = model2x.getSuccRewardsList();
    }

    const unsigned numStates(model2x.getNumStates());
    for (unsigned state(0); state < numStates; state++) {
      const bool isSinkState(model2x.inStateSet(0, state));
      const bool isTargetState(model2x.inStateSet(1, state));
      if (isSinkState || isTargetState) {
        // Absorbing state: probability-1 self-loop.
        if (rewardAnalysis) {
          mc.pmc->addSucc(state, 1, 0);
        } else {
          mc.pmc->addSucc(state, 1);
        }
        if (isTargetState) {
          mc.targetStates.insert(state);
        }
      } else {
        model2x.getStateSuccessors(state);
        const unsigned numSuccStates(model2x.getNumSuccStates());
        RationalFunction stateReward(mc.isRewardAnalysis() ? model2x.getStateReward(state) : 0);

        // Merge parallel successor entries that lead to the same state.
        map<PMM::state,RationalFunction> probMap;
        map<PMM::state,RationalFunction> rewardMap;
        for (unsigned succNr(0); succNr < numSuccStates; succNr++) {
          const unsigned succState(succStatesList[succNr]);
          RationalFunction prob(succRatesList[succNr]);
          probMap[succState] += prob;
          if (rewardAnalysis) {
            // Probability-weighted reward sum; normalized again below.
            RationalFunction reward(succRewardsList[succNr] + stateReward);
            rewardMap[succState] += reward * prob;
          }
        }

        for (map<PMM::state,RationalFunction>::iterator it(probMap.begin());
             it != probMap.end(); ++it) {
          const unsigned succState(it->first);
          RationalFunction prob(it->second);
          if (rewardAnalysis) {
            // Every key of probMap was also inserted into rewardMap above.
            RationalFunction reward(rewardMap[it->first]);
            mc.pmc->addSucc(succState, prob, reward / prob);
          } else {
            mc.pmc->addSucc(succState, prob);
          }
        }
      }
      mc.pmc->finishState();
    }
  }
  
  /**
   * Discover all states reachable from the initial states of model2x, then
   * reserve memory in mc.pmc and record model statistics.
   *
   * Exploration is breadth-first. Initial states are always expanded.
   * Successors in state set 0 (sink) or 1 (target) are not expanded, because
   * constructMC() gives them a single self-loop instead of their own
   * transitions.
   *
   * The transition count used for the reservations (and stored in
   * statistics->numTransitionsModel) counts:
   * - one self-loop per sink/target state;
   * - the successor entries of every other state (before constructMC()
   *   merges entries leading to the same successor).
   * It is therefore an upper bound on the entries constructMC() adds.
   */
  void PASSStateExplorer::exploreAllStates(Model2X &model2x) {
    const unsigned numInitStates(model2x.getNumInitStates());
    const unsigned *init(model2x.getInitStates());
    vector<unsigned> pres(init, init + numInitStates);
    vector<unsigned> next;
    // enqueued[s] != 0 once state s has been put into `next`. A newly found
    // state may occur several times in one successor list; without this flag
    // it would be expanded (and its transitions counted) more than once.
    vector<char> enqueued;

    unsigned numTrans(0);

    const unsigned *succStatesList(model2x.getSuccStatesList());
    do {
      for (unsigned stateNr(0); stateNr < pres.size(); stateNr++) {
        const unsigned state(pres[stateNr]);
        const unsigned lastNumStates(model2x.getNumStates());
        model2x.getStateSuccessors(state);
        const unsigned numSuccStates(model2x.getNumSuccStates());
        // Absorbing (sink/target) states are counted once, after exploration.
        const bool stateAbsorbing(model2x.inStateSet(0, state)
                                  || model2x.inStateSet(1, state));
        if (!stateAbsorbing) {
          numTrans += numSuccStates;
        }
        for (unsigned succNr(0); succNr < numSuccStates; succNr++) {
          const unsigned succState(succStatesList[succNr]);
          if (succState < lastNumStates) {
            continue;  // known before this expansion
          }
          if (succState >= enqueued.size()) {
            enqueued.resize(succState + 1, 0);
          }
          if (0 != enqueued[succState]) {
            continue;  // already queued earlier in this successor list
          }
          const bool isSinkState(model2x.inStateSet(0, succState));
          const bool isTargetState(model2x.inStateSet(1, succState));
          if (!isSinkState && !isTargetState) {
            enqueued[succState] = 1;
            next.push_back(succState);
          }
        }
      }

      pres.swap(next);
      next.clear();
    } while (!pres.empty());

    const unsigned numStates(model2x.getNumStates());
    // Each sink/target state receives exactly one self-loop in constructMC().
    for (unsigned state(0); state < numStates; state++) {
      if (model2x.inStateSet(0, state) || model2x.inStateSet(1, state)) {
        numTrans++;
      }
    }

    mc.pmc->reserveRowsMem(numStates);
    mc.pmc->reserveColsMem(numTrans);
    if (mc.isRewardAnalysis()) {
      mc.pmc->reserveStateRewardsMem(numStates);
      mc.pmc->reserveTransRewardsMem(numTrans);
    }
    mc.statistics->numStatesModel = numStates;
    mc.statistics->numTransitionsModel = numTrans;
  }

  void PASSStateExplorer::explore() {
    mc.statistics->exploreTime.start();
    string model_filename(mc.vm["model-file"].as<string>());
    string formula_filename;
    if (0 != mc.vm.count("formula-file")) {
      formula_filename = mc.vm["formula-file"].as<string>();
    }
    loadModel(model_filename, formula_filename);
    
    mc.model_type = model->getModelType();
    
    if ("" != formula_filename) {
      parseFormula();
    } else {
      leftF = vc.trueExpr();
      rightF = vc.falseExpr();
    }
    
    mc.prepareSymbols(*model);
    CVC3::Expr sinkF(vc.simplify(!(leftF || rightF)));
    Model2X model2x;
    model2x.setModel(*model);
    model2x.addStateSet(sinkF);
    model2x.addStateSet(rightF);
    model2x.setUseRewards(mc.isRewardAnalysis());
    model2x.build();
    model2x.addInitStates();
    exploreAllStates(model2x);
    constructMC(model2x);
    
    mc.statistics->exploreTime.stop();
  }
}
