/*
 * This file is part of PARAM.
 *
 * PARAM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PARAM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.

 * You should have received a copy of the GNU General Public License
 * along with PARAM.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright 2009 Ernst Moritz Hahn (emh@cs.uni-sb.de)
 */

#include "Eliminator.h"

/**
 * TODOs:
 * - use a priority queue to eliminate states in an order that keeps the
 *   overall number of transitions small
 */

#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <boost/graph/strong_components.hpp>
#include "SparseMC.h"

using namespace std;
using namespace boost;
using namespace Rational;

namespace parametric {
  
  /**
   * BFS visitor that appends every discovered vertex which is neither an
   * initial state nor a target state to a list, in discovery order.
   *
   * The visitor only holds references, so every copy of it (BGL algorithms
   * take visitors by value) appends to the same list. The referenced list and
   * sets must outlive the search. Vertices that are not collected are still
   * traversed, so the search continues through initial and target states.
   * All other BFS event points are the no-ops inherited from
   * default_bfs_visitor.
   */
  template < typename List > class bfs_list_visitor :
    public default_bfs_visitor {
  public:
    /**
     * @param list_ list receiving the collected states
     * @param initStates_ initial states; these are not collected
     * @param targetStates_ target states; these are not collected
     */
    bfs_list_visitor
    (List &list_, const StateSet &initStates_, const StateSet &targetStates_)
      : m(list_), initStates(initStates_), targetStates(targetStates_) {
    }

    /** Called by the BFS when @a u is reached for the first time. */
    template < typename Vertex, typename G >
    void discover_vertex(Vertex u, const G &) const {
      if ((targetStates.find(u) == targetStates.end())
          && (initStates.find(u) == initStates.end())) {
        m.push_back(u);
      }
    }

    List &m;
    const StateSet &initStates;
    const StateSet &targetStates;
  };
  
  /**
   * Returns the value with which state @a u leaves towards other states,
   * derived only from its self-loop:
   * - 0 if @a u has no outgoing transitions, or its only transition is a
   *   self-loop;
   * - 1 if @a u has outgoing transitions but no self-loop;
   * - 1 - val(self-loop) if @a u has a self-loop and further transitions.
   *
   * NOTE(review): the values of the non-loop transitions are not summed, so
   * this presumably relies on the outgoing values of @a u adding up to one.
   */
  RationalFunction Eliminator::leavingSum(vertex_descriptor u) {
    if (out_degree(u, *graph) == 0) {
      return RationalFunction(0);
    }
    pair<edge_descriptor,bool> selfLoop(edge(u, u, *graph));
    if (!selfLoop.second) {
      return RationalFunction(1);
    }
    if (out_degree(u, *graph) == 1) {
      /* the self-loop is the only transition: nothing ever leaves u */
      return RationalFunction(0);
    }
    return RationalFunction(1) - (*graph)[selfLoop.first].getValue();
  }
  
