mod fps;
pub mod resource;
mod text;

use crate::config;
use anyhow::{Context, Result};
use crevice::std140::AsStd140;
use fps::FpsCounter;
use glam::Vec2;
use sdl3::{
    gpu::{
        ColorTargetBlendState, ColorTargetDescription, ColorTargetInfo, CommandBuffer, CompareOp,
        CullMode, DepthStencilState, Device, FillMode, Filter, FrontFace, GraphicsPipeline,
        GraphicsPipelineTargetInfo, LoadOp, PrimitiveType, RasterizerState, RenderPass,
        SampleCount, Sampler, SamplerCreateInfo, Shader, ShaderFormat, StoreOp, Texture,
        TextureCreateInfo, TextureFormat, TextureSamplerBinding, TextureType, TextureUsage,
        VertexInputState,
    },
    pixels::Color,
    sys::gpu::SDL_GPUViewport,
    video::Window,
};
use std::{ops::BitOr, path::Path};
use text::Text;

/// Optional factor by which the configured resolution's aspect ratio is divided
/// when fitting the output into the window (see `Renderer::viewport`); `None`
/// leaves the ratio unchanged.
pub type StretchFactor = Option<f32>;

/// Fragment-shader uniforms pushed to uniform slot 0 before every draw.
///
/// Converted to std140 layout via `crevice`; the field order determines that layout and
/// therefore has to agree with the uniform block declared by the fragment shaders.
#[derive(AsStd140)]
struct Uniforms {
    /// Time value passed into `Renderer::render`.
    time: f32,
    /// Configured render resolution (width, height) in pixels.
    resolution: Vec2,
}

/// Source of a texture sampled by a shader quad.
enum TextureInput {
    /// Index into the renderer's framebuffer textures.
    Framebuffer(usize),
}

/// A quad drawn with a custom fragment-shader pipeline.
struct ShaderQuad {
    /// Pipeline built by `Renderer::load_shader_effect` for the pass's target formats.
    pipeline: GraphicsPipeline,
    /// Textures bound, in order, to the fragment shader's sampler slots starting at 0.
    textures: Vec<TextureInput>,
}

/// A single draw operation recorded into a render pass.
enum Draw {
    /// A quad shaded by a custom fragment-shader pipeline.
    Shader(ShaderQuad),
    /// Text drawn by the `text` module.
    Text(Text),
}

impl Draw {
    pub fn from_config(
        cfg: &config::Draw,
        renderer: &Renderer,
        vertex_shader: &Shader,
        target_formats: &[TextureFormat],
    ) -> Result<Self> {
        Ok(match cfg {
            config::Draw::Shader(config::Shader { name, textures }) => Draw::Shader(ShaderQuad {
                pipeline: renderer.load_shader_effect(
                    vertex_shader,
                    textures.len() as u32,
                    target_formats,
                    name,
                )?,
                textures: textures
                    .iter()
                    .map(|t| match t {
                        config::Texture::Framebuffer(i) => TextureInput::Framebuffer(*i),
                    })
                    .collect(),
            }),
            config::Draw::Text(text) => Draw::Text(Text::new(
                &renderer.gpu,
                target_formats[0],
                text,
                renderer.fps.value(),
            )?),
        })
    }

