1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
use crate::{
    config::Config,
    error::{FragmentErrorKind, Result},
    net::constants::FRAGMENT_HEADER_SIZE,
    packet::header::FragmentHeader,
    sequence_buffer::{ReassemblyData, SequenceBuffer},
};

use std::io::Write;

/// Splits outgoing payloads into fragments and reassembles incoming fragments
/// back into complete payloads.
pub struct Fragmentation {
    /// Configuration copied at construction; read for the fragment size, the fragment
    /// limit and the reassembly buffer size.
    config: Config,
    /// In-progress reassembly entries, looked up by the packet's sequence number.
    fragments: SequenceBuffer<ReassemblyData>,
}

impl Fragmentation {
    /// Creates and returns a new Fragmentation
    pub fn new(config: &Config) -> Fragmentation {
        Fragmentation {
            fragments: SequenceBuffer::with_capacity(config.fragment_reassembly_buffer_size),
            config: config.clone(),
        }
    }

    /// This functions checks how many times a number fits into another number and will round up.
    ///
    /// For example we have two numbers:
    /// - number 1 = 4000;
    /// - number 2 = 1024;
    /// If you do it the easy way the answer will be 4000/1024 = 3.90625.
    /// But since we care about how how many whole times the number fits in we need the result 4.
    ///
    /// Note that when rust is rounding it is always rounding to zero (3.456 as u32 = 3)
    /// 1. calculate with modulo if `number 1` fits exactly in the `number 2`.
    /// 2. Divide `number 1` with `number 2` (this wil be rounded to zero by rust)
    /// 3. So in all cases we need to add 1 to get the right amount of fragments.
    ///
    /// lets take an example
    ///
    /// Calculate modules:
    /// - number 1 % number 2 = 928
    /// - this is bigger than 0 so remainder = 1
    ///
    /// Calculate how many times the `number 1` fits in `number 2`:
    /// - number 1 / number 2 = 3,90625 (this will be rounded to 3)
    /// - add remainder from above to 3 = 4.
    ///
    /// The above described method will figure out for all number how many times it fits into another number rounded up.
    ///
    /// So an example of dividing an packet of bytes we get these fragments:
    ///
    /// So for 4000 bytes we need 4 fragments
    /// [fragment: 1024] [fragment: 1024] [fragment: 1024] [fragment: 928]
    pub fn fragments_needed(payload_length: u16, fragment_size: u16) -> u16 {
        let remainder = if payload_length % fragment_size > 0 {
            1
        } else {
            0
        };
        ((payload_length / fragment_size) + remainder)
    }

    /// Split the given payload into fragments and write those fragments to the passed packet data.
    pub fn spit_into_fragments<'a>(payload: &'a [u8], config: &Config) -> Result<Vec<&'a [u8]>> {
        let mut fragments = Vec::new();

        let payload_length = payload.len() as u16;
        let num_fragments =
            // Safe cast max fragments is u8
            Fragmentation::fragments_needed(payload_length, config.fragment_size) as u8;

        if num_fragments > config.max_fragments {
            Err(FragmentErrorKind::ExceededMaxFragments)?;
        }

        for fragment_id in 0..num_fragments {
            // get start and end position of buffer
            let start_fragment_pos = u16::from(fragment_id) * config.fragment_size;
            let mut end_fragment_pos = (u16::from(fragment_id) + 1) * config.fragment_size;

            // If remaining buffer fits int one packet just set the end position to the length of the packet payload.
            if end_fragment_pos > payload_length {
                end_fragment_pos = payload_length;
            }

            // get specific slice of data for fragment
            let fragment_data = &payload[start_fragment_pos as usize..end_fragment_pos as usize];

            fragments.push(fragment_data);
        }

        Ok(fragments)
    }

    /// This will read fragment data and return the complete packet when all fragments are received.
    pub fn handle_fragment(
        &mut self,
        fragment_header: FragmentHeader,
        fragment_payload: &[u8],
    ) -> Result<Option<Vec<u8>>> {
        // read fragment packet

        self.create_fragment_if_not_exists(fragment_header);

        let num_fragments_received;
        let num_fragments_total;
        let sequence;
        let total_buffer;

        {
            // get entry of previous received fragments
            let reassembly_data = match self.fragments.get_mut(fragment_header.sequence()) {
                Some(val) => val,
                None => Err(FragmentErrorKind::CouldNotFindFragmentById)?,
            };

            // Got the data
            if reassembly_data.num_fragments_total != fragment_header.fragment_count() {
                Err(FragmentErrorKind::FragmentWithUnevenNumberOfFragemts)?
            }

            if reassembly_data.fragments_received[usize::from(fragment_header.id())] {
                Err(FragmentErrorKind::AlreadyProcessedFragment)?
            }

            // increase number of received fragments and set the specific fragment to received.
            reassembly_data.num_fragments_received += 1;
            reassembly_data.fragments_received[usize::from(fragment_header.id())] = true;

            // add the payload from the fragment to the buffer whe have in cache
            reassembly_data.buffer.write_all(&*fragment_payload)?;

            num_fragments_received = reassembly_data.num_fragments_received;
            num_fragments_total = reassembly_data.num_fragments_total;
            sequence = reassembly_data.sequence as u16;
            total_buffer = reassembly_data.buffer.clone();
        }

        // if whe received all fragments then remove entry and return the total received bytes.
        if num_fragments_received == num_fragments_total {
            let sequence = sequence as u16;
            self.fragments.remove(sequence);

            return Ok(Some(total_buffer));
        }

        Ok(None)
    }

    /// If fragment does not exist we need to insert a new entry.
    fn create_fragment_if_not_exists(&mut self, fragment_header: FragmentHeader) {
        if !self.fragments.exists(fragment_header.sequence()) {
            let reassembly_data = ReassemblyData::new(
                fragment_header.sequence(),
                fragment_header.fragment_count(),
                (u16::from(FRAGMENT_HEADER_SIZE) + self.config.fragment_size) as usize,
            );

            self.fragments
                .insert(fragment_header.sequence(), reassembly_data);
        }
    }
}

#[cfg(test)]
mod test {
    use super::Fragmentation;

    /// A partially filled last fragment adds one fragment.
    #[test]
    pub fn expect_right_number_of_fragments() {
        assert_eq!(Fragmentation::fragments_needed(4000, 1024), 4);
        assert_eq!(Fragmentation::fragments_needed(500, 1024), 1);
        assert_eq!(Fragmentation::fragments_needed(u16::max_value(), 1024), 64);
    }

    /// Payloads that are an exact multiple of the fragment size need no extra fragment.
    #[test]
    fn exact_multiple_needs_no_extra_fragment() {
        assert_eq!(Fragmentation::fragments_needed(2048, 1024), 2);
        assert_eq!(Fragmentation::fragments_needed(1024, 1024), 1);
    }

    /// An empty payload needs no fragments at all.
    #[test]
    fn empty_payload_needs_no_fragments() {
        assert_eq!(Fragmentation::fragments_needed(0, 1024), 0);
    }
}