#include "lineIndex.h"
#include <cmath>
#include <algorithm>


// Creates a generator bound to the given configuration object.
// No LineInfo list is attached yet; d_lineInfos starts out as NULL.
LineIndex::LineIndex(SSEOConfiguration *config)
	: d_config(config)
	, d_lineInfos(NULL)
{
}

void LineIndex::setLI(std::vector<struct LineInfo> *d) {
	d_lineInfos = d;
}

void LineIndex::setDirOffsets(std::vector<int> a) {
	d_dirOffsets = a;
}

// Ordering predicate for LineCandidates: true when a's distance is smaller than b's.
bool LineIndex::lineCompare(struct LineCandidate a, struct LineCandidate b)
{
	const bool aIsNearer = (a.distance < b.distance);
	return aIsNearer;
}

// Ordering predicate for LineCandidates: true when a's dirIndex is smaller than b's.
bool LineIndex::indexCompare(struct LineCandidate a, struct LineCandidate b)
{
	const bool aComesFirst = (a.dirIndex < b.dirIndex);
	return aComesFirst;
}

// Ordering predicate for PixCandidates: true when a's distance is smaller than b's.
bool LineIndex::lengthCompare(struct PixCandidate a, struct PixCandidate b)
{
	const bool aIsNearer = (a.distance < b.distance);
	return aIsNearer;
}

// Ordering predicate for PixCandidates: true when a's dirIndex is smaller than b's.
bool LineIndex::dirCompare(struct PixCandidate a, struct PixCandidate b)
{
	const bool aComesFirst = (a.dirIndex < b.dirIndex);
	return aComesFirst;
}

// Lexicographic ordering predicate for PixCandidates.
// Keys, most significant first: dirIndex, globalIndex, writeIndex (all ascending).
bool LineIndex::writeCompare(struct PixCandidate a, struct PixCandidate b)
{
	// Primary key
	if (a.dirIndex != b.dirIndex)
		return a.dirIndex < b.dirIndex;

	// Secondary key
	if (a.globalIndex != b.globalIndex)
		return a.globalIndex < b.globalIndex;

	// Final tie-breaker
	return a.writeIndex < b.writeIndex;
}

// Maps a coordinate according to the configured step interpolation mode.
//
// Mode 0 returns c untouched.  Modes 1 and 2 scale c by the height field
// dimensions, truncate to an integer cell and return that cell's centre
// scaled back by the same dimensions.
// Throws std::string for any other mode.
float2 LineIndex::snapCoord(float2 c)
{
	const int mode = d_config->stepInterpolation();

	if (mode == 0)
		return c;

	if (mode != 1 && mode != 2)
		throw std::string("Unknown filtering mode");

	// ABOUT FILTERING MODE 2:  The height is not known here, so we cannot tell
	// which filtering will be selected; mode 2 is snapped exactly like mode 1.
	const float gridW = (float)d_config->hfWidth();
	const float gridH = (float)d_config->hfHeight();

	// Integer cell containing the coordinate (the cast truncates toward zero)
	const float cellX = (float)((int)(c.x*gridW));
	const float cellY = (float)((int)(c.y*gridH));

	// Cell centre, scaled back by the grid dimensions
	return float2(cellX*1.0f/gridW + 0.5f/gridW,
			cellY*1.0f/gridH + 0.5f/gridH);
}

// Tells whether (x, y) lies within [0, occWidth) x [0, occHeight).
bool LineIndex::insideOcc(int x, int y)
{
	if (x < 0 || y < 0)
		return false;
	if (!(x < d_config->occWidth()))
		return false;
	return y < d_config->occHeight();
}

void LineIndex::updatePix(std::vector<struct PixCandidate>* pool, struct PixCandidate px) {
	// The idea behind updating a pixel is as follows:
	// If a pixel with this dir already exists, replace it if its distance is larger than this one.
	// If it is smaller, ignore the new px.
	// Otherwise, push the new pix at the end.
	for (int i = 0; i < pool->size(); ++i)
		if (pool->at(i).dirIndex == px.dirIndex) {
			if (pool->at(i).distance > px.distance)
				pool->at(i) = px;
			return; // We break out 'cause we matched a dir
		}
	// We're here because we didn't match..  Pushing
	pool->push_back(px);
}

