#pragma once
#include <QOpenGLFunctions_4_5_Core>
#include <QOpenGLShaderProgram>
#include <QOpenGLTexture>
#include <QVector3D>
#include <QVector2D>
#include <QVector4D>
#include <vector>
#include <memory>
#include <QString>

// 顶点数据结构（位置+纹理坐标）
struct CubeVertex {
	QVector3D pos;    // 位置
	QVector2D texCoord; // 纹理坐标
};

class ViewCube : protected QOpenGLFunctions_4_5_Core {
public:
	// 构造函数：传入纹理路径
	explicit ViewCube(QOpenGLFunctions_4_5_Core* gl, const QString& texPath, unsigned int windowWidth, unsigned int windowHeight)
		: m_gl(gl), m_texPath(texPath), m_windowWidth(windowWidth), m_windowHeight(windowHeight) {
		init();
	}

	// 扩展draw函数：接收view和projection矩阵并存储（关键修改）
	void draw(QOpenGLShaderProgram& shader, const QMatrix4x4& model,
		const QMatrix4x4& view, const QMatrix4x4& projection) {
		// 存储当前绘制的矩阵（用于后续检测）
		m_modelMatrix = model;
		m_viewMatrix = view;
		m_projectionMatrix = projection;

		// 原有绘制逻辑
		m_gl->glBindVertexArray(m_vao);
		shader.setUniformValue("model", model);
		m_texture->bind(0);
		shader.setUniformValue("u_texture", 0);
		m_gl->glDrawElements(GL_TRIANGLES, 36, GL_UNSIGNED_INT, nullptr);
		m_gl->glBindVertexArray(0);
		m_texture->release();
	}

	enum class CubeFaceType {
		None,    // 无点击 0
		Top,     // 上面 1
		Bottom,  // 下面 2
		Left,    // 左面 3
		Right,   // 右面 4
		Front,   // 前面 5
		Back     // 后面 6
	};
	
	// 设置窗口宽高（外部窗口大小变化时调用）
	void setWindowSize(int width, int height) {
		m_windowWidth = width;
		m_windowHeight = height;
	}

	// 核心函数：判断鼠标点击射线与ViewCube最先相交的面
	CubeFaceType getFirstHitFace(
		float mouseX, float mouseY,
		int windowWidth, int windowHeight,
		const QMatrix4x4& viewMatrix,
		const QMatrix4x4& projectionMatrix)
	{
		// 1. 生成射线
		Ray ray = createRayFromScreenCoord(mouseX, mouseY, windowWidth, windowHeight, viewMatrix, projectionMatrix);

		// 2. 获取ViewCube的6个面
		std::vector<Plane> cubePlanes = createViewCubePlanes();

		// 3. 遍历所有面，找最小的有效t
		float minT = std::numeric_limits<float>::max();
		CubeFaceType hitFace = CubeFaceType::None; // 默认值

		for (const auto& plane : cubePlanes) {
			float t = getIntersectionT(ray, plane);
			if (t > 0 && t < minT) { // 有效且更小的t
				minT = t;
				hitFace = plane.face;
			}
		}

		return hitFace;
	}

private:
	// 射线结构体
	struct Ray {
		QVector3D origin;    // 起点
		QVector3D direction; // 方向向量（归一化）
	};

	// 平面方程（ax + by + cz + d = 0）
	struct Plane {
		float a, b, c, d;
		CubeFaceType face; // 关联的面
		Plane(float a, float b, float c, float d, CubeFaceType face)
			: a(a), b(b), c(c), d(d), face(face) {
		}
	};

	// 生成ViewCube的6个面（假设Cube中心在原点，边长为2，范围[-1,1]）
	std::vector<Plane> createViewCubePlanes() {
		return {
			// 前面（z=0.5）：法向量+z（向外）
			Plane(0, 0, 1, -0.5f, CubeFaceType::Front),
			// 后面（z=-0.5）：法向量-z（向外）
			Plane(0, 0, -1, -0.5f, CubeFaceType::Back),  // 修正：c=-1，d=-0.5
			// 左面（x=-0.5）：法向量-x（向外）
			Plane(-1, 0, 0, -0.5f, CubeFaceType::Left),  // 修正：a=-1，d=-0.5
			// 右面（x=0.5）：法向量+x（向外）
			Plane(1, 0, 0, -0.5f, CubeFaceType::Right),  // 修正：d=-0.5
			// 顶面（y=0.5）：法向量+y（向外）
			Plane(0, 1, 0, -0.5f, CubeFaceType::Top),    // 修正：b=1，d=-0.5
			// 底面（y=-0.5）：法向量-y（向外）
			Plane(0, -1, 0, -0.5f, CubeFaceType::Bottom) // 修正：b=-1，d=-0.5
		};
	}//n*(P-p0)

