#include <cmath>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <gmpxx.h>

#include "ResultPrinter.h"

namespace parametric {
  using namespace std;
  using namespace Rational;

  void ResultPrinter::setOutputPrefix(const string &outputPrefix__) {
    outputPrefix = outputPrefix__;
  }

  void ResultPrinter::setOutputFormat(const string &outputFormat__) {
    outputFormat = outputFormat__;
  }

  void ResultPrinter::setResult(const RegionScheduler &result__) {
    result = result__;
  }

  void ResultPrinter::setPlotStep(const mpq_class &plotStep__) {
    plotStep = plotStep__;
  }

  void ResultPrinter::setPlotStep(const string &plotStep__) {
    plotStep = plotStep__;
  }

  void ResultPrinter::setMinimize(bool minimize__) {
    minimize = minimize__;
  }

  /**
   * Writes the output selected by outputFormat:
   *   "regions" -> region map (.out)
   *   "gnuplot" -> gnuplot script (.gpi) followed by the data file (.dat)
   *   "dat"     -> data file (.dat) only
   * Any other format writes nothing.
   */
  void ResultPrinter::print() {
    if (outputFormat == "regions") {
      printRegionMap();
      return;
    }
    const bool wantsGnuplot(outputFormat == "gnuplot");
    if (wantsGnuplot) {
      printGnuplotFile();
    }
    // The gnuplot script plots the .dat file, so both formats need it.
    if (wantsGnuplot || outputFormat == "dat") {
      printDATFile();
    }
  }

  void ResultPrinter::printBox(const Box &box, ostream &out) {
    out << "[";
    for (unsigned symNr(0); symNr < box.size(); symNr++) {
      out << "[" << box[symNr].first << ", "
	  << box[symNr].second << "]";
      if (symNr != box.size() - 1) {
	out << " ";
      }
    }
    out << "]";
  }

  void ResultPrinter::printRegionMap() {
    const string outFilename(outputPrefix + ".out");
    ofstream out(outFilename.c_str(), ios::out);

    const unsigned numSymbols(RationalFunction::getNumSymbols());
    out << "[";
    for (unsigned symNr(0); symNr < numSymbols; symNr++) {
      out << RationalFunction::getSymbolName(symNr);
      if (symNr != numSymbols - 1) {
	out << ", ";
      }
    }
    out << "]\n";

    RegionScheduler::const_iterator rit;
    for (rit = result.begin(); rit != result.end(); rit++) {
      printBox(rit->get<0>(), out);
      out << "\n";
      const Results &r(rit->get<1>());
      Results::const_iterator it2;
      for (it2 = r.begin(); it2 != r.end(); it2++) {
	out << "  " << it2->second << "\n";
      }
    }
  }

  void ResultPrinter::printGnuplotFile() {
    const unsigned numSymbols(RationalFunction::getNumSymbols());
    if (2 < numSymbols) {
      throw runtime_error("gnuplot output can only be generated for one or "
			  "two parameters");
    }
    const string gpiFilename(outputPrefix + ".gpi");
    ofstream out(gpiFilename.c_str(), ios::out);
    if (1 == numSymbols) {
      out << "set xrange[0:1]\n"
	  << "set term postscript eps font \"Helvetica, 25\"\n"
	  << "set output \"" << outputPrefix << ".eps\"\n"
	  << "set xlabel \"" << RationalFunction::getSymbolName(0)
	  << "\" font \"Helvetica Italic, 25\"\n"
	  << "plot \"" << outputPrefix << ".dat\" with lines title \""
	  << outputPrefix << "\" lt rgb \"black\"\n";
    } else {
      out << "set data style lines\n"
	  << "set dgrid3d 20,20,3\n"
	  << "set xrange[0:1]\n"
	  << "set yrange[0:1]\n"
	  << "set term postscript eps font \"Helvetica, 25\"\n"
	  << "set output \"" << outputPrefix << ".eps\"\n"
	  << "set xlabel \"" << RationalFunction::getSymbolName(0)
	  << "\" font \"Helvetica Italic, 25\"\n"
	  << "set ylabel \"" << RationalFunction::getSymbolName(1)
	  << "\" font \"Helvetica Italic, 25\"\n"
	  << "set ztics 0.4\n"
	  << "splot \"" << outputPrefix << ".dat\" title \""
	  << outputPrefix << "\" lt rgb \"black\"\n";
    }
  }

  void ResultPrinter::printDATFile() {
    const string datFilename(outputPrefix + ".dat");
    ofstream out(datFilename.c_str(), ios::out);

    const unsigned numSymbols(RationalFunction::getNumSymbols());
    mpq_class pointWidth(plotStep);
    RegionScheduler::const_iterator rit;
    for (rit = result.begin(); rit != result.end(); rit++) {
      const Box &box(rit->get<0>());
      const Results &result(rit->get<1>());
      vector<mpq_class> startPoints;
      vector<unsigned> numPoints;
      unsigned maxCountNumber(1);
      for (unsigned symNr(0); symNr < numSymbols; symNr++) {
	mpq_class leftBar(box[symNr].first / pointWidth);
	mpq_class rightBar(box[symNr].second / pointWidth);
	mpq_class startPoint(ceil(leftBar.get_d()) * pointWidth);
	mpq_class endPoint(floor(rightBar.get_d()) * pointWidth);
	unsigned numP(floor(rightBar.get_d()) - ceil(leftBar.get_d()) + 1);
	// TODO modify if different var ranges become possible
	// TODO modify this for overlapping boxes
	// (may be possible when reading boxes from RSolver splitting)
	if ((box[symNr].second != mpq_class(1))
	    && (box[symNr].second == endPoint)) {
	  numP--;
	}
	startPoints.push_back(startPoint);
	numPoints.push_back(numP);
	maxCountNumber *= numP;
      }

      for (unsigned number(0); number < maxCountNumber; number++) {
	unsigned numberRest(number);
	vector<mpq_class> point;
	for (unsigned symNr(0); symNr < numSymbols; symNr++) {
	  unsigned mult(numberRest % numPoints[symNr]);
	  numberRest /= numPoints[symNr];
	  point.push_back(startPoints[symNr] + mult * pointWidth);
	}
	Results::const_iterator it2;
	mpq_class optPointValue(minimize ? 2 : -1);
	for (it2 = result.begin(); it2 != result.end(); it2++) {
	  RationalFunction ratFun(it2->second);
	  mpq_class pointValue(ratFun.evaluate(point));
	  optPointValue = ((minimize && (optPointValue < pointValue))
			   || (!minimize && (optPointValue > pointValue)))
	    ? optPointValue : pointValue;
	}
	for (unsigned symNr(0); symNr < numSymbols; symNr++) {
	  out << point[symNr].get_d() << "  ";
	}
	out << optPointValue.get_d() << "\n";
      }
    }
  }
}
