1
/*
2
 * This file is part of mailpot
3
 *
4
 * Copyright 2020 - Manos Pitsidianakis
5
 *
6
 * This program is free software: you can redistribute it and/or modify
7
 * it under the terms of the GNU Affero General Public License as
8
 * published by the Free Software Foundation, either version 3 of the
9
 * License, or (at your option) any later version.
10
 *
11
 * This program is distributed in the hope that it will be useful,
12
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
 * GNU Affero General Public License for more details.
15
 *
16
 * You should have received a copy of the GNU Affero General Public License
17
 * along with this program. If not, see <https://www.gnu.org/licenses/>.
18
 */
19

            
20
use std::{borrow::Cow, process::Stdio};
21

            
22
use tempfile::NamedTempFile;
23
use tokio::{fs::File, io::AsyncWriteExt, process::Command};
24

            
25
use super::*;
26

            
27
/// Session key under which the pending SSH challenge is stored as a
/// `(token, unix_timestamp)` pair.
const TOKEN_KEY: &str = "ssh_challenge";
28
/// Lifetime of a challenge token, in seconds (six minutes).
const EXPIRY_IN_SECS: i64 = 6 * 60;
29

            
30
10
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Copy, Eq, PartialEq, PartialOrd)]
31
pub enum Role {
32
    User,
33
    Admin,
34
}
35

            
36
12
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
37
pub struct User {
38
    /// SSH signature.
39
6
    pub ssh_signature: String,
40
    /// User role.
41
6
    pub role: Role,
42
    /// Database primary key.
43
6
    pub pk: i64,
44
    /// Accounts's display name, optional.
45
6
    pub name: Option<String>,
46
    /// Account's e-mail address.
47
6
    pub address: String,
48
    /// GPG public key.
49
6
    pub public_key: Option<String>,
50
    /// SSH public key.
51
6
    pub password: String,
52
    /// Whether this account is enabled.
53
6
    pub enabled: bool,
54
}
55

            
56
impl AuthUser<i64, Role> for User {
57
1
    fn get_id(&self) -> i64 {
58
1
        self.pk
59
2
    }
60

            
61
7
    fn get_password_hash(&self) -> SecretVec<u8> {
62
7
        SecretVec::new(self.ssh_signature.clone().into())
63
7
    }
64

            
65
6
    fn get_role(&self) -> Option<Role> {
66
6
        Some(self.role)
67
12
    }
68
}
69

            
70
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Default)]
71
pub struct AuthFormPayload {
72
    pub address: String,
73
    pub password: String,
74
}
75

            
76
1
pub async fn ssh_signin(
77
1
    _: LoginPath,
78
    mut session: WritableSession,
79
    Query(next): Query<Next>,
80
    auth: AuthContext,
81
    State(state): State<Arc<AppState>>,
82
1
) -> impl IntoResponse {
83
    if auth.current_user.is_some() {
84
        if let Err(err) = session.add_message(Message {
85
            message: "You are already logged in.".into(),
86
            level: Level::Info,
87
        }) {
88
            return err.into_response();
89
        }
90
        return next
91
            .or_else(|| format!("{}{}", state.root_url_prefix, SettingsPath.to_uri()))
92
            .into_response();
93
    }
94
    if next.next.is_some() {
95
        if let Err(err) = session.add_message(Message {
96
            message: "You need to be logged in to access this page.".into(),
97
            level: Level::Info,
98
        }) {
99
            return err.into_response();
100
        };
101
    }
102

            
103
    let now: i64 = chrono::offset::Utc::now().timestamp();
104

            
105
    let prev_token = if let Some(tok) = session.get::<(String, i64)>(TOKEN_KEY) {
106
        let timestamp: i64 = tok.1;
107
        if !(timestamp < now && now - timestamp < EXPIRY_IN_SECS) {
108
            session.remove(TOKEN_KEY);
109
            None
110
        } else {
111
            Some(tok)
112
        }
113
    } else {
114
        None
115
    };
116

            
117
    let (token, timestamp): (String, i64) = prev_token.map_or_else(
118
        || {
119
            use rand::{distributions::Alphanumeric, thread_rng, Rng};
120

            
121
            let mut rng = thread_rng();
122
            let chars: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect();
123
            println!("Random chars: {}", chars);
124
            session.insert(TOKEN_KEY, (&chars, now)).unwrap();
125
            (chars, now)
126
        },
127
        |tok| tok,
128
    );
129
    let timeout_left = ((timestamp + EXPIRY_IN_SECS) - now) as f64 / 60.0;
130

            
131
    let crumbs = vec![
132
        Crumb {
133
            label: "Home".into(),
134
            url: "/".into(),
135
        },
136
        Crumb {
137
            label: "Sign in".into(),
138
            url: LoginPath.to_crumb(),
139
        },
140
    ];
141

            
142
    let context = minijinja::context! {
143
        namespace => &state.public_url,
144
        page_title => "Log in",
145
        ssh_challenge => token,
146
        timeout_left => timeout_left,
147
        current_user => auth.current_user,
148
        messages => session.drain_messages(),
149
        crumbs => crumbs,
150
    };
151
    Html(
152
        TEMPLATES
153
            .get_template("auth.html")
154
            .unwrap()
155
            .render(context)
156
            .unwrap_or_else(|err| err.to_string()),
157
    )
158
    .into_response()
159
2
}
160

            
161
#[allow(non_snake_case)]
162
1
pub async fn ssh_signin_POST(
163
1
    _: LoginPath,
164
    mut session: WritableSession,
165
    Query(next): Query<Next>,
166
    mut auth: AuthContext,
167
    Form(payload): Form<AuthFormPayload>,
168
    state: Arc<AppState>,
169
1
) -> Result<Redirect, ResponseError> {
170
    if auth.current_user.as_ref().is_some() {
171
        session.add_message(Message {
172
            message: "You are already logged in.".into(),
173
            level: Level::Info,
174
        })?;
175
        return Ok(next.or_else(|| format!("{}{}", state.root_url_prefix, SettingsPath.to_uri())));
176
    }
177

            
178
    let now: i64 = chrono::offset::Utc::now().timestamp();
179

            
180
    let (_prev_token, _) = if let Some(tok @ (_, timestamp)) =
181
        session.get::<(String, i64)>(TOKEN_KEY)
182
    {
183
        if !(timestamp <= now && now - timestamp < EXPIRY_IN_SECS) {
184
            session.add_message(Message {
185
                message: "The token has expired. Please retry.".into(),
186
                level: Level::Error,
187
            })?;
188
            return Ok(Redirect::to(&format!(
189
                "{}{}?next={}",
190
                state.root_url_prefix,
191
                LoginPath.to_uri(),
192
                next.next.as_ref().map_or(Cow::Borrowed(""), |next| format!(
193
                    "?next={}",
194
                    percent_encoding::utf8_percent_encode(
195
                        next.as_str(),
196
                        percent_encoding::CONTROLS
197
                    )
198
                )
199
                .into())
200
            )));
201
        } else {
202
            tok
203
        }
204
    } else {
205
        session.add_message(Message {
206
            message: "The token has expired. Please retry.".into(),
207
            level: Level::Error,
208
        })?;
209
        return Ok(Redirect::to(&format!(
210
            "{}{}{}",
211
            state.root_url_prefix,
212
            LoginPath.to_uri(),
213
            next.next.as_ref().map_or(Cow::Borrowed(""), |next| format!(
214
                "?next={}",
215
                percent_encoding::utf8_percent_encode(next.as_str(), percent_encoding::CONTROLS)
216
            )
217
            .into())
218
        )));
219
    };
220

            
221
    let db = Connection::open_db(state.conf.clone())?;
222
    let mut acc = match db
223
        .account_by_address(&payload.address)
224
        .with_status(StatusCode::BAD_REQUEST)?
225
    {
226
        Some(v) => v,
227
        None => {
228
            session.add_message(Message {
229
                message: "Invalid account details, please retry.".into(),
230
                level: Level::Error,
231
            })?;
232
            return Ok(Redirect::to(&format!(
233
                "{}{}{}",
234
                state.root_url_prefix,
235
                LoginPath.to_uri(),
236
                next.next.as_ref().map_or(Cow::Borrowed(""), |next| format!(
237
                    "?next={}",
238
                    percent_encoding::utf8_percent_encode(
239
                        next.as_str(),
240
                        percent_encoding::CONTROLS
241
                    )
242
                )
243
                .into())
244
            )));
245
        }
246
    };
247
    #[cfg(not(debug_assertions))]
248
    let sig = SshSignature {
249
        email: payload.address.clone(),
250
        ssh_public_key: acc.password.clone(),
251
        ssh_signature: payload.password.clone(),
252
        namespace: std::env::var("SSH_NAMESPACE")
253
            .unwrap_or_else(|_| "lists.mailpot.rs".to_string())
254
            .into(),
255
        token: _prev_token,
256
    };
257
    #[cfg(not(debug_assertions))]
258
    {
259
        #[cfg(not(feature = "ssh-key"))]
260
        let ssh_verify_fn = ssh_verify;
261
        #[cfg(feature = "ssh-key")]
262
        let ssh_verify_fn = ssh_verify_in_memory;
263
        if let Err(err) = ssh_verify_fn(sig).await {
264
            session.add_message(Message {
265
                message: format!("Could not verify signature: {err}").into(),
266
                level: Level::Error,
267
            })?;
268
            return Ok(Redirect::to(&format!(
269
                "{}{}{}",
270
                state.root_url_prefix,
271
                LoginPath.to_uri(),
272
                next.next.as_ref().map_or(Cow::Borrowed(""), |next| format!(
273
                    "?next={}",
274
                    percent_encoding::utf8_percent_encode(
275
                        next.as_str(),
276
                        percent_encoding::CONTROLS
277
                    )
278
                )
279
                .into())
280
            )));
281
        }
282
    }
283

            
284
    let user = User {
285
        pk: acc.pk(),
286
        ssh_signature: payload.password,
287
        role: if db
288
            .conf()
289
            .administrators
290
            .iter()
291
            .any(|a| a.eq_ignore_ascii_case(&payload.address))
292
        {
293
            Role::Admin
294
        } else {
295
            Role::User
296
        },
297
        public_key: std::mem::take(&mut acc.public_key),
298
        password: std::mem::take(&mut acc.password),
299
        name: std::mem::take(&mut acc.name),
300
        address: payload.address,
301
        enabled: acc.enabled,
302
    };
303
    state.insert_user(acc.pk(), user.clone()).await;
304
    drop(session);
305
    auth.login(&user)
306
        .await
307
        .map_err(|err| ResponseError::new(err.to_string(), StatusCode::BAD_REQUEST))?;
308
    Ok(next.or_else(|| format!("{}{}", state.root_url_prefix, SettingsPath.to_uri())))
309
2
}
310

            
311
6
/// Everything needed to verify an SSH signature over a challenge token.
#[derive(Clone, Debug, Default)]
pub struct SshSignature {
    /// Identity (e-mail address) of the alleged signer.
    pub email: String,
    /// Signer's SSH public key.
    pub ssh_public_key: String,
    /// Armored SSH signature to verify.
    pub ssh_signature: String,
    /// Namespace the signature must have been made for.
    pub namespace: Cow<'static, str>,
    /// The signed message, i.e. the challenge token.
    pub token: String,
}
319

            
320
/// Run ssh signature validation with `ssh-keygen` binary.
321
///
322
/// ```no_run
323
/// use mailpot_web::{ssh_verify, SshSignature};
324
///
325
/// async fn verify_signature(
326
///     ssh_public_key: String,
327
///     ssh_signature: String,
328
/// ) -> std::result::Result<(), Box<dyn std::error::Error>> {
329
///     let sig = SshSignature {
330
///         email: "user@example.com".to_string(),
331
///         ssh_public_key,
332
///         ssh_signature,
333
///         namespace: "doc-test@example.com".into(),
334
///         token: "d074a61990".to_string(),
335
///     };
336
///
337
///     ssh_verify(sig).await?;
338
///     Ok(())
339
/// }
340
/// ```
341
21
pub async fn ssh_verify(sig: SshSignature) -> Result<(), Box<dyn std::error::Error>> {
342
    let SshSignature {
343
2
        email,
344
2
        ssh_public_key,
345
2
        ssh_signature,
346
2
        namespace,
347
2
        token,
348
    } = sig;
349
4
    let dir = tempfile::tempdir()?;
350

            
351
2
    let mut allowed_signers_fp = NamedTempFile::new_in(dir.path())?;
352
2
    let mut signature_fp = NamedTempFile::new_in(dir.path())?;
353
    {
354
2
        let (tempfile, path) = allowed_signers_fp.into_parts();
355
2
        let mut file = File::from(tempfile);
356

            
357
6
        file.write_all(format!("{email} {ssh_public_key}").as_bytes())
358
6
            .await?;
359
4
        file.flush().await?;
360
2
        allowed_signers_fp = NamedTempFile::from_parts(file.into_std().await, path);
361
2
    }
362
    {
363
2
        let (tempfile, path) = signature_fp.into_parts();
364
2
        let mut file = File::from(tempfile);
365

            
366
6
        file.write_all(ssh_signature.trim().replace("\r\n", "\n").as_bytes())
367
6
            .await?;
368
4
        file.flush().await?;
369
2
        signature_fp = NamedTempFile::from_parts(file.into_std().await, path);
370
2
    }
371

            
372
2
    let mut cmd = Command::new("ssh-keygen");
373

            
374
2
    cmd.stdout(Stdio::piped());
375
2
    cmd.stderr(Stdio::piped());
376
2
    cmd.stdin(Stdio::piped());
377

            
378
    // Once you have your allowed signers file, verification works like this:
379
    //
380
    // ```shell
381
    // ssh-keygen -Y verify -f allowed_signers -I alice@example.com -n file -s file_to_verify.sig < file_to_verify
382
    // ```
383
    //
384
    // Here are the arguments you may need to change:
385
    //
386
    // - `allowed_signers` is the path to the allowed signers file.
387
    // - `alice@example.com` is the email address of the person who allegedly signed
388
    //   the file. This email address is looked up in the allowed signers file to
389
    //   get possible public keys.
390
    // - `file` is the "namespace", which must match the namespace used for signing
391
    //   as described above.
392
    // - `file_to_verify.sig` is the path to the signature file.
393
    // - `file_to_verify` is the path to the file to be verified. Note that this
394
    //   file is read from standard in. In the above command, the < shell operator
395
    //   is used to redirect standard in from this file.
396
    //
397
    // If the signature is valid, the command exits with status `0` and prints a
398
    // message like this:
399
    //
400
    // > Good "file" signature for alice@example.com with ED25519 key
401
    // > SHA256:ZGa8RztddW4kE2XKPPsP9ZYC7JnMObs6yZzyxg8xZSk
402
    //
403
    // Otherwise, the command exits with a non-zero status and prints an error
404
    // message.
405

            
406
10
    let mut child = cmd
407
        .arg("-Y")
408
        .arg("verify")
409
        .arg("-f")
410
2
        .arg(allowed_signers_fp.path())
411
        .arg("-I")
412
2
        .arg(&email)
413
        .arg("-n")
414
2
        .arg(namespace.as_ref())
415
        .arg("-s")
416
2
        .arg(signature_fp.path())
417
        .spawn()
418
2
        .expect("failed to spawn command");
419

            
420
2
    let mut stdin = child
421
        .stdin
422
        .take()
423
2
        .expect("child did not have a handle to stdin");
424

            
425
8
    stdin
426
2
        .write_all(token.as_bytes())
427
6
        .await
428
        .expect("could not write to stdin");
429

            
430
2
    drop(stdin);
431

            
432
11
    let op = child.wait_with_output().await?;
433

            
434
2
    if !op.status.success() {
435
3
        return Err(format!(
436
            "ssh-keygen exited with {}:\nstdout: {}\n\nstderr: {}",
437
1
            op.status.code().unwrap_or(-1),
438
1
            String::from_utf8_lossy(&op.stdout),
439
1
            String::from_utf8_lossy(&op.stderr)
440
        )
441
        .into());
442
    }
443

            
444
1
    Ok(())
445
7
}
446

            
447
/// Run ssh signature validation in-process with the `ssh-key` crate.
///
/// The public key is parsed first, then the signature, and finally the
/// signature is checked against the namespace and the token bytes. The
/// `email` field is not used.
///
/// ```no_run
/// use mailpot_web::{ssh_verify_in_memory, SshSignature};
///
/// async fn ssh_verify(
///     ssh_public_key: String,
///     ssh_signature: String,
/// ) -> std::result::Result<(), Box<dyn std::error::Error>> {
///     let sig = SshSignature {
///         email: "user@example.com".to_string(),
///         ssh_public_key,
///         ssh_signature,
///         namespace: "doc-test@example.com".into(),
///         token: "d074a61990".to_string(),
///     };
///
///     ssh_verify_in_memory(sig).await?;
///     Ok(())
/// }
/// ```
#[cfg(feature = "ssh-key")]
pub async fn ssh_verify_in_memory(sig: SshSignature) -> Result<(), Box<dyn std::error::Error>> {
    use ssh_key::{Error, PublicKey, SshSig};

    let public_key = sig.ssh_public_key.parse::<PublicKey>().map_err(|err| {
        format!("Could not parse user's SSH public key. Is it valid? Reason given: {err}")
    })?;

    // Only input containing CRLF line endings is trimmed and normalized;
    // anything else is parsed exactly as submitted.
    let parsed = if sig.ssh_signature.contains("\r\n") {
        let normalized = sig.ssh_signature.trim().replace("\r\n", "\n");
        normalized.parse::<SshSig>()
    } else {
        sig.ssh_signature.parse::<SshSig>()
    };
    let signature =
        parsed.map_err(|err| format!("Invalid SSH signature. Reason given: {err}"))?;

    match public_key.verify(&sig.namespace, sig.token.as_bytes(), &signature) {
        Ok(_) => Ok(()),
        // I/O failures are internal: log the details, show a generic message.
        Err(Error::Io(err_kind)) => {
            log::error!(
                "ssh signature could not be verified because of internal error:\nSignature was \
                 {sig:#?}\nError was {err_kind}."
            );
            Err("SSH signature could not be verified because of internal error.".into())
        }
        Err(Error::Crypto) => Err("SSH signature is invalid.".into()),
        // Every remaining variant (unknown/unsupported algorithm, certificate,
        // encoding, namespace, public key, time, trailing data, version, ...)
        // is reported with the crate's own description.
        Err(err) => Err(format!("SSH signature could not be verified: Reason given: {err}").into()),
    }
}
523

            
524
pub async fn logout_handler(
525
    _: LogoutPath,
526
    mut auth: AuthContext,
527
    State(state): State<Arc<AppState>>,
528
) -> Redirect {
529
    auth.logout().await;
530
    Redirect::to(&format!("{}/", state.root_url_prefix))
531
}
532

            
533
pub mod auth_request {
534
    use std::{marker::PhantomData, ops::RangeBounds};
535

            
536
    use axum::body::HttpBody;
537
    use dyn_clone::DynClone;
538
    use tower_http::auth::AuthorizeRequest;
539

            
540
    use super::*;
541

            
542
    trait RoleBounds<Role>: DynClone + Send + Sync {
543
        fn contains(&self, role: Option<Role>) -> bool;
544
    }
545

            
546
    impl<T, Role> RoleBounds<Role> for T
547
    where
548
        Role: PartialOrd + PartialEq,
549
        T: RangeBounds<Role> + Clone + Send + Sync,
550
    {
551
        fn contains(&self, role: Option<Role>) -> bool {
552
            role.as_ref()
553
                .map_or_else(|| role.is_none(), |role| RangeBounds::contains(self, role))
554
        }
555
    }
556

            
557
    /// Type that performs login authorization.
558
    ///
559
    /// See [`RequireAuthorizationLayer::login`] for more details.
560
    pub struct Login<UserId, User, ResBody, Role = ()> {
561
        login_url: Option<Arc<Cow<'static, str>>>,
562
        redirect_field_name: Option<Arc<Cow<'static, str>>>,
563
        role_bounds: Box<dyn RoleBounds<Role>>,
564
        _user_id_type: PhantomData<UserId>,
565
        _user_type: PhantomData<User>,
566
        _body_type: PhantomData<fn() -> ResBody>,
567
    }
568

            
569
    impl<UserId, User, ResBody, Role> Clone for Login<UserId, User, ResBody, Role> {
570
        fn clone(&self) -> Self {
571
            Self {
572
                login_url: self.login_url.clone(),
573
                redirect_field_name: self.redirect_field_name.clone(),
574
                role_bounds: dyn_clone::clone_box(&*self.role_bounds),
575
                _user_id_type: PhantomData,
576
                _user_type: PhantomData,
577
                _body_type: PhantomData,
578
            }
579
        }
580
    }
581

            
582
    impl<UserId, User, ReqBody, ResBody, Role> AuthorizeRequest<ReqBody>
583
        for Login<UserId, User, ResBody, Role>
584
    where
585
        Role: PartialOrd + PartialEq + Clone + Send + Sync + 'static,
586
        User: AuthUser<UserId, Role>,
587
        ResBody: HttpBody + Default,
588
    {
589
        type ResponseBody = ResBody;
590

            
591
        fn authorize(
592
            &mut self,
593
            request: &mut Request<ReqBody>,
594
        ) -> Result<(), Response<Self::ResponseBody>> {
595
            let user = request
596
                .extensions()
597
                .get::<Option<User>>()
598
                .expect("Auth extension missing. Is the auth layer installed?");
599

            
600
            match user {
601
                Some(user) if self.role_bounds.contains(user.get_role()) => {
602
                    let user = user.clone();
603
                    request.extensions_mut().insert(user);
604

            
605
                    Ok(())
606
                }
607

            
608
                _ => {
609
                    let unauthorized_response = if let Some(ref login_url) = self.login_url {
610
                        let url: Cow<'static, str> = self.redirect_field_name.as_ref().map_or_else(
611
                            || login_url.as_ref().clone(),
612
                            |next| {
613
                                format!(
614
                                    "{login_url}?{next}={}",
615
                                    percent_encoding::utf8_percent_encode(
616
                                        request.uri().path(),
617
                                        percent_encoding::CONTROLS
618
                                    )
619
                                )
620
                                .into()
621
                            },
622
                        );
623

            
624
                        Response::builder()
625
                            .status(http::StatusCode::TEMPORARY_REDIRECT)
626
                            .header(http::header::LOCATION, url.as_ref())
627
                            .body(Default::default())
628
                            .unwrap()
629
                    } else {
630
                        Response::builder()
631
                            .status(http::StatusCode::UNAUTHORIZED)
632
                            .body(Default::default())
633
                            .unwrap()
634
                    };
635

            
636
                    Err(unauthorized_response)
637
                }
638
            }
639
        }
640
    }
