mavlink_core/connection/
udp.rs

1//! UDP MAVLink connection
2
3use std::collections::VecDeque;
4
5use crate::connectable::{UdpConnectable, UdpMode};
6use crate::connection::MavConnection;
7use crate::peek_reader::PeekReader;
8use crate::{MavHeader, MavlinkVersion, Message, ReadVersion};
9use core::ops::DerefMut;
10use std::io::{self, Read};
11use std::net::{SocketAddr, UdpSocket};
12use std::sync::Mutex;
13
14use super::{get_socket_addr, Connectable};
15
16#[cfg(not(feature = "signing"))]
17use crate::{read_versioned_msg, write_versioned_msg};
18
19#[cfg(feature = "signing")]
20use crate::{read_versioned_msg_signed, write_versioned_msg_signed, SigningConfig, SigningData};
21
/// Adapter that exposes a datagram-oriented [`UdpSocket`] through the
/// byte-stream [`Read`] interface.
///
/// A caller may ask for fewer bytes than a datagram holds; the unread tail is
/// kept in `buffer` and served by later `read` calls before the socket is
/// touched again.
struct UdpRead {
    /// Socket the datagrams are received from.
    socket: UdpSocket,
    /// Bytes of the most recent datagram that have not been handed out yet.
    buffer: VecDeque<u8>,
    /// Source address of the most recently received datagram, if any.
    last_recv_address: Option<SocketAddr>,
}

/// Size in bytes of the scratch buffer a single datagram is received into.
///
/// A datagram longer than this does not fit; per the `UdpSocket::recv_from`
/// documentation the excess bytes may be discarded.
const MTU_SIZE: usize = 1500;