	// 计算射线与平面的交点参数t（返回-1表示不相交）
	float getIntersectionT(const Ray& ray, const Plane& plane) {
		const float eps = 1e-6;
		float denom = plane.a * ray.direction.x() + plane.b * ray.direction.y() + plane.c * ray.direction.z();
		if (fabs(denom) < eps) return -1; // 平行无交点

		float numerator = -(plane.a * ray.origin.x() + plane.b * ray.origin.y() + plane.c * ray.origin.z() + plane.d);
		float t = numerator / denom;
		if (t < eps) return -1; // 交点在射线起点后方

		// 计算交点坐标
		QVector3D hitPoint = ray.origin + ray.direction * t;

		// 根据平面所属面，检查交点是否在立方体的面上（范围[-0.5, 0.5]）
		switch (plane.face) {
		case CubeFaceType::Front:    // z=0.5，检查x和y
		case CubeFaceType::Back:     // z=-0.5，检查x和y
			if (hitPoint.x() < -0.5f - eps || hitPoint.x() > 0.5f + eps) return -1;
			if (hitPoint.y() < -0.5f - eps || hitPoint.y() > 0.5f + eps) return -1;
			break;
		case CubeFaceType::Left:     // x=-0.5，检查y和z
		case CubeFaceType::Right:    // x=0.5，检查y和z
			if (hitPoint.y() < -0.5f - eps || hitPoint.y() > 0.5f + eps) return -1;
			if (hitPoint.z() < -0.5f - eps || hitPoint.z() > 0.5f + eps) return -1;
			break;
		case CubeFaceType::Top:      // y=0.5，检查x和z
		case CubeFaceType::Bottom:   // y=-0.5，检查x和z
			if (hitPoint.x() < -0.5f - eps || hitPoint.x() > 0.5f + eps) return -1;
			if (hitPoint.z() < -0.5f - eps || hitPoint.z() > 0.5f + eps) return -1;
			break;
		default: return -1;
		}

		return t;
	}

	// 从屏幕坐标生成射线（复用之前的函数）
	Ray createRayFromScreenCoord(
		float screenX, float screenY,
		int windowWidth, int windowHeight,
		const QMatrix4x4& viewMatrix,
		const QMatrix4x4& projectionMatrix)
	{
		// 1. 屏幕坐标→NDC坐标
		float ndcX = (2.0f * screenX) / windowWidth - 1.0f;
		float ndcY = 1.0f - (2.0f * screenY) / windowHeight; // 翻转Y轴
		// 2. 逆矩阵转换
		QMatrix4x4 invViewProj = (projectionMatrix * viewMatrix).inverted();
		QVector4D nearPt(ndcX, ndcY, -1.0f, 1.0f);
		QVector4D farPt(ndcX, ndcY, 1.0f, 1.0f);
		QVector4D worldNearHomogeneous = invViewProj * nearPt;
		QVector3D worldNear = worldNearHomogeneous.toVector3D() / worldNearHomogeneous.w();

		QVector4D worldFarHomogeneous = invViewProj * farPt;
		QVector3D worldFar = worldFarHomogeneous.toVector3D() / worldFarHomogeneous.w();
		// 3. 生成射线
		return { worldNear, (worldFar - worldNear).normalized() };
	}

	// 原有初始化逻辑（未修改）
	void init() {
		m_gl->initializeOpenGLFunctions();
		generateVertices();
		createTexture();
		setupBuffers();
	}