    pub fn draw(
        &mut self,
        command_buffer: &CommandBuffer,
        render_pass: &RenderPass,
        framebuffers: &[Texture<'static>],
        sampler: &Sampler,
        time: f32,
        resolution: config::Resolution,
    ) {
        // This goes to layout (set=3, binding = 0) as per
        // https://wiki.libsdl.org/SDL3/SDL_CreateGPUShader
        let uniforms = Uniforms {
            time,
            resolution: Vec2::new(resolution.width as f32, resolution.height as f32),
        }
        .as_std140();
        command_buffer.push_fragment_uniform_data(0, &uniforms);

        match self {
            Draw::Shader(ShaderQuad { pipeline, textures }) => {
                render_pass.bind_graphics_pipeline(pipeline);
                let sampler_bindings: Vec<TextureSamplerBinding> = textures
                    .iter()
                    .map(|texture_input| {
                        TextureSamplerBinding::new()
                            .with_texture(match texture_input {
                                TextureInput::Framebuffer(i) => &framebuffers[*i],
                            })
                            .with_sampler(sampler)
                    })
                    .collect();
                render_pass.bind_fragment_samplers(0, &sampler_bindings);
                render_pass.draw_primitives(4, 1, 0, 0);
            }
            Draw::Text(text) => {
                text.draw(command_buffer, render_pass, resolution, time);
            }
        }
    }
}

/// An offscreen render pass, recorded each frame before the output pass.
struct Pass {
    /// Draws recorded, in order, into this pass.
    draw: Vec<Draw>,
    /// Indices of the framebuffers used as color targets; each is cleared to black
    /// at the start of the pass.
    targets: Vec<usize>,
}

/// The final pass, recorded into the window's swapchain texture.
struct Output {
    /// Draws recorded, in order, into the swapchain texture.
    draw: Vec<Draw>,
}

/// Multi-pass GPU renderer.
///
/// Each frame it records the configured offscreen passes into framebuffer textures,
/// then records the output draws into the window's swapchain texture using an
/// aspect-ratio preserving viewport.
pub struct Renderer {
    /// Copy of the render configuration the renderer was built from.
    config: config::Render,
    /// Window whose swapchain texture receives the output pass.
    window: Window,
    /// Optional divisor applied to the target aspect ratio when computing the viewport.
    stretch: StretchFactor,
    /// GPU device used to create resources and record passes.
    gpu: Device,
    /// FPS counter; `frame()` is called after every submitted frame.
    fps: FpsCounter,
    /// Offscreen passes, recorded in order each frame.
    passes: Vec<Pass>,
    /// Draws recorded into the swapchain texture each frame.
    output: Output,
    /// One texture per configured framebuffer, sized to `config.resolution`.
    framebuffers: Vec<Texture<'static>>,
    /// Sampler used for every shader texture input (nearest filtering, clamp to edge).
    sampler: Sampler,
}

impl Renderer {
    pub fn new(gpu: &Device, window: &Window, config: &config::Render) -> Result<Self> {
        let sampler = gpu.create_sampler(
            SamplerCreateInfo::new()
                .with_min_filter(Filter::Nearest)
                .with_mag_filter(Filter::Nearest)
                .with_mipmap_mode(sdl3::gpu::SamplerMipmapMode::Nearest)
                .with_address_mode_u(sdl3::gpu::SamplerAddressMode::ClampToEdge)
                .with_address_mode_v(sdl3::gpu::SamplerAddressMode::ClampToEdge)
                .with_address_mode_w(sdl3::gpu::SamplerAddressMode::ClampToEdge),
        )?;

        let mut renderer = Self {
            config: config.clone(),
            window: window.clone(),
            stretch: None,
            gpu: gpu.clone(),
            fps: FpsCounter::new(),
            passes: Vec::new(),
            output: Output { draw: Vec::new() },
            framebuffers: Vec::new(),
            sampler,
        };

        let (code, stage) = resource::load_shader("quad.vert")?;
        let vertex_shader = gpu
            .create_shader()
            .with_code(ShaderFormat::SpirV, &code, stage)
            .with_entrypoint(c"main")
            .build()
            .context("Can't create vertex shader")?;

        for pass in &config.passes {
            let target_formats: Vec<TextureFormat> = pass
                .targets
                .iter()
                .map(|idx| TextureFormat::from(&config.framebuffers[*idx]))
                .collect();
            let mut draw = Vec::new();
            for cfg_draw in &pass.draw {
                draw.push(Draw::from_config(
                    cfg_draw,
                    &renderer,
                    &vertex_shader,
                    &target_formats,
                )?);
            }
            renderer.passes.push(Pass {
                draw,
                targets: pass.targets.clone(),
            });
        }

        let window_format = gpu.get_swapchain_texture_format(window);
        for cfg_draw in &config.output.draw {
            renderer.output.draw.push(Draw::from_config(
                cfg_draw,
                &renderer,
                &vertex_shader,
                &[window_format],
            )?);
        }

        for fb in &config.framebuffers {
            let texture = gpu
                .create_texture(
                    TextureCreateInfo::new()
                        .with_type(TextureType::_2D)
                        .with_format(fb.into())
                        .with_usage(TextureUsage::ColorTarget.bitor(TextureUsage::Sampler))
                        .with_width(config.resolution.width)
                        .with_height(config.resolution.height)
                        .with_layer_count_or_depth(1)
                        .with_num_levels(1)
                        .with_sample_count(SampleCount::NoMultiSampling),
                )
                .context("Can't create framebuffer texture")?;
            renderer.framebuffers.push(texture);
        }

        Ok(renderer)
    }

    pub fn set_stretch_factor(&mut self, stretch: StretchFactor) {
        self.stretch = stretch;
    }

