This commit is contained in:
gary go 2026-06-22 07:39:13 +00:00 committed by GitHub
commit a7af8fac3f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -63,6 +63,7 @@ use hbb_common::{
timeout,
tokio::{
self,
task::JoinSet,
net::UdpSocket,
sync::{
mpsc::{unbounded_channel, UnboundedReceiver},
@ -1793,41 +1794,13 @@ impl LoginConfigHandler {
shared_password: Option<String>,
conn_token: Option<String>,
) {
let mut id = id;
if id.contains("@") {
let mut v = id.split("@");
let raw_id: &str = v.next().unwrap_or_default();
let mut server_key = v.next().unwrap_or_default().split('?');
let server = server_key.next().unwrap_or_default();
let args = server_key.next().unwrap_or_default();
let key = if server == PUBLIC_SERVER {
config::RS_PUB_KEY.to_owned()
} else {
let mut args_map: HashMap<String, &str> = HashMap::new();
for arg in args.split('&') {
if let Some(kv) = arg.find('=') {
let k = arg[0..kv].to_lowercase();
let v = &arg[kv + 1..];
args_map.insert(k, v);
}
}
let key = args_map.remove("key").unwrap_or_default();
key.to_owned()
};
// here we can check <id>/r@server
let real_id = crate::ui_interface::handle_relay_id(raw_id).to_string();
if real_id != raw_id {
force_relay = true;
}
self.other_server = Some((real_id.clone(), server.to_owned(), key));
id = format!("{real_id}@{server}");
} else {
let real_id = crate::ui_interface::handle_relay_id(&id);
if real_id != id {
force_relay = true;
id = real_id.to_owned();
}
let format_id = format_id(id.as_str());
let id = format_id.id;
if format_id.force_relay {
force_relay = true;
}
if format_id.server.is_some() {
self.other_server = format_id.server;
}
self.id = id;
@ -4049,14 +4022,67 @@ async fn hc_connection_(
Ok(())
}
#[derive(Debug, Clone)]
pub struct FormatId {
pub id: String,
pub server: Option<(String, String, String)>,
pub force_relay: bool,
}
fn format_id(id: &str) -> FormatId {
if id.contains("@") {
let mut force_relay = false;
let mut v = id.split("@");
let raw_id: &str = v.next().unwrap_or_default();
let mut server_key = v.next().unwrap_or_default().split('?');
let server = server_key.next().unwrap_or_default();
let args = server_key.next().unwrap_or_default();
let key = if server == PUBLIC_SERVER {
config::RS_PUB_KEY.to_owned()
} else {
let mut args_map: HashMap<String, &str> = HashMap::new();
for arg in args.split('&') {
if let Some(kv) = arg.find('=') {
let k = arg[0..kv].to_lowercase();
let v = &arg[kv + 1..];
args_map.insert(k, v);
}
}
let key = args_map.remove("key").unwrap_or_default();
key.to_owned()
};
// here we can check <id>/r@server
let real_id = crate::ui_interface::handle_relay_id(raw_id).to_string();
if real_id != raw_id {
force_relay = true;
}
FormatId {
id: format!("{real_id}@{server}"),
server: Some((real_id.to_string(), server.to_string(), key.to_string())),
force_relay,
}
} else {
let real_id = crate::ui_interface::handle_relay_id(&id);
FormatId {
id: real_id.to_string(),
server: None,
force_relay: real_id != id,
}
}
}
pub mod peer_online {
use std::collections::HashMap;
use hbb_common::{
anyhow::bail,
config::{Config, CONNECT_TIMEOUT, READ_TIMEOUT},
log,
rendezvous_proto::*,
sleep,
socket_client::connect_tcp,
socket_client::{check_port, connect_tcp},
tokio::task::JoinSet,
ResultType, Stream,
};
@ -4069,35 +4095,78 @@ pub mod peer_online {
f(onlines, offlines)
} else {
let query_timeout = std::time::Duration::from_millis(3_000);
match query_online_states_(&ids, query_timeout).await {
Ok((onlines, offlines)) => {
f(onlines, offlines);
}
Err(e) => {
log::debug!("query onlines, {}", &e);
let (rendezvous_server, _servers, _contained) =
crate::get_rendezvous_server(READ_TIMEOUT).await;
let group = group_query_online_states(ids, rendezvous_server.as_str());
let mut onlines = Vec::new();
let mut offlines = Vec::new();
let mut join = JoinSet::new();
for (server, map) in group.into_iter() {
let ids: Vec<String> = map.keys().map(|t| t.to_string()).collect();
let query_timeout = query_timeout.clone();
join.spawn(async move {
match query_online_states_(ids, query_timeout, server.as_str()).await {
Ok((on, off)) => {
let on: Vec<String> = on
.into_iter()
.filter_map(|id| map.get(&id).cloned())
.flat_map(|p| p.into_iter())
.collect();
let off: Vec<String> = off
.into_iter()
.filter_map(|id| map.get(&id).cloned())
.flat_map(|p| p.into_iter())
.collect();
Ok((on, off))
}
Err(e) => Err(e),
}
});
}
while let Some(res) = join.join_next().await {
match res {
Ok(Ok((on, off))) => {
onlines.extend(on);
offlines.extend(off);
}
Ok(Err(e)) => {
log::debug!("query onlines error: {}", e);
}
Err(e) => {
log::error!("task panicked: {}", e);
}
}
}
f(onlines, offlines);
}
}
async fn create_online_stream(rendezvous_server: &str) -> ResultType<Stream> {
let tmp = rendezvous_server.rfind(':').map(|pos| {
let url = &rendezvous_server[..pos];
let port: u16 = rendezvous_server[pos + 1..].parse().unwrap_or(0);
(url, port)
});
match tmp {
Some((url, port)) if port > 1 => {
let online_server = format!("{}:{}", url, port - 1);
connect_tcp(online_server, CONNECT_TIMEOUT).await
}
_ => {
bail!("Invalid server address: {}", rendezvous_server);
}
}
}
async fn create_online_stream() -> ResultType<Stream> {
let (rendezvous_server, _servers, _contained) =
crate::get_rendezvous_server(READ_TIMEOUT).await;
let tmp: Vec<&str> = rendezvous_server.split(":").collect();
if tmp.len() != 2 {
bail!("Invalid server address: {}", rendezvous_server);
}
let port: u16 = tmp[1].parse()?;
if port == 0 {
bail!("Invalid server address: {}", rendezvous_server);
}
let online_server = format!("{}:{}", tmp[0], port - 1);
connect_tcp(online_server, CONNECT_TIMEOUT).await
}
async fn query_online_states_(
ids: &Vec<String>,
ids: Vec<String>,
timeout: std::time::Duration,
rendezvous_server: &str,
) -> ResultType<(Vec<String>, Vec<String>)> {
let mut msg_out = RendezvousMessage::new();
msg_out.set_online_request(OnlineRequest {
@ -4106,7 +4175,7 @@ pub mod peer_online {
..Default::default()
});
let mut socket = match create_online_stream().await {
let mut socket = match create_online_stream(rendezvous_server).await {
Ok(s) => s,
Err(e) => {
log::debug!("Failed to create peers online stream, {e}");
@ -4156,6 +4225,50 @@ pub mod peer_online {
bail!("Failed to query online states, no online response");
}
fn group_query_online_states(
ids: Vec<String>,
main_rendezvous_server: &str,
) -> HashMap<String, HashMap<String, Vec<String>>> {
ids.iter()
.filter(|id| {
!hbb_common::is_ip_str(id.as_str()) && !hbb_common::is_domain_port_str(id.as_str())
})
.map(|id| (id.clone(), super::format_id(id.as_str())))
.map(|(raw_id, format)| {
if let Some((pure_id, server, _key)) = format.server {
(raw_id, pure_id, server)
} else {
(raw_id, format.id, main_rendezvous_server.to_string())
}
})
.map(|(raw_id, id, server)| {
if server == crate::client::PUBLIC_SERVER {
(
raw_id,
id,
check_port(
hbb_common::config::RENDEZVOUS_SERVERS[0],
hbb_common::config::RENDEZVOUS_PORT,
),
)
} else {
(
raw_id,
id,
check_port(server, hbb_common::config::RENDEZVOUS_PORT),
)
}
})
.fold(HashMap::new(), |mut map, (raw_id, id, server)| {
map.entry(server)
.or_insert_with(HashMap::new)
.entry(id)
.or_insert_with(Vec::new)
.push(raw_id);
map
})
}
#[cfg(test)]
mod tests {
use hbb_common::tokio;
@ -4175,6 +4288,55 @@ pub mod peer_online {
)
.await;
}
#[test]
fn test_group_query_online_states() {
use std::collections::HashMap;
assert_eq!(
super::group_query_online_states(
vec![
"152183996".to_string(),
"152183996/r".to_string(),
"456@custom.com".to_string(),
"45611@custom.com".to_string(),
"45611@custom.com:2000".to_string(),
"789@public".to_string(),
"abc".to_string(),
],
"localhost",
),
HashMap::from([
(
"localhost:21116".to_string(),
HashMap::from([
(
"152183996".to_string(),
vec!["152183996".to_string(), "152183996/r".to_string()]
),
("abc".to_string(), vec!["abc".to_string()])
])
),
(
"custom.com:21116".to_string(),
HashMap::from([
("456".to_string(), vec!["456@custom.com".to_string()]),
("45611".to_string(), vec!["45611@custom.com".to_string()])
])
),
(
"custom.com:2000".to_string(),
HashMap::from([(
"45611".to_string(),
vec!["45611@custom.com:2000".to_string()]
)])
),
(
"rs-ny.rustdesk.com:21116".to_string(),
HashMap::from([("789".to_string(), vec!["789@public".to_string()]),])
)
])
);
}
}
}