use anyhow::{Context, Result, anyhow, bail};
use ogg::PacketReader;
use opusic_sys::{
    OPUS_OK, OpusDecoder, opus_decode_float, opus_decoder_create, opus_decoder_destroy,
    opus_packet_get_nb_samples,
};
use std::{
    io::{Read, Seek, SeekFrom},
    ptr::NonNull,
};

use super::SAMPLE_RATE;

/// Fields of the Opus identification header (the `OpusHead` packet).
///
/// The byte offsets mentioned below are the ones `Metadata::new` reads from.
/// Field meanings follow the OggOpus ID header layout (RFC 7845, section 5.1).
/// In the code visible in this file only `channel_count` and `pre_skip` are
/// read back after parsing.
#[expect(dead_code, reason = "Most of the metadata is not needed for playback")]
#[derive(Debug, Clone)]
pub struct Metadata {
    /// Header byte 8: encapsulation version number.
    version_number: u8,
    /// Header byte 9: number of output channels.
    channel_count: u8,
    /// Header bytes 10..12 (little-endian): number of samples to discard
    /// from the start of the decoded output, per RFC 7845.
    pre_skip: u16,
    /// Header bytes 12..16 (little-endian): sample rate of the original
    /// input; informational per RFC 7845.
    input_sample_rate: u32,
    /// Header bytes 16..18 (little-endian, signed): output gain.
    /// NOTE(review): RFC 7845 defines this as Q7.8 dB — nothing in this file
    /// applies it.
    output_gain: i16,
    /// Header byte 18: channel mapping family.
    channel_mapping_family: u8,
}

impl Metadata {
    /// Parses the fixed-size leading part of an `OpusHead` identification
    /// header packet.
    ///
    /// Multi-byte fields are decoded as little-endian. Bytes past offset 18
    /// are ignored.
    ///
    /// # Errors
    ///
    /// Fails with "Invalid ogg/opus file" when `id_header` is shorter than
    /// 19 bytes or does not begin with the `OpusHead` magic.
    fn new(id_header: &[u8]) -> Result<Self> {
        const MAGIC: &[u8; 8] = b"OpusHead";
        const MIN_HEADER_LEN: usize = 19;

        let well_formed = id_header.len() >= MIN_HEADER_LEN && id_header.starts_with(MAGIC);
        if !well_formed {
            bail!("Invalid ogg/opus file");
        }

        // The length was validated above, so every fixed offset used below
        // is in bounds.
        let byte_pair = |at: usize| [id_header[at], id_header[at + 1]];

        Ok(Self {
            version_number: id_header[8],
            channel_count: id_header[9],
            pre_skip: u16::from_le_bytes(byte_pair(10)),
            input_sample_rate: u32::from_le_bytes([
                id_header[12],
                id_header[13],
                id_header[14],
                id_header[15],
            ]),
            output_gain: i16::from_le_bytes(byte_pair(16)),
            channel_mapping_family: id_header[18],
        })
    }
}

/// An Ogg/Opus stream read from `T` and decoded through a libopus decoder.
pub struct OpusStream<T: Read + Seek> {
    /// Ogg packet reader wrapping the underlying byte source.
    ogg: PacketReader<T>,
    /// Values parsed from the stream's ID header packet.
    metadata: Metadata,
    /// Page granule position reported for the first packet after the two
    /// header packets; used as the origin when converting PCM positions.
    firstgp: u64,
    /// libopus decoder handle; created in `new` and destroyed in `Drop`.
    decoder: NonNull<OpusDecoder>,
    /// Approximate track length, derived in `new` from the first and last
    /// granule positions and the pre-skip.
    length: u64,
}

// The raw `NonNull<OpusDecoder>` field prevents `Send` from being derived
// automatically, hence this manual impl. The decoder pointer is created in
// `OpusStream::new`, freed in `Drop`, and not handed out by any method visible
// in this file.
// NOTE(review): this impl does not require `T: Send`, so an `OpusStream` over
// a non-`Send` reader is also marked `Send` — confirm this is intended.
unsafe impl<T: Read + Seek> Send for OpusStream<T> {}

