/*
 * This file is part of PARAM.
 *
 * PARAM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PARAM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.

 * You should have received a copy of the GNU General Public License
 * along with PARAM.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright 2009 Ernst Moritz Hahn (emh@cs.uni-sb.de)
 */

#include <assert.h>
#include "RationalFunction.h"
#include "Base.h"
#include "Polynomial.h"

using namespace std;

namespace Rational {
  // Forwards to Base::start().
  // NOTE(review): presumably initialises the shared rational-function table
  // and must run before any RationalFunction is constructed -- confirm in Base.
  void RationalFunction::start() {
    Base::start();
  }

  // Forwards to Base::removeUnusedSymbols(); what counts as "unused" is
  // decided entirely by Base.
  void RationalFunction::removeUnusedSymbols() {
    Base::removeUnusedSymbols();
  }

  // Adds the symbol name `symbol` by forwarding it to Base::addSymbol().
  void RationalFunction::addSymbol(const std::string &symbol) {
    Base::addSymbol(symbol);
  }

  // Forwards the symbol names in `symbols` to
  // Base::addNewSymbolsWhileRunning().
  // NOTE(review): the name suggests this is the variant to use once rational
  // functions already exist (as opposed to addSymbol) -- confirm in Base.
  void RationalFunction::addNewSymbolsWhileRunning
  (const vector<string> &symbols) {
    Base::addNewSymbolsWhileRunning(symbols);
  }

  // Returns the number of symbols Base currently knows about.
  unsigned RationalFunction::getNumSymbols() {
    const unsigned symbolCount(Base::getNumSymbols());
    return symbolCount;
  }

  // Default constructor: builds the constant 0 as the pair
  // (numerator with no terms, denominator 1). The denominator is a single
  // term with coefficient 1 and every exponent 0. The pair is handed to
  // Base::insertRational(); this object is registered with Base and takes
  // one reference on the resulting entry.
  RationalFunction::RationalFunction() {
    Polynomial *zeroNumerator(new Polynomial(0));
    Polynomial *unitDenominator(new Polynomial(1));
    mpz_set_si(unitDenominator->coefficients[0], 1);

    const unsigned symbolCount(Base::getNumSymbols());
    for (unsigned symbol(0); symbol != symbolCount; ++symbol) {
      unitDenominator->monomials[symbol] = 0;
    }

    Base::PolyPair entry(make_pair(zeroNumerator, unitDenominator));
    index = Base::insertRational(entry);
    Base::registerRational(this);
    Base::incRef(index);
  }

  // Copy constructor: shares rat's table entry instead of copying the
  // polynomials. Only the entry index is copied; this object is then
  // registered with Base and takes its own reference on the entry.
  RationalFunction::RationalFunction(const RationalFunction &rat)
    : index(rat.index) {
    Base::registerRational(this);
    Base::incRef(index);
  }

  // Copy assignment: makes *this share rat's table entry.
  //
  // The reference on the new entry is taken before the reference on the old
  // one is dropped. That keeps the entry alive when both objects already
  // share it, e.g. on self-assignment (presumably decRef may release an entry
  // whose count reaches zero -- see Base). Keep this statement order.
  RationalFunction &RationalFunction::operator=
  (const RationalFunction &rat) {
    Base::incRef(rat.index);
    Base::decRef(index);
    index = rat.index;

    return *this;
  }

  // Constructs the rational function  coefficient * x^monomial / 1.
  //
  // monomial[i] is the exponent of symbol i. Symbols beyond the end of
  // `monomial` get exponent 0. A zero coefficient yields a numerator with no
  // terms. The denominator is the constant 1 (coefficient 1, every exponent
  // 0).
  //
  // Every exponent slot up to Base::getNumSymbols() is written, matching the
  // other constructors, so a short `monomial` cannot leave slots
  // uninitialised.
  // NOTE(review): a single-term Polynomial is presumed to hold exactly
  // getNumSymbols() exponents (that is how the other constructors fill it),
  // hence the assert that `monomial` is not longer than that.
  RationalFunction::RationalFunction
  (int coefficient, vector<unsigned> monomial) {
    const unsigned numSymbols(Base::getNumSymbols());
    assert(monomial.size() <= numSymbols);

    Polynomial *nom;
    if (0 == coefficient) {
      nom = new Polynomial(0);
    } else {
      nom = new Polynomial(1);
      mpz_set_si(nom->coefficients[0], coefficient);
      for (unsigned i = 0; i < numSymbols; i++) {
        // Zero-pad symbols that the caller did not supply an exponent for.
        nom->monomials[i] = (i < monomial.size()) ? monomial[i] : 0;
      }
    }
    Polynomial *den = new Polynomial(1);
    mpz_set_si(den->coefficients[0], 1);
    for (unsigned i = 0; i < numSymbols; i++) {
      den->monomials[i] = 0;
    }
    Base::PolyPair rational(make_pair(nom, den));
    index = Base::insertRational(rational);
    Base::registerRational(this);
    Base::incRef(index);
  }

