use crate::{
config::Config,
error::{ErrorKind, Result},
net::{connection::ActiveConnections, events::SocketEvent, link_conditioner::LinkConditioner},
packet::{DeliveryGuarantee, Outgoing, Packet},
};
use crossbeam_channel::{self, unbounded, Receiver, Sender};
use log::error;
use std::{
self, io,
net::{SocketAddr, ToSocketAddrs, UdpSocket},
};
/// A UDP socket together with the per-peer connection state that is layered on top of it.
///
/// Built by `Socket::bind` / `Socket::bind_with_config`, which also return the channel ends
/// used to talk to it: a sender for outgoing [`Packet`]s and a receiver for [`SocketEvent`]s.
pub struct Socket {
    /// The OS-level UDP socket (switched to non-blocking mode at bind time).
    socket: UdpSocket,
    /// Settings supplied at bind time (e.g. receive buffer size, idle connection timeout).
    config: Config,
    /// Connection state for each remote peer, looked up by the peer's `SocketAddr`.
    connections: ActiveConnections,
    /// Scratch buffer that incoming datagrams are read into; sized from
    /// `config.receive_buffer_max_size`.
    recv_buffer: Vec<u8>,
    /// Optional network simulator consulted before each send; when it says "don't send",
    /// the datagram is skipped. `bind_with_config` always initialises this to `None`.
    link_conditioner: Option<LinkConditioner>,
    /// Publishes events (connects, received packets, timeouts) to the user.
    event_sender: Sender<SocketEvent>,
    /// Packets queued by the user for transmission; drained by `start_polling`.
    packet_receiver: Receiver<Packet>,
}
impl Socket {
pub fn bind<A: ToSocketAddrs>(
addresses: A,
) -> Result<(Self, Sender<Packet>, Receiver<SocketEvent>)> {
Socket::bind_with_config(addresses, Config::default())
}
pub fn bind_with_config<A: ToSocketAddrs>(
addresses: A,
config: Config,
) -> Result<(Self, Sender<Packet>, Receiver<SocketEvent>)> {
let socket = UdpSocket::bind(addresses)?;
socket.set_nonblocking(true)?;
let (event_sender, event_receiver) = unbounded();
let (packet_sender, packet_receiver) = unbounded();
Ok((
Socket {
recv_buffer: vec![0; config.receive_buffer_max_size],
socket,
config,
connections: ActiveConnections::new(),
link_conditioner: None,
event_sender,
packet_receiver,
},
packet_sender,
event_receiver,
))
}
pub fn start_polling(&mut self) -> Result<()> {
loop {
if let Err(e) = self.recv_from() {
error!("Encountered an error receiving data: {:?}", e);
};
while let Ok(p) = self.packet_receiver.try_recv() {
if let Err(e) = self.send_to(p) {
match e {
ErrorKind::IOError(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
_ => error!("There was an error sending packet: {:?}", e),
}
}
}
if let Err(e) = self.handle_idle_clients() {
error!("Encountered an error when sending TimeoutEvent: {:?}", e);
}
}
}
fn handle_idle_clients(&mut self) -> Result<()> {
let idle_addresses = self
.connections
.idle_connections(self.config.idle_connection_timeout);
for address in idle_addresses {
self.connections.remove_connection(&address);
self.event_sender.send(SocketEvent::Timeout(address))?;
}
Ok(())
}
fn send_to(&mut self, packet: Packet) -> Result<usize> {
let connection = self
.connections
.get_or_insert_connection(packet.addr(), &self.config);
let dropped = connection.gather_dropped_packets();
let mut processed_packets: Vec<Outgoing> = dropped
.iter()
.flat_map(|waiting_packet| {
connection.process_outgoing(
&waiting_packet.payload,
DeliveryGuarantee::Reliable,
waiting_packet.ordering_guarantee,
)
})
.collect();
let processed_packet = connection.process_outgoing(
packet.payload(),
packet.delivery_guarantee(),
packet.order_guarantee(),
)?;
processed_packets.push(processed_packet);
let mut bytes_sent = 0;
for processed_packet in processed_packets {
if self.should_send_packet() {
match processed_packet {
Outgoing::Packet(outgoing) => {
bytes_sent += self.send_packet(&packet.addr(), &outgoing.contents())?;
}
Outgoing::Fragments(packets) => {
for outgoing in packets {
bytes_sent += self.send_packet(&packet.addr(), &outgoing.contents())?;
}
}
}
}
}
Ok(bytes_sent)
}
fn recv_from(&mut self) -> Result<()> {
match self.socket.recv_from(&mut self.recv_buffer) {
Ok((recv_len, address)) => {
if recv_len == 0 {
return Err(ErrorKind::ReceivedDataToShort)?;
}
let received_payload = &self.recv_buffer[..recv_len];
if !self.connections.exists(&address) {
self.event_sender.send(SocketEvent::Connect(address))?;
}
let connection = self
.connections
.get_or_insert_connection(address, &self.config);
connection.process_incoming(received_payload, &self.event_sender)?;
}
Err(e) => {
if e.kind() != io::ErrorKind::WouldBlock {
error!("Encountered an error receiving data: {:?}", e);
return Err(e.into());
}
}
}
Ok(())
}
fn send_packet(&self, addr: &SocketAddr, payload: &[u8]) -> Result<usize> {
let bytes_sent = self.socket.send_to(payload, addr)?;
Ok(bytes_sent)
}
fn should_send_packet(&self) -> bool {
if let Some(link_conditioner) = &self.link_conditioner {
link_conditioner.should_send()
} else {
true
}
}
}
#[cfg(test)]
mod tests {
    use crate::{
        net::constants::{ACKED_PACKET_HEADER, FRAGMENT_HEADER_SIZE, STANDARD_HEADER_SIZE},
        Config, Packet, Socket, SocketEvent,
    };
    use std::net::SocketAddr;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn can_send_and_receive() {
        let (mut server, _, packet_receiver) =
            Socket::bind("127.0.0.1:12342".parse::<SocketAddr>().unwrap()).unwrap();
        let (mut client, packet_sender, _) =
            Socket::bind("127.0.0.1:12341".parse::<SocketAddr>().unwrap()).unwrap();
        thread::spawn(move || client.start_polling());
        thread::spawn(move || server.start_polling());
        for _ in 0..3 {
            packet_sender
                .send(Packet::unreliable(
                    "127.0.0.1:12342".parse::<SocketAddr>().unwrap(),
                    vec![1, 2, 3, 4, 5, 6, 7, 8, 9],
                ))
                .unwrap();
        }
        let mut iter = packet_receiver.iter();
        assert!(iter.next().is_some());
        assert!(iter.next().is_some());
        assert!(iter.next().is_some());
    }