  /**
   * Eliminates state @a v: every path back -> v -> to is replaced by a direct
   * transition back -> to, and v is then removed from the graph.
   *
   * The value of the new transition is val(back -> v) * val(v -> to) / leaving,
   * where leaving = leavingSum(v). When v has a self-loop l (and other
   * transitions), leaving is 1 - val(l), so the division multiplies by
   * val(l^*). If a transition back -> to already exists, its value is added
   * and the old transition is replaced by the new one.
   *
   * With reward analysis enabled, the reward of the new transition is the
   * value-weighted average of the rewards of the merged alternatives. A path
   * through v contributes reward(back -> v) + loopReward + reward(v -> to),
   * where loopReward is the expected reward collected on v's self-loop.
   *
   * NOTE(review): edge(back, to, ...) returns a single edge, so this assumes
   * at most one transition between any ordered pair of states.
   *
   * @param v state to eliminate; it is removed from the graph by this call
   */
  void Eliminator::eliminateState(vertex_descriptor v) {
    /* leaving = 1 - val(l) if v has a self-loop l and further transitions
     * (dividing by it is multiplying by val(l^*)). It is 1 if v has no
     * self-loop, and 0 if v has no transitions or only the self-loop.
     */
    RationalFunction leaving(leavingSum(v));
    RationalFunction loopReward(0);
    
    /* Drop v's self-loop. Before that, with reward analysis, record
     * (val(l) / leaving) * reward(l): the expected number of loop iterations
     * times the loop reward. The out_degree > 1 check skips this when the
     * loop is v's only transition, where leaving is 0.
     */
    pair<edge_descriptor,bool> loopEdge(edge(v, v, *graph));
    if (loopEdge.second) {
      if (rewardAnalysis && (out_degree(v, *graph) > 1)) {
        loopReward = ((*graph)[loopEdge.first].getValue() / leaving)
          * (*graph)[loopEdge.first].getReward();
      }
      remove_edge(loopEdge.first, *graph);
    }
    
    /* calculate val(l^* * out) = val(out) / leaving for each outgoing
     * transition "out" (the self-loop is already gone), plus its reward
     * including the expected loop reward. Afterwards, remove all outgoing
     * transitions of the state to be removed.
     */
    vector<RationalFunction> out_val;
    vector<vertex_descriptor> out_target;
    vector<RationalFunction> out_reward;
    pair<out_edge_iterator, out_edge_iterator> it;
    for (it = out_edges(v, *graph); it.first != it.second; ++it.first) {
      edge_descriptor e(*it.first);
      vertex_descriptor u(target(e, *graph));
      out_val.push_back((*graph)[e].getValue() / leaving);   
      if (rewardAnalysis) {
        out_reward.push_back(loopReward + (*graph)[e].getReward());
      }
      out_target.push_back(u);
    }
    clear_out_edges(v, *graph);
    
    /* all states with transitions to the state to be removed get new
     * transitions. The incoming transitions are copied first because the
     * loop below removes them (remove_edge(e, ...)) and so must not iterate
     * over in_edges(v) directly.
     */
    vector<vertex_descriptor> back_v;
    vector<edge_descriptor> back_e;
    pair<in_edge_iterator,in_edge_iterator> in_it;
    for (in_it = in_edges(v, *graph); in_it.first != in_it.second;
         ++in_it.first) {
      edge_descriptor e(*in_it.first);
      vertex_descriptor u(source(e, *graph));
      back_v.push_back(u);
      back_e.push_back(e);
    }
    
    /* iterate over incoming states of state to remove */
    /* reward of the transition being built; only used with reward analysis */
    RationalFunction rew(0);
    for (unsigned incomingNr(0); incomingNr < back_v.size(); incomingNr++) {
      vertex_descriptor back(back_v[incomingNr]);
      edge_descriptor e(back_e[incomingNr]);
      
      /* e (back -> v) is only removed after the inner loop, so back_val
       * stays valid while it is used below */
      const RationalFunction &back_val((*graph)[e].getValue());
      unsigned out_idx = 0;
      /* iterate over outgoing transitions */
      for (out_idx = 0; out_idx < out_target.size(); out_idx++) {
        vertex_descriptor to(out_target[out_idx]);
        /* value of the path back -> v (-> loop^*) -> to */
        RationalFunction r(back_val * out_val[out_idx]);
        if (rewardAnalysis) {
          /* reward of that path, weighted by its value */
          const RationalFunction &back_reward((*graph)[e].getReward());
          rew = (back_reward + out_reward[out_idx]) * r;
        }
        RationalFunction old_reward(0);
        
        /* see whether there is already an existing edge; if so, merge it
         * into the new one (sum of values, sum of value-weighted rewards)
         * and remove it, since it is re-added below */
        pair<edge_descriptor,bool> backToTo(edge(back, to, *graph));
        if (backToTo.second) {
          RationalFunction tr_val((*graph)[backToTo.first].getValue());
          r = r + tr_val;
          if (rewardAnalysis) {
            old_reward = (*graph)[backToTo.first].getReward() * tr_val;
            rew += old_reward;
          }
          remove_edge(backToTo.first, *graph);
        }
        
        TransProp tp(r);
        if (rewardAnalysis) {
          /* turn the weighted sum back into a per-transition reward.
           * NOTE(review): divides by r; confirm r can never be zero here. */
          rew = rew / r;
          tp.setReward(rew);
        }
        add_edge(back, to, tp, *graph);
      } /* end iterating outgoing transitions */
      /* back -> v is now represented by the new back -> to transitions */
      remove_edge(e, *graph);
    } /* end iterating incoming transitions */
    
    /* v has no remaining transitions at this point */
    remove_vertex(v, *graph);
  }
  
  /**
   * Eliminates the given states one after another, in list order, by calling
   * eliminateState() on each of them. The list itself is not modified.
   *
   * @param elemStates states to eliminate, in elimination order
   */
  void Eliminator::eliminateStates(StateList &elemStates) {
    StateList::iterator current(elemStates.begin());
    while (current != elemStates.end()) {
      eliminateState(*current);
      ++current;
    }
  }
  
  /**
   * Inserts every state of the graph that is neither an initial state nor a
   * target state into @a stateSet. Reachability is not taken into account,
   * and existing contents of @a stateSet are kept.
   *
   * @param stateSet receives all non-init, non-target states
   */
  void Eliminator::collectNonInitNonTargetStates(StateSet &stateSet) {
    pair<vertex_iterator, vertex_iterator> range(vertices(*graph));
    for (; range.first != range.second; ++range.first) {
      vertex_descriptor state(*range.first);
      if (initStates->find(state) != initStates->end()) {
        continue;
      }
      if (targetStates->find(state) != targetStates->end()) {
        continue;
      }
      stateSet.insert(state);
    }
  }
  
