#include "gpgpu_frame.h"
#include <cstdarg>
#include <cstdio>
#include <string>
#include <vector>

// Starts with no keyword dictionary (set later through setKeywords()) and no indentation.
KernelSrc::KernelSrc() : d_keywords(NULL), d_indent(0) {}

/*
 * Returns the final kernel source: the dictionary header followed by the
 * generated source, with all #keyword# snippets expanded.
 *
 * On first call this runs the subclass's genSrc(), then expands the result and
 * caches it in d_postExpandSrc; later calls return the cached text.
 *
 * Throws std::string if genSrc() produced nothing or if no keyword dictionary
 * is set (checked explicitly instead of dereferencing a null d_keywords).
 */
std::string KernelSrc::getSrc() {
	if (d_postExpandSrc.empty()) {
		// We're running the inherited source generation
		genSrc();
		if (d_preExpandSrc.empty())
			throw std::string("Empty sources");

		// header() below needs the dictionary too; checked after genSrc() in case
		// the generator is what installs it.
		if (!d_keywords)
			throw std::string("Can't replace keywords without a dictionary");

		d_postExpandSrc = d_keywords->header() + d_preExpandSrc;
		expandKeywords();
	}

	return d_postExpandSrc;
}

// Indents subsequently appended lines by 'steps' more levels (two spaces per level).
void KernelSrc::incrementIndent(int steps) {
	const int spacesPerLevel = 2;
	d_indent = d_indent + spacesPerLevel * steps;
}

/*
 * Removes 'steps' indentation levels (two spaces per level) from subsequently
 * appended lines.
 *
 * The indent is clamped at zero: a negative d_indent would make the next
 * std::string(d_indent, ' ') in append() request a huge length and throw.
 */
void KernelSrc::decrementIndent(int steps) {
	const int delta = 2*steps;
	d_indent = (d_indent > delta) ? d_indent - delta : 0;
}

/*
 * printf-style append: formats the arguments and appends the result through
 * append(std::string), so the same indentation rules apply.
 *
 * The output is sized with a measuring vsnprintf pass, so arbitrarily long text
 * is handled (the old fixed 10 KiB buffer + vsprintf could overflow the stack).
 * Throws std::string if the format cannot be rendered.
 */
void KernelSrc::append(const char *format, ...) {
	va_list list;

	// Pass 1: measure. A va_list may only be traversed once, so restart it below.
	va_start(list, format);
	const int needed = vsnprintf(NULL, 0, format, list);
	va_end(list);
	if (needed < 0)
		throw std::string("Invalid format string");

	// Pass 2: render into a buffer that is exactly large enough (+1 for the NUL).
	std::vector<char> buffer(static_cast<size_t>(needed) + 1);
	va_start(list, format);
	vsnprintf(&buffer[0], buffer.size(), format, list);
	va_end(list);

	append(std::string(&buffer[0]));
}

/*
 * printf-style formatting into a std::string (nothing is appended to the source).
 *
 * The output is sized with a measuring vsnprintf pass instead of a fixed
 * 10 KiB buffer, so long results cannot overflow. Throws std::string if the
 * format cannot be rendered.
 */
std::string KernelSrc::genStr(const char *format, ...) {
	va_list list;

	// Pass 1: measure. A va_list may only be traversed once, so restart it below.
	va_start(list, format);
	const int needed = vsnprintf(NULL, 0, format, list);
	va_end(list);
	if (needed < 0)
		throw std::string("Invalid format string");

	// Pass 2: render into a buffer that is exactly large enough (+1 for the NUL).
	std::vector<char> buffer(static_cast<size_t>(needed) + 1);
	va_start(list, format);
	vsnprintf(&buffer[0], buffer.size(), format, list);
	va_end(list);

	return std::string(&buffer[0]);
}

// Loads the contents of 'fname' as the assembly later returned by getAsm().
// readFile() throws std::string if the file cannot be opened; d_uploadAsm is then left unchanged.
void KernelSrc::useAsm(const std::string fname) {
	const std::string contents = readFile(fname);
	d_uploadAsm = contents;
}