	// 原有顶点生成逻辑（未修改）
	void generateVertices() {
		const QVector3D positions[8] = {
			{ -0.5f, -0.5f, -0.5f },{ 0.5f, -0.5f, -0.5f },
			{ 0.5f, 0.5f, -0.5f },{ -0.5f, 0.5f, -0.5f },
			{ -0.5f, -0.5f, 0.5f },{ 0.5f, -0.5f, 0.5f },
			{ 0.5f, 0.5f, 0.5f },{ -0.5f, 0.5f, 0.5f }
		};

		const QVector2D texCoords[6][4] = {
			{ { 0.0f, 1.0f },{ 0.5f, 1.0f },{ 0.5f, 2.0f / 3.0f },{ 0.0f, 2.0f / 3.0f } },
			{ { 1.0f, 2.0f / 3.0f },{ 0.5f, 2.0f / 3.0f },{ 0.5f, 1.0f },{ 1.0f, 1.0f } },
			{ { 0.0f, 1.0f / 3.0f },{ 0.0f, 2.0f / 3.0f },{ 0.5f, 2.0f / 3.0f },{ 0.5f, 1.0f / 3.0f } },
			{ { 1.0f, 2.0f / 3.0f },{ 1.0f, 1.0f / 3.0f },{ 0.5f, 1.0f / 3.0f },{ 0.5f, 2.0f / 3.0f } },
			{ { 0.0f, 0.0f },{ 0.5f, 0.0f },{ 0.5f, 1.0f / 3.0f },{ 0.0f, 1.0f / 3.0f } },
			{ { 1.0f, 0.0f } ,{ 0.5f, 0.0f },{ 0.5f, 1.0f / 3.0f } ,{ 1.0f, 1.0f / 3.0f } }
		};

		const unsigned int indices[6][6] = {
			{ 3, 2, 6, 3, 7, 6 },{ 1, 0, 4, 4, 5, 1 },
			{ 0, 3, 7, 7, 4, 0 },{ 2, 1, 5, 5, 6, 2 },
			{ 4, 5, 6, 6, 7, 4 },{ 0, 1, 2, 2, 3, 0 }
		};

		m_vertices.clear();
		m_indices.clear();

		for (int face = 0; face < 6; ++face) {
			for (int i = 0; i < 6; ++i) {
				unsigned int globalIdx = indices[face][i];
				int localIdx = -1;
				for (int j = 0; j < 4; ++j) {
					if (faceVertices[face][j] == globalIdx) {
						localIdx = j;
						break;
					}
				}
				CubeVertex v;
				v.pos = positions[globalIdx];
				v.texCoord = texCoords[face][localIdx];
				m_vertices.push_back(v);
			}
		}

		for (unsigned int i = 0; i < m_vertices.size(); ++i) {
			m_indices.push_back(i);
		}
	}

	// 原有纹理创建逻辑（未修改）
	void createTexture() {
		QImage img(m_texPath);
		if (img.isNull()) {
			qWarning() << "纹理加载失败：" << m_texPath;
			return;
		}
		m_texture = std::make_unique<QOpenGLTexture>(img.mirrored());
		m_texture->setMinificationFilter(QOpenGLTexture::Linear);
		m_texture->setMagnificationFilter(QOpenGLTexture::Linear);
	}

	// 原有缓冲区初始化逻辑（未修改）
	void setupBuffers() {
		m_gl->glGenVertexArrays(1, &m_vao);
		m_gl->glGenBuffers(1, &m_vbo);
		m_gl->glGenBuffers(1, &m_ebo);

		m_gl->glBindVertexArray(m_vao);
		m_gl->glBindBuffer(GL_ARRAY_BUFFER, m_vbo);
		m_gl->glBufferData(GL_ARRAY_BUFFER, m_vertices.size() * sizeof(CubeVertex), m_vertices.data(), GL_STATIC_DRAW);
		m_gl->glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, m_ebo);
		m_gl->glBufferData(GL_ELEMENT_ARRAY_BUFFER, m_indices.size() * sizeof(unsigned int), m_indices.data(), GL_STATIC_DRAW);

		m_gl->glEnableVertexAttribArray(0);
		m_gl->glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, sizeof(CubeVertex), (void*)offsetof(CubeVertex, pos));
		m_gl->glEnableVertexAttribArray(1);
		m_gl->glVertexAttribPointer(1, 2, GL_FLOAT, GL_FALSE, sizeof(CubeVertex), (void*)offsetof(CubeVertex, texCoord));

		m_gl->glBindVertexArray(0);
	}

private:
	// 原有成员变量
	QOpenGLFunctions_4_5_Core* m_gl;
	QString m_texPath;
	std::vector<CubeVertex> m_vertices;
	std::vector<unsigned int> m_indices;
	unsigned int m_vao, m_vbo, m_ebo;
	std::unique_ptr<QOpenGLTexture> m_texture;
	const unsigned int faceVertices[6][4] = {
		{ 3, 2, 6, 7 },{ 1, 0, 4, 5 },{ 0, 3, 7, 4 },
		{ 2, 1, 5, 6 },{ 4, 5, 6, 7 },{ 0, 1, 2, 3 }
	};

	// 新增成员变量（用于检测）
	QMatrix4x4 m_modelMatrix;    // 存储绘制时的模型矩阵
	QMatrix4x4 m_viewMatrix;     // 存储绘制时的视图矩阵
	QMatrix4x4 m_projectionMatrix;// 存储绘制时的投影矩阵
	int m_windowWidth;           // 窗口宽度
	int m_windowHeight;          // 窗口高度
};