641

            
642
    /// A wrapper around [`tower_http::auth::RequireAuthorizationLayer`] which
643
    /// provides login authorization.
644
    pub struct RequireAuthorizationLayer<UserId, User, Role = ()>(UserId, User, Role);
645

            
646
    impl<UserId, User, Role> RequireAuthorizationLayer<UserId, User, Role>
647
    where
648
        Role: PartialOrd + PartialEq + Clone + Send + Sync + 'static,
649
        User: AuthUser<UserId, Role>,
650
    {
651
        /// Authorizes requests by requiring a logged in user, otherwise it
652
        /// rejects with [`http::StatusCode::UNAUTHORIZED`].
653
        pub fn login<ResBody>(
654
        ) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
655
        where
656
            ResBody: HttpBody + Default,
657
        {
658
            tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
659
                login_url: None,
660
                redirect_field_name: None,
661
                role_bounds: Box::new(..),
662
                _user_id_type: PhantomData,
663
                _user_type: PhantomData,
664
                _body_type: PhantomData,
665
            })
666
        }
667

            
668
        /// Authorizes requests by requiring a logged in user to have a specific
669
        /// range of roles, otherwise it rejects with
670
        /// [`http::StatusCode::UNAUTHORIZED`].
671
        pub fn login_with_role<ResBody>(
672
            role_bounds: impl RangeBounds<Role> + Clone + Send + Sync + 'static,
673
        ) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
