/*
 * This file is part of PARAM.
 *
 * PARAM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PARAM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PARAM. If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright 2010 Ernst Moritz Hahn (emh@cs.uni-sb.de)
 */

#include <assert.h>
#include <cstdlib>
#include "SPMC.h"

namespace parametric {
  using namespace Rational;

  // Creates an empty model: no states finished, no transitions added and
  // no storage reserved yet (the reserve*Mem functions allocate it later).
  SPMC::SPMC() {
    // Counters for the states finished so far and the next transition slot.
    numStates = 0;
    colIndex = 0;
    // Forward transition arrays (reserveRowsMem / reserveColsMem).
    rows = NULL;
    cols = NULL;
    nonZeros = NULL;
    // Reward arrays (reserveTransRewardsMem / reserveStateRewardsMem).
    transRewards = NULL;
    stateRewards = NULL;
    // Predecessor arrays (computeBackTransitions).
    backRows = NULL;
    backCols = NULL;
  }

  // Releases every array this object may have allocated. Arrays that were
  // never reserved are still NULL, and delete[] on a null pointer does
  // nothing, so no per-array guard is required.
  SPMC::~SPMC() {
    delete[] rows;
    delete[] cols;
    delete[] nonZeros;
    delete[] transRewards;
    delete[] stateRewards;
    delete[] backRows;
    delete[] backCols;
  }

  // Allocates the row-offset array: one start offset per state plus a
  // trailing entry, with the first state's row starting at offset 0.
  // Must be called only while no row array exists yet.
  void SPMC::reserveRowsMem(unsigned numStates_) {
    assert(NULL == rows);
    unsigned *const offsets = new unsigned[numStates_ + 1];
    offsets[0] = 0;
    rows = offsets;
  }
  
  void SPMC::reserveColsMem(unsigned numCols) {
    assert(NULL == cols);
    cols = new unsigned[numCols];
    nonZeros = new RationalFunction[numCols];
  }

  void SPMC::reserveTransRewardsMem(unsigned numCols) {
    assert(NULL == transRewards);
    transRewards = new RationalFunction[numCols];
  }

  void SPMC::reserveStateRewardsMem(unsigned numStates_) {
    assert(NULL == stateRewards);
    stateRewards = new RationalFunction[numStates_];
  }

  // Stores the reward of the state currently under construction. Its index
  // equals the number of states finished so far, since finishState()
  // advances numStates. Requires reserveStateRewardsMem() to have run.
  void SPMC::setStateReward(RationalFunction rat) {
    RationalFunction &slot(stateRewards[numStates]);
    slot = rat;
  }

  // Appends a transition of the current state to `state` with probability
  // `prob`, writing into the next free transition slot.
  void SPMC::addSucc(state state, RationalFunction prob) {
    const unsigned slot(colIndex);
    cols[slot] = state;
    nonZeros[slot] = prob;
    ++colIndex;
  }

  // Appends a rewarded transition of the current state to `succState`.
  // The target, probability and reward share the next free slot index.
  // Requires reserveTransRewardsMem() to have run.
  void SPMC::addSucc(state succState, RationalFunction prob, RationalFunction rew) {
    const unsigned slot(colIndex);
    cols[slot] = succState;
    nonZeros[slot] = prob;
    transRewards[slot] = rew;
    ++colIndex;
  }

  // Closes the current state: bumps the state count and records the
  // current transition slot as the row end (and next row's start).
  void SPMC::finishState() {
    rows[++numStates] = colIndex;
  }
  
  // Number of states finished via finishState().
  unsigned SPMC::getNumStates() const {
    const unsigned count(numStates);
    return count;
  }

  // Total number of stored transitions: the row offset recorded after the
  // last finished state.
  unsigned SPMC::getNumTrans() const {
    const unsigned total(rows[numStates]);
    return total;
  }

  // Number of outgoing transitions of `state` (the length of its row).
  unsigned SPMC::getNumSuccStates(state state) const {
    const unsigned rowBegin(rows[state]);
    const unsigned rowEnd(rows[state + 1]);
    return rowEnd - rowBegin;
  }

  // Target of the `number`-th outgoing transition of `state`.
  PMC::state SPMC::getSuccState(state state, unsigned number) const {
    const unsigned pos(rows[state] + number);
    return cols[pos];
  }

  // Probability of the `number`-th outgoing transition of `state`.
  RationalFunction SPMC::getSuccProb(state state, unsigned number) const {
    const unsigned pos(rows[state] + number);
    return nonZeros[pos];
  }

  // Reward attached to `state` (returned by value).
  RationalFunction SPMC::getStateReward(state state) const {
    const RationalFunction &reward(stateRewards[state]);
    return reward;
  }

  // Reward of the `number`-th outgoing transition of `state`.
  RationalFunction SPMC::getSuccReward(state state, unsigned number) const {
    const unsigned pos(rows[state] + number);
    return transRewards[pos];
  }

  // Number of incoming transitions of `state`. Requires
  // computeBackTransitions() to have built the predecessor arrays.
  unsigned SPMC::getNumPredStates(state state) const {
    assert(NULL != backRows);
    const unsigned *const entry = backRows + state;
    return entry[1] - entry[0];
  }
  
