use core::net::IpAddr; use futures::future::TryFutureExt; use rocket::fairing::{self, AdHoc}; use rocket::form::Form; use rocket::http::{Cookie, CookieJar, Status}; use rocket::outcome::{try_outcome, IntoOutcome}; use rocket::request::{FromRequest, Outcome, Request}; use rocket::response::status::Unauthorized; use rocket::serde::json::{self, Json}; use rocket::serde::{Deserialize, Serialize}; use rocket::{Build, Rocket, State}; use rocket_db_pools::{sqlx, Connection, Database}; use std::collections::BTreeMap; use std::sync::Mutex; use std::time::Instant; use time::Duration; use utoipa::openapi::security::{ApiKey, ApiKeyValue, SecurityScheme}; use utoipa::{Modify, OpenApi, ToSchema}; use crate::api_model; use crate::Db; #[derive(OpenApi)] #[openapi( paths(login, logout, status,), modifiers(&AuthApiAddon), )] pub struct AuthApi; pub struct AuthApiAddon; impl Modify for AuthApiAddon { fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { let components = openapi.components.as_mut().unwrap(); components.add_security_scheme( "session", SecurityScheme::ApiKey(ApiKey::Cookie(ApiKeyValue::new(SESSION_COOKIE))), ) } } #[derive(FromForm, ToSchema)] struct Login<'r> { username: &'r str, password: &'r str, } #[derive(Debug, Deserialize)] #[allow(dead_code)] struct AuthConfig { session_max_age_days: u32, } struct SessionsData { active_ids: BTreeMap, next_id: u32, } struct Sessions { data: Mutex, } #[derive(Debug, Deserialize, Serialize)] pub struct Session { pub user_id: u64, session_id: u32, remote: String, } #[derive(Debug)] pub enum SessionError { Invalid, } const SESSION_COOKIE: &str = "s"; const STATUS_OK: api_model::StatusResponse = api_model::StatusResponse { ok: true, error: None, }; const STATUS_UNAUTHORIZED: api_model::StatusResponse = api_model::StatusResponse { ok: false, error: Some("Unauthorized"), }; fn validate(sessions: &State, session: &Session, request: &Request<'_>) -> bool { match request.client_ip() { Some(addr) => { if session.remote == addr.to_string() { { let sessions_data = sessions.data.lock().unwrap(); match sessions_data.active_ids.get(&session.session_id) { // We could remove the expired session here, but it will be cleaned // next time anyone logs in anyway. Some(&expire) => expire > Instant::now(), None => false, } } } else { false } } None => false, } } #[rocket::async_trait] impl<'r> FromRequest<'r> for Session { type Error = SessionError; async fn from_request(request: &'r Request<'_>) -> Outcome { let sessions = try_outcome!(request .guard::<&State>() .await .map_error(|_| (Status::Unauthorized, SessionError::Invalid))); request .cookies() .get_private(SESSION_COOKIE) .and_then(|cookie| -> Option { json::from_str(cookie.value()).ok() }) .and_then(|session| { if validate(sessions, &session, request) { Some(session) } else { None } }) .or_error((Status::Unauthorized, SessionError::Invalid)) } } fn new_session( sessions: &State, user_id: u64, remote: String, max_age: Duration, ) -> Session { let session_id; { let mut sessions_data = sessions.data.lock().unwrap(); session_id = sessions_data.next_id; sessions_data.next_id += 1; let now = Instant::now(); // Remove expired sessions first sessions_data .active_ids .retain(|_, &mut expire| expire > now); sessions_data.active_ids.insert(session_id, now + max_age); } Session { user_id, session_id, remote, } } #[utoipa::path( responses( (status = 200, description = "Login successful", body = api_model::StatusResponse, example = json!(STATUS_OK)), (status = 401, description = "Login failed", body = api_model::StatusResponse, example = json!(STATUS_UNAUTHORIZED)), ), security( (), ), )] #[post("/login", data = "")] async fn login( auth_config: &State, sessions: &State, ipaddr: IpAddr, cookies: &CookieJar<'_>, mut db: Connection, login: Form>, ) -> Result, Unauthorized<&'static str>> { if login.username == "user" && login.password == "password" { let user_id = sqlx::query!("SELECT id FROM users WHERE username=?", login.username) .fetch_one(&mut **db) .map_ok(|r| r.id) .map_err(|_| Unauthorized("Unknown username or password")) .await .unwrap(); let max_age = Duration::days(i64::from(auth_config.session_max_age_days)); let session = new_session(sessions, user_id, ipaddr.to_string(), max_age); let cookie = Cookie::build((SESSION_COOKIE, json::to_string(&session).unwrap())) .path("/api") .max_age(max_age) .http_only(true) .build(); cookies.add_private(cookie); Ok(Json(STATUS_OK)) } else { Err(Unauthorized("Unknown username or password")) } } #[utoipa::path( responses( (status = 200, description = "Logout successful", body = api_model::StatusResponse, example = json!(STATUS_OK)), ), security( ("session" = []), ), )] #[get("/logout")] fn logout( session: Session, sessions: &State, cookies: &CookieJar<'_>, ) -> Json { { let mut sessions_data = sessions.data.lock().unwrap(); sessions_data.active_ids.remove(&session.session_id); } let cookie = Cookie::build((SESSION_COOKIE, "")) .path("/api") .http_only(true) .build(); cookies.remove_private(cookie); Json(STATUS_OK) } #[utoipa::path( responses( (status = 200, description = "Current status", body = api_model::StatusResponse, example = json!(STATUS_OK)), (status = 401, description = "Not authorized", body = api_model::StatusResponse, example = json!(STATUS_UNAUTHORIZED)), ), security( (), ("session" = []), ), )] #[get("/status")] fn status(_session: Session) -> Json { Json(STATUS_OK) } #[catch(401)] fn unauthorized() -> Json { Json(STATUS_UNAUTHORIZED) } async fn run_import(rocket: Rocket) -> fairing::Result { match Db::fetch(&rocket) { Some(db) => match sqlx::query!("INSERT IGNORE INTO users (username) VALUES (?)", "user") .execute(&**db) .await { Ok(_) => Ok(rocket), Err(_) => Err(rocket), }, None => Err(rocket), } } pub fn stage(basepath: &str) -> AdHoc { let l_basepath = basepath.to_string(); AdHoc::on_ignite("Auth Stage", |rocket| async { rocket .manage(Sessions { data: Mutex::new(SessionsData { active_ids: BTreeMap::new(), next_id: 1, }), }) .attach(AdHoc::config::()) .attach(AdHoc::try_on_ignite("Auth Import", run_import)) .mount(l_basepath.clone(), routes![login, logout, status]) .register(l_basepath, catchers![unauthorized]) }) }