674
        where
675
            ResBody: HttpBody + Default,
676
        {
677
            tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
678
                login_url: None,
679
                redirect_field_name: None,
680
                role_bounds: Box::new(role_bounds),
681
                _user_id_type: PhantomData,
682
                _user_type: PhantomData,
683
                _body_type: PhantomData,
684
            })
685
        }
686

            
687
        /// Authorizes requests by requiring a logged in user, otherwise it
688
        /// redirects to the provided login URL.
689
        ///
690
        /// If `redirect_field_name` is set to a value, the login page will
691
        /// receive the path it was redirected from in the URI query
692
        /// part. For example, attempting to visit a protected path
693
        /// `/protected` would redirect you to `/login?next=/protected` allowing
694
        /// you to know how to return the visitor to their requested
695
        /// page.
696
        pub fn login_or_redirect<ResBody>(
697
            login_url: Arc<Cow<'static, str>>,
698
            redirect_field_name: Option<Arc<Cow<'static, str>>>,
699
        ) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
700
        where
701
            ResBody: HttpBody + Default,
702
        {
703
            tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
704
                login_url: Some(login_url),
705
                redirect_field_name,
706
                role_bounds: Box::new(..),
707
                _user_id_type: PhantomData,
708
                _user_type: PhantomData,
709
                _body_type: PhantomData,
710
            })
