/*
 * This file is part of PARAM.
 *
 * PARAM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PARAM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.

 * You should have received a copy of the GNU General Public License
 * along with PARAM. If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright 2009-2010 Ernst Moritz Hahn (emh@cs.uni-sb.de)
 */

#include <map>
#include <list>
#include <boost/dynamic_bitset.hpp>
#include "Controller.h"
#include "GPMC.h"
#include "Eliminator.h"

using namespace std;
using namespace Rational;

namespace parametric {
  /**
   * Returns the total probability with which @a state moves to a different
   * state, i.e. the sum over all its transitions except a self-loop.
   *
   * @param state state whose outgoing transitions are summed
   * @return sum of the probabilities of the non-self-loop transitions
   */
  RationalFunction Eliminator::leavingSum(PMM::state state) {
    RationalFunction total(0);
    for (unsigned nr(0); nr < pmc->getNumSuccStates(state); ++nr) {
      if (pmc->getSuccState(state, nr) == state) {
        /* skip the self-loop */
        continue;
      }
      total += pmc->getSuccProb(state, nr);
    }
    return total;
  }
  
  /**
   * Removes @a elimState from all paths of the PMC by connecting each of its
   * predecessors directly to each of its successors.
   *
   * Let leaving be the total probability of the non-self-loop transitions of
   * @a elimState (see leavingSum()). If leaving is 0 the model is left
   * untouched. Otherwise any self-loop is removed, and the outgoing
   * probabilities are divided by leaving. For an initial state, these
   * rescaled values are also written back to its transitions.
   *
   * In reward analysis, the self-loop reward multiplied by
   * (1 - leaving) / leaving is added to every outgoing reward. This factor
   * presumably is the expected number of self-loop iterations, assuming the
   * outgoing probabilities of @a elimState sum to 1 — TODO confirm.
   *
   * For every predecessor p (probability in) and successor s (rescaled
   * probability out), the transition p->elimState is removed and in*out is
   * added to the transition p->s, which is created if it does not exist
   * yet. When p->s already exists, its reward becomes the
   * probability-weighted average of its old reward and the reward of the
   * new path (in-reward plus out-reward).
   *
   * An initial state keeps its rescaled outgoing transitions; any other
   * state is made absorbing at the end.
   *
   * @param elimState state to eliminate
   */
  void Eliminator::eliminateState(PMM::state elimState) {
    /* initial states are not made absorbing, see the end of this function */
    bool isInit(initStates->find(elimState) != initStates->end());
    RationalFunction leaving(leavingSum(elimState));
    RationalFunction loopReward(0);
    
    /* index of the self-loop; getInvalidState() means there is none */
    unsigned loopNr(pmc->getSuccNrBySuccState(elimState, elimState));
    if (pmc->getInvalidState() != loopNr) {
      /* the self-loop is only removed if the state can be left at all */
      if (0 != leaving) {
	if (rewardAnalysis) {
	  loopReward = ((1 - leaving) / leaving)
	    * pmc->getSuccReward(elimState, loopNr);
	}
	pmc->removeSuccTrans(elimState, loopNr);
      }
    }

    /* nothing leaves this state: leave the model unchanged */
    if (0 == leaving) {
      return;
    }

    /* calculate outgoing probabilities (and rewards), normalised by leaving;
     * the self-loop has already been removed at this point */
    vector<RationalFunction> outVals;
    vector<PMM::state> outStates;
    vector<RationalFunction> outRewards;
    for (unsigned succ(0); succ < pmc->getNumSuccStates(elimState); succ++) {
      PMM::state succState(pmc->getSuccState(elimState, succ));
      RationalFunction outProb(pmc->getSuccProb(elimState, succ) / leaving);      
      outVals.push_back(outProb);
      /* an initial state stays in the model with rescaled transitions */
      if (isInit) {
	pmc->setSuccProb(elimState, succ, outProb);
      }
      if (rewardAnalysis) {
	RationalFunction outReward(loopReward + pmc->getSuccReward(elimState, succ));
        outRewards.push_back(outReward);
	if (isInit) {
	  pmc->setSuccReward(elimState, succ, outReward);
	}
      }
      outStates.push_back(succState);
    }

    /* collect incoming transitions; remove[i] is the index of the transition
     * inStates[i] -> elimState within the successors of inStates[i] */
    vector<RationalFunction> inVals;
    vector<PMM::state> inStates;
    vector<RationalFunction> inRewards;
    vector<unsigned> remove;
    for (unsigned pred(0); pred < pmc->getNumPredStates(elimState); pred++) {
      PMM::state predState(pmc->getPredState(elimState, pred));
      unsigned predSuccNr(pmc->getSuccNrBySuccState(predState, elimState));
      RationalFunction inVal(pmc->getSuccProb(predState, predSuccNr));
      inVals.push_back(inVal);
      if (rewardAnalysis) {
	RationalFunction inReward(pmc->getSuccReward(predState, predSuccNr));
	inRewards.push_back(inReward);
      }
      inStates.push_back(predState);
      remove.push_back(predSuccNr);
    }
    /* delete the incoming transitions only after all of them were read, so
     * the predecessor list is not changed while it is being iterated
     * (presumably removeSuccTrans also updates it) */
    for (unsigned nr(0); nr < inStates.size(); nr++) {
      PMM::state predState(inStates[nr]);
      pmc->removeSuccTrans(predState, remove[nr]);
    }

    /* connect every predecessor directly to every successor */
    for (unsigned inNr(0); inNr < inStates.size(); inNr++) {
      RationalFunction inProb(inVals[inNr]);
      PMM::state inState(inStates[inNr]);
      for (unsigned outNr(0); outNr < outStates.size(); outNr++) {
	RationalFunction outProb(outVals[outNr]);
	PMM::state outState(outStates[outNr]);
	RationalFunction prob(inProb * outProb);
	unsigned inToOut(pmc->getSuccNrBySuccState(inState, outState));
	if (pmc->getInvalidState() != inToOut) {
	  /* merge with the existing transition inState -> outState */
	  RationalFunction oldProb(pmc->getSuccProb(inState, inToOut));
	  pmc->setSuccProb(inState, inToOut, prob + oldProb);
	  if (rewardAnalysis) {
	    /* probability-weighted average of the old and the new reward */
	    RationalFunction newRew(inRewards[inNr] + outRewards[outNr]);
	    RationalFunction oldRew(pmc->getSuccReward(inState, inToOut));
	    RationalFunction rew((newRew * prob + oldRew * oldProb) / (prob + oldProb));
	    pmc->setSuccReward(inState, inToOut, rew);
	  }
	} else {
	  /* no such transition yet: create it */
	  if (rewardAnalysis) {
	    RationalFunction rew(inRewards[inNr] + outRewards[outNr]);
	    pmc->addSucc(inState, outState, prob, rew);
	  } else {
	    pmc->addSucc(inState, outState, prob);
	  }
	}
      }
    }

    if (!isInit) {
      pmc->makeAbsorbing(elimState);
    }
  }
  