impl Read for UdpRead {
    /// Copies pending bytes into `buf`, receiving a new datagram only when
    /// nothing is buffered.
    ///
    /// Zero-length datagrams are skipped: returning `Ok(0)` for a non-empty
    /// `buf` would tell the caller that the stream has reached end-of-file,
    /// which a UDP socket never does. Their sender address is still recorded.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `UdpSocket::recv_from`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if !self.buffer.is_empty() {
            return self.buffer.read(buf);
        }

        let mut read_buffer = [0u8; MTU_SIZE];
        loop {
            let (n_buffer, address) = self.socket.recv_from(&mut read_buffer)?;
            self.last_recv_address = Some(address);

            if n_buffer == 0 {
                // Empty datagram: nothing to deliver, wait for the next one
                // instead of reporting a spurious EOF.
                continue;
            }

            // Hand out as much as fits and keep the rest for the next call.
            let n = (&read_buffer[..n_buffer]).read(buf)?;
            self.buffer.extend(&read_buffer[n..n_buffer]);
            return Ok(n);
        }
    }
}
44
/// Sending-side state of a UDP connection.
struct UdpWrite {
    /// Socket outgoing datagrams are sent through.
    socket: UdpSocket,
    /// Address datagrams are sent to; `None` while no destination is known.
    dest: Option<SocketAddr>,
    /// Sequence number to stamp on the next outgoing message header.
    sequence: u8,
}
50
51pub struct UdpConnection {
52    reader: Mutex<PeekReader<UdpRead>>,
53    writer: Mutex<UdpWrite>,
54    protocol_version: MavlinkVersion,
55    recv_any_version: bool,
56    server: bool,
57    #[cfg(feature = "signing")]
58    signing_data: Option<SigningData>,
59}
60
61impl UdpConnection {
62    fn new(socket: UdpSocket, server: bool, dest: Option<SocketAddr>) -> io::Result<Self> {
63        Ok(Self {
64            server,
65            reader: Mutex::new(PeekReader::new(UdpRead {
66                socket: socket.try_clone()?,
67                buffer: VecDeque::new(),
68                last_recv_address: None,
69            })),
70            writer: Mutex::new(UdpWrite {
71                socket,
72                dest,
73                sequence: 0,
74            }),
75            protocol_version: MavlinkVersion::V2,
76            recv_any_version: false,
77            #[cfg(feature = "signing")]
78            signing_data: None,
79        })
80    }
81}
82
83impl<M: Message> MavConnection<M> for UdpConnection {
84    fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
85        let mut reader = self.reader.lock().unwrap();
86
87        loop {
88            let version = ReadVersion::from_conn_cfg::<_, M>(self);
89            #[cfg(not(feature = "signing"))]
90            let result = read_versioned_msg(reader.deref_mut(), version);
91            #[cfg(feature = "signing")]
92            let result =
93                read_versioned_msg_signed(reader.deref_mut(), version, self.signing_data.as_ref());
94            if self.server {
95                if let addr @ Some(_) = reader.reader_ref().last_recv_address {
96                    self.writer.lock().unwrap().dest = addr;
97                }
98            }
99            if let ok @ Ok(..) = result {
100                return ok;
101            }
102        }
103    }
104
105    fn send(&self, header: &MavHeader, data: &M) -> Result<usize, crate::error::MessageWriteError> {
106        let mut guard = self.writer.lock().unwrap();
107        let state = &mut *guard;
108
109        let header = MavHeader {
110            sequence: state.sequence,
111            system_id: header.system_id,
112            component_id: header.component_id,
113        };
114
115        state.sequence = state.sequence.wrapping_add(1);
116
117        let len = if let Some(addr) = state.dest {
118            let mut buf = Vec::new();
119            #[cfg(not(feature = "signing"))]
120            write_versioned_msg(&mut buf, self.protocol_version, header, data)?;
121            #[cfg(feature = "signing")]
122            write_versioned_msg_signed(
123                &mut buf,
124                self.protocol_version,
125                header,
126                data,
127                self.signing_data.as_ref(),
128            )?;
129            state.socket.send_to(&buf, addr)?
130        } else {
131            0
132        };
133
134        Ok(len)
135    }
136
137    fn set_protocol_version(&mut self, version: MavlinkVersion) {
138        self.protocol_version = version;
139    }
140
141    fn protocol_version(&self) -> MavlinkVersion {
142        self.protocol_version
143    }
144
145    fn set_allow_recv_any_version(&mut self, allow: bool) {
146        self.recv_any_version = allow
147    }
148
149    fn allow_recv_any_version(&self) -> bool {
150        self.recv_any_version
151    }
152
153    #[cfg(feature = "signing")]
154    fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
155        self.signing_data = signing_data.map(SigningData::from_config)
156    }
157}
158
159impl Connectable for UdpConnectable {
160    fn connect<M: Message>(&self) -> io::Result<Box<dyn MavConnection<M> + Sync + Send>> {
161        let (addr, server, dest): (&str, _, _) = match self.mode {
162            UdpMode::Udpin => (&self.address, true, None),
163            _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
164        };
165        let socket = UdpSocket::bind(addr)?;
166        if matches!(self.mode, UdpMode::Udpcast) {
167            socket.set_broadcast(true)?;
168        }
169        Ok(Box::new(UdpConnection::new(socket, server, dest)?))
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a datagram larger than the caller's buffer is delivered
    /// across several `read` calls, and that the next datagram is only
    /// fetched once the previous one is exhausted.
    #[test]
    fn test_datagram_buffering() {
        // Port 0 lets the OS pick a free port, so the test neither collides
        // with other processes nor with tests running in parallel.
        let receiver_socket = UdpSocket::bind("127.0.0.1:0").unwrap();
        let receiver_addr = receiver_socket.local_addr().unwrap();
        // Fail instead of hanging forever should a datagram get lost.
        receiver_socket
            .set_read_timeout(Some(std::time::Duration::from_secs(5)))
            .unwrap();

        let mut udp_reader = UdpRead {
            socket: receiver_socket.try_clone().unwrap(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("0.0.0.0:0").unwrap();
        sender_socket.connect(receiver_addr).unwrap();

        let datagram: Vec<u8> = (0..50).collect();

        // Queue two identical datagrams before reading anything.
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];

        // Each 50-byte datagram comes out as a 30-byte read followed by the
        // 20 buffered bytes.
        for _ in 0..2 {
            let n_read = udp_reader.read(&mut buf).unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], &datagram[..30]);

            let n_read = udp_reader.read(&mut buf).unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }
    }
}