711
        }
712

            
713
        /// Authorizes requests by requiring a logged in user to have a specific
714
        /// range of roles, otherwise it redirects to the
715
        /// provided login URL.
716
        ///
717
        /// If `redirect_field_name` is set to a value, the login page will
718
        /// receive the path it was redirected from in the URI query
719
        /// part. For example, attempting to visit a protected path
720
        /// `/protected` would redirect you to `/login?next=/protected` allowing
721
        /// you to know how to return the visitor to their requested
722
        /// page.
723
        pub fn login_with_role_or_redirect<ResBody>(
724
            role_bounds: impl RangeBounds<Role> + Clone + Send + Sync + 'static,
725
            login_url: Arc<Cow<'static, str>>,
726
            redirect_field_name: Option<Arc<Cow<'static, str>>>,
727
        ) -> tower_http::auth::RequireAuthorizationLayer<Login<UserId, User, ResBody, Role>>
728
        where
729
            ResBody: HttpBody + Default,
730
        {
731
            tower_http::auth::RequireAuthorizationLayer::custom(Login::<_, _, _, _> {
732
                login_url: Some(login_url),
733
                redirect_field_name,
734
                role_bounds: Box::new(role_bounds),
735
                _user_id_type: PhantomData,
736
                _user_type: PhantomData,
737
                _body_type: PhantomData,
738
            })
739
        }
