1use std::collections::VecDeque;
4
5use crate::connectable::{UdpConnectable, UdpMode};
6use crate::connection::MavConnection;
7use crate::peek_reader::PeekReader;
8use crate::{MavHeader, MavlinkVersion, Message, ReadVersion};
9use core::ops::DerefMut;
10use std::io::{self, Read};
11use std::net::{SocketAddr, UdpSocket};
12use std::sync::Mutex;
13
14use super::{get_socket_addr, Connectable};
15
16#[cfg(not(feature = "signing"))]
17use crate::{read_versioned_msg, write_versioned_msg};
18
19#[cfg(feature = "signing")]
20use crate::{read_versioned_msg_signed, write_versioned_msg_signed, SigningConfig, SigningData};
21
/// Receiving half of a UDP connection, presented to the parser as a byte stream
/// through its `Read` implementation.
struct UdpRead {
    /// Socket that incoming datagrams are received from.
    socket: UdpSocket,
    /// Bytes of the most recent datagram that `read` has not handed out yet.
    buffer: VecDeque<u8>,
    /// Sender of the most recently received datagram; `None` until one arrives.
    last_recv_address: Option<SocketAddr>,
}
27
/// Capacity of the stack buffer a single `recv_from` call reads into.
/// Per the `UdpSocket::recv_from` documentation, excess bytes of a longer
/// datagram may be discarded.
const MTU_SIZE: usize = 1_500;
29impl Read for UdpRead {
30 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
31 if !self.buffer.is_empty() {
32 self.buffer.read(buf)
33 } else {
34 let mut read_buffer = [0u8; MTU_SIZE];
35 let (n_buffer, address) = self.socket.recv_from(&mut read_buffer)?;
36 let n = (&read_buffer[0..n_buffer]).read(buf)?;
37 self.buffer.extend(&read_buffer[n..n_buffer]);
38
39 self.last_recv_address = Some(address);
40 Ok(n)
41 }
42 }
43}
44
/// Sending half of a UDP connection.
struct UdpWrite {
    /// Socket that encoded frames are sent from.
    socket: UdpSocket,
    /// Destination for outgoing frames; while `None`, `send` transmits nothing.
    dest: Option<SocketAddr>,
    /// Sequence number stamped into the next outgoing header (wraps around).
    sequence: u8,
}
50
51pub struct UdpConnection {
52 reader: Mutex<PeekReader<UdpRead>>,
53 writer: Mutex<UdpWrite>,
54 protocol_version: MavlinkVersion,
55 recv_any_version: bool,
56 server: bool,
57 #[cfg(feature = "signing")]
58 signing_data: Option<SigningData>,
59}
60
61impl UdpConnection {
62 fn new(socket: UdpSocket, server: bool, dest: Option<SocketAddr>) -> io::Result<Self> {
63 Ok(Self {
64 server,
65 reader: Mutex::new(PeekReader::new(UdpRead {
66 socket: socket.try_clone()?,
67 buffer: VecDeque::new(),
68 last_recv_address: None,
69 })),
70 writer: Mutex::new(UdpWrite {
71 socket,
72 dest,
73 sequence: 0,
74 }),
75 protocol_version: MavlinkVersion::V2,
76 recv_any_version: false,
77 #[cfg(feature = "signing")]
78 signing_data: None,
79 })
80 }
81}
82
83impl<M: Message> MavConnection<M> for UdpConnection {
84 fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
85 let mut reader = self.reader.lock().unwrap();
86
87 loop {
88 let version = ReadVersion::from_conn_cfg::<_, M>(self);
89 #[cfg(not(feature = "signing"))]
90 let result = read_versioned_msg(reader.deref_mut(), version);
91 #[cfg(feature = "signing")]
92 let result =
93 read_versioned_msg_signed(reader.deref_mut(), version, self.signing_data.as_ref());
94 if self.server {
95 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
96 self.writer.lock().unwrap().dest = addr;
97 }
98 }
99 if let ok @ Ok(..) = result {
100 return ok;
101 }
102 }
103 }
104
105 fn try_recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
106 let mut reader = self.reader.lock().unwrap();
107 let version = ReadVersion::from_conn_cfg::<_, M>(self);
108
109 #[cfg(not(feature = "signing"))]
110 let result = read_versioned_msg(reader.deref_mut(), version);
111 #[cfg(feature = "signing")]
112 let result =
113 read_versioned_msg_signed(reader.deref_mut(), version, self.signing_data.as_ref());
114
115 if self.server {
116 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
117 self.writer.lock().unwrap().dest = addr;
118 }
119 }
120
121 result
122 }
123
124 fn send(&self, header: &MavHeader, data: &M) -> Result<usize, crate::error::MessageWriteError> {
125 let mut guard = self.writer.lock().unwrap();
126 let state = &mut *guard;
127
128 let header = MavHeader {
129 sequence: state.sequence,
130 system_id: header.system_id,
131 component_id: header.component_id,
132 };
133
134 state.sequence = state.sequence.wrapping_add(1);
135
136 let len = if let Some(addr) = state.dest {
137 let mut buf = Vec::new();
138 #[cfg(not(feature = "signing"))]
139 write_versioned_msg(&mut buf, self.protocol_version, header, data)?;
140 #[cfg(feature = "signing")]
141 write_versioned_msg_signed(
142 &mut buf,
143 self.protocol_version,
144 header,
145 data,
146 self.signing_data.as_ref(),
147 )?;
148 state.socket.send_to(&buf, addr)?
149 } else {
150 0
151 };
152
153 Ok(len)
154 }
155
156 fn set_protocol_version(&mut self, version: MavlinkVersion) {
157 self.protocol_version = version;
158 }
159
160 fn protocol_version(&self) -> MavlinkVersion {
161 self.protocol_version
162 }
163
164 fn set_allow_recv_any_version(&mut self, allow: bool) {
165 self.recv_any_version = allow
166 }
167
168 fn allow_recv_any_version(&self) -> bool {
169 self.recv_any_version
170 }
171
172 #[cfg(feature = "signing")]
173 fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
174 self.signing_data = signing_data.map(SigningData::from_config)
175 }
176}
177
178impl Connectable for UdpConnectable {
179 fn connect<M: Message>(&self) -> io::Result<Box<dyn MavConnection<M> + Sync + Send>> {
180 let (addr, server, dest): (&str, _, _) = match self.mode {
181 UdpMode::Udpin => (&self.address, true, None),
182 _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
183 };
184 let socket = UdpSocket::bind(addr)?;
185 if matches!(self.mode, UdpMode::Udpcast) {
186 socket.set_broadcast(true)?;
187 }
188 Ok(Box::new(UdpConnection::new(socket, server, dest)?))
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A datagram longer than the caller's buffer must be handed out across
    /// several `read` calls without being mixed with the next datagram, and
    /// the sender's address must be recorded.
    #[test]
    fn test_datagram_buffering() {
        // Port 0 lets the OS pick a free port, so the test cannot collide with
        // another process or a concurrently running test on a fixed port.
        let receiver_socket = UdpSocket::bind("127.0.0.1:0").unwrap();
        // Fail instead of hanging the test suite if a datagram never arrives.
        receiver_socket
            .set_read_timeout(Some(Duration::from_secs(5)))
            .unwrap();
        let receiver_addr = receiver_socket.local_addr().unwrap();

        let mut udp_reader = UdpRead {
            socket: receiver_socket.try_clone().unwrap(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };

        let sender_socket = UdpSocket::bind("127.0.0.1:0").unwrap();
        sender_socket.connect(receiver_addr).unwrap();
        let sender_addr = sender_socket.local_addr().unwrap();

        let datagram: Vec<u8> = (0..50).collect();

        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];
        for _ in 0..2 {
            let n_read = udp_reader.read(&mut buf).unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], &datagram[..30]);

            let n_read = udp_reader.read(&mut buf).unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }

        assert_eq!(udp_reader.last_recv_address, Some(sender_addr));
    }
}