// Returns a copy of the assembly loaded by useAsm() (empty if none was loaded).
std::string KernelSrc::getAsm() {
	std::string copy(d_uploadAsm);
	return copy;
}

/*
 * Appends raw text to the pre-expansion source, applying the current
 * indentation (d_indent spaces) to every line start.
 *
 * The indentation for a line is emitted when the line's first text arrives:
 * - If the previous append ended with '\n', this text starts a new line and is
 *   indented first.
 * - A '\n' inside 'str' is followed by indentation.
 * - A trailing '\n' is not; that line's indent is deferred to the next append,
 *   so an intervening incrementIndent()/decrementIndent() still takes effect.
 *
 * Throws std::string once the source has been expanded (getSrc() was called).
 */
void KernelSrc::append(const std::string str) {
	if (d_postExpandSrc.size())
		throw std::string("Already expanded, can't append anymore");

	// We indent after newlines
	if (!d_preExpandSrc.empty() && d_preExpandSrc[d_preExpandSrc.length() - 1] == '\n')
		d_preExpandSrc += std::string(d_indent, ' ');

	// Copy line by line (size_t indices; no signed/unsigned mix), indenting after
	// each newline that is not the final character of 'str'.
	size_t lineStart = 0;
	size_t newline;
	while ((newline = str.find('\n', lineStart)) != std::string::npos && newline + 1 < str.length()) {
		d_preExpandSrc.append(str, lineStart, newline + 1 - lineStart);
		d_preExpandSrc += std::string(d_indent, ' ');
		lineStart = newline + 1;
	}
	d_preExpandSrc.append(str, lineStart, std::string::npos);
}

/*
 * Writes 'src' to the file 'fname', replacing any existing contents.
 *
 * Throws std::string if the file cannot be opened. Short writes and a failing
 * close (which flushes buffered data) are reported on stderr but not thrown.
 */
void KernelSrc::write(std::string fname, std::string src) {
	// Plain "w": the former "wo" carried an invalid mode character that some C
	// runtimes reject outright.
	FILE *f = fopen(fname.c_str(), "w");
	if (!f)
		throw std::string("Couldn't open file for writing");

	const size_t written = fwrite(src.c_str(), sizeof(char), src.size(), f);
	printf("Writing kernel to %s\n", fname.c_str());
	// size_t has no portable pre-C99 printf specifier; cast to unsigned long for %lu.
	if (written != src.size())
		fprintf(stderr, "Was only able to write %lu out of %lu total bytes to %s\n",
		        static_cast<unsigned long>(written), static_cast<unsigned long>(src.size()), fname.c_str());

	// fclose flushes the stream; a failure here means the file may be incomplete.
	if (fclose(f) != 0)
		fprintf(stderr, "Failed to finish writing %s\n", fname.c_str());
}

/*
 * Writes the expanded kernel source (see getSrc()) to 'fname', with the
 * extension reported by the keyword dictionary's fileExt() appended when a
 * dictionary is set.
 */
void KernelSrc::write(std::string fname) {
	std::string finalFname(fname);
	if (d_keywords)
		finalFname += d_keywords->fileExt();

	write(finalFname, getSrc());
}

/*
 * Replaces the first occurrence of 'src' in the expanded source with 'dest'.
 * Does nothing if 'src' is empty or does not occur.
 *
 * Generates and expands the source first if that has not happened yet.
 */
void KernelSrc::replace(std::string src, std::string dest) {
	// FIXME: Replace is meant for post-expanded sources; we just generate them here
	// instead of throwing.
	if (d_postExpandSrc.empty())
		getSrc();

	// find("") matches at position 0, which would silently prepend 'dest'.
	if (src.empty())
		return;

	const size_t replacePos = d_postExpandSrc.find(src);
	if (replacePos == std::string::npos)
		return;
	d_postExpandSrc.replace(replacePos, src.size(), dest);
}

/*
 * Expands every #keyword{...}# snippet in d_postExpandSrc using d_keywords.
 *
 * Scans left to right for pairs of '#' and handles the text between each pair:
 * - d_keywords->replaceContent() accepts it: the whole "#...#" span
 *   (delimiters included) is replaced by the generated text. Scanning resumes
 *   right after that text, so generated text is never re-expanded.
 * - It is rejected: the opening '#' is taken to belong to the code itself.
 *   Scanning resumes one character later, so the closing '#' becomes the next
 *   opener.
 *
 * A final unpaired '#' is left untouched.
 *
 * Throws std::string if no keyword dictionary has been set.
 */