740
    }
741
}
742

            
743
#[cfg(test)]
744
mod tests {
745
    use super::*;
746
    /// RSA public key (OpenSSH format) used as the test fixture key; `ARMOR_SIG`
    /// is verified against it in the tests below.
    const PKEY: &str = concat!(
747
        "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQCzXp8nLJL8GPNw7S+Dqt0m3Dw/",
748
            "xFOAdwKXcekTFI9cLDEUII2rNPf0uUZTpv57OgU+",
749
            "QOEEIvWMjz+5KSWBX8qdP8OtV0QNvynlZkEKZN0cUqGKaNXo5a+PUDyiJ2rHroPe1aMo6mUBL9kLR6J2U1CYD/dLfL8ywXsAGmOL0bsK0GRPVBJAjpUNRjpGU/",
750
            "2FFIlU6s6GawdbDXEHDox/UoOVAKIlhKabaTrFBA0ACFLRX2/GCBmHqqt5d4ZZjefYzReLs/beOjafYImoyhHC428wZDcUjvLrpSJbIOE/",
751
            "gSPCWlRbcsxg4JGcKOtALUurE+ok+avy9M7eFjGhLGSlTKLdshIVQr/3W667M7bYfOT6xP/",
752
            "lyjxeWIUYyj7rjlqKJ9tzygek7QNxCtuqH5xsZAZqzQCN8wfrPAlwDykvWityKOw+Bt2DWjimITqyKgsBsOaA+",
753
            "eVCllFvooJxoYvAjODASjAUoOdgVzyBDpFnOhLFYiIIyL3F6NROS9i7z086paX7mrzcQzvLr4ckF9qT7DrI88ikISCR9bFR4vPq3aH",
754
            "zJdjDDpWxACa5b11NG8KdCJPe/L0kDw82Q00U13CpW9FI9sZjvk+",
755
            "lyw8bTFvVsIl6A0ueboFvrNvznAqHrtfWu75fXRh5sKj2TGk8rhm3vyNgrBSr5zAfFVM8LgqBxbAAYw=="
756
            );
757

            
758
    /// Armored SSH signature fixture. The tests verify it against `PKEY` with
    /// the namespace `doc-test@example.com` and the token `d074a61990` (see
    /// `create_sig`).
    const ARMOR_SIG: &str = concat!(
759
        "-----BEGIN SSH SIGNATURE-----\n",
760
        "U1NIU0lHAAAAAQAAAhcAAAAHc3NoLXJzYQAAAAMBAAEAAAIBALNenycskvwY83DtL4Oq3S\n",
761
        "bcPD/EU4B3Apdx6RMUj1wsMRQgjas09/S5RlOm/ns6BT5A4QQi9YyPP7kpJYFfyp0/w61X\n",
762
        "RA2/KeVmQQpk3RxSoYpo1ejlr49QPKInaseug97VoyjqZQEv2QtHonZTUJgP90t8vzLBew\n",
763
        "AaY4vRuwrQZE9UEkCOlQ1GOkZT/YUUiVTqzoZrB1sNcQcOjH9Sg5UAoiWEpptpOsUEDQAI\n",
764
        "UtFfb8YIGYeqq3l3hlmN59jNF4uz9t46Np9giajKEcLjbzBkNxSO8uulIlsg4T+BI8JaVF\n",
765
        "tyzGDgkZwo60AtS6sT6iT5q/L0zt4WMaEsZKVMot2yEhVCv/dbrrsztth85PrE/+XKPF5Y\n",
766
        "hRjKPuuOWoon23PKB6TtA3EK26ofnGxkBmrNAI3zB+s8CXAPKS9aK3Io7D4G3YNaOKYhOr\n",
767
        "IqCwGw5oD55UKWUW+ignGhi8CM4MBKMBSg52BXPIEOkWc6EsViIgjIvcXo1E5L2LvPTzql\n",
768
        "pfuavNxDO8uvhyQX2pPsOsjzyKQhIJH1sVHi8+rdofMl2MMOlbEAJrlvXU0bwp0Ik978vS\n",
769
        "QPDzZDTRTXcKlb0Uj2xmO+T6XLDxtMW9WwiXoDS55ugW+s2/OcCoeu19a7vl9dGHmwqPZM\n",
770
        "aTyuGbe/I2CsFKvnMB8VUzwuCoHFsABjAAAAFGRvYy10ZXN0QGV4YW1wbGUuY29tAAAAAA\n",
771
        "AAAAZzaGE1MTIAAAIUAAAADHJzYS1zaGEyLTUxMgAAAgBxaMqIfeapKTrhQzggDssD+76s\n",
772
        "jZxv3XxzgsuAjlIdtw+/nyxU6skTnrGoam2shpmQvx0HuqSQ7HyS2USBK7T4LZNoE53zR/\n",
773
        "ZmHLGoyQAoexiHSEW9Lk53kyRNPhpXQedTvm8REHPGM3zw6WO6mAXVVxvebvawf81LTbBb\n",
774
        "p9ubNRcHgktVeywMO/sD6zWSyShq1gjVv1PdRBOjUgqkwjImL8dFKi1QUeoffCxyk3JhTO\n",
775
        "siTy79HZSz/kOvkvL1vQuqaP2R8lE9P1uaD19dGOMTPRod3u+QmpYX47ri5KM3Fmkfxdwq\n",
776
        "p8JVmfAA9nme7bmNS1hWgmF2Nbh9qjh1zOZvCimIpuNtz5eEl9K+1DxG6w5tX86wSGvBMO\n",
777
        "znx0k1gGfkiAULqgrkdul7mqMPRvPN9J6QlNJ7SLFChRhzlJIJc6tOvCs7qkVD43Zcb+I5\n",
778
        "Z+K4NiFf5jf8kVX/pjjeW/ucbrctJIkGsZ58OkHKi1EDRcq7NtCF6SKlcv8g3fMLd9wW6K\n",
779
        "aaed0TBDC+s+f6naNIGvWqfWCwDuK5xGyDTTmJGcrsMwWuT9K6uLk8cGdv7t5mOFuWi5jl\n",
780
        "E+IKZKVABMuWqSj96ErMIiBjtsAZfNSezpsK49wQztoSPhdwLhD6fHrSAyPCqN2xRkcsIb\n",
781
        "6PxWKC/OELf3gyEBRPouxsF7xSZQ==\n",
782
        "-----END SSH SIGNATURE-----\n"
783
    );
784

            
785
3
    fn create_sig() -> SshSignature {
786
3
        SshSignature {
787
3
            email: "user@example.com".to_string(),
788
3
            ssh_public_key: PKEY.to_string(),
789
3
            ssh_signature: ARMOR_SIG.to_string(),
790
3
            namespace: "doc-test@example.com".into(),
791
3
            token: "d074a61990".to_string(),
792
        }
793
3
    }
794

            
795
18
    #[tokio::test]
796
2
    async fn test_ssh_verify() {
797
1
        let mut sig = create_sig();
798
8
        ssh_verify(sig.clone()).await.unwrap();
799

            
800
1
        sig.ssh_signature = sig.ssh_signature.replace('J', "0");
801

            
802
9
        let err = ssh_verify(sig).await.unwrap_err();
803

            
804
2
        assert!(
805
1
            err.to_string().starts_with("ssh-keygen exited with"),
806
            "{}",
807
            err
808
1
        );
809
3
    }
810

            
811
    /// The in-memory verifier accepts the fixture and reports distinct errors
    /// for a corrupted signature, a corrupted public key and a wrong token.
    #[cfg(feature = "ssh-key")]
    #[tokio::test]
    async fn test_ssh_verify_in_memory() {
        let mut broken = create_sig();
        ssh_verify_in_memory(broken.clone()).await.unwrap();

        // 1. Corrupted signature: fails while parsing the signature.
        broken.ssh_signature = broken.ssh_signature.replace('J', "0");
        let err = ssh_verify_in_memory(broken.clone()).await.unwrap_err();
        assert_eq!(
            &err.to_string(),
            "Invalid SSH signature. Reason given: invalid label: 'ssh-}3a'",
            "{}",
            err
        );

        // 2. Corrupted public key on top: the key is parsed first, so its
        //    error wins over the signature error.
        broken.ssh_public_key = broken.ssh_public_key.replace(' ', "0");
        let err = ssh_verify_in_memory(broken).await.unwrap_err();
        assert_eq!(
            &err.to_string(),
            "Could not parse user's SSH public key. Is it valid? Reason given: length invalid",
            "{}",
            err
        );

        // 3. Intact key and signature, but a different token was "signed".
        let mut wrong_token = create_sig();
        wrong_token.token = wrong_token.token.replace('d', "0");
        let err = ssh_verify_in_memory(wrong_token).await.unwrap_err();
        assert_eq!(&err.to_string(), "SSH signature is invalid.", "{}", err);
    }
844
}