impl<T: Read + Seek> OpusStream<T> {
    /// Opens an ogg/opus stream from `rdr`.
    ///
    /// Reads the ID and comment headers, estimates the track length from the
    /// first and last granule positions, rewinds the reader so the next
    /// packet read is the first audio packet, and creates the libopus
    /// decoder.
    ///
    /// # Errors
    ///
    /// Returns an error when the headers or the first audio packet cannot be
    /// read, when the ID header is invalid, when the last granule position
    /// cannot be located, when the stream has more than
    /// `super::MAX_CHANNELS` channels, or when libopus fails to create a
    /// decoder.
    pub fn new(rdr: T) -> Result<Self> {
        let mut ogg = PacketReader::new(rdr);

        // Read headers
        let id_header = ogg
            .read_packet_expected()
            .context("Can't read ogg/opus ID Header")?;
        let _comment_header = ogg
            .read_packet_expected()
            .context("Can't read ogg/opus Comment Header")?;
        let metadata = Metadata::new(&id_header.data)?;

        // Read first packet granule position for determining track length
        // https://wiki.xiph.org/OggOpus
        // 'PCM sample position' = 'granule position' - 'pre-skip'
        let firstgp = ogg
            .read_packet_expected()
            .context("ogg/opus contains no audio")?
            .absgp_page();

        // Find the last packet and read granule position
        let lastgp = Self::seek_to_last_gp(&mut ogg).context("Can't seek ogg/opus file")?;

        // Track length can be approximated by subtracting granule positions.
        // Saturate instead of subtracting directly: for very short streams
        // (or inconsistent granule positions) `pre_skip + firstgp` can exceed
        // `lastgp`, and a plain `-` would panic in debug builds and wrap
        // around to a huge length in release builds.
        let length = lastgp
            .saturating_sub(u64::from(metadata.pre_skip))
            .saturating_sub(firstgp);

        if usize::from(metadata.channel_count) > super::MAX_CHANNELS {
            bail!(
                "Can't decode ogg/opus file with {} channels, only mono and stereo are supported",
                metadata.channel_count
            );
        }

        // Seek back to start and skip both header packets again so that the
        // next packet read is the first audio packet. Failures are reported
        // rather than ignored: otherwise a header packet could later be
        // handed to the decoder as if it were audio.
        ogg.seek_bytes(SeekFrom::Start(0))
            .context("Failed to start ogg/opus stream")?;
        ogg.read_packet_expected()
            .context("Can't read ogg/opus ID Header")?;
        ogg.read_packet_expected()
            .context("Can't read ogg/opus Comment Header")?;

        let mut result_code = 0;
        // SAFETY: `result_code` is a live, writable integer for the whole
        // call, and the returned pointer is checked for null (and the result
        // code for `OPUS_OK`) before it is stored or used.
        let raw_decoder = unsafe {
            opus_decoder_create(
                super::SAMPLE_RATE as _,
                metadata.channel_count.into(),
                &mut result_code,
            )
        };
        let decoder = NonNull::new(raw_decoder)
            .filter(|_| result_code == OPUS_OK)
            .context("Failed to initialize opus decoder")?;

        Ok(Self {
            ogg,
            metadata,
            firstgp,
            decoder,
            length,
        })
    }

    /// Seek to a position.
    ///
    /// Seeks to a granule position that is <= `pos_goal`,
    /// and returns the number of granules to skip to reach `pos_goal`.
    pub fn seek(&mut self, pos_goal_pcm: u64) -> Result<usize> {
        // Go back 80ms at a time
        const SEEK_ADJUSTMENT: u64 = SAMPLE_RATE as u64 * 8 / 100;
        let pos_goal_gp = self.pcm_to_gp(pos_goal_pcm);
        let mut cur_goal_gp = pos_goal_gp.saturating_sub(SEEK_ADJUSTMENT);
        let mut seeked_to_gp = loop {
            if !self
                .ogg
                .seek_absgp(None, cur_goal_gp)
                .context("Can't seek ogg/opus file")?
            {
                return Err(anyhow!("ogg::PacketReader::seek_absgp returned false"));
            }
            if let Ok(packet) = self.ogg.read_packet_expected() {
                let page_gp = packet.absgp_page();
                if page_gp <= pos_goal_gp {
                    break page_gp;
                }
            }
            if cur_goal_gp == 0 {
                unreachable!("File doesn't contain pages earlier than 1");
            }
            cur_goal_gp = cur_goal_gp.saturating_sub(SEEK_ADJUSTMENT);
        };

        // Fix if gone to metadata pages
        if seeked_to_gp < self.firstgp {
            self.ogg.seek_bytes(SeekFrom::Start(0))?;
            let _id_header = self.ogg.read_packet_expected()?;
            let _comment_header = self.ogg.read_packet_expected()?;
            seeked_to_gp = self.firstgp;
        }

        self.ogg
            .seek_absgp(None, seeked_to_gp)
            .context("Can't seek ogg/opus file")?;

        Ok((pos_goal_gp - seeked_to_gp) as usize)
    }