  // Constructs the constant rational function  coefficient / 1.
  //
  // For a zero coefficient the numerator has no terms. Otherwise the
  // numerator is one term carrying `coefficient` with every exponent 0. The
  // denominator is always the single term 1 with every exponent 0.
  RationalFunction::RationalFunction(const int coefficient) {
    const bool isZero(0 == coefficient);

    Polynomial *numerator(new Polynomial(isZero ? 0 : 1));
    if (!isZero) {
      mpz_set_si(numerator->coefficients[0], coefficient);
      for (unsigned sym = 0; sym < Base::getNumSymbols(); ++sym) {
        numerator->monomials[sym] = 0;
      }
    }

    Polynomial *denominator(new Polynomial(1));
    mpz_set_si(denominator->coefficients[0], 1);
    for (unsigned sym = 0; sym < Base::getNumSymbols(); ++sym) {
      denominator->monomials[sym] = 0;
    }

    Base::PolyPair entry(make_pair(numerator, denominator));
    index = Base::insertRational(entry);
    Base::registerRational(this);
    Base::incRef(index);
  }

  // Destructor: drops this object's reference on its table entry, then
  // deregisters this object from Base (the inverse of the
  // registerRational() call made by every constructor).
  RationalFunction::~RationalFunction() {
    Base::decRef(index);
    Base::deregisterRational(this);
  }

  // Returns r1 + r2.
  //
  // `result` is built before the sum is computed. Its default constructor
  // calls Base::insertRational(), and that must not run between Base::add()
  // returning the new entry and the incRef below.
  // The reference the default constructor took on its placeholder zero is
  // released before `result` is pointed at the sum; otherwise it would never
  // be dropped. incRef-before-decRef matches the other reassignments in this
  // file.
  RationalFunction operator+
  (const RationalFunction &r1, const RationalFunction &r2) {
    RationalFunction result;
    RationalFunction::PolyPairIter number(Base::add(r1, r2));
    Base::incRef(number);
    Base::decRef(result.index);
    result.index = number;
    return result;
  }
  
  // Returns r1 - r2.
  //
  // `result` is built before the difference is computed. Its default
  // constructor calls Base::insertRational(), and that must not run between
  // Base::sub() returning the new entry and the incRef below.
  // The reference the default constructor took on its placeholder zero is
  // released before `result` is pointed at the difference; otherwise it
  // would never be dropped.
  RationalFunction operator-
  (const RationalFunction &r1, const RationalFunction &r2) {
    RationalFunction result;
    RationalFunction::PolyPairIter number(Base::sub(r1, r2));
    Base::incRef(number);
    Base::decRef(result.index);
    result.index = number;
    return result;
  }
  
  // Returns r1 * r2.
  //
  // `result` is built before the product is computed. Its default
  // constructor calls Base::insertRational(), and that must not run between
  // Base::mul() returning the new entry and the incRef below.
  // The reference the default constructor took on its placeholder zero is
  // released before `result` is pointed at the product; otherwise it would
  // never be dropped.
  RationalFunction operator*
  (const RationalFunction &r1, const RationalFunction &r2) {
    RationalFunction result;
    RationalFunction::PolyPairIter number(Base::mul(r1, r2));
    Base::incRef(number);
    Base::decRef(result.index);
    result.index = number;
    return result;
  }

  // Returns r1 / r2.
  //
  // `result` is built before the quotient is computed. Its default
  // constructor calls Base::insertRational(), and that must not run between
  // Base::div() returning the new entry and the incRef below.
  // The reference the default constructor took on its placeholder zero is
  // released before `result` is pointed at the quotient; otherwise it would
  // never be dropped.
  RationalFunction operator/
  (const RationalFunction &r1, const RationalFunction &r2) {
    RationalFunction result;
    RationalFunction::PolyPairIter number(Base::div(r1, r2));
    Base::incRef(number);
    Base::decRef(result.index);
    result.index = number;
    return result;
  }

  // Returns -r1 (unary negation).
  //
  // `result` is built before the negation is computed. Its default
  // constructor calls Base::insertRational(), and that must not run between
  // Base::neg() returning the new entry and the incRef below.
  // The reference the default constructor took on its placeholder zero is
  // released before `result` is pointed at the negation; otherwise it would
  // never be dropped.
  RationalFunction operator-
  (const RationalFunction &r1) {
    RationalFunction result;
    RationalFunction::PolyPairIter number(Base::neg(r1));
    Base::incRef(number);
    Base::decRef(result.index);
    result.index = number;
    return result;
  }

  // In-place addition: r1 becomes r1 + r2.
  //
  // The new entry is referenced before r1's old entry is released, so it
  // survives even when Base::add() hands back r1's current entry.
  // NOTE(review): returns a copy of r1 rather than a reference, as declared
  // in the header; changing that requires a matching header change.
  RationalFunction operator+=
  (RationalFunction &r1, const RationalFunction &r2) {
    RationalFunction::PolyPairIter sum = Base::add(r1, r2);
    Base::incRef(sum);
    Base::decRef(r1.index);
    r1.index = sum;
    return r1;
  }

