use core::net::IpAddr; use futures::{future::TryFutureExt, stream::TryStreamExt}; use ldap3::{Ldap, LdapConnAsync}; use rocket::fairing::{self, AdHoc}; use rocket::form::Form; use rocket::http::{Cookie, CookieJar, Status}; use rocket::outcome::{try_outcome, IntoOutcome}; use rocket::request::{FromRequest, Outcome, Request}; use rocket::response::status::Unauthorized; use rocket::serde::json::{self, Json}; use rocket::serde::{Deserialize, Serialize}; use rocket::{Build, Rocket, State}; use rocket_db_pools::{sqlx, Connection, Database}; use std::borrow::Cow; use std::cmp::Ordering; use std::collections::BTreeMap; use std::sync::Mutex; use std::sync::OnceLock; use std::time::Instant; use time::Duration; use utoipa::openapi::security::{ApiKey, ApiKeyValue, SecurityScheme}; use utoipa::{Modify, OpenApi, ToSchema}; use crate::api_model; use crate::Db; #[derive(OpenApi)] #[openapi( paths(login, logout, status,), modifiers(&AuthApiAddon), )] pub struct AuthApi; pub struct AuthApiAddon; impl Modify for AuthApiAddon { fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { let components = openapi.components.as_mut().unwrap(); components.add_security_scheme( "session", SecurityScheme::ApiKey(ApiKey::Cookie(ApiKeyValue::new(SESSION_COOKIE))), ) } } #[derive(FromForm, ToSchema)] struct Login<'r> { username: &'r str, password: &'r str, } #[derive(Debug, Deserialize)] #[cfg_attr(test, allow(dead_code))] struct AuthConfig<'a> { session_max_age_days: u32, ldap_url: Cow<'a, str>, ldap_users: Cow<'a, str>, ldap_filter: Cow<'a, str>, } struct SessionsData { active_ids: BTreeMap, next_id: u32, } struct Sessions { data: Mutex, } #[cfg_attr(test, allow(dead_code))] struct LdapState { ldap: OnceLock, } #[derive(Debug, Deserialize, Serialize)] pub struct Session { pub user_id: u64, session_id: u32, remote: String, } #[derive(Debug)] pub enum SessionError { Invalid, } const SESSION_COOKIE: &str = "s"; const STATUS_OK: api_model::StatusResponse = api_model::StatusResponse { ok: true, error: None, }; const STATUS_UNAUTHORIZED: api_model::StatusResponse = api_model::StatusResponse { ok: false, error: Some("Unauthorized"), }; fn validate(sessions: &State, session: &Session, request: &Request<'_>) -> bool { match request.client_ip() { Some(addr) => { if session.remote == addr.to_string() { { let sessions_data = sessions.data.lock().unwrap(); match sessions_data.active_ids.get(&session.session_id) { // We could remove the expired session here, but it will be cleaned // next time anyone logs in anyway. Some(&expire) => expire > Instant::now(), None => false, } } } else { false } } None => false, } } #[rocket::async_trait] impl<'r> FromRequest<'r> for Session { type Error = SessionError; async fn from_request(request: &'r Request<'_>) -> Outcome { let sessions = try_outcome!(request .guard::<&State>() .await .map_error(|_| (Status::Unauthorized, SessionError::Invalid))); request .cookies() .get_private(SESSION_COOKIE) .and_then(|cookie| -> Option { json::from_str(cookie.value()).ok() }) .and_then(|session| { if validate(sessions, &session, request) { Some(session) } else { None } }) .or_error((Status::Unauthorized, SessionError::Invalid)) } } fn new_session( sessions: &State, user_id: u64, remote: String, max_age: Duration, ) -> Session { let session_id; { let mut sessions_data = sessions.data.lock().unwrap(); session_id = sessions_data.next_id; sessions_data.next_id += 1; let now = Instant::now(); // Remove expired sessions first sessions_data .active_ids .retain(|_, &mut expire| expire > now); sessions_data.active_ids.insert(session_id, now + max_age); } Session { user_id, session_id, remote, } } #[cfg(not(test))] async fn authenticate(ldap_state: &State, dn: &str, password: &str) -> bool { let mut ldap = ldap_state.ldap.get().unwrap().clone(); let maybe_result = ldap.compare(dn, "userPassword", password.as_bytes()).await; if let Ok(result) = maybe_result { if let Ok(is_equal) = result.equal() { return is_equal; } } false } #[cfg(test)] async fn authenticate(_ldap_state: &State, dn: &str, password: &str) -> bool { match dn { "user" => password == "password", "other" => password == "secret", _ => false, } } #[utoipa::path( responses( (status = 200, description = "Login successful", body = api_model::StatusResponse, example = json!(STATUS_OK)), (status = 401, description = "Login failed", body = api_model::StatusResponse, example = json!(STATUS_UNAUTHORIZED)), ), security( (), ), )] #[post("/login", data = "")] async fn login( auth_config: &State>, ldap_state: &State, sessions: &State, ipaddr: IpAddr, cookies: &CookieJar<'_>, mut db: Connection, login: Form>, ) -> Result, Unauthorized<&'static str>> { let (user_id, maybe_dn) = sqlx::query!("SELECT id,dn FROM users WHERE username=?", login.username) .fetch_one(&mut **db) .map_ok(|r| (r.id, r.dn)) .map_err(|_| Unauthorized("Unknown username or password")) .await?; if let Some(dn) = maybe_dn { if authenticate(ldap_state, dn.as_str(), login.password).await { let max_age = Duration::days(i64::from(auth_config.session_max_age_days)); let session = new_session(sessions, user_id, ipaddr.to_string(), max_age); let cookie = Cookie::build((SESSION_COOKIE, json::to_string(&session).unwrap())) .path("/api") .max_age(max_age) .http_only(true) .build(); cookies.add_private(cookie); return Ok(Json(STATUS_OK)); } } Err(Unauthorized("Unknown username or password")) } #[utoipa::path( responses( (status = 200, description = "Logout successful", body = api_model::StatusResponse, example = json!(STATUS_OK)), ), security( ("session" = []), ), )] #[get("/logout")] fn logout( session: Session, sessions: &State, cookies: &CookieJar<'_>, ) -> Json { { let mut sessions_data = sessions.data.lock().unwrap(); sessions_data.active_ids.remove(&session.session_id); } let cookie = Cookie::build((SESSION_COOKIE, "")) .path("/api") .http_only(true) .build(); cookies.remove_private(cookie); Json(STATUS_OK) } #[utoipa::path( responses( (status = 200, description = "Current status", body = api_model::StatusResponse, example = json!(STATUS_OK)), (status = 401, description = "Not authorized", body = api_model::StatusResponse, example = json!(STATUS_UNAUTHORIZED)), ), security( (), ("session" = []), ), )] #[get("/status")] fn status(_session: Session) -> Json { Json(STATUS_OK) } #[catch(401)] fn unauthorized() -> Json { Json(STATUS_UNAUTHORIZED) } #[cfg_attr(test, allow(dead_code))] async fn setup_ldap( ldap_state: &LdapState, config: &AuthConfig<'_>, ) -> Result { let (conn, ldap) = LdapConnAsync::new(&config.ldap_url).await?; ldap3::drive!(conn); let ret = ldap.clone(); ldap_state .ldap .set(ldap) .expect("setup_ldap must only be called once"); Ok(ret) } #[derive(Debug)] #[allow(dead_code)] enum LdapOrSqlError { LdapError(ldap3::LdapError), SqlError(sqlx::Error), } #[cfg_attr(test, allow(dead_code))] async fn sync_ldap( ldap_state: &LdapState, config: &AuthConfig<'_>, db: &Db, ) -> Result<(), LdapOrSqlError> { let mut ldap = setup_ldap(ldap_state, config) .map_err(LdapOrSqlError::LdapError) .await?; let (entries, _) = ldap .search( &config.ldap_users, ldap3::Scope::OneLevel, &config.ldap_filter, vec!["uid"], ) .map_err(LdapOrSqlError::LdapError) .await? .success() .map_err(LdapOrSqlError::LdapError)?; let mut tx = db.begin().await.unwrap(); // TODO: Insert/Update name as well as dn. let db_users = sqlx::query!("SELECT id,username,dn FROM users ORDER BY username") .fetch(&mut *tx) .map_ok(|r| (r.id, r.username, r.dn)) .try_collect::>() .await .unwrap(); let mut new_users: Vec<(String, String)> = Vec::new(); let mut updated_users: Vec<(u64, String)> = Vec::new(); let mut old_users: Vec = Vec::new(); let mut db_user = db_users.iter().peekable(); for entry in entries { let se = ldap3::SearchEntry::construct(entry); let uid = se.attrs.get("uid").unwrap().first().unwrap(); loop { if let Some(du) = db_user.peek() { match du.1.cmp(uid) { Ordering::Equal => { if du.2.as_ref().is_none_or(|x| *x != se.dn) { updated_users.push((du.0, se.dn)); } db_user.next(); break; } Ordering::Less => { old_users.push(du.0); db_user.next(); continue; } Ordering::Greater => (), } } new_users.push((uid.to_string(), se.dn)); break; } } if !new_users.is_empty() { let mut query_builder: sqlx::QueryBuilder = sqlx::QueryBuilder::new("INSERT INTO users (username,dn) VALUES"); let mut first = true; for pair in new_users { if first { first = false; } else { query_builder.push(","); } query_builder.push("("); query_builder.push_bind(pair.0); query_builder.push(","); query_builder.push_bind(pair.1); query_builder.push(")"); } query_builder .build() .execute(&mut *tx) .map_err(LdapOrSqlError::SqlError) .await?; } for pair in updated_users { sqlx::query!("UPDATE users SET dn=? WHERE id=?", pair.1, pair.0) .execute(&mut *tx) .map_err(LdapOrSqlError::SqlError) .await?; } if !old_users.is_empty() { let params = format!("?{}", ", ?".repeat(old_users.len() - 1)); let query_str = format!("UPDATE users SET dn=NULL WHERE id IN ({})", params); let mut query = sqlx::query(&query_str); for id in old_users { query = query.bind(id); } query .execute(&mut *tx) .map_err(LdapOrSqlError::SqlError) .await?; } tx.commit().map_err(LdapOrSqlError::SqlError).await?; Ok(()) } #[cfg(not(test))] async fn run_import(rocket: Rocket) -> fairing::Result { match rocket.state::() { Some(config) => match rocket.state::() { Some(ldap) => match Db::fetch(&rocket) { Some(db) => match sync_ldap(ldap, config, db).await { Ok(_) => Ok(rocket), Err(_) => Err(rocket), }, None => Err(rocket), }, None => Err(rocket), }, None => Err(rocket), } } #[cfg(test)] async fn run_import(rocket: Rocket) -> fairing::Result { match Db::fetch(&rocket) { Some(db) => match sqlx::query!( "INSERT IGNORE INTO users (username,dn) VALUES (?,?), (?,?)", "user", "user", "other", "other", ) .execute(&**db) .await { Ok(_) => Ok(rocket), Err(_) => Err(rocket), }, None => Err(rocket), } } pub fn stage(basepath: &str) -> AdHoc { let l_basepath = basepath.to_string(); AdHoc::on_ignite("Auth Stage", |rocket| async { rocket .manage(Sessions { data: Mutex::new(SessionsData { active_ids: BTreeMap::new(), next_id: 1, }), }) .attach(AdHoc::config::()) .manage(LdapState { ldap: OnceLock::new(), }) .attach(AdHoc::try_on_ignite("Auth Import", run_import)) .mount(l_basepath.clone(), routes![login, logout, status]) .register(l_basepath, catchers![unauthorized]) }) }