  /**
   * Collects the non-init, non-target states reachable from the initial
   * states and appends them to @a stateList.
   *
   * A breadth-first search is started from each initial state that has not
   * yet been discovered by an earlier search; all searches share one color
   * map, and states are appended in discovery order. Depending on
   * eliminationOrder the list is then reversed (backward) or shuffled
   * (random); for any other order the BFS order is kept.
   *
   * @param stateList list receiving the states to be eliminated
   */
  void Eliminator::collectNonInitNonTargetStatesOrdered(StateList &stateList) {
    typedef std::map<vertex_descriptor, default_color_type> ColorStore;
    ColorStore colorMap;
    associative_property_map<ColorStore> colorMapBoost(colorMap);
    bfs_list_visitor<StateList> list_visitor(stateList, *initStates, *targetStates);

    /* color_traits is parameterised by the color value type, not the graph */
    const default_color_type white(color_traits<default_color_type>::white());

    /* every vertex starts undiscovered */
    pair<vertex_iterator, vertex_iterator> vp;
    for (vp = vertices(*graph); vp.first != vp.second; ++vp.first) {
      colorMap[*vp.first] = white;
    }
    for (vp = vertices(*graph); vp.first != vp.second; ++vp.first) {
      vertex_descriptor v(*vp.first);
      if ((initStates->find(v) != initStates->end())
          && (colorMap[v] == white)) {
        breadth_first_visit(*graph, v,
                            visitor(list_visitor).color_map(colorMapBoost));
      }
    }

    if (backward == eliminationOrder) {
      stateList.reverse();
    } else if (random == eliminationOrder) {
      vector<vertex_descriptor> stateVector(stateList.begin(), stateList.end());
      /* NOTE: std::random_shuffle is deprecated in C++14 and removed in
       * C++17; switch to std::shuffle with an explicit engine once the
       * project is built as C++11 or later. */
      random_shuffle(stateVector.begin(), stateVector.end());
      stateList.assign(stateVector.begin(), stateVector.end());
    }
  }
  
  /**
   * Prepares initial states for state-elimination based algorithms.
   * Only initial states with incoming transitions are considered.
   */
  void Eliminator::prepareInitialStatesForUnbounded() {
    vector<vertex_descriptor> mustPrepare;
    pair<vertex_iterator, vertex_iterator> vp;
    for (vp = vertices(*graph); vp.first != vp.second; ++vp.first) {
      vertex_descriptor u(*vp.first);
      if ((initStates->find(u) != initStates->end())
          && (0 != in_degree(u, *graph))) {
        mustPrepare.push_back(u);
      }
    }
    
    for (vector<vertex_descriptor>::iterator it = mustPrepare.begin();
         it != mustPrepare.end(); it++) {
      RationalFunction zero(0);
      RationalFunction one(1);
      vertex_descriptor u(*it);
      initStates->erase(u);
      vertex_descriptor v(add_vertex(*graph));
      pair<edge_descriptor,bool> ep(add_edge(v, u, (*graph)));
      (*graph)[ep.first].setValue(one);
      (*graph)[ep.first].setReward(zero);
      initStates->insert(v);
    }
  }
  
  void Eliminator::setGraph(Graph *graph__) {
    graph = graph__;
  }
  
  void Eliminator::setInitStates(StateSet &initStates__) {
    initStates = &initStates__;
  }
  
  
  void Eliminator::setTargetStates(StateSet &targetStates__) {
    targetStates = &targetStates__;
  }
  
  
  void Eliminator::setStateRewards(RewardMap &stateRewards__) {
    stateRewards = &stateRewards__;
  }
  
  void Eliminator::setRewardAnalysis(bool rewardAnalysis__) {
    rewardAnalysis = rewardAnalysis__;
  }
  
  /**
   * Sets the order in which states are eliminated. It is applied to the
   * BFS order of the reachable states: backward reverses it, random
   * shuffles it, and any other value keeps it.
   *
   * @param order elimination order to use
   */
  void Eliminator::setEliminationOrder(EliminationOrder order) {
    eliminationOrder = order;
  }
  
  void Eliminator::eliminate(std::vector<RationalFunction> &result) {
    prepareInitialStatesForUnbounded();
    
    /* remove all states except initial and target ones */
    StateList stateList;
    collectNonInitNonTargetStatesOrdered(stateList);
    eliminateStates(stateList);
    
    /* now collect results */
    pair<vertex_iterator, vertex_iterator> vp;
    for (vp = vertices(*graph); vp.first != vp.second; ++vp.first) {
      vertex_descriptor v(*vp.first);
      if (initStates->find(v) != initStates->end()) {
        RationalFunction vValue(0);
        if (targetStates->find(v) != targetStates->end()) {
          if (!rewardAnalysis) {
            vValue = RationalFunction(1);
          }
        } else {
          pair<out_edge_iterator, out_edge_iterator> it;
          for (it = out_edges(v, *graph); it.first != it.second; ++it.first) {
            edge_descriptor e(*it.first);
            if (rewardAnalysis) {
              vValue += (*graph)[e].getValue() * (*graph)[e].getReward();
            } else {
              vValue += (*graph)[e].getValue();
            }
          }
        }
        result.push_back(vValue);
      }
    }
  }
}