    pub fn render(&mut self, time: f32) -> Result<()> {
        // Record offscreen passes
        let mut command_buffer = self
            .gpu
            .acquire_command_buffer()
            .context("Can't acquire command buffer")?;

        for pass in &mut self.passes {
            let color_infos: Vec<ColorTargetInfo> = pass
                .targets
                .iter()
                .map(|target| {
                    ColorTargetInfo::default()
                        .with_texture(&self.framebuffers[*target])
                        .with_clear_color(Color::BLACK)
                        .with_load_op(LoadOp::Clear)
                        .with_store_op(StoreOp::Store)
                })
                .collect();

            let render_pass = self
                .gpu
                .begin_render_pass(&command_buffer, &color_infos, None)
                .context("Can't begin render pass")?;

            for draw in &mut pass.draw {
                draw.draw(
                    &command_buffer,
                    &render_pass,
                    &self.framebuffers,
                    &self.sampler,
                    time,
                    self.config.resolution,
                );
            }

            self.gpu.end_render_pass(render_pass);
        }

        // Record onscreen pass
        let Ok(swapchain_texture) = command_buffer.wait_and_acquire_swapchain_texture(&self.window)
        else {
            command_buffer.cancel();
            return Ok(());
        };
        let width = swapchain_texture.width();
        let height = swapchain_texture.height();

        // Compute aspect-ratio preserving viewport size
        let viewport = self.viewport(width, height);

        let color_infos = [ColorTargetInfo::default()
            .with_texture(&swapchain_texture)
            .with_clear_color(Color::BLACK)
            .with_load_op(LoadOp::Clear)
            .with_store_op(StoreOp::Store)];

        let render_pass = self
            .gpu
            .begin_render_pass(&command_buffer, &color_infos, None)
            .context("Can't begin render pass")?;

        // Configure viewport so that aspect ratio is preserved
        self.gpu.set_viewport(&render_pass, viewport);

        for draw in &mut self.output.draw {
            draw.draw(
                &command_buffer,
                &render_pass,
                &self.framebuffers,
                &self.sampler,
                time,
                self.config.resolution,
            );
        }

        self.gpu.end_render_pass(render_pass);

        command_buffer
            .submit()
            .context("Command buffer submission failed")?;

        self.fps.frame();

        Ok(())
    }

    fn load_shader_effect(
        &self,
        vertex_shader: &Shader,
        texture_inputs: u32,
        target_formats: &[TextureFormat],
        name: impl AsRef<Path>,
    ) -> Result<GraphicsPipeline> {
        let (code, stage) = resource::load_shader(name)?;
        let fragment_shader = self
            .gpu
            .create_shader()
            .with_code(ShaderFormat::SpirV, &code, stage)
            .with_entrypoint(c"main")
            .with_uniform_buffers(2)
            .with_samplers(texture_inputs)
            .build()
            .context("Can't create fragment shader")?;

        let color_target_descs: Vec<ColorTargetDescription> = target_formats
            .iter()
            .map(|texture_format| {
                ColorTargetDescription::new()
                    .with_format(*texture_format)
                    .with_blend_state(ColorTargetBlendState::default())
            })
            .collect();

        let pipeline = self
            .gpu
            .create_graphics_pipeline()
            .with_vertex_shader(vertex_shader)
            .with_fragment_shader(&fragment_shader)
            .with_vertex_input_state(VertexInputState::default())
            .with_primitive_type(PrimitiveType::TriangleStrip)
            .with_rasterizer_state(
                RasterizerState::new()
                    .with_fill_mode(FillMode::Fill)
                    .with_cull_mode(CullMode::None)
                    .with_front_face(FrontFace::Clockwise),
            )
            .with_depth_stencil_state(
                DepthStencilState::new()
                    .with_compare_op(CompareOp::Greater)
                    .with_enable_depth_test(false)
                    .with_enable_stencil_test(false),
            )
            .with_target_info(
                GraphicsPipelineTargetInfo::new()
                    .with_color_target_descriptions(&color_target_descs)
                    .with_has_depth_stencil_target(false),
            )
            .build()
            .context("Can't create graphics pipeline")?;

        Ok(pipeline)
    }

    fn viewport(&self, width: u32, height: u32) -> SDL_GPUViewport {
        let aspect_ratio = width as f32 / height as f32;
        let mut target_ratio =
            self.config.resolution.width as f32 / self.config.resolution.height as f32;
        if let Some(stretch) = self.stretch {
            target_ratio /= stretch;
        }
        let (vp_width, vp_height) = if aspect_ratio > target_ratio {
            (height as f32 * target_ratio, height as f32)
        } else {
            (width as f32, width as f32 / target_ratio)
        };

        sdl3::sys::gpu::SDL_GPUViewport {
            x: if aspect_ratio > target_ratio {
                (width as f32 - vp_width) / 2.
            } else {
                0.
            },
            y: if aspect_ratio > target_ratio {
                0.
            } else {
                (height as f32 - vp_height) / 2.
            },
            w: vp_width,
            h: vp_height,
            min_depth: 0.,
            max_depth: 1.,
        }
    }
}