// Sorts the candidates of one block in place with writeCompare as the ordering predicate.
void LineIndex::organizeBlock(std::vector<struct PixCandidate> *block)
{
	std::vector<struct PixCandidate> &items = *block;
	std::sort(items.begin(), items.end(), writeCompare);
}

// Reorders the hits inside every full accBlockX() x accBlockY() block of pixHits.
//
// For each full block, the first linesPerPixel() hits of every pixel are
// collected (row by row), sorted by organizeBlock(), and written back slot by
// slot: slot 0 of every pixel of the block first, then slot 1, and so on.
// Partial blocks at the right/bottom edge are left untouched.
//
// pixHits is indexed [y][x][hit].  Every pixel in a full block must hold at
// least linesPerPixel() hits, otherwise vector::at() throws std::out_of_range.
void LineIndex::optimizeWrites(std::vector<std::vector<std::vector<struct PixCandidate> > > *pixHits)
{
	const int blockW = d_config->accBlockX();
	const int blockH = d_config->accBlockY();
	const int hitsPerPixel = d_config->linesPerPixel();

	for (int by = 0; by < d_config->occHeight()/blockH; ++by) {
		for (int bx = 0; bx < d_config->occWidth()/blockW; ++bx) {
			const int originY = by*blockH;
			const int originX = bx*blockW;

			// Gather: pixel-major, all hits of one pixel are adjacent
			std::vector<struct PixCandidate> gathered;
			for (int ly = 0; ly < blockH; ++ly) {
				for (int lx = 0; lx < blockW; ++lx) {
					for (int hit = 0; hit < hitsPerPixel; ++hit)
						gathered.push_back(pixHits->at(originY + ly).at(originX + lx).at(hit));
				}
			}

			organizeBlock(&gathered);

			// Scatter: slot-major, consecutive sorted entries go to consecutive pixels
			int cursor = 0;
			for (int hit = 0; hit < hitsPerPixel; ++hit) {
				for (int ly = 0; ly < blockH; ++ly) {
					for (int lx = 0; lx < blockW; ++lx) {
						pixHits->at(originY + ly).at(originX + lx).at(hit) = gathered.at(cursor);
						++cursor;
					}
				}
			}
		}
	}
}

