use core::net::IpAddr; use rocket::fairing::AdHoc; use rocket::form::Form; use rocket::http::{Cookie, CookieJar, Status}; use rocket::outcome::{try_outcome, IntoOutcome}; use rocket::request::{FromRequest, Outcome, Request}; use rocket::response::status::Unauthorized; use rocket::serde::json::{self, Json}; use rocket::serde::{Deserialize, Serialize}; use rocket::State; use std::collections::BTreeMap; use std::sync::Mutex; use std::time::Instant; use time::Duration; use crate::api_model; #[derive(FromForm)] struct Login<'r> { username: &'r str, password: &'r str, } #[derive(Debug, Deserialize)] #[allow(dead_code)] struct AuthConfig { session_max_age_days: u32, } struct SessionsData { active_ids: BTreeMap, next_id: u32, } struct Sessions { data: Mutex, } #[derive(Debug, Deserialize, Serialize)] pub struct Session { pub user_id: u64, session_id: u32, remote: String, } #[derive(Debug)] pub enum SessionError { Invalid, } const SESSION_COOKIE: &str = "s"; fn validate(sessions: &State, session: &Session, request: &Request<'_>) -> bool { match request.client_ip() { Some(addr) => { if session.remote == addr.to_string() { { let sessions_data = sessions.data.lock().unwrap(); match sessions_data.active_ids.get(&session.session_id) { // We could remove the expired session here, but it will be cleaned // next time anyone logs in anyway. Some(&expire) => expire > Instant::now(), None => false, } } } else { false } } None => false, } } #[rocket::async_trait] impl<'r> FromRequest<'r> for Session { type Error = SessionError; async fn from_request(request: &'r Request<'_>) -> Outcome { let sessions = try_outcome!(request .guard::<&State>() .await .map_error(|_| (Status::Unauthorized, SessionError::Invalid))); request .cookies() .get_private(SESSION_COOKIE) .and_then(|cookie| -> Option { json::from_str(cookie.value()).ok() }) .and_then(|session| { if validate(&sessions, &session, request) { Some(session) } else { None } }) .or_error((Status::Unauthorized, SessionError::Invalid)) } } fn new_session( sessions: &State, user_id: u64, remote: String, max_age: Duration, ) -> Session { let session_id; { let mut sessions_data = sessions.data.lock().unwrap(); session_id = sessions_data.next_id; sessions_data.next_id += 1; let now = Instant::now(); // Remove expired sessions first sessions_data .active_ids .retain(|_, &mut expire| expire > now); sessions_data.active_ids.insert(session_id, now + max_age); } Session { user_id: user_id, session_id: session_id, remote: remote, } } #[post("/login", data = "")] fn login( auth_config: &State, sessions: &State, ipaddr: IpAddr, cookies: &CookieJar<'_>, login: Form>, ) -> Result, Unauthorized<&'static str>> { if login.username == "user" && login.password == "password" { let max_age = Duration::days(i64::from(auth_config.session_max_age_days)); let session = new_session(&sessions, 1u64, ipaddr.to_string(), max_age); let cookie = Cookie::build((SESSION_COOKIE, json::to_string(&session).unwrap())) .path("/api") .max_age(max_age) .http_only(true) .build(); cookies.add_private(cookie); Ok(Json(api_model::StatusResponse { ok: true, error: None, })) } else { Err(Unauthorized("Unknown username or password")) } } #[get("/logout")] fn logout( session: Session, sessions: &State, cookies: &CookieJar<'_>, ) -> Json { { let mut sessions_data = sessions.data.lock().unwrap(); sessions_data.active_ids.remove(&session.session_id); } let cookie = Cookie::build((SESSION_COOKIE, "")) .path("/api") .http_only(true) .build(); cookies.remove_private(cookie); Json(api_model::StatusResponse { ok: true, error: None, }) } #[get("/status")] fn status(_session: Session) -> Json { Json(api_model::StatusResponse { ok: true, error: None, }) } #[catch(401)] fn unauthorized() -> Json { Json(api_model::StatusResponse { ok: false, error: Some("Unauthorized".to_string()), }) } pub fn stage(basepath: String) -> AdHoc { AdHoc::on_ignite("Auth Stage", |rocket| async { rocket .manage(Sessions { data: Mutex::new(SessionsData { active_ids: BTreeMap::new(), next_id: 1, }), }) .attach(AdHoc::config::()) .mount(basepath.clone(), routes![login, logout, status]) .register(basepath, catchers![unauthorized]) }) }