  /**
   * Eliminates the states of @a elemStates one after another, in list order.
   *
   * @param elemStates states to eliminate (see eliminateState())
   */
  void Eliminator::eliminateStates(StateList &elemStates) {
    StateList::iterator current(elemStates.begin());
    while (current != elemStates.end()) {
      eliminateState(*current);
      ++current;
    }
  }
  
  /**
   * Collects the states to be eliminated into @a stateList, in elimination
   * order.
   *
   * A breadth-first search is started from the initial states. The list
   * begins with the initial states that are not target states, followed by
   * every reachable non-target state in BFS order.
   *
   * Target states are never collected, because eliminate() reads the final
   * result from the transitions into them. The search does not continue
   * beyond non-initial target states.
   *
   * Initial states are collected on purpose: eliminateState() keeps them in
   * the model, but removes their self-loops and lets their predecessors
   * bypass them.
   *
   * Afterwards the list is reversed when the elimination order is
   * @c backward, or shuffled with random_shuffle() when it is @c random.
   * Otherwise the BFS order is kept.
   *
   * @param stateList list receiving the states (expected to be empty on
   *        entry; existing entries would be reordered as well)
   */
  void Eliminator::collectStatesOrdered(StateList &stateList) {
    boost::dynamic_bitset<> seen(pmc->getNumStates());
    list<PMM::state> work;
    for (StateSet::iterator it(initStates->begin()); it != initStates->end();
         it++) {
      work.push_back(*it);
      seen[*it] = true;
      /* An initial state that is also a target must not be eliminated.
       * Doing so would redirect the transitions of the other initial states
       * away from it, so eliminate() would no longer see them reach the
       * target. It is still explored by the search. */
      if (0 == targetStates->count(*it)) {
        stateList.push_back(*it);
      }
    }
    while (!work.empty()) {
      PMM::state state(work.front());
      work.pop_front();
      for (unsigned succ(0); succ < pmc->getNumSuccStates(state); succ++) {
        PMM::state succState(pmc->getSuccState(state, succ));
        /* only unseen non-target successors are enqueued, so the search
         * stops at target states */
        if ((0 == targetStates->count(succState)) && (!seen[succState])) {
          work.push_back(succState);
          stateList.push_back(succState);
        }
        seen[succState] = true;
      }
    }

    if (backward == eliminationOrder) {
      stateList.reverse();
    } else if (random == eliminationOrder) {
      vector<PMM::state> stateVector(stateList.begin(), stateList.end());
      random_shuffle(stateVector.begin(), stateVector.end());
      stateList.assign(stateVector.begin(), stateVector.end());
    }
  }
  