void KernelSrc::expandKeywords() {
	if (d_keywords == NULL)
		throw std::string("Can't replace keywords without a dictionary");

	size_t open = d_postExpandSrc.find('#');
	while (open != std::string::npos) {
		const size_t close = d_postExpandSrc.find('#', open + 1);
		if (close == std::string::npos)
			break;

		std::string content = d_postExpandSrc.substr(open + 1, close - open - 1);

		size_t resumeAt;
		if (d_keywords->replaceContent(content)) {
			d_postExpandSrc.replace(open, close - open + 1, content);
			resumeAt = open + content.length();
		} else {
			resumeAt = open + 1;
		}

		open = d_postExpandSrc.find('#', resumeAt);
	}
}

/*
 * Returns the full contents of the file 'fname'.
 *
 * Reads in 1 KiB chunks, printing the size of each chunk read to stdout. The
 * file is always closed, including when an exception propagates.
 *
 * Throws std::string if the file cannot be opened or a read error occurs.
 */
std::string KernelSrc::readFile(const std::string fname) {
	// Plain "r": the former "ro" carried an invalid mode character.
	FILE *f = fopen(fname.c_str(), "r");
	if (!f)
		throw std::string("Couldn't open file for reading kernel src");

	std::string src;
	const size_t segment = 1024; // Kilobyte at a time
	char data[segment];
	bool readError = false;

	try {
		// Loop until fread yields nothing: a short read is not guaranteed to mean EOF.
		size_t count;
		while ((count = fread(data, sizeof(char), segment, f)) > 0) {
			// Append by length so embedded NUL bytes are not truncated.
			src.append(data, count);
			printf("Read %lu bytes\n", static_cast<unsigned long>(count));
		}
		readError = ferror(f) != 0;
	} catch (...) {
		fclose(f);
		throw;
	}
	fclose(f);

	if (readError)
		throw std::string("Error while reading kernel src");
	return src;
}

/*
 * Appends the contents of 'fname' to the pre-expansion source, with the same
 * indentation handling as append(std::string).
 *
 * Throws std::string if the file cannot be read or the source was already expanded.
 */
void KernelSrc::appendFromFile(const std::string fname) {
	const std::string contents = readFile(fname);
	append(contents);
}

// Installs the keyword dictionary used for the header, file extension and keyword expansion.
// The pointer is stored as-is (not copied or owned); it must stay valid while this object uses it.
void KernelSrc::setKeywords(gpgpuKeywords *keywords) {
	this->d_keywords = keywords;
}


/*
 * Extracts the parameter text beginning at 'start' and running up to (not
 * including) the next '}'.
 *
 * Returns "" when no closing '}' follows, which callers treat as "parameter
 * missing". An empty parameter ("{}") is returned as a single space instead, so
 * it is not mistaken for a missing one; this covers e.g. function and kernel
 * declarations without input params.
 */
std::string gpgpuKeywords::getNextParam(std::string a, size_t start) {
	const size_t close = a.find('}', start);
	if (close == std::string::npos)
		return "";

	std::string param(a, start, close - start);
	if (param.empty())
		param = " ";
	return param;
}

