1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
use crate::{
    db::{self, AccessTokenId, Database, UserId},
    AppState, Error, Result,
};
use anyhow::{anyhow, Context};
use axum::{
    http::{self, Request, StatusCode},
    middleware::Next,
    response::IntoResponse,
};
use lazy_static::lazy_static;
use prometheus::{exponential_buckets, register_histogram, Histogram};
use rand::thread_rng;
use scrypt::{
    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
    Scrypt,
};
use serde::{Deserialize, Serialize};
use std::{sync::Arc, time::Instant};

lazy_static! {
    /// Prometheus histogram fed by `verify_access_token` with the duration (in milliseconds)
    /// of each scrypt access-token verification. Despite the name, it times verification
    /// rather than hash creation.
    ///
    /// Buckets are exponential: 10 upper bounds starting at 10.0 and doubling each step.
    static ref METRIC_ACCESS_TOKEN_HASHING_TIME: Histogram = {
        let buckets = exponential_buckets(10.0, 2.0, 10).unwrap();
        register_histogram!(
            "access_token_hashing_time",
            "time spent hashing access tokens",
            buckets,
        )
        .unwrap()
    };
}

/// Axum middleware that authenticates a request from its `Authorization` header.
///
/// The header is split on whitespace into `<user_id> <access_token>`. The access token is
/// either `ADMIN_TOKEN:<token>`, compared against the server's configured API token, or a
/// serialized access token checked by `verify_access_token`. On success the user record for
/// `user_id` is inserted into the request extensions and the request is passed to `next`.
///
/// # Errors
///
/// - `401 Unauthorized` if the header is absent or cannot be read as a string.
/// - `400 Bad Request` if the user id is missing/unparseable or the access token is missing.
/// - `401 Unauthorized` if the credentials are invalid. This includes any error raised while
///   verifying a non-admin token.
/// - The database error, or a "user ... not found" error, if loading the user fails.
///
/// # Panics
///
/// Panics if no `Arc<AppState>` has been installed in the request extensions.
pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let mut auth_header = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::Http(
                StatusCode::UNAUTHORIZED,
                "missing authorization header".to_string(),
            )
        })?
        .split_whitespace();

    let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
        Error::Http(
            StatusCode::BAD_REQUEST,
            "missing user id in authorization header".to_string(),
        )
    })?);

    let access_token = auth_header.next().ok_or_else(|| {
        Error::Http(
            StatusCode::BAD_REQUEST,
            "missing access token in authorization header".to_string(),
        )
    })?;

    let state = req.extensions().get::<Arc<AppState>>().unwrap();
    let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
        admin_token_matches(&state.config.api_token, admin_token)
    } else {
        // Any verification failure (malformed token, unknown token id, database error) is
        // reported to the client as invalid credentials rather than as a server error.
        verify_access_token(access_token, user_id, &state.db)
            .await
            .unwrap_or(false)
    };

    if credentials_valid {
        let user = state
            .db
            .get_user_by_id(user_id)
            .await?
            .ok_or_else(|| anyhow!("user {} not found", user_id))?;
        req.extensions_mut().insert(user);
        Ok::<_, Error>(next.run(req).await)
    } else {
        Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "invalid credentials".to_string(),
        ))
    }
}

/// Returns whether a presented admin token equals the configured one.
///
/// An empty presented token is always rejected, so an unset or empty configured token can
/// never grant admin impersonation. Tokens of different lengths are rejected immediately.
/// For equal lengths, every byte is compared without short-circuiting on the first mismatch.
/// This makes response timing not reveal how many leading bytes matched (best-effort; the
/// optimizer is not formally prevented from changing this).
fn admin_token_matches(expected: &str, presented: &str) -> bool {
    if presented.is_empty() || expected.len() != presented.len() {
        return false;
    }
    expected
        .bytes()
        .zip(presented.bytes())
        .fold(0u8, |diff, (a, b)| diff | (a ^ b))
        == 0
}