    #[test]
    fn sending_large_unreliable_packet_should_fail() {
        let (mut server, _, _) =
            Socket::bind("127.0.0.1:12370".parse::<SocketAddr>().unwrap()).unwrap();
        assert!(server
            .send_to(Packet::unreliable(
                "127.0.0.1:12360".parse().unwrap(),
                vec![1; 5000]
            ))
            .is_err());
    }

    #[test]
    fn send_returns_right_size() {
        let (mut server, _, _) =
            Socket::bind("127.0.0.1:12371".parse::<SocketAddr>().unwrap()).unwrap();
        assert_eq!(
            server
                .send_to(Packet::unreliable(
                    "127.0.0.1:12361".parse().unwrap(),
                    vec![1; 1024]
                ))
                .unwrap(),
            1024 + STANDARD_HEADER_SIZE as usize
        );
    }

    #[test]
    fn fragmentation_send_returns_right_size() {
        let (mut server, _, _) =
            Socket::bind("127.0.0.1:12372".parse::<SocketAddr>().unwrap()).unwrap();
        let fragment_packet_size = STANDARD_HEADER_SIZE + FRAGMENT_HEADER_SIZE;
        assert_eq!(
            server
                .send_to(Packet::reliable_unordered(
                    "127.0.0.1:12362".parse().unwrap(),
                    vec![1; 4000]
                ))
                .unwrap(),
            4000 + (fragment_packet_size * 4 + ACKED_PACKET_HEADER) as usize
        );
    }

