#include "Pathfinding.h"
#include "DEFINES.h"
#include "Crawler.h"

INCLUDE_game

void Pathfinding::Initialize(){
    if(nodes!=nullptr){
        delete[] nodes;
    }
    nodes = new sNode[game->GetCurrentMap().width * game->GetCurrentMap().height];
		for (int x = 0; x < game->GetCurrentMap().width; x++)
			for (int y = 0; y < game->GetCurrentMap().height; y++)
			{
				nodes[y * game->GetCurrentMap().width + x].x = x; // ...because we give each node its own coordinates
				nodes[y * game->GetCurrentMap().width + x].y = y;
                geom2d::rect<int>tile=game->GetTileCollision(game->GetCurrentLevel(),{float(x*game->GetCurrentMap().tilewidth),float(y*game->GetCurrentMap().tilewidth)});
				nodes[y * game->GetCurrentMap().width + x].bObstacle = tile.pos!=game->NO_COLLISION.pos||tile.size!=game->NO_COLLISION.size;
                tile=game->GetTileCollision(game->GetCurrentLevel(),{float(x*game->GetCurrentMap().tilewidth),float(y*game->GetCurrentMap().tilewidth)},true);
                nodes[y * game->GetCurrentMap().width + x].bObstacleUpper = tile.pos!=game->NO_COLLISION.pos||tile.size!=game->NO_COLLISION.size;
				nodes[y * game->GetCurrentMap().width + x].parent = nullptr;
				nodes[y * game->GetCurrentMap().width + x].bVisited = false;
			}

		for (int x = 0; x < game->GetCurrentMap().width; x++)
			for (int y = 0; y < game->GetCurrentMap().height; y++)
			{
				if(y>0)
					nodes[y*game->GetCurrentMap().width + x].vecNeighbours.push_back(&nodes[(y - 1) * game->GetCurrentMap().width + (x + 0)]);
				if(y<game->GetCurrentMap().height-1)
					nodes[y*game->GetCurrentMap().width + x].vecNeighbours.push_back(&nodes[(y + 1) * game->GetCurrentMap().width + (x + 0)]);
				if (x>0)
					nodes[y*game->GetCurrentMap().width + x].vecNeighbours.push_back(&nodes[(y + 0) * game->GetCurrentMap().width + (x - 1)]);
				if(x<game->GetCurrentMap().width-1)
					nodes[y*game->GetCurrentMap().width + x].vecNeighbours.push_back(&nodes[(y + 0) * game->GetCurrentMap().width + (x + 1)]);
                if (y>0 && x>0)
                    nodes[y*game->GetCurrentMap().width + x].vecNeighbours.push_back(&nodes[(y - 1) * game->GetCurrentMap().width + (x - 1)]);
                if (y<game->GetCurrentMap().height-1 && x>0)
                    nodes[y*game->GetCurrentMap().width + x].vecNeighbours.push_back(&nodes[(y + 1) * game->GetCurrentMap().width + (x - 1)]);
                if (y>0 && x<game->GetCurrentMap().width-1)
                    nodes[y*game->GetCurrentMap().width + x].vecNeighbours.push_back(&nodes[(y - 1) * game->GetCurrentMap().width + (x + 1)]);
                if (y<game->GetCurrentMap().height - 1 && x<game->GetCurrentMap().width-1)
                    nodes[y*game->GetCurrentMap().width + x].vecNeighbours.push_back(&nodes[(y + 1) * game->GetCurrentMap().width + (x + 1)]);
			}

		// Manually position the start and end markers so they are not nullptr
		nodeStart = &nodes[(game->GetCurrentMap().height / 2) * game->GetCurrentMap().width + 1];
		nodeEnd = &nodes[(game->GetCurrentMap().height / 2) * game->GetCurrentMap().width + game->GetCurrentMap().width-2];
}

