feat: implement authentication module with JWT support and user management

This commit is contained in:
GW_MC
2025-12-15 15:51:45 +08:00
parent 1233f3b736
commit 3354154b87
13 changed files with 1232 additions and 5 deletions

View File

@@ -0,0 +1,269 @@
pub mod strategies;
use std::{collections::HashSet, sync::Arc};
use argon2::password_hash::{SaltString, rand_core::OsRng};
use jsonwebtoken::{
DecodingKey, EncodingKey, Header, Validation, decode, encode,
errors::ErrorKind::{ExpiredSignature, InvalidSubject, InvalidToken},
};
use sea_orm::prelude::Uuid;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::errors::service_error::ServiceError;
// Number of requests between invalidation cache cleanups
const INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS: usize = 100; // Cleanup every 100 for invalidation checks
#[derive(Serialize, Deserialize, Clone)]
pub struct Claims {
// Subject - user ID
pub sub: String,
// Issued at as UNIX timestamp
pub iat: u64,
// Expiration time as UNIX timestamp
pub exp: u64,
}
#[async_trait::async_trait]
pub trait AuthenticationService: Send + Sync {
async fn generate_jwt(&self, user_id: Uuid, duration_secs: u64)
-> Result<String, ServiceError>;
async fn is_valid_jwt(
&self,
token: &str,
target_sub: Option<String>,
) -> Result<bool, ServiceError>;
async fn parse_jwt(&self, token: &str) -> Result<Claims, ServiceError>;
async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError>;
async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result<String, ServiceError>;
async fn logout(&self, token: &str) -> Result<(), ServiceError>;
async fn cleanup_invalidation_cache(&self);
}
#[derive(Eq, Hash, PartialEq)]
struct InvalidationEntry {
token: String,
invalidated_at: u64,
valid_until: u64,
}
pub struct AuthenticationServiceImpl {
secret: String,
invalidation_cache: Arc<RwLock<HashSet<InvalidationEntry>>>,
cache_cleanup_counter: Arc<RwLock<usize>>,
}
impl AuthenticationServiceImpl {
pub fn new(secret: Option<String>) -> Self {
let secret = secret.unwrap_or_else(|| {
// generate a random secret if none is provided
SaltString::generate(&mut OsRng).as_str().to_owned()
});
Self {
secret,
invalidation_cache: Arc::new(RwLock::new(HashSet::new())),
cache_cleanup_counter: Arc::new(RwLock::new(0)),
}
}
}
#[async_trait::async_trait]
impl AuthenticationService for AuthenticationServiceImpl {
async fn generate_jwt(
&self,
user_id: Uuid,
duration_secs: u64,
) -> Result<String, ServiceError> {
let header = Header::default();
let expiration = chrono::Utc::now()
.checked_add_signed(chrono::Duration::seconds(duration_secs as i64))
.ok_or(ServiceError::InternalError(
"Invalid expiration time".into(),
))?
.timestamp() as u64;
let claims = Claims {
sub: user_id.to_string(),
iat: chrono::Utc::now().timestamp() as u64,
exp: expiration,
};
let token = encode(
&header,
&claims,
&EncodingKey::from_secret(self.secret.as_ref()),
)
.map_err(|e| ServiceError::InternalError(format!("JWT generation error: {}", e)))?;
Ok(token)
}
async fn is_valid_jwt(
&self,
token: &str,
target_sub: Option<String>,
) -> Result<bool, ServiceError> {
let mut validation = Validation::default();
if let Some(expected_sub) = target_sub {
validation.sub = Some(expected_sub);
}
let decoding_key = DecodingKey::from_secret(self.secret.as_ref());
match decode::<Claims>(token, &decoding_key, &validation) {
Ok(_) => Ok(true),
Err(err) => match *err.kind() {
InvalidToken | InvalidSubject | ExpiredSignature => Ok(false),
_ => Err(ServiceError::InternalError(format!(
"JWT validation error: {}",
err
))),
},
}
}
async fn parse_jwt(&self, token: &str) -> Result<Claims, ServiceError> {
let decoding_key = DecodingKey::from_secret(self.secret.as_ref());
let token_data = decode::<Claims>(token, &decoding_key, &Validation::default())
.map_err(|e| ServiceError::InternalError(format!("JWT parsing error: {}", e)))?;
Ok(token_data.claims)
}
async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError> {
let claims = self.parse_jwt(token).await?;
let valid_until = claims.exp;
let invalidated_at = chrono::Utc::now().timestamp() as u64;
let entry = InvalidationEntry {
token: token.to_string(),
invalidated_at,
valid_until,
};
{
self.invalidation_cache.write().await.insert(entry);
}
//
if self.cache_cleanup_counter.read().await.wrapping_add(1)
% INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS
== 0
{
self.cleanup_invalidation_cache().await;
}
//
Ok(())
}
async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result<String, ServiceError> {
let claims = self.parse_jwt(token).await?;
let user_id = Uuid::parse_str(&claims.sub).map_err(|e| {
ServiceError::InternalError(format!("Invalid user ID in JWT claims: {}", e))
})?;
let new_token = self.generate_jwt(user_id, duration_secs).await?;
Ok(new_token)
}
async fn logout(&self, token: &str) -> Result<(), ServiceError> {
self.invalidate_jwt(token).await
}
async fn cleanup_invalidation_cache(&self) {
let now = chrono::Utc::now().timestamp() as u64;
let mut cache = self.invalidation_cache.write().await;
cache.retain(|entry| entry.valid_until > now);
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::{Duration, sleep};
#[tokio::test]
async fn test_jwt_generation_and_validation() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service
.generate_jwt(user_id, 60)
.await
.expect("generate jwt");
let valid = service
.is_valid_jwt(&token, None)
.await
.expect("validate jwt");
assert!(valid, "Generated token should be valid");
let claims = service.parse_jwt(&token).await.expect("parse jwt");
assert_eq!(claims.sub, user_id.to_string());
}
#[tokio::test]
async fn test_jwt_validation_with_wrong_subject() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 60).await.unwrap();
let other_sub = Uuid::new_v4().to_string();
let valid = service.is_valid_jwt(&token, Some(other_sub)).await.unwrap();
assert!(!valid, "Token should be invalid for a different subject");
}
#[tokio::test]
async fn test_parse_jwt_invalid_token() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let res = service.parse_jwt("not_a_token").await;
assert!(matches!(res, Err(ServiceError::InternalError(_))));
}
#[tokio::test]
async fn test_refresh_jwt() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 60).await.unwrap();
let new_token = service.refresh_jwt(&token, 120).await.unwrap();
let claims = service.parse_jwt(&new_token).await.unwrap();
assert_eq!(claims.sub, user_id.to_string());
assert_eq!(claims.exp - claims.iat, 120);
}
#[tokio::test]
async fn test_is_valid_jwt_expired() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 1).await.unwrap();
sleep(Duration::from_secs(2)).await;
let valid = service.is_valid_jwt(&token, None).await.unwrap();
assert!(!valid, "Token should be expired and thus invalid");
}
#[tokio::test]
async fn test_invalidate_and_cleanup() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 1).await.unwrap();
service.invalidate_jwt(&token).await.unwrap();
// ensure entry is present
{
let cache = service.invalidation_cache.read().await;
assert!(cache.iter().any(|e| e.token == token));
}
// wait until token validity ends and cleanup
sleep(Duration::from_secs(2)).await;
service.cleanup_invalidation_cache().await;
let cache = service.invalidation_cache.read().await;
assert!(
cache.is_empty(),
"Cleanup should remove expired invalidation entries"
);
}
}

