Skip to content

Commit

Permalink
refactor session insert, update if already exists (rustdesk#9739)
Browse files Browse the repository at this point in the history
* All share the same last_receive_time
* Not second port forward

Signed-off-by: 21pages <[email protected]>
  • Loading branch information
21pages authored Oct 24, 2024
1 parent 4da5840 commit c8b9031
Showing 1 changed file with 61 additions and 69 deletions.
130 changes: 61 additions & 69 deletions src/server/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ struct Session {
last_recv_time: Arc<Mutex<Instant>>,
random_password: String,
tfa: bool,
conn_type: AuthConnType,
conn_id: i32,
}

#[cfg(not(any(target_os = "android", target_os = "ios")))]
Expand Down Expand Up @@ -217,7 +215,7 @@ pub struct Connection {
server_audit_conn: String,
server_audit_file: String,
lr: LoginRequest,
last_recv_time: Arc<Mutex<Instant>>,
session_last_recv_time: Option<Arc<Mutex<Instant>>>,
chat_unanswered: bool,
file_transferred: bool,
#[cfg(windows)]
Expand Down Expand Up @@ -364,7 +362,7 @@ impl Connection {
server_audit_conn: "".to_owned(),
server_audit_file: "".to_owned(),
lr: Default::default(),
last_recv_time: Arc::new(Mutex::new(Instant::now())),
session_last_recv_time: None,
chat_unanswered: false,
file_transferred: false,
#[cfg(windows)]
Expand Down Expand Up @@ -595,7 +593,7 @@ impl Connection {
},
Ok(bytes) => {
last_recv_time = Instant::now();
*conn.last_recv_time.lock().unwrap() = Instant::now();
conn.session_last_recv_time.as_mut().map(|t| *t.lock().unwrap() = Instant::now());
if let Ok(msg_in) = Message::parse_from_bytes(&bytes) {
if !conn.on_message(msg_in).await {
break;
Expand Down Expand Up @@ -762,6 +760,10 @@ impl Connection {
}
if let Err(err) = conn.try_port_forward_loop(&mut rx_from_cm).await {
conn.on_close(&err.to_string(), false).await;
raii::AuthedConnID::remove_session_if_last_duplication(
conn.inner.id(),
conn.session_key(),
);
}

conn.post_conn_audit(json!({
Expand Down Expand Up @@ -1140,6 +1142,11 @@ impl Connection {
auth_conn_type,
self.session_key(),
));
self.session_last_recv_time = SESSIONS
.lock()
.unwrap()
.get(&self.session_key())
.map(|s| s.last_recv_time.clone());
self.post_conn_audit(
json!({"peer": ((&self.lr.my_id, &self.lr.my_name)), "type": conn_type}),
);
Expand Down Expand Up @@ -1549,15 +1556,10 @@ impl Connection {
if password::temporary_enabled() {
let password = password::temporary_password();
if self.validate_one_password(password.clone()) {
raii::AuthedConnID::insert_session(
raii::AuthedConnID::update_or_insert_session(
self.session_key(),
Session {
last_recv_time: self.last_recv_time.clone(),
random_password: password,
tfa: false,
conn_type: self.conn_type(),
conn_id: self.inner.id(),
},
Some(password),
Some(false),
);
return true;
}
Expand All @@ -1581,15 +1583,11 @@ impl Connection {
.get(&self.session_key())
.map(|s| s.to_owned());
// last_recv_time is a mutex variable shared with connection, can be updated lively.
if let Some(mut session) = session {
if let Some(session) = session {
if !self.lr.password.is_empty()
&& (tfa && session.tfa
|| !tfa && self.validate_one_password(session.random_password.clone()))
{
session.last_recv_time = self.last_recv_time.clone();
session.conn_id = self.inner.id();
session.conn_type = self.conn_type();
raii::AuthedConnID::insert_session(self.session_key(), session);
log::info!("is recent session");
return true;
}
Expand Down Expand Up @@ -1841,34 +1839,13 @@ impl Connection {
if res {
self.update_failure(failure, true, 1);
self.require_2fa.take();
raii::AuthedConnID::set_session_2fa(self.session_key());
self.send_logon_response().await;
self.try_start_cm(
self.lr.my_id.to_owned(),
self.lr.my_name.to_owned(),
self.authorized,
);
let session = SESSIONS
.lock()
.unwrap()
.get(&self.session_key())
.map(|s| s.to_owned());
if let Some(mut session) = session {
session.tfa = true;
session.conn_id = self.inner.id();
session.conn_type = self.conn_type();
raii::AuthedConnID::insert_session(self.session_key(), session);
} else {
raii::AuthedConnID::insert_session(
self.session_key(),
Session {
last_recv_time: self.last_recv_time.clone(),
random_password: "".to_owned(),
tfa: true,
conn_type: self.conn_type(),
conn_id: self.inner.id(),
},
);
}
if !tfa.hwid.is_empty() && Self::enable_trusted_devices() {
Config::add_trusted_device(TrustedDevice {
hwid: tfa.hwid,
Expand Down Expand Up @@ -3872,49 +3849,64 @@ mod raii {
}

pub fn remove_session_if_last_duplication(conn_id: i32, key: SessionKey) {
let contains = SESSIONS.lock().unwrap().contains_key(&key);
let mut lock = SESSIONS.lock().unwrap();
let contains = lock.contains_key(&key);
if contains {
let another = AUTHED_CONNS
.lock()
.unwrap()
.iter()
.any(|c| c.0 != conn_id && c.2 == key && c.1 != AuthConnType::PortForward);
.any(|c| c.0 != conn_id && c.2 == key);
if !another {
// Keep the session if there is another connection with same peer_id and session_id.
SESSIONS.lock().unwrap().remove(&key);
lock.remove(&key);
log::info!("remove session");
} else {
log::info!("skip remove session");
}
}
}

pub fn insert_session(key: SessionKey, session: Session) {
let mut insert = true;
if session.conn_type == AuthConnType::PortForward {
// port forward doesn't update last received time
let other_alive_conns = AUTHED_CONNS
.lock()
.unwrap()
.iter()
.filter(|c| {
c.2 == key && c.1 != AuthConnType::PortForward // port forward doesn't remove itself
})
.map(|c| c.0)
.collect::<Vec<_>>();
let another = SESSIONS.lock().unwrap().get(&key).map(|s| {
other_alive_conns.contains(&s.conn_id)
&& s.tfa == session.tfa
&& s.conn_type != AuthConnType::PortForward
}) == Some(true);
if another {
insert = false;
log::info!("skip insert session for port forward");
}
}
if insert {
log::info!("insert session for {:?}", session.conn_type);
SESSIONS.lock().unwrap().insert(key, session);
pub fn update_or_insert_session(
key: SessionKey,
password: Option<String>,
tfa: Option<bool>,
) {
let mut lock = SESSIONS.lock().unwrap();
let session = lock.get_mut(&key);
if let Some(session) = session {
if let Some(password) = password {
session.random_password = password;
}
if let Some(tfa) = tfa {
session.tfa = tfa;
}
} else {
lock.insert(
key,
Session {
random_password: password.unwrap_or_default(),
tfa: tfa.unwrap_or_default(),
last_recv_time: Arc::new(Mutex::new(Instant::now())),
},
);
}
}

pub fn set_session_2fa(key: SessionKey) {
let mut lock = SESSIONS.lock().unwrap();
let session = lock.get_mut(&key);
if let Some(session) = session {
session.tfa = true;
} else {
lock.insert(
key,
Session {
last_recv_time: Arc::new(Mutex::new(Instant::now())),
random_password: "".to_owned(),
tfa: true,
},
);
}
}
}
Expand Down

0 comments on commit c8b9031

Please sign in to comment.