bool gpgpuKeywords::replaceContent(std::string &a) {
	std::string keyword = a.substr(0, a.find('{'));

	// Zero parameter keywords
	if (keyword == "blockX") a = blockX();
	else if (keyword == "blockY") a = blockY();
	else if (keyword == "threadX") a = threadX();
	else if (keyword == "threadY") a = threadY();
	else if (keyword == "blockDimX") a = blockDimX();
	else if (keyword == "blockDimY") a = blockDimY();
	else if (keyword == "globalThreadX") a = globalThreadX();
	else if (keyword == "globalThreadY") a = globalThreadY();
	else if (keyword == "sharedMem") a = sharedMem();
	else if (keyword == "constMem") a = constMem();
	else if (keyword == "localSync") a = localSync();
	else if (keyword == "float2Operators") a = float2Operators();
	else if (keyword == "float3Operators") a = float3Operators();
	else if (keyword == "vectorOperators") a = vectorOperators();
	else if (keyword == "halfType") a = halfType();
	else if (keyword == "regType") a = regType();
	else if (keyword == "sharedType") a = sharedType();
	else if (keyword == "globalType") a = globalType();

	// One parameter keywords
	else {
		size_t start = keyword.length()+1;

		std::string p1 = getNextParam(a, start);
		if (p1 == "")
			//fprintf(stderr, "Warning:  Unknown keyword (or not enough params for) \"%s\"\n", keyword.c_str());
			return false;

		if (keyword == "sqrt") a = sqrt(p1);
		else if (keyword == "rsqrt") a = rsqrt(p1);
		else if (keyword == "sin") a = sin(p1);
		else if (keyword == "cos") a = cos(p1);
		else if (keyword == "exp2") a = exp2(p1);
		else if (keyword == "absf") a = absf(p1);
		else if (keyword == "rcp") a = rcp(p1);
		else if (keyword == "floorf") a = floorf(p1);
		/*else if (keyword == "float2Half") a = float2Half(p1);
		else if (keyword == "half2Float") a = half2Float(p1);*/

		// Two parameter keywords
		else {
			start = a.find('{', start);
			
			std::string p2 = getNextParam(a, start+1);
			if (start == std::string::npos || p2 == "")
				return false;

			if (keyword == "kernelDecl") a = kernelDecl(p1, p2);
			else if (keyword == "float2Ctor") a = float2Ctor(p1, p2);
			else if (keyword == "writeHalf") a = writeHalf(p1, p2);
			else if (keyword == "div") a = div(p1, p2);
			else if (keyword == "maxf") a = maxf(p1, p2);
			else if (keyword == "minf") a = minf(p1, p2);
			else if (keyword == "pow") a = pow(p1, p2);
			else if (keyword == "atomicMin") a = atomicMin(p1, p2);
			else if (keyword == "atomicMax") a = atomicMax(p1, p2);
			else if (keyword == "atomicAdd") a = atomicAdd(p1, p2);

			// Three parameter keywords
			else {
				start = a.find('{', start+1);
				
				std::string p3 = getNextParam(a, start+1);
				if (start == std::string::npos || p3 == "")
					return false;

				if (keyword == "float3Ctor") a = float3Ctor(p1, p2, p3);
				else if (keyword == "sincos") a = sincos(p1, p2, p3);
				else if (keyword == "tex2DSample1") a = tex2DSample1(p1, p2, p3);
				else if (keyword == "tex2DSample4") a = tex2DSample4(p1, p2, p3);
								
				// Specially treated
				else if (keyword == "funcDecl") {
					// The remaining parameters will be textures
					std::vector<std::string> textures;

					while ((start = a.find('{', start+1)) != std::string::npos) {
						std::string newTex = getNextParam(a, start+1);
						textures.push_back(newTex);
					}
					
					a = funcDecl(p1, p2, p3, textures);
				}

				// Four parameter keywords
				else {
					start = a.find('{', start+1);
					
					std::string p4 = getNextParam(a, start+1);
					if (start == std::string::npos || p4 == "")
						return false;

					if (keyword == "float4Ctor") a = float4Ctor(p1, p2, p3, p4);
					else if (keyword == "tex3DSample") a = tex3DSample(p1, p2, p3, p4);
					else if (keyword == "surf2DRead") a = surf2DRead(p1, p2, p3, p4);
					else if (keyword == "readHalf2") a = readHalf2(p1, p2, p3, p4);
					else if (keyword == "writeHalf2") a = writeHalf2(p1, p2, p3, p4);

					// Five parameter keywords
					else { 
						start = a.find('{', start+1);
						
						std::string p5 = getNextParam(a, start+1);
						if (start == std::string::npos || p5 == "")
							return false;


						else if (keyword == "surf2DWrite") a = surf2DWrite(p1, p2, p3, p4, p5);
						else
							return false;
					}
				}
			}
		}
	}

	return true;
}