View File

@@ -0,0 +1 @@
pub mod password;

View File

@@ -0,0 +1,453 @@
use std::sync::Arc;
use crate::{errors::service_error::ServiceError, with_conn};
use argon2::{
Argon2,
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core::OsRng},
};
use database::generated::entities::{user, user_identity};
use sea_orm::{
ColumnTrait, DatabaseConnection, DatabaseTransaction, EntityTrait, IntoActiveModel,
QueryFilter, prelude::Uuid,
};
pub struct PasswordStrategy {
connection: Arc<DatabaseConnection>,
}
const MAX_PASSWORD_LENGTH: usize = 32;
const PASSWORD_PROVIDER: &str = "password";
impl PasswordStrategy {
pub fn new(connection: Arc<DatabaseConnection>) -> Self {
Self { connection }
}
pub async fn authenticate(
&self,
username: &str,
password: &str,
tx: Option<&mut DatabaseTransaction>,
) -> Result<Uuid, ServiceError> {
// Find user by username
let user = with_conn!(&*self.connection, tx, conn, {
user::Entity::find()
.filter(user::Column::Name.eq(username))
.one(*conn)
.await?
.ok_or_else(|| {
ServiceError::Unauthorized("Invalid username or password".to_string())
})?
});
// Get user's identity
let identity = with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::find()
.filter(user_identity::Column::UserId.eq(user.id))
.one(*conn)
.await?
.ok_or_else(|| {
ServiceError::Unauthorized("Invalid username or password".to_string())
})?
});
// Check if revoked
if identity.is_revoked {
return Err(ServiceError::Unauthorized("Account is revoked".to_string()));
}
// Verify password
let password_hash = identity
.password_hash
.ok_or_else(|| ServiceError::InternalError("Invalid password hash".to_string()))?;
let parsed_hash = PasswordHash::new(&password_hash)
.map_err(|_| ServiceError::InternalError("Invalid password hash".to_string()))?;
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.map_err(|_| ServiceError::Unauthorized("Invalid username or password".to_string()))?;
Ok(user.id)
}
pub async fn revoke_identity(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<(), ServiceError> {
let mut identity = with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::find()
.filter(user_identity::Column::UserId.eq(user_id))
.one(*conn)
.await?
.ok_or_else(|| ServiceError::NotFound("User identity not found".to_string()))?
});
identity.is_revoked = true;
with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::update(identity.into_active_model())
.exec(*conn)
.await
.map_err(ServiceError::from)
})?;
Ok(())
}
pub async fn create_identity(
&self,
user_id: Uuid,
password: &str,
tx: Option<&mut DatabaseTransaction>,
) -> Result<(), ServiceError> {
Self::is_valid_password(password).map_err(ServiceError::BadRequest)?;
let password_hash = Argon2::default()
.hash_password(password.as_bytes(), &SaltString::generate(&mut OsRng))
.map_err(|_| ServiceError::InternalError("Failed to hash password".to_string()))?
.to_string();
let new_identity = user_identity::ActiveModel {
user_id: sea_orm::ActiveValue::Set(user_id),
provider: sea_orm::ActiveValue::Set(PASSWORD_PROVIDER.to_string()),
password_hash: sea_orm::ActiveValue::Set(Some(password_hash)),
metadata: sea_orm::ActiveValue::Set(None),
is_revoked: sea_orm::ActiveValue::Set(false),
..Default::default()
};
with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::insert(new_identity)
.exec(*conn)
.await
.map_err(ServiceError::from)
})?;
Ok(())
}
pub async fn update_password(
&self,
user_id: Uuid,
new_password: &str,
tx: Option<&mut DatabaseTransaction>,
) -> Result<(), ServiceError> {
Self::is_valid_password(new_password).map_err(ServiceError::BadRequest)?;
let password_hash = Argon2::default()
.hash_password(new_password.as_bytes(), &SaltString::generate(&mut OsRng))
.map_err(|_| ServiceError::InternalError("Failed to hash password".to_string()))?
.to_string();
let mut identity = with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::find()
.filter(user_identity::Column::UserId.eq(user_id))
.one(*conn)
.await?
.ok_or_else(|| ServiceError::NotFound("User identity not found".to_string()))?
});
identity.password_hash = Some(password_hash);
identity.password_changed_at = Some(chrono::Utc::now());
with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::update(identity.into_active_model())
.exec(*conn)
.await
.map_err(ServiceError::from)
})?;
Ok(())
}
fn is_valid_password(password: &str) -> Result<(), String> {
if password.is_empty() {
return Err("Password cannot be empty".to_string());
}
if password.len() > MAX_PASSWORD_LENGTH {
return Err(format!(
"Password cannot be longer than {} characters",
MAX_PASSWORD_LENGTH
));
}
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
use database::generated::entities::{user, user_identity};
use sea_orm::MockDatabase;
#[test]
fn ensure_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<PasswordStrategy>();
}
#[test]
fn password_validation() {
let valid_password = "ValidPassword123!";
let long_password = "a".repeat(129);
assert!(PasswordStrategy::is_valid_password(valid_password).is_ok());
assert!(PasswordStrategy::is_valid_password(long_password.as_str()).is_err());
}
#[tokio::test]
async fn authenticate_user_not_found() {
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![Vec::<sea_orm::MockRow>::new()])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy
.authenticate("nonexistent_user", "password", None)
.await;
assert!(matches!(result, Err(ServiceError::Unauthorized(_))));
}
#[tokio::test]
async fn authenticate_invalid_password() {
let user_id = Uuid::new_v4();
let password_hash = Argon2::default()
.hash_password(
"CorrectPassword".as_bytes(),
&SaltString::generate(&mut OsRng),
)
.unwrap()
.to_string();
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![vec![user::Model {
id: user_id,
name: "test_user".to_string(),
is_active: true,
is_admin: false,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
deleted_at: None,
last_login_at: None,
}]])
.append_query_results(vec![vec![user_identity::Model {
id: Uuid::new_v4(),
user_id,
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: Some(password_hash),
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
}]])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy
.authenticate("test_user", "InvalidPassword", None)
.await;
assert!(matches!(result, Err(ServiceError::Unauthorized(_))));
}
#[tokio::test]
async fn authenticate_success() {
let user_id = Uuid::new_v4();
let password_hash = Argon2::default()
.hash_password(
"CorrectPassword".as_bytes(),
&SaltString::generate(&mut OsRng),
)
.unwrap()
.to_string();
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![vec![user::Model {
id: user_id,
name: "test_user".to_string(),
is_active: true,
is_admin: false,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
deleted_at: None,
last_login_at: None,
}]])
.append_query_results(vec![vec![user_identity::Model {
id: Uuid::new_v4(),
user_id,
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: Some(password_hash),
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
}]])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy
.authenticate("test_user", "CorrectPassword", None)
.await;
assert!(matches!(result, Ok(id) if id == user_id));
}
#[tokio::test]
async fn revoke_identity_not_found() {
let user_id = Uuid::new_v4();
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![Vec::<sea_orm::MockRow>::new()])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy.revoke_identity(user_id, None).await;
assert!(matches!(result, Err(ServiceError::NotFound(_))));
}
#[tokio::test]
async fn revoke_identity_success() {
let user_id = Uuid::new_v4();
let identity = user_identity::Model {
id: Uuid::new_v4(),
user_id,
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: None,
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
};
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![
vec![identity.clone()],
vec![user_identity::Model {
is_revoked: true,
..identity
}],
])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy.revoke_identity(user_id, None).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn create_identity_invalid_password() {
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite).into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy.create_identity(Uuid::new_v4(), "", None).await;
assert!(matches!(result, Err(ServiceError::BadRequest(_))));
}
#[tokio::test]
async fn create_identity_success() {
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![vec![user_identity::Model {
id: Uuid::new_v4(),
user_id: Uuid::new_v4(),
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: Some("somehash".to_string()),
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
}]])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy
.create_identity(Uuid::new_v4(), "ValidPass1!", None)
.await;
assert!(
result.is_ok(),
"Failed to create identity, error: {:?}",
result.err()
);
}
#[tokio::test]
async fn update_password_not_found() {
let user_id = Uuid::new_v4();
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![Vec::<sea_orm::MockRow>::new()])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy.update_password(user_id, "NewPass1!", None).await;
assert!(matches!(result, Err(ServiceError::NotFound(_))));
}
#[tokio::test]
async fn update_password_success() {
let user_id = Uuid::new_v4();
let identity = user_identity::Model {
id: Uuid::new_v4(),
user_id,
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: Some("oldhash".to_string()),
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
};
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![
vec![identity],
vec![user_identity::Model {
id: Uuid::new_v4(),
user_id,
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: Some("newhash".to_string()),
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
}],
])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy.update_password(user_id, "NewPass1!", None).await;
assert!(
result.is_ok(),
"Failed to update password, error: {:?}",
result.err()
);
}
}