// Replays the storage positions every line would touch during a sweep and
// terminates the process if any of them falls outside [0, bufSize).
//
// bufSize: number of elements in the buffer being validated.
//
// Nothing is checked when the configuration selects V4 or half buffers.
// Throws std::string when no LineInfo list has been set, or when a line
// reports more than 1000 idle steps.
// Calls exit(13) on a negative position and exit(12) on a position >= bufSize,
// after printing a diagnostic.
void LineIndex::checkBounds(size_t bufSize) {
	// We're not checking if non-normal buffer types are used
	if (d_config->useV4Buffers() || d_config->useHalfBuffers())
		return;

	// Same guard genData() applies before touching the list
	if (!d_lineInfos)
		throw std::string("Can't check LineIndex bounds without LineInfo pointer set");

	const bool pairOpposite = d_config->matchOpposite();
	const int storageStripe = d_config->sweepStorageStripe();
	const int lineCount = (int)d_lineInfos->size();

	for (int line = 0; line < lineCount; ++line) {
		// Work on a copy: the counters below are consumed while replaying the line
		struct LineInfo li = d_lineInfos->at(line);

		// With opposite matching, directions in the upper half walk the storage backwards
		int myStripe = storageStripe;
		if (pairOpposite && !(li.dirIndex < d_config->dirs()/2))
			myStripe = -storageStripe;

		if (li.idleSteps > 1000)
			throw std::string("An oddly big number of idle steps");

		// Idle steps do not write anything; they only consume step budget
		int stepCounter = 0;
		if (li.idleSteps > 0) {
			stepCounter = li.idleSteps;
			li.numSteps -= li.idleSteps;
			li.idleSteps = 0;
		}

		while (li.numSteps > 0) {
			int writePos;
			if (pairOpposite) {
				// Three elements per location; forward lines are checked at
				// offset 0, backward lines at offset 2
				if (myStripe > 0)
					writePos = li.layerDistance*3 + 0;
				else
					writePos = li.layerDistance*3 + 2;
			} else {
				// Two elements per location
				writePos = li.layerDistance*2;
			}

			if (writePos < 0) {
				fprintf(stderr, "Sweep step on line %d, step %d, at location %d underflowed\n",
						line, stepCounter, writePos);
				exit(13);
			}
			// writePos is known to be non-negative here, so the cast is lossless
			if ((size_t)writePos >= bufSize) {
				// size_t has no portable pre-C++11 printf specifier: print it as unsigned long
				fprintf(stderr, "Sweep step on line %d, step %d, at location %d overflowed (a buffer of size %lu elems)\n",
						line, stepCounter, writePos, (unsigned long)bufSize);
				printf("Line %d, dir index %d, start index %d, stripe %d, num steps %d\n",
						line,
						li.dirIndex, li.layerDistance, myStripe, li.numSteps+stepCounter);
				exit(12);
			}

			stepCounter++;
			li.layerDistance += myStripe;
			li.numSteps--;
		}
	}
}

