Skip to content
Merged

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

210 changes: 197 additions & 13 deletions crates/defguard_common/src/db/models/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,25 +239,25 @@ impl UserDevice {
// fetch device config and connection info for all allowed networks
let result = query!(
"SELECT n.id network_id, n.name network_name, n.endpoint gateway_endpoint, \
wnd.wireguard_ips \"device_wireguard_ips: Vec<IpAddr>\", vs.endpoint \"device_endpoint?\", \
vs.latest_handshake \"latest_handshake?\", \
vs.state \"state?: VpnClientSessionState\" \
wnd.wireguard_ips \"device_wireguard_ips: Vec<IpAddr>\", latest_session.endpoint \"device_endpoint?\", \
last_successful_session.connected_at \"last_connected_at?\", \
latest_session.state \"state?: VpnClientSessionState\" \
FROM wireguard_network_device wnd \
JOIN wireguard_network n ON n.id = wnd.wireguard_network_id \
LEFT JOIN LATERAL ( \
SELECT id, state, location_id, endpoint, latest_handshake \
SELECT id, state, location_id, endpoint, connected_at \
FROM vpn_client_session \
LEFT JOIN LATERAL ( \
SELECT session_id, endpoint, latest_handshake \
FROM vpn_session_stats \
WHERE session_id = vpn_client_session.id \
ORDER BY collected_at DESC \
LIMIT 1 \
) vss ON vss.session_id = vpn_client_session.id \
WHERE location_id = n.id and device_id = $1 \
ORDER BY created_at DESC, id DESC \
LIMIT 1 \
) vs ON vs.location_id = n.id \
) latest_session ON latest_session.location_id = n.id \
LEFT JOIN LATERAL ( \
SELECT connected_at \
FROM vpn_client_session \
WHERE location_id = n.id AND device_id = $1 AND connected_at IS NOT NULL \
ORDER BY connected_at DESC, id DESC \
LIMIT 1 \
) last_successful_session ON true \
WHERE wnd.device_id = $1",
device.id,
)
Expand Down Expand Up @@ -293,7 +293,7 @@ impl UserDevice {
.map(IpAddr::to_string)
.collect(),
last_connected_ip: device_ip,
last_connected_at: r.latest_handshake,
last_connected_at: r.last_connected_at,
is_active,
}
})
Expand Down Expand Up @@ -1793,6 +1793,190 @@ mod test {
assert_eq!(network_info.preshared_key, None);
}

#[sqlx::test]
async fn test_user_device_from_device_keeps_latest_successful_connection_timestamp(
_: PgPoolOptions,
options: PgConnectOptions,
) {
let pool = setup_pool(options).await;

let user = User::new(
"testuser",
Some("password"),
"Tester",
"Test",
"test@test.com",
None,
)
.save(&pool)
.await
.unwrap();

let device = Device::new(
"device".into(),
"pubkey".into(),
user.id,
DeviceType::User,
None,
true,
)
.save(&pool)
.await
.unwrap();

let network = WireguardNetwork::default()
.try_set_address("10.1.1.1/24")
.unwrap()
.save(&pool)
.await
.unwrap();

WireguardNetworkDevice::new(
network.id,
device.id,
[IpAddr::from_str("10.1.1.2").unwrap()],
)
.insert(&pool)
.await
.unwrap();

let last_successful_connection = NaiveDate::from_ymd_opt(2026, 1, 2)
.expect("expected valid date")
.and_hms_opt(3, 4, 5)
.expect("expected valid time");
let newer_session_created_at = NaiveDate::from_ymd_opt(2026, 1, 3)
.expect("expected valid date")
.and_hms_opt(4, 5, 6)
.expect("expected valid time");

let mut connected_session = VpnClientSession::new(
network.id,
user.id,
device.id,
Some(last_successful_connection),
None,
);
connected_session.created_at = last_successful_connection;
connected_session.save(&pool).await.unwrap();

let mut disconnected_session =
VpnClientSession::new(network.id, user.id, device.id, None, None);
disconnected_session.created_at = newer_session_created_at;
disconnected_session.disconnected_at = Some(newer_session_created_at);
disconnected_session.state = VpnClientSessionState::Disconnected;
disconnected_session.save(&pool).await.unwrap();

let user_device = UserDevice::from_device(&pool, device)
.await
.unwrap()
.unwrap();
let network_info = user_device
.networks
.into_iter()
.find(|network_info| network_info.network_id == network.id)
.expect("expected created network in user device response");

assert!(!network_info.is_active);
assert_eq!(
network_info.last_connected_at,
Some(last_successful_connection)
);
}

#[sqlx::test]
async fn test_user_device_from_device_keeps_latest_successful_connection_timestamp_for_newer_new_session(
_: PgPoolOptions,
options: PgConnectOptions,
) {
let pool = setup_pool(options).await;

let user = User::new(
"testuser",
Some("password"),
"Tester",
"Test",
"test@test.com",
None,
)
.save(&pool)
.await
.unwrap();

let device = Device::new(
"device".into(),
"pubkey".into(),
user.id,
DeviceType::User,
None,
true,
)
.save(&pool)
.await
.unwrap();

let network = WireguardNetwork::default()
.try_set_address("10.1.1.1/24")
.unwrap()
.save(&pool)
.await
.unwrap();

WireguardNetworkDevice::new(
network.id,
device.id,
[IpAddr::from_str("10.1.1.2").unwrap()],
)
.insert(&pool)
.await
.unwrap();

let last_successful_connection = NaiveDate::from_ymd_opt(2026, 1, 2)
.expect("expected valid date")
.and_hms_opt(3, 4, 5)
.expect("expected valid time");
let newer_session_created_at = NaiveDate::from_ymd_opt(2026, 1, 3)
.expect("expected valid date")
.and_hms_opt(4, 5, 6)
.expect("expected valid time");

let disconnected_at = NaiveDate::from_ymd_opt(2026, 1, 2)
.expect("expected valid date")
.and_hms_opt(3, 5, 6)
.expect("expected valid time");

let mut connected_session = VpnClientSession::new(
network.id,
user.id,
device.id,
Some(last_successful_connection),
None,
);
connected_session.created_at = last_successful_connection;
connected_session.disconnected_at = Some(disconnected_at);
connected_session.state = VpnClientSessionState::Disconnected;
connected_session.save(&pool).await.unwrap();

let mut new_session = VpnClientSession::new(network.id, user.id, device.id, None, None);
new_session.created_at = newer_session_created_at;
new_session.save(&pool).await.unwrap();

let user_device = UserDevice::from_device(&pool, device)
.await
.unwrap()
.unwrap();
let network_info = user_device
.networks
.into_iter()
.find(|network_info| network_info.network_id == network.id)
.expect("expected created network in user device response");

assert!(!network_info.is_active);
assert_eq!(
network_info.last_connected_at,
Some(last_successful_connection)
);
}

#[sqlx::test]
fn test_all_for_network_and_user(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;
Expand Down
Loading
Loading