diff options
| author | Joel Klinghed <the_jk@spawned.biz> | 2025-01-03 01:28:54 +0100 |
|---|---|---|
| committer | Joel Klinghed <the_jk@spawned.biz> | 2025-01-03 01:28:54 +0100 |
| commit | 7494db93b9262c3d8330fd11631e711a1642b8fc (patch) | |
| tree | 565534ddaa990a861c0ef8a9439f7656fce7f132 | |
| parent | 4b1f7fec1cf9d427234ff5bded79a6d18d5c88ce (diff) | |
Add initital tests
Also add /users endpoint.
| -rw-r--r-- | server/Cargo.lock | 7 | ||||
| -rw-r--r-- | server/Cargo.toml | 3 | ||||
| -rw-r--r-- | server/src/api_model.rs | 39 | ||||
| -rw-r--r-- | server/src/auth.rs | 20 | ||||
| -rw-r--r-- | server/src/main.rs | 83 | ||||
| -rw-r--r-- | server/src/tests.rs | 367 |
6 files changed, 496 insertions, 23 deletions
diff --git a/server/Cargo.lock b/server/Cargo.lock index 1983c29..622a32b 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -502,6 +502,7 @@ dependencies = [ "rocket_db_pools", "serde", "sqlx", + "stdext", "time", "utoipa", "utoipa-swagger-ui", @@ -2253,6 +2254,12 @@ dependencies = [ ] [[package]] +name = "stdext" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4af28eeb7c18ac2dbdb255d40bee63f203120e1db6b0024b177746ebec7049c1" + +[[package]] name = "stringprep" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/server/Cargo.toml b/server/Cargo.toml index 9348d86..d939013 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -12,3 +12,6 @@ sqlx = { version = "0.7.0", default-features = false, features = ["macros", "mig time = "0.3.34" utoipa = { version = "5", features = ["rocket_extras"] } utoipa-swagger-ui = { version = "8", features = ["rocket", "vendored"], default-features = false } + +[dev-dependencies] +stdext = "0.3.3" diff --git a/server/src/api_model.rs b/server/src/api_model.rs index f3bb5cf..b08ad5c 100644 --- a/server/src/api_model.rs +++ b/server/src/api_model.rs @@ -1,7 +1,7 @@ use rocket::serde::{Deserialize, Serialize}; use utoipa::ToSchema; -#[derive(Deserialize, Serialize, Copy, Clone, ToSchema)] +#[derive(Copy, Clone, Deserialize, Serialize, ToSchema)] pub enum ReviewState { Draft, Open, @@ -9,14 +9,14 @@ pub enum ReviewState { Closed, } -#[derive(Deserialize, Serialize, Copy, Clone, ToSchema)] +#[derive(Copy, Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] pub enum UserReviewRole { Reviewer, Watcher, None, } -#[derive(Serialize, ToSchema)] +#[derive(Debug, Deserialize, Serialize, PartialEq, ToSchema)] pub struct User { #[schema(example = 1337u64)] pub id: u64, @@ -28,6 +28,19 @@ pub struct User { pub active: bool, } +#[derive(Deserialize, Serialize, ToSchema)] +pub struct Users { + #[schema(example = 0u32)] + pub offset: u32, + #[schema(example = 10u32)] + pub limit: u32, + #[schema(example = 42u32)] + pub total_count: u32, + #[schema(example = true)] + pub more: bool, + pub users: Vec<User>, +} + #[derive(Serialize, ToSchema)] pub struct ReviewUserEntry { pub user: User, @@ -77,7 +90,7 @@ pub struct Reviews { pub reviews: Vec<ReviewEntry>, } -#[derive(Serialize, ToSchema)] +#[derive(Debug, Deserialize, PartialEq, Serialize, ToSchema)] pub struct ProjectUserEntry { pub user: User, #[schema(example = UserReviewRole::Reviewer)] @@ -86,7 +99,7 @@ pub struct ProjectUserEntry { pub maintainer: bool, } -#[derive(Deserialize, ToSchema)] +#[derive(Deserialize, Serialize, ToSchema)] pub struct ProjectUserEntryData { #[schema(example = UserReviewRole::Reviewer)] pub default_role: Option<UserReviewRole>, @@ -94,7 +107,7 @@ pub struct ProjectUserEntryData { pub maintainer: Option<bool>, } -#[derive(Serialize, ToSchema)] +#[derive(Debug, Deserialize, PartialEq, Serialize, ToSchema)] pub struct Project { #[schema(example = 1u64)] pub id: u64, @@ -105,7 +118,7 @@ pub struct Project { pub users: Vec<ProjectUserEntry>, } -#[derive(Deserialize, ToSchema)] +#[derive(Deserialize, Serialize, ToSchema)] pub struct ProjectData<'r> { #[schema(example = "FAKE: Features All Kids Erase")] pub title: Option<&'r str>, @@ -113,7 +126,7 @@ pub struct ProjectData<'r> { pub description: Option<&'r str>, } -#[derive(Serialize, ToSchema)] +#[derive(Deserialize, Serialize, ToSchema)] pub struct ProjectEntry { #[schema(example = 1u64)] pub id: u64, @@ -121,7 +134,7 @@ pub struct ProjectEntry { pub title: String, } -#[derive(Serialize, ToSchema)] +#[derive(Deserialize, Serialize, ToSchema)] pub struct Projects { #[schema(example = 0u32)] pub offset: u32, @@ -134,9 +147,13 @@ pub struct Projects { pub projects: Vec<ProjectEntry>, } -#[derive(Serialize, ToSchema)] +#[derive(Deserialize, Serialize, ToSchema)] pub struct StatusResponse { pub ok: bool, - #[serde(skip_serializing_if = "Option::is_none")] + #[serde( + skip_serializing_if = "Option::is_none", + // &'static str is problematic for serde, only used in tests anyway. + skip_deserializing, + )] pub error: Option<&'static str>, } diff --git a/server/src/auth.rs b/server/src/auth.rs index f1b8f70..db3a6a0 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -248,8 +248,10 @@ fn unauthorized() -> Json<api_model::StatusResponse> { Json(STATUS_UNAUTHORIZED) } +#[cfg(not(test))] async fn run_import(rocket: Rocket<Build>) -> fairing::Result { match Db::fetch(&rocket) { + // TODO: Replace with ldap Some(db) => match sqlx::query!("INSERT IGNORE INTO users (username) VALUES (?)", "user") .execute(&**db) .await @@ -261,6 +263,24 @@ async fn run_import(rocket: Rocket<Build>) -> fairing::Result { } } +#[cfg(test)] +async fn run_import(rocket: Rocket<Build>) -> fairing::Result { + match Db::fetch(&rocket) { + Some(db) => match sqlx::query!( + "INSERT IGNORE INTO users (username) VALUES (?), (?)", + "user", + "other", + ) + .execute(&**db) + .await + { + Ok(_) => Ok(rocket), + Err(_) => Err(rocket), + }, + None => Err(rocket), + } +} + pub fn stage(basepath: &str) -> AdHoc { let l_basepath = basepath.to_string(); AdHoc::on_ignite("Auth Stage", |rocket| async { diff --git a/server/src/main.rs b/server/src/main.rs index 54ad279..53cdb89 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -3,6 +3,7 @@ extern crate rocket; use futures::{future::TryFutureExt, stream::TryStreamExt}; use rocket::fairing::{self, AdHoc}; +use rocket::figment::Figment; use rocket::http::Status; use rocket::response::status::{Custom, NotFound}; use rocket::serde::json::Json; @@ -12,6 +13,9 @@ use sqlx::Acquire; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; +#[cfg(test)] +mod tests; + mod api_model; mod auth; mod db_utils; @@ -34,6 +38,7 @@ struct Db(sqlx::MySqlPool); project_user_del, reviews, review, + users, ), modifiers(&AuthApiAddon), )] @@ -574,6 +579,56 @@ async fn review( Ok(Json(review)) } +#[utoipa::path( + responses( + (status = 200, description = "Get all users", body = api_model::Users), + ), + security( + ("session" = []), + ), +)] +#[get("/users?<limit>&<offset>")] +async fn users( + mut db: Connection<Db>, + _session: auth::Session, + limit: Option<u32>, + offset: Option<u32>, +) -> Json<api_model::Users> { + let uw_offset = offset.unwrap_or(0); + let uw_limit = limit.unwrap_or(10); + let entries = sqlx::query!( + "SELECT id,username,name,active FROM users ORDER BY username LIMIT ? OFFSET ?", + uw_limit, + uw_offset + ) + .fetch(&mut **db) + .map_ok(|r| api_model::User { + id: r.id, + username: r.username, + name: r.name, + active: r.active != 0, + }) + .try_collect::<Vec<_>>() + .await + .unwrap(); + + let count = sqlx::query!("SELECT COUNT(id) AS count FROM users") + .fetch_one(&mut **db) + .map_ok(|r| r.count) + .await + .unwrap(); + + let u32_count = u32::try_from(count).unwrap(); + + Json(api_model::Users { + offset: uw_offset, + limit: uw_limit, + total_count: u32_count, + more: uw_offset + uw_limit < u32_count, + users: entries, + }) +} + async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result { match Db::fetch(&rocket) { Some(db) => match sqlx::migrate!().run(&**db).await { @@ -587,17 +642,9 @@ async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result { } } -#[rocket::main] -async fn main() -> Result<(), rocket::Error> { +fn rocket_from_config(figment: Figment) -> Rocket<Build> { let basepath = "/api/v1"; - - let mut api = MainApi::openapi(); - api.merge(auth::AuthApi::openapi()); - api.servers = Some(vec![utoipa::openapi::ServerBuilder::new() - .url(basepath) - .build()]); - - let _rocket = rocket::build() + rocket::custom(figment) .attach(Db::init()) .attach(AdHoc::try_on_ignite("Database Migrations", run_migrations)) .mount( @@ -612,14 +659,26 @@ async fn main() -> Result<(), rocket::Error> { project_user_update, project_user_del, reviews, - review + review, + users, ], ) + .attach(auth::stage(basepath)) +} + +#[rocket::main] +async fn main() -> Result<(), rocket::Error> { + let mut api = MainApi::openapi(); + api.merge(auth::AuthApi::openapi()); + api.servers = Some(vec![utoipa::openapi::ServerBuilder::new() + .url("/api/v1") + .build()]); + + let _rocket = rocket_from_config(rocket::Config::figment()) .mount( "/", SwaggerUi::new("/openapi/ui/<_..>").url("/openapi/openapi.json", api), ) - .attach(auth::stage(basepath)) .launch() .await?; diff --git a/server/src/tests.rs b/server/src/tests.rs new file mode 100644 index 0000000..b6476a0 --- /dev/null +++ b/server/src/tests.rs @@ -0,0 +1,367 @@ +use rocket::figment::util::map; +use rocket::figment::value::{Map, Value}; +use rocket::http::{ContentType, Header, Status}; +use rocket::local::asynchronous::{Client, LocalRequest}; +use sqlx::mysql::{MySql, MySqlConnectOptions, MySqlPoolOptions}; +use sqlx::{Acquire, Executor, Pool}; +use std::sync::OnceLock; +use stdext::function_name; + +use crate::api_model; + +struct RealIP(&'static str); + +impl From<&RealIP> for Header<'static> { + fn from(ip: &RealIP) -> Header<'static> { + Header::new("X-Real-IP", ip.0) + } +} + +static FAKE_IP: RealIP = RealIP("127.0.1.10"); +static ANOTHER_FAKE_IP: RealIP = RealIP("192.168.0.1"); + +static MASTER_POOL: OnceLock<Pool<MySql>> = OnceLock::new(); + +fn find_password(url: &'_ str) -> Option<&'_ str> { + let protocol = url.find("://"); + if protocol.is_none() { + return None; + } + let specific = &url[protocol.unwrap() + 3..]; + let at = specific.find('@'); + if at.is_none() { + return None; + } + let auth = &specific[0..at.unwrap()]; + let colon = auth.find(':'); + if colon.is_none() { + return None; + } + return Some(&auth[colon.unwrap() + 1..]); +} + +fn make_db_name_safe(name: &str) -> String { + let mut ret = String::new(); + for c in name.chars() { + if c >= 'a' && c <= 'z' { + ret.push(c); + } else if c >= '0' && c <= '9' { + ret.push(c); + } else { + ret.push('_'); + } + } + return ret; +} + +async fn async_client_with_private_database(test_name: String) -> Client { + let base_figment = rocket::Config::figment(); + + let base_url_value = base_figment + .find_value("databases.eyeballs.url") + .expect("database_url"); + let base_url = base_url_value.as_str().expect("database_url as string"); + let base_options: MySqlConnectOptions = base_url.parse().expect("valid database_url"); + + let maybe_password = find_password(base_url); + + let database = + make_db_name_safe(&format!("_{}", test_name.trim_end_matches("::{{closure}}"))[..]); + + // Cannot get sqlx::test (0.7.4) to work with MySQL, always errors out + // with connection (already?) closed when closing the setup connection. + // So doing our own lazier setup where each test gets a db based on + // their name. + { + let mut pool_conn = MASTER_POOL + .get_or_init(|| { + let options: MySqlConnectOptions = + base_url_value.as_str().unwrap().parse().unwrap(); + + MySqlPoolOptions::new() + .max_connections(20) + .after_release(|_conn, _| Box::pin(async move { Ok(false) })) + .connect_lazy_with(options) + }) + .acquire() + .await + .unwrap(); + + let conn = pool_conn.acquire().await.unwrap(); + conn.execute(&format!("DROP DATABASE IF EXISTS {database}")[..]) + .await + .unwrap(); + conn.execute(&format!("CREATE DATABASE {database}")[..]) + .await + .unwrap(); + } + + let db_url = format!( + "mysql://{}{}@{}:{}/{}", + base_options.get_username(), + if let Some(password) = maybe_password { + format!(":{}", password) + } else { + "".to_string() + }, + base_options.get_host(), + base_options.get_port(), + database, + ); + + let db_config: Map<_, Value> = map! { + "url" => db_url.into(), + }; + + let figment = base_figment.merge(("databases", map!["eyeballs" => db_config])); + + Client::tracked(crate::rocket_from_config(figment)) + .await + .expect("valid rocket instance") +} + +async fn get_status_from<'a>(request: LocalRequest<'a>) -> api_model::StatusResponse { + request + .header(&FAKE_IP) + .dispatch() + .await + .into_json::<api_model::StatusResponse>() + .await + .unwrap() +} + +async fn get_status<'a>(client: &Client) -> api_model::StatusResponse { + get_status_from(client.get("/api/v1/status")).await +} + +async fn login(client: &Client) { + let login = get_status_from( + client + .post("/api/v1/login") + .body("username=user&password=password") + .header(ContentType::Form), + ) + .await; + assert_eq!(login.ok, true); +} + +async fn get_projects<'a>(client: &Client) -> api_model::Projects { + client + .get("/api/v1/projects") + .header(&FAKE_IP) + .dispatch() + .await + .into_json::<api_model::Projects>() + .await + .unwrap() +} + +async fn get_project_from<'a>(request: LocalRequest<'a>) -> api_model::Project { + request + .header(&FAKE_IP) + .dispatch() + .await + .into_json::<api_model::Project>() + .await + .unwrap() +} + +async fn get_users<'a>(client: &Client) -> api_model::Users { + client + .get("/api/v1/users") + .header(&FAKE_IP) + .dispatch() + .await + .into_json::<api_model::Users>() + .await + .unwrap() +} + +async fn new_project(client: &Client) -> api_model::Project { + get_project_from( + client + .post("/api/v1/project/new") + .json(&api_model::ProjectData { + title: Some("foo"), + description: Some("bar"), + }), + ) + .await +} + +#[rocket::async_test] +async fn test_not_logged_in_status() { + let client = async_client_with_private_database(function_name!().to_string()).await; + let not_logged_in = get_status(&client).await; + assert_eq!(not_logged_in.ok, false); +} + +#[rocket::async_test] +async fn test_login_status() { + let client = async_client_with_private_database(function_name!().to_string()).await; + + login(&client).await; + + let logged_in = get_status(&client).await; + assert_eq!(logged_in.ok, true); +} + +#[rocket::async_test] +async fn test_bad_login_status() { + let client = async_client_with_private_database(function_name!().to_string()).await; + + let bad_password = client + .post("/api/v1/login") + .body("username=user&password=incorrect") + .header(ContentType::Form) + .header(&FAKE_IP) + .dispatch() + .await; + assert_eq!(bad_password.status(), Status::Unauthorized); + + let bad_username = client + .post("/api/v1/login") + .body("username=incorrect&password=password") + .header(ContentType::Form) + .header(&FAKE_IP) + .dispatch() + .await; + assert_eq!(bad_username.status(), Status::Unauthorized); +} + +#[rocket::async_test] +async fn test_change_ip() { + let client = async_client_with_private_database(function_name!().to_string()).await; + + login(&client).await; + + let new_ip = client + .get("/api/v1/status") + .header(&ANOTHER_FAKE_IP) + .dispatch() + .await; + assert_eq!(new_ip.status(), Status::Unauthorized); +} + +#[rocket::async_test] +async fn test_logout() { + let client = async_client_with_private_database(function_name!().to_string()).await; + + login(&client).await; + + let logged_in = get_status(&client).await; + assert_eq!(logged_in.ok, true); + + let logout = get_status_from(client.get("/api/v1/logout")).await; + assert_eq!(logout.ok, true); + + let not_logged_in = get_status(&client).await; + assert_eq!(not_logged_in.ok, false); +} + +#[rocket::async_test] +async fn test_projects_empty() { + let client = async_client_with_private_database(function_name!().to_string()).await; + + login(&client).await; + + let projects = get_projects(&client).await; + assert_eq!(projects.total_count, 0); + assert_eq!(projects.more, false); + assert_eq!(projects.projects.len(), 0); +} + +#[rocket::async_test] +async fn test_project_new() { + let client = async_client_with_private_database(function_name!().to_string()).await; + + login(&client).await; + + let project = new_project(&client).await; + + assert_eq!(project.title, "foo"); + assert_eq!(project.description, "bar"); + assert_eq!(project.users.len(), 1); + let user = project.users.get(0).unwrap(); + assert_eq!(user.user.username, "user"); + assert_eq!(user.default_role, api_model::UserReviewRole::Reviewer); + assert_eq!(user.maintainer, true); + + let projects = get_projects(&client).await; + assert_eq!(projects.total_count, 1); + assert_eq!(projects.more, false); + assert_eq!(projects.projects.len(), 1); + let project_entry = projects.projects.get(0).unwrap(); + assert_eq!(project_entry.id, project.id); + assert_eq!(project_entry.title, project.title); + + let project2 = get_project_from(client.get(format!("/api/v1/project/{}", project.id))).await; + assert_eq!(project, project2); +} + +#[rocket::async_test] +async fn test_project_update() { + let client = async_client_with_private_database(function_name!().to_string()).await; + + login(&client).await; + + let project = get_project_from(client.post("/api/v1/project/new").json( + &api_model::ProjectData { + title: Some("foo"), + description: None, + }, + )) + .await; + + let project_url = format!("/api/v1/project/{}", project.id); + + let update = client + .post(project_url.clone()) + .json(&api_model::ProjectData { + title: None, + description: Some("bar"), + }) + .header(&FAKE_IP) + .dispatch() + .await; + assert_eq!(update.status(), Status::Ok); + + let updated_project = get_project_from(client.get(project_url)).await; + assert_eq!(updated_project.title, project.title); + assert_eq!(updated_project.description, "bar"); +} + +#[rocket::async_test] +async fn test_project_new_user() { + let client = async_client_with_private_database(function_name!().to_string()).await; + + login(&client).await; + + let project = new_project(&client).await; + let project_url = format!("/api/v1/project/{}", project.id); + + let users = get_users(&client).await; + let other = users.users.iter().find(|u| u.username == "other").unwrap(); + + let new = client + .post(format!("{project_url}/user/new?userid={}", other.id)) + .json(&api_model::ProjectUserEntryData { + default_role: Some(api_model::UserReviewRole::Watcher), + maintainer: Some(true), + }) + .header(&FAKE_IP) + .dispatch() + .await; + assert_eq!(new.status(), Status::Ok); + + let updated_project = get_project_from(client.get(project_url)).await; + assert_eq!(updated_project.users.len(), 2); + let other_entry = updated_project + .users + .iter() + .find(|ue| ue.user.id == other.id) + .unwrap(); + assert_eq!(other_entry.user, *other); + assert_eq!(other_entry.default_role, api_model::UserReviewRole::Watcher); + assert_eq!(other_entry.maintainer, true); +} |