/// Count passed to `Database::create_access_token` each time a token is created. Presumably
/// the database keeps at most this many tokens per user and discards older ones — TODO confirm.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;

/// JSON envelope handed to clients as their access token.
///
/// `version` records the envelope format (`create_access_token` writes 1). `id` is used to
/// look up the stored hash. `token` is the plaintext secret checked against that hash.
/// Serialized JSON follows field declaration order.
///
/// NOTE(review): `verify_access_token` does not currently inspect `version`. The struct
/// carries a plaintext secret, so avoid logging it.
#[derive(Serialize, Deserialize)]
struct AccessTokenJson {
    version: usize,
    id: AccessTokenId,
    token: String,
}

/// Issues a new access token for `user_id`.
///
/// Generates a random secret and stores only its scrypt hash in the database, together with
/// `MAX_ACCESS_TOKENS_TO_STORE`. Returns the plaintext secret wrapped in an `AccessTokenJson`
/// envelope (fields `version`, `id`, `token`) serialized to JSON. This is the format that
/// `verify_access_token` later parses.
///
/// # Errors
///
/// Fails if hashing, the database insert, or JSON serialization fails.
pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
    const VERSION: usize = 1;

    let token = rpc::auth::random_token();
    let hashed = hash_access_token(&token).context("failed to hash access token")?;
    let id = db
        .create_access_token(user_id, &hashed, MAX_ACCESS_TOKENS_TO_STORE)
        .await?;

    let envelope = AccessTokenJson {
        version: VERSION,
        id,
        token,
    };
    let serialized = serde_json::to_string(&envelope)?;
    Ok(serialized)
}

fn hash_access_token(token: &str) -> Result<String> {
    // Avoid slow hashing in debug mode.
    let params = if cfg!(debug_assertions) {
        scrypt::Params::new(1, 1, 1).unwrap()
    } else {
        scrypt::Params::new(14, 8, 1).unwrap()
    };

    Ok(Scrypt
        .hash_password(
            token.as_bytes(),
            None,
            params,
            &SaltString::generate(thread_rng()),
        )
        .map_err(anyhow::Error::new)?
        .to_string())
}

/// Encrypts `access_token` with the native app's public key.
///
/// `public_key` is parsed into an `rpc::auth::PublicKey` first. Parse and encryption failures
/// are returned with context describing which step failed.
pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
    rpc::auth::PublicKey::try_from(public_key)
        .context("failed to parse app public key")
        .and_then(|app_key| {
            app_key
                .encrypt_string(access_token)
                .context("failed to encrypt access token with public key")
        })
        .map_err(Into::into)
}

/// Checks whether `token` is a valid access token for `user_id`.
///
/// `token` must be the JSON envelope produced by `create_access_token`. Its `id` is used to
/// load the stored hash, which must belong to `user_id`. Its plaintext secret is then checked
/// against that scrypt hash. The verification time is logged and recorded in
/// `METRIC_ACCESS_TOKEN_HASHING_TIME`, in milliseconds.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `token` does not parse as the expected JSON.
/// - The token id cannot be loaded from the database.
/// - The token belongs to a different user.
/// - The stored hash cannot be parsed.
///
/// A well-formed token whose secret does not match yields `Ok(false)`.
pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database>) -> Result<bool> {
    let token: AccessTokenJson = serde_json::from_str(token)?;

    let db_token = db.get_access_token(token.id).await?;
    if db_token.user_id != user_id {
        return Err(anyhow!("no such access token").into());
    }

    let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
    let t0 = Instant::now();
    let is_valid = Scrypt
        .verify_password(token.token.as_bytes(), &db_hash)
        .is_ok();
    let duration = t0.elapsed();
    log::info!("hashed access token in {:?}", duration);
    // Record fractional milliseconds. `as_millis()` truncates, so every sub-millisecond
    // verification (e.g. with debug-build parameters) would be observed as 0.
    METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_secs_f64() * 1000.0);
    Ok(is_valid)
}