View File

@@ -0,0 +1,208 @@
use std::sync::Arc;
use database::generated::entities::user::{
self, ActiveModel as UserActiveModel, Model as UserModel,
};
use sea_orm::{
ActiveModelTrait, ActiveValue, ColumnTrait, DatabaseConnection, DatabaseTransaction, DbErr,
EntityTrait, IntoActiveModel, QueryFilter, prelude::Uuid,
};
use crate::{errors::service_error::ServiceError, with_conn};
#[async_trait::async_trait]
pub trait UserService: Send + Sync {
async fn get_user_by_id(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError>;
async fn is_admin(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<bool, ServiceError>;
async fn user_exists(
&self,
username: &str,
tx: Option<&mut DatabaseTransaction>,
) -> Result<bool, ServiceError>;
async fn create_user(
&self,
user: NewUser,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError>;
async fn update_user(
&self,
user_id: Uuid,
user: UpdateUser,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError>;
async fn delete_user(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<(), ServiceError>;
}
pub struct User {
pub id: Uuid,
pub username: String,
pub is_admin: bool,
}
impl From<UserModel> for User {
fn from(model: UserModel) -> Self {
Self {
id: model.id,
username: model.name,
is_admin: model.is_admin,
}
}
}
pub struct NewUser {
pub username: String,
pub is_admin: bool,
}
pub struct UpdateUser {
pub username: Option<String>,
pub is_admin: Option<bool>,
pub is_active: Option<bool>,
}
impl UpdateUser {
fn apply_to_active_model(&self, model: &mut UserActiveModel) {
if let Some(username) = &self.username {
model.name = ActiveValue::Set(username.clone());
}
if let Some(is_admin) = self.is_admin {
model.is_admin = ActiveValue::Set(is_admin);
}
if let Some(is_active) = self.is_active {
model.is_active = ActiveValue::Set(is_active);
}
}
}
pub struct UserServiceImpl {
connection: Arc<DatabaseConnection>,
}
impl UserServiceImpl {
pub fn new(connection: Arc<DatabaseConnection>) -> Self {
Self { connection }
}
async fn get_user_by_id_from_db(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<UserModel, ServiceError> {
let user = with_conn!(&*self.connection, tx, conn, {
user::Entity::find_by_id(user_id).one(*conn).await
});
match user {
Err(err) => Err(ServiceError::from(err)),
Ok(None) => Err(ServiceError::NotFound(format!(
"User with id '{}' not found",
user_id
))),
Ok(Some(record)) => Ok(record),
}
}
}
#[async_trait::async_trait]
impl UserService for UserServiceImpl {
async fn get_user_by_id(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError> {
let user = self.get_user_by_id_from_db(user_id, tx).await?;
Ok(User::from(user))
}
async fn is_admin(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<bool, ServiceError> {
let user = self.get_user_by_id(user_id, tx).await?;
Ok(user.is_admin)
}
async fn user_exists(
&self,
username: &str,
tx: Option<&mut DatabaseTransaction>,
) -> Result<bool, ServiceError> {
let user = with_conn!(&*self.connection, tx, conn, {
user::Entity::find()
.filter(user::Column::Name.eq(username))
.one(*conn)
.await
});
match user {
Err(err) => match err {
DbErr::RecordNotFound(_) => Ok(false),
_ => Err(ServiceError::from(err)),
},
Ok(None) => Ok(false),
Ok(Some(_)) => Ok(true),
}
}
async fn create_user(
&self,
user: NewUser,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError> {
let user_active_model = UserActiveModel {
id: ActiveValue::NotSet,
name: ActiveValue::Set(user.username),
is_admin: ActiveValue::Set(user.is_admin),
is_active: ActiveValue::Set(true),
..Default::default()
};
let user_model = with_conn!(&*self.connection, tx, conn, {
user_active_model.insert(*conn).await
})?;
Ok(User::from(user_model))
}
async fn update_user(
&self,
user_id: Uuid,
update_user: UpdateUser,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError> {
let existing_user = self.get_user_by_id_from_db(user_id, tx).await?;
let mut user_active_model = existing_user.into_active_model();
update_user.apply_to_active_model(&mut user_active_model);
let user_model = user_active_model.update(&*self.connection).await?;
Ok(User::from(user_model))
}
async fn delete_user(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<(), ServiceError> {
let user = self.get_user_by_id_from_db(user_id, tx).await?;
let user_active_model = user.into_active_model();
user_active_model.delete(&*self.connection).await?;
Ok(())
}
}