    #[test]
    fn connect_event_occurs() {
        let (mut server, _, packet_receiver) =
            Socket::bind("127.0.0.1:12345".parse::<SocketAddr>().unwrap()).unwrap();
        let (mut client, packet_sender, _) =
            Socket::bind("127.0.0.1:12344".parse::<SocketAddr>().unwrap()).unwrap();
        thread::spawn(move || client.start_polling());
        thread::spawn(move || server.start_polling());
        packet_sender
            .send(Packet::unreliable(
                "127.0.0.1:12345".parse().unwrap(),
                vec![0, 1, 2],
            ))
            .unwrap();
        assert_eq!(
            packet_receiver.recv().unwrap(),
            SocketEvent::Connect("127.0.0.1:12344".parse().unwrap())
        );
    }

    #[test]
    fn disconnect_event_occurs() {
        // The short idle timeout must actually be applied to the server, otherwise the test
        // silently falls back to the default timeout.
        let mut config = Config::default();
        config.idle_connection_timeout = Duration::from_millis(1);
        let (mut server, _, packet_receiver) = Socket::bind_with_config(
            "127.0.0.1:12347".parse::<SocketAddr>().unwrap(),
            config,
        )
        .unwrap();
        let (mut client, packet_sender, _) =
            Socket::bind("127.0.0.1:12346".parse::<SocketAddr>().unwrap()).unwrap();
        thread::spawn(move || client.start_polling());
        thread::spawn(move || server.start_polling());
        packet_sender
            .send(Packet::unreliable(
                "127.0.0.1:12347".parse().unwrap(),
                vec![0, 1, 2],
            ))
            .unwrap();
        assert_eq!(
            packet_receiver.recv().unwrap(),
            SocketEvent::Connect("127.0.0.1:12346".parse().unwrap())
        );
        assert_eq!(
            packet_receiver.recv().unwrap(),
            SocketEvent::Packet(Packet::unreliable(
                "127.0.0.1:12346".parse().unwrap(),
                vec![0, 1, 2]
            ))
        );
        assert_eq!(
            packet_receiver.recv().unwrap(),
            SocketEvent::Timeout("127.0.0.1:12346".parse().unwrap())
        );
    }

    const LOCAL_ADDR: &str = "127.0.0.1:13000";
    const REMOTE_ADDR: &str = "127.0.0.1:14000";

    /// Builds a reliable-unordered packet to `addr` whose single payload byte is `id`.
    fn create_test_packet(id: u8, addr: &str) -> Packet {
        let payload = vec![id];
        Packet::reliable_unordered(addr.parse().unwrap(), payload)
    }

    #[test]
    fn multiple_sends_should_start_sending_dropped() {
        let (mut server, server_sender, server_receiver) =
            Socket::bind(REMOTE_ADDR.parse::<SocketAddr>().unwrap()).unwrap();
        thread::spawn(move || server.start_polling());
        let (mut client, client_sender, client_receiver) =
            Socket::bind(LOCAL_ADDR.parse::<SocketAddr>().unwrap()).unwrap();
        thread::spawn(move || client.start_polling());
        for i in 0..35 {
            client_sender
                .send(create_test_packet(i, REMOTE_ADDR))
                .unwrap();
        }
        let mut events = Vec::new();
        while let Ok(event) = server_receiver.recv_timeout(Duration::from_millis(500)) {
            events.push(event);
        }
        // One Connect event plus 35 packets.
        assert_eq!(events.len(), 36);
        server_sender
            .send(create_test_packet(0, LOCAL_ADDR))
            .unwrap();
        client_receiver.recv().unwrap();
        events.clear();
        client_sender
            .send(create_test_packet(35, REMOTE_ADDR))
            .unwrap();
        while let Ok(event) = server_receiver.recv_timeout(Duration::from_millis(500)) {
            events.push(event);
        }
        let sent_events: Vec<u8> = events
            .iter()
            .filter_map(|e| match e {
                SocketEvent::Packet(p) => Some(p.payload()[0]),
                _ => None,
            })
            .collect();
        assert_eq!(sent_events, vec![0, 1, 35]);
    }
}