  // In-place subtraction: r1 becomes r1 - r2.
  //
  // The new entry is referenced before r1's old entry is released, so it
  // survives even when Base::sub() hands back r1's current entry.
  // NOTE(review): returns a copy of r1 rather than a reference, as declared
  // in the header; changing that requires a matching header change.
  RationalFunction operator-=
  (RationalFunction &r1, const RationalFunction &r2) {
    RationalFunction::PolyPairIter difference = Base::sub(r1, r2);
    Base::incRef(difference);
    Base::decRef(r1.index);
    r1.index = difference;
    return r1;
  }

  // In-place multiplication: r1 becomes r1 * r2.
  //
  // The new entry is referenced before r1's old entry is released, so it
  // survives even when Base::mul() hands back r1's current entry.
  // NOTE(review): returns a copy of r1 rather than a reference, as declared
  // in the header; changing that requires a matching header change.
  RationalFunction operator*=
  (RationalFunction &r1, const RationalFunction &r2) {
    RationalFunction::PolyPairIter product = Base::mul(r1, r2);
    Base::incRef(product);
    Base::decRef(r1.index);
    r1.index = product;
    return r1;
  }

  // In-place division: r1 becomes r1 / r2.
  //
  // The new entry is referenced before r1's old entry is released, so it
  // survives even when Base::div() hands back r1's current entry.
  // NOTE(review): returns a copy of r1 rather than a reference, as declared
  // in the header; changing that requires a matching header change.
  RationalFunction operator/=
  (RationalFunction &r1, const RationalFunction &r2) {
    RationalFunction::PolyPairIter quotient = Base::div(r1, r2);
    Base::incRef(quotient);
    Base::decRef(r1.index);
    r1.index = quotient;
    return r1;
  }

  // In-place addition of an integer constant: r1 becomes r1 + i2.
  // The constant is first wrapped in a temporary RationalFunction.
  // NOTE(review): returns a copy of r1, as declared in the header.
  RationalFunction operator+=(RationalFunction &r1, const int i2) {
    const RationalFunction addend(i2);
    RationalFunction::PolyPairIter sum = Base::add(r1, addend);
    Base::incRef(sum);
    Base::decRef(r1.index);
    r1.index = sum;
    return r1;
  }

  // In-place subtraction of an integer constant: r1 becomes r1 - i2.
  // The constant is first wrapped in a temporary RationalFunction.
  // NOTE(review): returns a copy of r1, as declared in the header.
  RationalFunction operator-=(RationalFunction &r1, const int i2) {
    const RationalFunction subtrahend(i2);
    RationalFunction::PolyPairIter difference = Base::sub(r1, subtrahend);
    Base::incRef(difference);
    Base::decRef(r1.index);
    r1.index = difference;
    return r1;
  }

  // In-place multiplication by an integer constant: r1 becomes r1 * i2.
  // The constant is first wrapped in a temporary RationalFunction.
  // NOTE(review): returns a copy of r1, as declared in the header.
  RationalFunction operator*=(RationalFunction &r1, const int i2) {
    const RationalFunction factor(i2);
    RationalFunction::PolyPairIter product = Base::mul(r1, factor);
    Base::incRef(product);
    Base::decRef(r1.index);
    r1.index = product;
    return r1;
  }

  // In-place division by an integer constant: r1 becomes r1 / i2.
  // The constant is first wrapped in a temporary RationalFunction.
  // NOTE(review): returns a copy of r1, as declared in the header.
  RationalFunction operator/=(RationalFunction &r1, const int i2) {
    const RationalFunction divisor(i2);
    RationalFunction::PolyPairIter quotient = Base::div(r1, divisor);
    Base::incRef(quotient);
    Base::decRef(r1.index);
    r1.index = quotient;
    return r1;
  }

  // Selects the cleanup strategy by forwarding `method` to
  // Base::setCleanupMethod(); the available strategies are defined by
  // CleanupMethod in Base.
  void RationalFunction::setCleanupMethod(CleanupMethod method) {
    Base::setCleanupMethod(method);
  }

  // Writes rat to os.
  //
  // If the denominator is the constant 1 (exactly one term, coefficient 1,
  // every exponent 0), only the numerator is printed. Otherwise the output
  // has the form "(numerator) / (denominator)".
  //
  // The denominator's leading term is inspected only when the denominator
  // has exactly one term, so a denominator without terms is never read out
  // of bounds; it is simply printed in fraction form.
  ostream &operator<<(ostream &os, const RationalFunction &rat) {
    Polynomial &nom(*(rat.index->first.first));
    Polynomial &den(*(rat.index->first.second));

    bool isPoly(1 == den.getNumTerms());
    if (isPoly) {
      const mpz_t *coefficients(den.getCoefficients());
      const unsigned *monomials(den.getMonomials());
      if (0 != mpz_cmp_si(coefficients[0], 1)) {
        isPoly = false;
      } else {
        for (unsigned symbolNr(0); symbolNr < Base::getNumSymbols();
             symbolNr++) {
          if (0 != monomials[symbolNr]) {
            isPoly = false;
            break;
          }
        }
      }
    }

    if (isPoly) {
      os << nom;
    } else {
      os << "(" << nom << ") / (" << den << ")";
    }

    return os;
  }
}