// Builds the per-pixel gather index table.
//
// Every line (mirrored directions are skipped when matchOpposite() is set) is
// stepped through; each step landing inside the occlusion area is splatted to
// the pixels around it.  For every pixel at most one candidate per direction is
// kept (the nearest one, see updatePix()).  When linesPerPixel() is smaller than
// the number of usable directions, only the nearest linesPerPixel() candidates
// survive.  Each pixel's list is then sorted by dirIndex and, if gatherBuffer()
// is enabled, reshuffled per thread block by optimizeWrites().
//
// indPtrPtr: receives a buffer allocated with malloc(); it is handed to the
//            caller, who is presumably expected to free() it.  On the
//            "too many incomplete pixels" failure the buffer is released here
//            and *indPtrPtr is set to NULL.
// Returns:   the size of the buffer in bytes.
//
// Layout: slot `px` of pixel (x, y) lives at
//   index = (occHeight-1-y)*occWidth + x + occWidth*occHeight*px
// With edgeAwareAcc() == 2 or gatherBuffer() each slot is a pair at
// [index*2], [index*2+1]; otherwise it is the single value at [index].
// Slots without a contributor are filled from an empty candidate (globalIndex 0).
//
// Throws std::string when: no LineInfo list is set, linesPerPixel() exceeds
// dirs(), edgeAwareAcc() == 2 is combined with gatherBuffer(), the allocation
// fails, or more than 5% (relative to the pixel count) of the slots are empty.
//
// NOTE(review): with gatherBuffer() enabled, optimizeWrites() requires every
// pixel of a full block to hold linesPerPixel() hits and throws
// std::out_of_range otherwise; incomplete pixels are not padded here.
size_t LineIndex::genData(unsigned int **indPtrPtr) {
	if (!d_lineInfos)
		throw std::string("Can't gen LineIndex data without LineInfo pointer set");

	if (d_config->linesPerPixel() > d_config->dirs())
		throw std::string("More than K lines per pixel should be done via post processing");

	if (d_config->edgeAwareAcc() == 2 && d_config->gatherBuffer())
		throw std::string("Right now, edge aware acc 2 and gather buffer don't mix");

	const int occW = d_config->occWidth();
	const int occH = d_config->occHeight();
	const int linesPerPixel = d_config->linesPerPixel();
	const int blockW = d_config->accBlockX();
	const int blockH = d_config->accBlockY();
	const int storageStripe = d_config->sweepStorageStripe();
	const int lineCount = (int)d_lineInfos->size();

	// Placeholder emitted for slots that received no contributor
	struct PixCandidate emptyPix = {
		-1, 0, 0, 0, 0, 0.0f
	};

	// pixHits[y][x] collects the candidates of pixel (x, y)
	std::vector<std::vector<std::vector<struct PixCandidate> > > pixHits;
	for (int y = 0; y < occH; ++y) {
		pixHits.push_back(std::vector<std::vector<struct PixCandidate> >());
		for (int x = 0; x < occW; ++x)
			pixHits.at(y).push_back(std::vector<struct PixCandidate>());
	}

	// Offset of the occlusion area inside the (larger) height field
	float borderSideW = (d_config->hfWidth() - d_config->occWidth())/2;
	float borderSideH = (d_config->hfHeight() - d_config->occHeight())/2;

	// The definition of a neighborhood is half the radius of (lineSkip, stepSkip)..  We round this up.
	// NOTE(review): the value is a squared radius but is used below as the
	// half-extent of the splat window in pixels — confirm this is intended.
	float neighborRadius2 = (d_config->lineSkip()*d_config->lineSkip() + d_config->stepSkip()*d_config->stepSkip())/2.0f;
	int pixRadius2 = (int)(neighborRadius2 + 0.5f);

	// Report progress roughly every percent.  Clamped to 1 so that fewer than
	// 100 lines do not lead to a modulo by zero.
	const int reportOffset = std::max(1, lineCount/100);

	// Now we are ready to traverse the lines
	for (int line = 0; line < lineCount; ++line) {
		struct LineInfo li = d_lineInfos->at(line);

		// We don't count mirrored lines
		if (d_config->matchOpposite() && li.dirIndex > d_config->dirs()/2-1)
			continue;

		for (int step = 0; step < li.numSteps; ++step) {
			float2 pos = li.startPos + li.stepDir*(float)step;

			// Snap, then scale by the height field dimensions to get texel coordinates
			float2 snapPos = snapCoord(pos);
			snapPos *= float2((float)d_config->hfWidth(), (float)d_config->hfHeight());

			int posX = (int)snapPos.x - borderSideW;
			int posY = (int)snapPos.y - borderSideH;

			if (!insideOcc(posX, posY))
				continue;

			// Step number counted from the end of the idle steps
			int meatStep = step - li.idleSteps;

			if (meatStep < -5) {
				fprintf(stderr, "We supposedly found a hit far outside occbox (meatiters %d)\n", meatStep);
			}

			int stepLocation = li.layerDistance + meatStep*storageStripe;

			if (stepLocation < 0)
				stepLocation = 0;

			// We're ready to splat this contribution.  The window is clamped to
			// the occlusion area up front, which visits exactly the pixels an
			// insideOcc() test per pixel would accept.
			const int yFirst = std::max(0, posY - pixRadius2);
			const int yLast = std::min(occH - 1, posY + pixRadius2);
			const int xFirst = std::max(0, posX - pixRadius2);
			const int xLast = std::min(occW - 1, posX + pixRadius2);

			for (int y = yFirst; y <= yLast; ++y) {
				int mirrorY = occH - y - 1;
				for (int x = xFirst; x <= xLast; ++x) {
					// Position of the pixel inside its thread block (rows mirrored)
					unsigned int writeIndex = (mirrorY % blockH)*blockW + (x % blockW);
					float distance2 = (float)((x-posX)*(x-posX) + (y-posY)*(y-posY));

					struct PixCandidate px = {
						li.dirIndex,
						stepLocation,
						posX, posY,
						writeIndex,
						distance2
					};

					updatePix(&pixHits.at(y).at(x), px);
				}
			}
		}

		if (!(line%reportOffset)) {
			float progress = (float)(line+1)/(float)lineCount*100.0f;
			printf("Generating gather lists (%d%% done)\r", (int)(progress + 0.49f));
			fflush(stdout);
		}
	}
	printf("\n");

	// Now that we have all the relevant pixels in, and no duplicates (per-dir)
	// we can sort them by the distance and get the first ones..
	// Unless we need all the dirs..
	int maxDirs = d_config->dirs();
	if (d_config->matchOpposite())
		maxDirs /= 2;
	if (linesPerPixel < maxDirs) {
		for (int y = 0; y < occH; ++y)
			for (int x = 0; x < occW; ++x) {
				std::vector<struct PixCandidate> *pixes = &pixHits.at(y).at(x);
				std::sort(pixes->begin(), pixes->end(), lengthCompare);

				// Drop only the surplus.  Pixels with fewer candidates than
				// requested are left alone (popping them down would run past
				// an empty vector); the emission loop below pads them.
				if (linesPerPixel >= 0 && pixes->size() > (size_t)linesPerPixel)
					pixes->erase(pixes->begin() + linesPerPixel, pixes->end());
			}
	}

	// And in all cases we sort by dirIndex..
	// FIXME: This is not the optimal strategy, actually.  You should maximize per-iter coherency
	for (int y = 0; y < occH; ++y)
		for (int x = 0; x < occW; ++x)
			std::sort(pixHits.at(y).at(x).begin(), pixHits.at(y).at(x).end(), dirCompare);

	// This is an important step:  We reorder all pixels within a thread block
	// to maximize write coalescing.
	if (d_config->gatherBuffer())
		optimizeWrites(&pixHits);

	// Then we allocate the resultant buffer
	size_t size = sizeof(unsigned int)*
		occH*
		occW*
		linesPerPixel;

	// Both of these modes store two values per slot (they are mutually exclusive, see above)
	if (d_config->edgeAwareAcc() == 2)
		size *= 2;

	if (d_config->gatherBuffer())
		size *= 2;

	*indPtrPtr = (unsigned int*) malloc(size);
	unsigned int *indData = *indPtrPtr;

	// malloc(0) may legitimately return NULL, so only a non-empty request counts as a failure
	if (!indData && size > 0)
		throw std::string("Failed to allocate the gather index buffer");

	int incompletePixels = 0;
	double totalDist = 0.0;
	// Emitting the data..
	for (int y = 0; y < occH; ++y)
		for (int x = 0; x < occW; ++x) {
			const std::vector<struct PixCandidate> &hits = pixHits.at(y).at(x);
			const int mirrorY = occH - y - 1;

			for (int px = 0; px < linesPerPixel; ++px) {
				// One full image plane per slot number
				int index = (mirrorY*occW + x) + (occW*occH)*px;
				struct PixCandidate pixel;

				if ((size_t)px >= hits.size()) {
					incompletePixels++;
					pixel = emptyPix;
				} else {
					pixel = hits.at(px);
					totalDist += sqrt(pixel.distance);
				}

				if (d_config->edgeAwareAcc() == 2) {
					indData[index*2 + 0] = pixel.globalIndex;
					// Source pixel coordinates packed as x in the low 16 bits, y above
					indData[index*2 + 1] = pixel.pixCoordX | (pixel.pixCoordY << 16);
				} else if (d_config->gatherBuffer()) {
					indData[index*2 + 0] = pixel.globalIndex;
					indData[index*2 + 1] = pixel.writeIndex;
				} else
					indData[index] = pixel.globalIndex;
			}
		}

	printf("%d pixels didn't have full list of contributors\n", incompletePixels);
	if ((float)incompletePixels/(float)(occW*occH) > 0.05f) {
		printf("\tThis is too much!  Aborting\n");
		// The size is never returned on this path, so the buffer would be
		// unusable for the caller: release it instead of leaking it.
		free(indData);
		*indPtrPtr = NULL;
		throw std::string("Too little contributors to too many pixels");
	}

	printf("Average distance for a gathered pixel: %f\n",
			totalDist/(double)(occH*occW*linesPerPixel));

	return size;
}