// Runs an A* search over the node grid and returns the tile path from startPos
// to endPos.
//
// startPos, endPos : world positions. They are turned into tile indices by
//                    dividing by the map's tilewidth, which is used for both axes.
// maxRange         : search radius in tiles.
//                    - If the straight-line distance between the two positions
//                      exceeds it, no search is done.
//                    - Tiles outside the start/end bounding box grown by maxRange
//                      tiles are pre-marked as visited, so they are never expanded.
// upperLevel       : when true, bObstacleUpper is the walkability test instead of
//                    bObstacle.
//
// Returns the node (tile) coordinates of the path. The start tile is excluded
// and the end tile is the last element.
// Returns an empty vector when:
//   - no route was found;
//   - the points are further apart than maxRange;
//   - either point lies outside the map;
//   - the grid has not been built.
std::vector<vf2d> Pathfinding::Solve_AStar(vf2d startPos,vf2d endPos,float maxRange,bool upperLevel){
    const auto& currentMap=game->GetCurrentMap();
    const auto tw=currentMap.tilewidth;

    const float dx=endPos.x-startPos.x;
    const float dy=endPos.y-startPos.y;
    if(sqrt(dx*dx+dy*dy)>maxRange*tw)return {};

    // Tile indices of both endpoints. Reject anything outside the grid: indexing
    // `nodes` with them would read out of bounds, or wrap into another row.
    const int startX=int(startPos.x/tw);
    const int startY=int(startPos.y/tw);
    const int endX=int(endPos.x/tw);
    const int endY=int(endPos.y/tw);
    const auto inBounds=[&](int x,int y){
        return x>=0&&x<currentMap.width&&y>=0&&y<currentMap.height;
    };
    if(nodes==nullptr||!inBounds(startX,startY)||!inBounds(endX,endY))return {};

    nodeStart=&nodes[startY*currentMap.width+startX];
    nodeEnd=&nodes[endY*currentMap.width+endX];

    // Search window in pixels: the start/end bounding box grown by maxRange tiles
    // on every side, clamped to the map.
    const float range=maxRange*tw;
    const float mapPixelW=currentMap.width*float(tw);
    const float mapPixelH=currentMap.height*float(tw);
    geom2d::rect<int>posPerimeter{
        {int(std::min(startPos.x,endPos.x)),int(std::min(startPos.y,endPos.y))},
        {int(std::max(startPos.x,endPos.x)-std::min(startPos.x,endPos.x)),int(std::max(startPos.y,endPos.y)-std::min(startPos.y,endPos.y))}};
    posPerimeter.pos={int(std::clamp(posPerimeter.pos.x-range,0.f,mapPixelW)),int(std::clamp(posPerimeter.pos.y-range,0.f,mapPixelH))};
    posPerimeter.size={int(std::clamp(posPerimeter.size.x+range*2,0.f,mapPixelW-posPerimeter.pos.x)),int(std::clamp(posPerimeter.size.y+range*2,0.f,mapPixelH-posPerimeter.pos.y))};

    // Reset per-search state. A tile whose top-left corner lies outside the
    // window is marked visited, so the search never expands into it.
    for(int x=0;x<currentMap.width;x++){
        for(int y=0;y<currentMap.height;y++){
            sNode&node=nodes[y*currentMap.width+x];
            node.bVisited=!geom2d::overlaps(posPerimeter,vi2d{x*tw,y*tw});
            node.fGlobalGoal=INFINITY;
            node.fLocalGoal=INFINITY;
            node.parent=nullptr;
        }
    }

    auto distance=[](sNode*a,sNode*b){
        return sqrtf((a->x-b->x)*(a->x-b->x)+(a->y-b->y)*(a->y-b->y));
    };
    auto heuristic=[distance](sNode*a,sNode*b){
        return distance(a,b); // Euclidean distance in tiles
    };

    sNode*nodeCurrent=nodeStart;
    nodeStart->fLocalGoal=0.0f;
    nodeStart->fGlobalGoal=heuristic(nodeStart,nodeEnd);

    std::list<sNode*>listNotTestedNodes;
    //if((!upperLevel && nodeStart->bObstacle)||(upperLevel && nodeStart->bObstacleUpper))return {};
    listNotTestedNodes.push_back(nodeStart);

    while(!listNotTestedNodes.empty()&&nodeCurrent!=nodeEnd){
        // Most promising candidate first. std::list::sort is stable, so nodes
        // with equal cost keep their insertion order.
        listNotTestedNodes.sort([](const sNode*lhs,const sNode*rhs){return lhs->fGlobalGoal<rhs->fGlobalGoal;});

        // A node can be queued several times; skip entries already expanded.
        while(!listNotTestedNodes.empty()&&listNotTestedNodes.front()->bVisited)
            listNotTestedNodes.pop_front();
        if(listNotTestedNodes.empty())
            break;

        nodeCurrent=listNotTestedNodes.front();
        nodeCurrent->bVisited=true;
        for(sNode*nodeNeighbour:nodeCurrent->vecNeighbours){
            const bool blocked=upperLevel?bool(nodeNeighbour->bObstacleUpper):bool(nodeNeighbour->bObstacle);
            if(!nodeNeighbour->bVisited&&!blocked)
                listNotTestedNodes.push_back(nodeNeighbour);

            // NOTE: costs are relaxed for blocked neighbours too; only queueing
            // is filtered above. Consequently, if the end tile is itself blocked
            // but adjacent to an expanded tile, it still gets a parent, and the
            // returned path ends on it.
            const float fPossiblyLowerGoal=nodeCurrent->fLocalGoal+distance(nodeCurrent,nodeNeighbour);
            if(fPossiblyLowerGoal<nodeNeighbour->fLocalGoal){
                nodeNeighbour->parent=nodeCurrent;
                nodeNeighbour->fLocalGoal=fPossiblyLowerGoal;
                nodeNeighbour->fGlobalGoal=nodeNeighbour->fLocalGoal+heuristic(nodeNeighbour,nodeEnd);
            }
        }
    }

    // Walk the parent links back from the end node.
    // - The start node never gets a parent: its local goal is 0 and relaxation
    //   needs a strictly lower cost.
    // - So the walk stops there, and the start tile is not part of the result.
    // - If the end node was never reached, it has no parent and the result is empty.
    std::vector<vf2d>finalPath;
    for(sNode*p=nodeEnd;p->parent!=nullptr;p=p->parent){
        finalPath.push_back({float(p->x),float(p->y)});
    }
    std::reverse(finalPath.begin(),finalPath.end());
    return finalPath;
}