    /// Reads the next ogg packet and decodes it into `buf`.
    ///
    /// Returns the value reported by `opus_decode_float` multiplied by the
    /// channel count. Returns 0 when no packet could be read, when libopus
    /// reports a non-positive sample count for the packet, or when decoding
    /// yields a non-positive result.
    ///
    /// NOTE(review): nothing here checks that `buf` can hold the decoded
    /// frame for every channel — presumably `super::Buffer` is sized for the
    /// largest possible Opus frame; confirm against its definition.
    pub fn decode_packet(&mut self, buf: &mut super::Buffer) -> usize {
        let packet = match self.ogg.read_packet_expected() {
            Ok(packet) => packet,
            Err(_) => return 0,
        };

        let payload = packet.data.as_slice();
        #[allow(
            clippy::cast_possible_wrap,
            reason = "ogg packets aren't longer than 2^31"
        )]
        let payload_len = payload.len() as i32;

        // Ask libopus how many samples this packet holds; that count is the
        // frame size handed to the decoder below.
        let frame_size = unsafe {
            opus_packet_get_nb_samples(payload.as_ptr(), payload_len, super::SAMPLE_RATE as _)
        };
        if frame_size <= 0 {
            return 0;
        }

        let decoded = unsafe {
            opus_decode_float(
                self.decoder.as_ptr(),
                payload.as_ptr(),
                payload_len,
                buf.as_mut_ptr(),
                frame_size,
                0,
            )
        };
        if decoded <= 0 {
            return 0;
        }

        decoded as usize * usize::from(self.metadata.channel_count)
    }

    /// Returns the channel count taken from the stream's ID header.
    pub fn channel_count(&self) -> u8 {
        self.metadata.channel_count
    }

    /// Returns the approximate track length computed when the stream was
    /// opened (derived from granule positions and the pre-skip).
    pub fn length(&self) -> u64 {
        self.length
    }

    /// Returns the pre-skip value from the stream's ID header, widened to
    /// `usize`.
    pub fn pre_skip(&self) -> usize {
        usize::from(self.metadata.pre_skip)
    }

    /// Converts a PCM sample position into a granule position by offsetting
    /// it with the first granule position and the pre-skip.
    fn pcm_to_gp(&self, pcm: u64) -> u64 {
        let origin = self.firstgp;
        let pre_skip = u64::from(self.metadata.pre_skip);
        pre_skip + origin + pcm
    }

    fn seek_to_last_gp(ogg: &mut PacketReader<T>) -> Result<u64> {
        const SEEK: i64 = -4096;
        ogg.seek_bytes(SeekFrom::End(SEEK))?;
        let gp = loop {
            let packet = ogg.read_packet()?;
            if let Some(packet) = packet {
                if packet.last_in_stream() {
                    let data = packet.data.as_slice();
                    let samples = unsafe {
                        #[allow(
                            clippy::cast_possible_wrap,
                            reason = "ogg packets aren't longer than 2^31"
                        )]
                        opus_packet_get_nb_samples(
                            data.as_ptr(),
                            data.len() as i32,
                            super::SAMPLE_RATE as _,
                        )
                    };
                    break packet.absgp_page() + if samples > 0 { samples as u64 } else { 0 };
                }
            } else {
                ogg.seek_bytes(SeekFrom::Current(SEEK))?;
            }
        };
        Ok(gp)
    }
}

impl<T: Read + Seek> Drop for OpusStream<T> {
    /// Releases the libopus decoder owned by this stream.
    fn drop(&mut self) {
        let raw_decoder = self.decoder.as_ptr();
        // SAFETY: `decoder` is the non-null pointer obtained from
        // `opus_decoder_create` in `OpusStream::new`; this is the only place
        // in this file that destroys it, and `drop` runs at most once.
        unsafe { opus_decoder_destroy(raw_decoder) };
    }
}