  // Source state of the `number`-th incoming transition of `state`.
  // Like getNumPredStates(), this reads the predecessor arrays, so
  // computeBackTransitions() must have been called first. That
  // precondition is now asserted here too, instead of silently
  // dereferencing NULL.
  PMC::state SPMC::getPredState(state state, unsigned number) const {
    assert(NULL != backRows);
    assert(NULL != backCols);
    return backCols[backRows[state] + number];
  }

  // Probability of the `number`-th incoming transition of `state`.
  //
  // The predecessor arrays store only source states, so the probability is
  // found by scanning the source's outgoing row for `state`. The first
  // matching transition is used.
  //
  // A match must exist by construction of the predecessor arrays. If it
  // does not, the program terminates. Previously control fell off the end
  // of this non-void function when assertions were disabled (NDEBUG),
  // which is undefined behaviour.
  RationalFunction SPMC::getPredProb(state state, unsigned number) const {
    const PMC::state pred(getPredState(state, number));
    const unsigned numSucc(getNumSuccStates(pred));
    for (unsigned succNr(0); succNr < numSucc; succNr++) {
      if (state == getSuccState(pred, succNr)) {
        return getSuccProb(pred, succNr);
      }
    }
    assert(false);
    std::abort();
  }

  // Reward of the `number`-th incoming transition of `state`.
  //
  // The source state's outgoing row is scanned for `state`, and the reward
  // of the first matching transition is returned.
  //
  // A match must exist by construction of the predecessor arrays. If it
  // does not, the program terminates rather than falling off the end of a
  // non-void function (undefined behaviour when NDEBUG disables the assert).
  RationalFunction SPMC::getPredReward(state state, unsigned number) const {
    const PMC::state pred(getPredState(state, number));
    const unsigned numSucc(getNumSuccStates(pred));
    for (unsigned succNr(0); succNr < numSucc; succNr++) {
      if (state == getSuccState(pred, succNr)) {
        return getSuccReward(pred, succNr);
      }
    }
    assert(false);
    std::abort();
  }

  // Builds the predecessor (reverse) arrays from the forward rows/cols arrays.
  //
  // Afterwards backRows has numStates + 1 entries. For each state t, the
  // range backCols[backRows[t] .. backRows[t + 1]) holds the source of every
  // transition into t, in increasing source order. A source appears once
  // per transition, so parallel transitions give repeated entries.
  //
  // Must be called at most once, after all states have been finished.
  //
  // Unlike the previous version, the in-degrees are counted only once. The
  // scatter pass advances each start offset to the next state's start; one
  // backward shift restores them, instead of freeing backRows and
  // recounting every transition.
  void SPMC::computeBackTransitions() {
    assert(NULL == backRows);
    assert(NULL == backCols);

    // In-degree of each state, stored one slot to the right so that the
    // prefix sum below turns backRows[t] into the start offset of t.
    backRows = new unsigned[numStates + 1];
    for (unsigned s(0); s < numStates; s++) {
      backRows[s] = 0;
    }
    backRows[numStates] = 0;
    for (unsigned src(0); src < numStates; src++) {
      for (unsigned pos(rows[src]); pos < rows[src + 1]; pos++) {
        const unsigned target(cols[pos]);
        assert(target < numStates);
        backRows[target + 1]++;
      }
    }
    for (unsigned s(0); s < numStates; s++) {
      backRows[s + 1] += backRows[s];
    }

    // Scatter each source into its target's range, using backRows[target]
    // as a write cursor. Afterwards backRows[t] equals the start of t + 1.
    // The final entry, the total count, is never advanced.
    backCols = new unsigned[backRows[numStates]];
    for (unsigned src(0); src < numStates; src++) {
      for (unsigned pos(rows[src]); pos < rows[src + 1]; pos++) {
        const unsigned target(cols[pos]);
        backCols[backRows[target]] = src;
        backRows[target]++;
      }
    }

    // Shift the advanced cursors back by one slot to recover the start
    // offsets; state 0 always starts at 0.
    for (unsigned s(numStates); s > 0; s--) {
      backRows[s] = backRows[s - 1];
    }
    backRows[0] = 0;
  }

  // Overwrites the probability of the `number`-th outgoing transition of `state`.
  void SPMC::setSuccProb(state state, unsigned number,
                         Rational::RationalFunction newVal) {
    const unsigned pos(rows[state] + number);
    nonZeros[pos] = newVal;
  }

  // Overwrites the reward of the `number`-th outgoing transition of `state`.
  void SPMC::setSuccReward(state state, unsigned number,
                           Rational::RationalFunction newVal) {
    const unsigned pos(rows[state] + number);
    transRewards[pos] = newVal;
  }

  // Position, within `state`'s outgoing row, of the first transition whose
  // target is `succState`. Returns getInvalidState() when there is none.
  // The row is searched linearly.
  PMM::state SPMC::getSuccNrBySuccState(state state, unsigned succState) const {
    const unsigned rowBegin(rows[state]);
    const unsigned rowEnd(rows[state + 1]);
    for (unsigned pos(rowBegin); pos < rowEnd; ++pos) {
      if (succState == cols[pos]) {
        return pos - rowBegin;
      }
    }
    return getInvalidState();
  }
}