  /**
   * Sets the model on which elimination operates. Only its address is
   * stored, so @a pmc__ must outlive its use by this Eliminator.
   *
   * @param pmc__ model to operate on
   */
  void Eliminator::setPMC(GPMC &pmc__) {
    this->pmc = &pmc__;
  }
  
  /**
   * Sets the initial states. Only the address of the set is stored, so
   * @a initStates__ must outlive its use by this Eliminator.
   *
   * @param initStates__ set of initial states
   */
  void Eliminator::setInitStates(StateSet &initStates__) {
    this->initStates = &initStates__;
  }
  
  
  /**
   * Sets the target states. Only the address of the set is stored, so
   * @a targetStates__ must outlive its use by this Eliminator.
   *
   * @param targetStates__ set of target states
   */
  void Eliminator::setTargetStates(StateSet &targetStates__) {
    this->targetStates = &targetStates__;
  }
  
  
  void Eliminator::setRewardAnalysis(bool rewardAnalysis__) {
    rewardAnalysis = rewardAnalysis__;
  }
  
  /**
   * Sets the order in which collected states are eliminated (see
   * collectStatesOrdered()).
   *
   * @param eliminationOrder__ elimination order to use
   */
  void Eliminator::setEliminationOrder(EliminationOrder eliminationOrder__) {
    this->eliminationOrder = eliminationOrder__;
  }
  
  /**
   * Eliminates all collected states and stores one value per initial state
   * in @a result.
   *
   * For an initial state that is also a target, the value is 1 in
   * probability analysis and 0 in reward analysis. For any other initial
   * state, the value is the sum, over its remaining transitions into target
   * states, of the transition probability. In reward analysis it is the
   * sum of probability times transition reward instead.
   *
   * @param result receives (initial state, value) pairs
   */
  void Eliminator::eliminate(Results &result) {
    /* remove every collected state, leaving direct transitions only */
    StateList order;
    collectStatesOrdered(order);
    eliminateStates(order);

    for (PMM::state state(0); state < pmc->getNumStates(); state++) {
      /* only initial states get a result entry */
      if (initStates->find(state) == initStates->end()) {
        continue;
      }
      RationalFunction value(0);
      const bool isTarget(targetStates->find(state) != targetStates->end());
      if (isTarget) {
        if (!rewardAnalysis) {
          value = RationalFunction(1);
        }
      } else {
        for (unsigned nr(0); nr < pmc->getNumSuccStates(state); nr++) {
          PMM::state next(pmc->getSuccState(state, nr));
          if (targetStates->find(next) == targetStates->end()) {
            continue;
          }
          RationalFunction prob(pmc->getSuccProb(state, nr));
          if (!rewardAnalysis) {
            value += prob;
          } else {
            value += prob * pmc->getSuccReward(state, nr);
          }
        }
      }
      result.insert(make_pair(state, value));
    }
  }
}
