use byteorder::{LittleEndian, WriteBytesExt};
use std::io::{self, Write};
use color;
/// Bytes in the file header: "BM" magic (2) + file size (4) + two reserved
/// words (2 + 2) + offset of the pixel data (4).
const BITMAPFILEHEADER_SIZE: u32 = 2 + 4 + 2 + 2 + 4;
/// Bytes in a BITMAPINFOHEADER: header size, width, height (4 each), planes and
/// bit count (2 each), then six 4-byte fields (compression .. important colors).
const BITMAPINFOHEADER_SIZE: u32 = 4 + 4 + 4 + 2 + 2 + 4 * 6;
/// Bytes in a BITMAPV4HEADER: the 40-byte info header, four channel masks,
/// a color-space tag, and 48 bytes of color-space fields (endpoints and gamma).
const BITMAPV4HEADER_SIZE: u32 = 40 + 4 * 4 + 4 + 48;
/// Encoder that serializes raw pixel buffers into BMP files, writing into a
/// borrowed writer.
pub struct BMPEncoder<'a, W>
where
    W: 'a,
{
    writer: &'a mut W,
}
impl<'a, W: Write + 'a> BMPEncoder<'a, W> {
    
    pub fn new(w: &'a mut W) -> Self {
        BMPEncoder { writer: w }
    }
    
    
    
    pub fn encode(
        &mut self,
        image: &[u8],
        width: u32,
        height: u32,
        c: color::ColorType,
    ) -> io::Result<()> {
        let bmp_header_size = BITMAPFILEHEADER_SIZE;
        let (dib_header_size, written_pixel_size, palette_color_count) = try!(get_pixel_info(c));
        let row_pad_size = (4 - (width * written_pixel_size) % 4) % 4; 
        let image_size = width * height * written_pixel_size + (height * row_pad_size);
        let palette_size = palette_color_count * 4; 
        let file_size = bmp_header_size + dib_header_size + palette_size + image_size;
        
        try!(self.writer.write_u8(b'B'));
        try!(self.writer.write_u8(b'M'));
        try!(self.writer.write_u32::<LittleEndian>(file_size)); 
        try!(self.writer.write_u16::<LittleEndian>(0)); 
        try!(self.writer.write_u16::<LittleEndian>(0)); 
        try!(
            self.writer
                .write_u32::<LittleEndian>(bmp_header_size + dib_header_size + palette_size)
        ); 
        
        try!(self.writer.write_u32::<LittleEndian>(dib_header_size));
        try!(self.writer.write_i32::<LittleEndian>(width as i32));
        try!(self.writer.write_i32::<LittleEndian>(height as i32));
        try!(self.writer.write_u16::<LittleEndian>(1)); 
        try!(
            self.writer
                .write_u16::<LittleEndian>((written_pixel_size * 8) as u16)
        ); 
        if dib_header_size >= BITMAPV4HEADER_SIZE {
            
            try!(self.writer.write_u32::<LittleEndian>(3)); 
        } else {
            try!(self.writer.write_u32::<LittleEndian>(0)); 
        }
        try!(self.writer.write_u32::<LittleEndian>(image_size));
        try!(self.writer.write_i32::<LittleEndian>(0)); 
        try!(self.writer.write_i32::<LittleEndian>(0)); 
        try!(self.writer.write_u32::<LittleEndian>(palette_color_count));
        try!(self.writer.write_u32::<LittleEndian>(0)); 
        if dib_header_size >= BITMAPV4HEADER_SIZE {
            
            try!(self.writer.write_u32::<LittleEndian>(0xff << 16)); 
            try!(self.writer.write_u32::<LittleEndian>(0xff << 8)); 
            try!(self.writer.write_u32::<LittleEndian>(0xff << 0)); 
            try!(self.writer.write_u32::<LittleEndian>(0xff << 24)); 
            try!(self.writer.write_u32::<LittleEndian>(0x73524742)); 
            
            for _ in 0..12 {
                try!(self.writer.write_u32::<LittleEndian>(0));
            }
        }
        
        match c {
            color::ColorType::RGB(8) => {
                try!(self.encode_rgb(image, width, height, row_pad_size, 3))
            }
            color::ColorType::RGBA(8) => {
                try!(self.encode_rgba(image, width, height, row_pad_size, 4))
            }
            color::ColorType::Gray(8) => {
                try!(self.encode_gray(image, width, height, row_pad_size, 1))
            }
            color::ColorType::GrayA(8) => {
                try!(self.encode_gray(image, width, height, row_pad_size, 2))
            }
            _ => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    &get_unsupported_error_message(c)[..],
                ))
            }
        }
        Ok(())
    }
    fn encode_rgb(
        &mut self,
        image: &[u8],
        width: u32,
        height: u32,
        row_pad_size: u32,
        bytes_per_pixel: u32,
    ) -> io::Result<()> {
        let x_stride = bytes_per_pixel;
        let y_stride = width * x_stride;
        for row in 0..height {
            
            let row_start = (height - row - 1) * y_stride;
            for col in 0..width {
                let pixel_start = (row_start + (col * x_stride)) as usize;
                let r = image[pixel_start];
                let g = image[pixel_start + 1];
                let b = image[pixel_start + 2];
                
                try!(self.writer.write_u8(b));
                try!(self.writer.write_u8(g));
                try!(self.writer.write_u8(r));
                
            }
            try!(self.write_row_pad(row_pad_size));
        }
        Ok(())
    }
    fn encode_rgba(
        &mut self,
        image: &[u8],
        width: u32,
        height: u32,
        row_pad_size: u32,
        bytes_per_pixel: u32,
    ) -> io::Result<()> {
        let x_stride = bytes_per_pixel;
        let y_stride = width * x_stride;
        for row in 0..height {
            
            let row_start = (height - row - 1) * y_stride;
            for col in 0..width {
                let pixel_start = (row_start + (col * x_stride)) as usize;
                let r = image[pixel_start];
                let g = image[pixel_start + 1];
                let b = image[pixel_start + 2];
                let a = image[pixel_start + 3];
                
                try!(self.writer.write_u8(b));
                try!(self.writer.write_u8(g));
                try!(self.writer.write_u8(r));
                try!(self.writer.write_u8(a));
            }
            try!(self.write_row_pad(row_pad_size));
        }
        Ok(())
    }
    fn encode_gray(
        &mut self,
        image: &[u8],
        width: u32,
        height: u32,
        row_pad_size: u32,
        bytes_per_pixel: u32,
    ) -> io::Result<()> {
        
        for val in 0..256 {
            
            let val = val as u8;
            try!(self.writer.write_u8(val));
            try!(self.writer.write_u8(val));
            try!(self.writer.write_u8(val));
            try!(self.writer.write_u8(0));
        }
        
        let x_stride = bytes_per_pixel;
        let y_stride = width * x_stride;
        for row in 0..height {
            
            let row_start = (height - row - 1) * y_stride;
            for col in 0..width {
                let pixel_start = (row_start + (col * x_stride)) as usize;
                
                try!(self.writer.write_u8(image[pixel_start]));
                
            }
            try!(self.write_row_pad(row_pad_size));
        }
        Ok(())
    }
    fn write_row_pad(&mut self, row_pad_size: u32) -> io::Result<()> {
        for _ in 0..row_pad_size {
            try!(self.writer.write_u8(0));
        }
        Ok(())
    }
}
/// Builds the message used in the `InvalidInput` error returned for a color
/// type the encoder cannot write. The message also lists the accepted types.
fn get_unsupported_error_message(c: color::ColorType) -> String {
    format!("Unsupported color type {:?}.  Supported types: RGB(8), RGBA(8), Gray(8), GrayA(8).", &c)
}
/// Describes the output layout for color type `c`.
///
/// Returns `(dib_header_size, written_bytes_per_pixel, palette_color_count)`.
/// RGBA(8) needs the larger V4 header (for its channel masks). Gray(8) and
/// GrayA(8) share one layout: one byte per pixel plus a 256-color palette.
///
/// # Errors
///
/// `io::ErrorKind::InvalidInput` for any other color type.
fn get_pixel_info(c: color::ColorType) -> io::Result<(u32, u32, u32)> {
    match c {
        color::ColorType::RGB(8) => Ok((BITMAPINFOHEADER_SIZE, 3, 0)),
        color::ColorType::RGBA(8) => Ok((BITMAPV4HEADER_SIZE, 4, 0)),
        color::ColorType::Gray(8) | color::ColorType::GrayA(8) => {
            Ok((BITMAPINFOHEADER_SIZE, 1, 256))
        }
        unsupported => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            &get_unsupported_error_message(unsupported)[..],
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::super::BMPDecoder;
    use super::BMPEncoder;
    use color::ColorType;
    use image::ImageDecoder;
    use std::io::Cursor;

    /// Encodes `image` as BMP and decodes it again with `BMPDecoder`,
    /// returning the decoded pixel bytes.
    fn round_trip_image(image: &[u8], width: u32, height: u32, c: ColorType) -> Vec<u8> {
        let mut encoded_data = Vec::new();
        {
            let mut encoder = BMPEncoder::new(&mut encoded_data);
            encoder
                .encode(image, width, height, c)
                .expect("could not encode image");
        }
        let decoder = BMPDecoder::new(Cursor::new(&encoded_data)).expect("failed to decode");
        decoder.read_image().expect("failed to decode")
    }

    #[test]
    fn round_trip_single_pixel_rgb() {
        let image = [255u8, 0, 0]; // single red pixel
        let decoded = round_trip_image(&image, 1, 1, ColorType::RGB(8));
        assert_eq!(&decoded[..], &image[..]);
    }

    #[test]
    fn round_trip_single_pixel_rgba() {
        let image = [1, 2, 3, 4];
        let decoded = round_trip_image(&image, 1, 1, ColorType::RGBA(8));
        assert_eq!(&decoded[..], &image[..]);
    }

    #[test]
    fn round_trip_3px_rgb() {
        let image = [0u8; 3 * 3 * 3]; // 3x3 pixels, 3 bytes per pixel
        let decoded = round_trip_image(&image, 3, 3, ColorType::RGB(8));
        assert_eq!(&decoded[..], &image[..]);
    }

    #[test]
    fn round_trip_multi_row_rgb_preserves_order() {
        // 3 px * 3 bytes = 9 bytes per row, so every encoded row carries 3 bytes of
        // padding. Distinct values catch flipped or misaligned rows.
        let image: Vec<u8> = (0..18).collect();
        let decoded = round_trip_image(&image, 3, 2, ColorType::RGB(8));
        assert_eq!(decoded, image);
    }

    #[test]
    fn round_trip_gray() {
        let image = [0u8, 1, 2]; // 3 pixels, increasing intensity
        let decoded = round_trip_image(&image, 3, 1, ColorType::Gray(8));
        // The decoder expands paletted gray to RGB.
        assert_eq!(&decoded[..], &[0u8, 0, 0, 1, 1, 1, 2, 2, 2][..]);
    }

    #[test]
    fn round_trip_graya() {
        let image = [0u8, 0, 1, 0, 2, 0]; // 3 pixels, alpha bytes are dropped
        let decoded = round_trip_image(&image, 1, 3, ColorType::GrayA(8));
        // The decoder expands paletted gray to RGB.
        assert_eq!(&decoded[..], &[0u8, 0, 0, 1, 1, 1, 2, 2, 2][..]);
    }
}