Coverage Report

Created: 2025-05-07 06:59

/rust/registry/src/index.crates.io-6f17d22bba15001f/reqsign-0.16.3/src/aws/credential.rs
Line
Count
Source (jump to first uncovered line)
1
use std::fmt::Debug;
2
use std::fmt::Write;
3
use std::fs;
4
use std::sync::Arc;
5
use std::sync::Mutex;
6
7
use anyhow::anyhow;
8
use anyhow::Result;
9
use async_trait::async_trait;
10
use http::header::CONTENT_LENGTH;
11
use log::debug;
12
use quick_xml::de;
13
use reqwest::Client;
14
use serde::Deserialize;
15
16
use super::config::Config;
17
use super::constants::X_AMZ_CONTENT_SHA_256;
18
use super::v4::Signer;
19
use crate::time::now;
20
use crate::time::parse_rfc3339;
21
use crate::time::DateTime;
22
23
/// Hex-encoded SHA-256 digest of the empty string.
///
/// Used as the `X_AMZ_CONTENT_SHA_256` header value for requests that carry no body.
pub const EMPTY_STRING_SHA256: &str =
    "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
25
26
/// Credential that holds the access_key and secret_key.
27
#[derive(Default, Clone)]
28
#[cfg_attr(test, derive(Debug))]
29
pub struct Credential {
30
    /// Access key id for aws services.
31
    pub access_key_id: String,
32
    /// Secret access key for aws services.
33
    pub secret_access_key: String,
34
    /// Session token for aws services.
35
    pub session_token: Option<String>,
36
    /// Expiration time for this credential.
37
    pub expires_in: Option<DateTime>,
38
}
39
40
impl Credential {
41
    /// is current cred is valid?
42
0
    pub fn is_valid(&self) -> bool {
43
0
        if (self.access_key_id.is_empty() || self.secret_access_key.is_empty())
44
0
            && self.session_token.is_none()
45
        {
46
0
            return false;
47
0
        }
48
        // Take 120s as buffer to avoid edge cases.
49
0
        if let Some(valid) = self
50
0
            .expires_in
51
0
            .map(|v| v > now() + chrono::TimeDelta::try_minutes(2).expect("in bounds"))
52
        {
53
0
            return valid;
54
0
        }
55
0
56
0
        true
57
0
    }
58
}
59
60
/// Loader trait will try to load credential from different sources.
61
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
62
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
63
pub trait CredentialLoad: 'static + Send + Sync {
64
    /// Load credential from sources.
65
    ///
66
    /// - If succeed, return `Ok(Some(cred))`
67
    /// - If not found, return `Ok(None)`
68
    /// - If unexpected errors happened, return `Err(err)`
69
    async fn load_credential(&self, client: Client) -> Result<Option<Credential>>;
70
}
71
72
/// CredentialLoader will load credential from different methods.
73
pub struct DefaultLoader {
74
    client: Client,
75
    config: Config,
76
    credential: Arc<Mutex<Option<Credential>>>,
77
    imds_v2_loader: Option<IMDSv2Loader>,
78
}
79
80
impl DefaultLoader {
81
    /// Create a new CredentialLoader
82
0
    pub fn new(client: Client, config: Config) -> Self {
83
0
        let imds_v2_loader = if config.ec2_metadata_disabled {
84
0
            None
85
        } else {
86
0
            Some(IMDSv2Loader::new(client.clone()))
87
        };
88
0
        Self {
89
0
            client,
90
0
            config,
91
0
            credential: Arc::default(),
92
0
            imds_v2_loader,
93
0
        }
94
0
    }
95
96
    /// Disable load from ec2 metadata.
97
0
    pub fn with_disable_ec2_metadata(mut self) -> Self {
98
0
        self.imds_v2_loader = None;
99
0
        self
100
0
    }
101
102
    /// Load credential.
103
    ///
104
    /// Resolution order:
105
    /// 1. Environment variables
106
    /// 2. Shared config (`~/.aws/config`, `~/.aws/credentials`)
107
    /// 3. Web Identity Tokens
108
    /// 4. ECS (IAM Roles for Tasks) & General HTTP credentials:
109
    /// 5. EC2 IMDSv2
110
0
    pub async fn load(&self) -> Result<Option<Credential>> {
111
0
        // Return cached credential if it has been loaded at least once.
112
0
        match self.credential.lock().expect("lock poisoned").clone() {
113
0
            Some(cred) if cred.is_valid() => return Ok(Some(cred)),
114
0
            _ => (),
115
        }
116
117
0
        let cred = self.load_inner().await?;
118
119
0
        let mut lock = self.credential.lock().expect("lock poisoned");
120
0
        lock.clone_from(&cred);
121
0
122
0
        Ok(cred)
123
0
    }
124
125
0
    async fn load_inner(&self) -> Result<Option<Credential>> {
126
0
        if let Some(cred) = self.load_via_config().map_err(|err| {
127
0
            debug!("load credential via config failed: {err:?}");
128
0
            err
129
0
        })? {
130
0
            return Ok(Some(cred));
131
0
        }
132
133
0
        if let Some(cred) = self
134
0
            .load_via_assume_role_with_web_identity()
135
0
            .await
136
0
            .map_err(|err| {
137
0
                debug!("load credential via assume_role_with_web_identity failed: {err:?}");
138
0
                err
139
0
            })?
140
        {
141
0
            return Ok(Some(cred));
142
0
        }
143
144
0
        if let Some(cred) = self.load_via_imds_v2().await.map_err(|err| {
145
0
            debug!("load credential via imds_v2 failed: {err:?}");
146
0
            err
147
0
        })? {
148
0
            return Ok(Some(cred));
149
0
        }
150
0
151
0
        Ok(None)
152
0
    }
153
154
0
    fn load_via_config(&self) -> Result<Option<Credential>> {
155
0
        if let (Some(ak), Some(sk)) = (&self.config.access_key_id, &self.config.secret_access_key) {
156
0
            Ok(Some(Credential {
157
0
                access_key_id: ak.clone(),
158
0
                secret_access_key: sk.clone(),
159
0
                session_token: self.config.session_token.clone(),
160
0
                // Set expires_in to 10 minutes to enforce re-read
161
0
                // from file.
162
0
                expires_in: Some(now() + chrono::TimeDelta::try_minutes(10).expect("in bounds")),
163
0
            }))
164
        } else {
165
0
            Ok(None)
166
        }
167
0
    }
168
169
0
    async fn load_via_imds_v2(&self) -> Result<Option<Credential>> {
170
0
        let loader = match &self.imds_v2_loader {
171
0
            Some(loader) => loader,
172
0
            None => return Ok(None),
173
        };
174
175
0
        loader.load().await
176
0
    }
177
178
0
    async fn load_via_assume_role_with_web_identity(&self) -> Result<Option<Credential>> {
179
0
        let (token_file, role_arn) =
180
0
            match (&self.config.web_identity_token_file, &self.config.role_arn) {
181
0
                (Some(token_file), Some(role_arn)) => (token_file, role_arn),
182
0
                _ => return Ok(None),
183
            };
184
185
0
        let token = fs::read_to_string(token_file)?;
186
0
        let role_session_name = &self.config.role_session_name;
187
188
0
        let endpoint = self.sts_endpoint()?;
189
190
        // Construct request to AWS STS Service.
191
0
        let url = format!("https://{endpoint}/?Action=AssumeRoleWithWebIdentity&RoleArn={role_arn}&WebIdentityToken={token}&Version=2011-06-15&RoleSessionName={role_session_name}");
192
0
        let req = self.client.get(&url).header(
193
0
            http::header::CONTENT_TYPE.as_str(),
194
0
            "application/x-www-form-urlencoded",
195
0
        );
196
197
0
        let resp = req.send().await?;
198
0
        if resp.status() != http::StatusCode::OK {
199
0
            let content = resp.text().await?;
200
0
            return Err(anyhow!("request to AWS STS Services failed: {content}"));
201
0
        }
202
203
0
        let resp: AssumeRoleWithWebIdentityResponse = de::from_str(&resp.text().await?)?;
204
0
        let resp_cred = resp.result.credentials;
205
206
0
        let cred = Credential {
207
0
            access_key_id: resp_cred.access_key_id,
208
0
            secret_access_key: resp_cred.secret_access_key,
209
0
            session_token: Some(resp_cred.session_token),
210
0
            expires_in: Some(parse_rfc3339(&resp_cred.expiration)?),
211
        };
212
213
0
        Ok(Some(cred))
214
0
    }
215
216
    /// Get the sts endpoint.
217
    ///
218
    /// The returning format may look like `sts.{region}.amazonaws.com`
219
    ///
220
    /// # Notes
221
    ///
222
    /// AWS could have different sts endpoint based on it's region.
223
    /// We can check them by region name.
224
    ///
225
    /// ref: https://github.com/awslabs/aws-sdk-rust/blob/31cfae2cf23be0c68a47357070dea1aee9227e3a/sdk/sts/src/aws_endpoint.rs
226
0
    fn sts_endpoint(&self) -> Result<String> {
227
0
        // use regional sts if sts_regional_endpoints has been set.
228
0
        if self.config.sts_regional_endpoints == "regional" {
229
0
            let region = self.config.region.clone().ok_or_else(|| {
230
0
                anyhow!("sts_regional_endpoints set to reginal, but region is not set")
231
0
            })?;
232
0
            if region.starts_with("cn-") {
233
0
                Ok(format!("sts.{region}.amazonaws.com.cn"))
234
            } else {
235
0
                Ok(format!("sts.{region}.amazonaws.com"))
236
            }
237
        } else {
238
0
            let region = self.config.region.clone().unwrap_or_default();
239
0
            if region.starts_with("cn") {
240
                // TODO: seems aws china doesn't support global sts?
241
0
                Ok("sts.amazonaws.com.cn".to_string())
242
            } else {
243
0
                Ok("sts.amazonaws.com".to_string())
244
            }
245
        }
246
0
    }
247
}
248
249
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
250
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
251
impl CredentialLoad for DefaultLoader {
252
0
    async fn load_credential(&self, _: Client) -> Result<Option<Credential>> {
253
0
        self.load().await
254
0
    }
255
}
256
257
pub struct IMDSv2Loader {
258
    client: Client,
259
260
    token: Arc<Mutex<(String, DateTime)>>,
261
}
262
263
impl IMDSv2Loader {
264
    /// Create a new IMDSv2Loader.
265
0
    pub fn new(client: Client) -> Self {
266
0
        Self {
267
0
            client,
268
0
            token: Arc::new(Mutex::new(("".to_string(), DateTime::MIN_UTC))),
269
0
        }
270
0
    }
271
272
0
    pub async fn load(&self) -> Result<Option<Credential>> {
273
0
        let token = self.load_ec2_metadata_token().await?;
274
275
        // List all credentials that node has.
276
0
        let url = "http://169.254.169.254/latest/meta-data/iam/security-credentials/";
277
0
        let req = self
278
0
            .client
279
0
            .get(url)
280
0
            .header("x-aws-ec2-metadata-token", &token);
281
0
        let resp = req.send().await?;
282
0
        if resp.status() != http::StatusCode::OK {
283
0
            let content = resp.text().await?;
284
0
            return Err(anyhow!(
285
0
                "request to AWS EC2 Metadata Services failed: {content}"
286
0
            ));
287
0
        }
288
0
        let profile_name = resp.text().await?;
289
290
        // Get the credentials via role_name.
291
0
        let url = format!(
292
0
            "http://169.254.169.254/latest/meta-data/iam/security-credentials/{profile_name}"
293
0
        );
294
0
        let req = self
295
0
            .client
296
0
            .get(&url)
297
0
            .header("x-aws-ec2-metadata-token", &token);
298
0
        let resp = req.send().await?;
299
0
        if resp.status() != http::StatusCode::OK {
300
0
            let content = resp.text().await?;
301
0
            return Err(anyhow!(
302
0
                "request to AWS EC2 Metadata Services failed: {content}"
303
0
            ));
304
0
        }
305
306
0
        let content = resp.text().await?;
307
0
        let resp: Ec2MetadataIamSecurityCredentials = serde_json::from_str(&content)?;
308
0
        if resp.code != "Success" {
309
0
            return Err(anyhow!(
310
0
                "request to AWS EC2 Metadata Services failed: {content}"
311
0
            ));
312
0
        }
313
314
0
        let cred = Credential {
315
0
            access_key_id: resp.access_key_id,
316
0
            secret_access_key: resp.secret_access_key,
317
0
            session_token: Some(resp.token),
318
0
            expires_in: Some(parse_rfc3339(&resp.expiration)?),
319
        };
320
321
0
        Ok(Some(cred))
322
0
    }
323
324
    /// load_ec2_metadata_token will load ec2 metadata token from IMDS.
325
    ///
326
    /// Return value is (token, expires_in).
327
0
    async fn load_ec2_metadata_token(&self) -> Result<String> {
328
0
        {
329
0
            let (token, expires_in) = self.token.lock().expect("lock poisoned").clone();
330
0
            if expires_in > now() {
331
0
                return Ok(token);
332
0
            }
333
0
        }
334
0
335
0
        let url = "http://169.254.169.254/latest/api/token";
336
0
        #[allow(unused_mut)]
337
0
        let mut req = self
338
0
            .client
339
0
            .put(url)
340
0
            .header(CONTENT_LENGTH, "0")
341
0
            // 21600s (6h) is recommended by AWS.
342
0
            .header("x-aws-ec2-metadata-token-ttl-seconds", "21600");
343
0
344
0
        // Set timeout to 1s to avoid hanging on non-s3 env.
345
0
        #[cfg(not(target_arch = "wasm32"))]
346
0
        {
347
0
            req = req.timeout(std::time::Duration::from_secs(1));
348
0
        }
349
350
0
        let resp = req.send().await?;
351
0
        if resp.status() != http::StatusCode::OK {
352
0
            let content = resp.text().await?;
353
0
            return Err(anyhow!(
354
0
                "request to AWS EC2 Metadata Services failed: {content}"
355
0
            ));
356
0
        }
357
0
        let ec2_token = resp.text().await?;
358
        // Set expires_in to 10 minutes to enforce re-read.
359
0
        let expires_in = now() + chrono::TimeDelta::try_seconds(21600).expect("in bounds")
360
0
            - chrono::TimeDelta::try_seconds(600).expect("in bounds");
361
0
362
0
        {
363
0
            *self.token.lock().expect("lock poisoned") = (ec2_token.clone(), expires_in);
364
0
        }
365
0
366
0
        Ok(ec2_token)
367
0
    }
368
}
369
370
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
371
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
372
impl CredentialLoad for IMDSv2Loader {
373
0
    async fn load_credential(&self, _: Client) -> Result<Option<Credential>> {
374
0
        self.load().await
375
0
    }
376
}
377
378
/// AssumeRoleLoader will load credential via assume role.
379
pub struct AssumeRoleLoader {
380
    client: Client,
381
    config: Config,
382
383
    source_credential: Box<dyn CredentialLoad>,
384
    sts_signer: Signer,
385
}
386
387
impl AssumeRoleLoader {
388
    /// Create a new assume role loader.
389
0
    pub fn new(
390
0
        client: Client,
391
0
        config: Config,
392
0
        source_credential: Box<dyn CredentialLoad>,
393
0
    ) -> Result<Self> {
394
0
        let region = config.region.clone().ok_or_else(|| {
395
0
            anyhow!("assume role loader requires region, but not found, please check your configuration")
396
0
        })?;
397
398
0
        Ok(Self {
399
0
            client,
400
0
            config,
401
0
            source_credential,
402
0
403
0
            sts_signer: Signer::new("sts", &region),
404
0
        })
405
0
    }
406
407
    /// Load credential via assume role.
408
0
    pub async fn load(&self) -> Result<Option<Credential>> {
409
0
        let role_arn =self.config.role_arn.clone().ok_or_else(|| {
410
0
            anyhow!("assume role loader requires role_arn, but not found, please check your configuration")
411
0
        })?;
412
413
0
        let role_session_name = &self.config.role_session_name;
414
415
0
        let endpoint = self.sts_endpoint()?;
416
417
        // Construct request to AWS STS Service.
418
0
        let mut url = format!("https://{endpoint}/?Action=AssumeRole&RoleArn={role_arn}&Version=2011-06-15&RoleSessionName={role_session_name}");
419
0
        if let Some(external_id) = &self.config.external_id {
420
0
            write!(url, "&ExternalId={external_id}")?;
421
0
        }
422
0
        if let Some(duration_seconds) = &self.config.duration_seconds {
423
0
            write!(url, "&DurationSeconds={duration_seconds}")?;
424
0
        }
425
0
        if let Some(tags) = &self.config.tags {
426
0
            for (idx, (key, value)) in tags.iter().enumerate() {
427
0
                let tag_index = idx + 1;
428
0
                write!(
429
0
                    url,
430
0
                    "&Tags.member.{tag_index}.Key={key}&Tags.member.{tag_index}.Value={value}"
431
0
                )?;
432
            }
433
0
        }
434
435
0
        let mut req = self
436
0
            .client
437
0
            .get(&url)
438
0
            .header(
439
0
                http::header::CONTENT_TYPE.as_str(),
440
0
                "application/x-www-form-urlencoded",
441
0
            )
442
0
            // Set content sha to empty string.
443
0
            .header(X_AMZ_CONTENT_SHA_256, EMPTY_STRING_SHA256)
444
0
            .build()?;
445
446
0
        let source_cred = self
447
0
            .source_credential
448
0
            .load_credential(self.client.clone())
449
0
            .await?
450
0
            .ok_or_else(|| {
451
0
                anyhow!("source credential is required for AssumeRole, but not found, please check your configuration")
452
0
            })?;
453
454
0
        self.sts_signer.sign(&mut req, &source_cred)?;
455
456
0
        let resp = self.client.execute(req).await?;
457
0
        if resp.status() != http::StatusCode::OK {
458
0
            let content = resp.text().await?;
459
0
            return Err(anyhow!("request to AWS STS Services failed: {content}"));
460
0
        }
461
462
0
        let resp: AssumeRoleResponse = de::from_str(&resp.text().await?)?;
463
0
        let resp_cred = resp.result.credentials;
464
465
0
        let cred = Credential {
466
0
            access_key_id: resp_cred.access_key_id,
467
0
            secret_access_key: resp_cred.secret_access_key,
468
0
            session_token: Some(resp_cred.session_token),
469
0
            expires_in: Some(parse_rfc3339(&resp_cred.expiration)?),
470
        };
471
472
0
        Ok(Some(cred))
473
0
    }
474
475
    /// Get the sts endpoint.
476
    ///
477
    /// The returning format may look like `sts.{region}.amazonaws.com`
478
    ///
479
    /// # Notes
480
    ///
481
    /// AWS could have different sts endpoint based on it's region.
482
    /// We can check them by region name.
483
    ///
484
    /// ref: https://github.com/awslabs/aws-sdk-rust/blob/31cfae2cf23be0c68a47357070dea1aee9227e3a/sdk/sts/src/aws_endpoint.rs
485
0
    fn sts_endpoint(&self) -> Result<String> {
486
0
        // use regional sts if sts_regional_endpoints has been set.
487
0
        if self.config.sts_regional_endpoints == "regional" {
488
0
            let region = self.config.region.clone().ok_or_else(|| {
489
0
                anyhow!("sts_regional_endpoints set to reginal, but region is not set")
490
0
            })?;
491
0
            if region.starts_with("cn-") {
492
0
                Ok(format!("sts.{region}.amazonaws.com.cn"))
493
            } else {
494
0
                Ok(format!("sts.{region}.amazonaws.com"))
495
            }
496
        } else {
497
0
            let region = self.config.region.clone().unwrap_or_default();
498
0
            if region.starts_with("cn") {
499
                // TODO: seems aws china doesn't support global sts?
500
0
                Ok("sts.amazonaws.com.cn".to_string())
501
            } else {
502
0
                Ok("sts.amazonaws.com".to_string())
503
            }
504
        }
505
0
    }
506
}
507
508
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
509
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
510
impl CredentialLoad for AssumeRoleLoader {
511
0
    async fn load_credential(&self, _: Client) -> Result<Option<Credential>> {
512
0
        self.load().await
513
0
    }
514
}
515
516
#[derive(Default, Debug, Deserialize)]
517
#[serde(default, rename_all = "PascalCase")]
518
struct AssumeRoleWithWebIdentityResponse {
519
    #[serde(rename = "AssumeRoleWithWebIdentityResult")]
520
    result: AssumeRoleWithWebIdentityResult,
521
}
522
523
#[derive(Default, Debug, Deserialize)]
524
#[serde(default, rename_all = "PascalCase")]
525
struct AssumeRoleWithWebIdentityResult {
526
    credentials: AssumeRoleWithWebIdentityCredentials,
527
}
528
529
#[derive(Default, Debug, Deserialize)]
530
#[serde(default, rename_all = "PascalCase")]
531
struct AssumeRoleWithWebIdentityCredentials {
532
    access_key_id: String,
533
    secret_access_key: String,
534
    session_token: String,
535
    expiration: String,
536
}
537
538
#[derive(Default, Debug, Deserialize)]
539
#[serde(default, rename_all = "PascalCase")]
540
struct AssumeRoleResponse {
541
    #[serde(rename = "AssumeRoleResult")]
542
    result: AssumeRoleResult,
543
}
544
545
#[derive(Default, Debug, Deserialize)]
546
#[serde(default, rename_all = "PascalCase")]
547
struct AssumeRoleResult {
548
    credentials: AssumeRoleCredentials,
549
}
550
551
#[derive(Default, Debug, Deserialize)]
552
#[serde(default, rename_all = "PascalCase")]
553
struct AssumeRoleCredentials {
554
    access_key_id: String,
555
    secret_access_key: String,
556
    session_token: String,
557
    expiration: String,
558
}
559
560
#[derive(Default, Debug, Deserialize)]
561
#[serde(default, rename_all = "PascalCase")]
562
struct Ec2MetadataIamSecurityCredentials {
563
    access_key_id: String,
564
    secret_access_key: String,
565
    token: String,
566
    expiration: String,
567
568
    code: String,
569
}
570
571
#[cfg(test)]
572
mod tests {
573
    use std::env;
574
    use std::str::FromStr;
575
    use std::vec;
576
577
    use anyhow::Result;
578
    use http::Request;
579
    use http::StatusCode;
580
    use once_cell::sync::Lazy;
581
    use quick_xml::de;
582
    use reqwest::Client;
583
    use tokio::runtime::Runtime;
584
585
    use super::*;
586
    use crate::aws::constants::*;
587
    use crate::aws::v4::Signer;
588
589
    static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
590
        tokio::runtime::Builder::new_multi_thread()
591
            .enable_all()
592
            .build()
593
            .expect("Should create a tokio runtime")
594
    });
595
596
    #[test]
597
    fn test_credential_env_loader_without_env() {
598
        let _ = env_logger::builder().is_test(true).try_init();
599
600
        temp_env::with_vars_unset(vec![AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY], || {
601
            RUNTIME.block_on(async {
602
                let l = DefaultLoader::new(reqwest::Client::new(), Config::default())
603
                    .with_disable_ec2_metadata();
604
                let x = l.load().await.expect("load must succeed");
605
                assert!(x.is_none());
606
            })
607
        });
608
    }
609
610
    #[test]
611
    fn test_credential_env_loader_with_env() {
612
        let _ = env_logger::builder().is_test(true).try_init();
613
614
        temp_env::with_vars(
615
            vec![
616
                (AWS_ACCESS_KEY_ID, Some("access_key_id")),
617
                (AWS_SECRET_ACCESS_KEY, Some("secret_access_key")),
618
            ],
619
            || {
620
                RUNTIME.block_on(async {
621
                    let l = DefaultLoader::new(Client::new(), Config::default().from_env());
622
                    let x = l.load().await.expect("load must succeed");
623
624
                    let x = x.expect("must load succeed");
625
                    assert_eq!("access_key_id", x.access_key_id);
626
                    assert_eq!("secret_access_key", x.secret_access_key);
627
                })
628
            },
629
        );
630
    }
631
632
    #[test]
633
    fn test_credential_profile_loader_from_config() {
634
        let _ = env_logger::builder().is_test(true).try_init();
635
636
        temp_env::with_vars(
637
            vec![
638
                (AWS_ACCESS_KEY_ID, None),
639
                (AWS_SECRET_ACCESS_KEY, None),
640
                (
641
                    AWS_CONFIG_FILE,
642
                    Some(format!(
643
                        "{}/testdata/services/aws/default_config",
644
                        env::current_dir()
645
                            .expect("current_dir must exist")
646
                            .to_string_lossy()
647
                    )),
648
                ),
649
                (
650
                    AWS_SHARED_CREDENTIALS_FILE,
651
                    Some(format!(
652
                        "{}/testdata/services/aws/not_exist",
653
                        env::current_dir()
654
                            .expect("current_dir must exist")
655
                            .to_string_lossy()
656
                    )),
657
                ),
658
            ],
659
            || {
660
                RUNTIME.block_on(async {
661
                    let l = DefaultLoader::new(
662
                        Client::new(),
663
                        Config::default().from_env().from_profile(),
664
                    );
665
                    let x = l.load().await.unwrap().unwrap();
666
                    assert_eq!("config_access_key_id", x.access_key_id);
667
                    assert_eq!("config_secret_access_key", x.secret_access_key);
668
                })
669
            },
670
        );
671
    }
672
673
    #[test]
674
    fn test_credential_profile_loader_from_shared() {
675
        let _ = env_logger::builder().is_test(true).try_init();
676
677
        temp_env::with_vars(
678
            vec![
679
                (AWS_ACCESS_KEY_ID, None),
680
                (AWS_SECRET_ACCESS_KEY, None),
681
                (
682
                    AWS_CONFIG_FILE,
683
                    Some(format!(
684
                        "{}/testdata/services/aws/not_exist",
685
                        env::current_dir()
686
                            .expect("load must exist")
687
                            .to_string_lossy()
688
                    )),
689
                ),
690
                (
691
                    AWS_SHARED_CREDENTIALS_FILE,
692
                    Some(format!(
693
                        "{}/testdata/services/aws/default_credential",
694
                        env::current_dir()
695
                            .expect("load must exist")
696
                            .to_string_lossy()
697
                    )),
698
                ),
699
            ],
700
            || {
701
                RUNTIME.block_on(async {
702
                    let l = DefaultLoader::new(
703
                        Client::new(),
704
                        Config::default().from_env().from_profile(),
705
                    );
706
                    let x = l.load().await.unwrap().unwrap();
707
                    assert_eq!("shared_access_key_id", x.access_key_id);
708
                    assert_eq!("shared_secret_access_key", x.secret_access_key);
709
                })
710
            },
711
        );
712
    }
713
714
    /// AWS_SHARED_CREDENTIALS_FILE should be taken first.
715
    #[test]
716
    fn test_credential_profile_loader_from_both() {
717
        let _ = env_logger::builder().is_test(true).try_init();
718
719
        temp_env::with_vars(
720
            vec![
721
                (AWS_ACCESS_KEY_ID, None),
722
                (AWS_SECRET_ACCESS_KEY, None),
723
                (
724
                    AWS_CONFIG_FILE,
725
                    Some(format!(
726
                        "{}/testdata/services/aws/default_config",
727
                        env::current_dir()
728
                            .expect("current_dir must exist")
729
                            .to_string_lossy()
730
                    )),
731
                ),
732
                (
733
                    AWS_SHARED_CREDENTIALS_FILE,
734
                    Some(format!(
735
                        "{}/testdata/services/aws/default_credential",
736
                        env::current_dir()
737
                            .expect("current_dir must exist")
738
                            .to_string_lossy()
739
                    )),
740
                ),
741
            ],
742
            || {
743
                RUNTIME.block_on(async {
744
                    let l = DefaultLoader::new(
745
                        Client::new(),
746
                        Config::default().from_env().from_profile(),
747
                    );
748
                    let x = l.load().await.expect("load must success").unwrap();
749
                    assert_eq!("shared_access_key_id", x.access_key_id);
750
                    assert_eq!("shared_secret_access_key", x.secret_access_key);
751
                })
752
            },
753
        );
754
    }
755
756
    #[test]
757
    fn test_signer_with_web_loader() -> Result<()> {
758
        let _ = env_logger::builder().is_test(true).try_init();
759
760
        dotenv::from_filename(".env").ok();
761
762
        if env::var("REQSIGN_AWS_S3_TEST").is_err()
763
            || env::var("REQSIGN_AWS_S3_TEST").unwrap() != "on"
764
        {
765
            return Ok(());
766
        }
767
768
        // Ignore test if role_arn not set
769
        let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") {
770
            v
771
        } else {
772
            return Ok(());
773
        };
774
775
        // let provider_arn = env::var("REQSIGN_AWS_PROVIDER_ARN").expect("REQSIGN_AWS_PROVIDER_ARN not exist");
776
        let region = env::var("REQSIGN_AWS_S3_REGION").expect("REQSIGN_AWS_S3_REGION not exist");
777
778
        let github_token = env::var("GITHUB_ID_TOKEN").expect("GITHUB_ID_TOKEN not exist");
779
        let file_path = format!(
780
            "{}/testdata/services/aws/web_identity_token_file",
781
            env::current_dir()
782
                .expect("current_dir must exist")
783
                .to_string_lossy()
784
        );
785
        fs::write(&file_path, github_token)?;
786
787
        temp_env::with_vars(
788
            vec![
789
                (AWS_REGION, Some(&region)),
790
                (AWS_ROLE_ARN, Some(&role_arn)),
791
                (AWS_WEB_IDENTITY_TOKEN_FILE, Some(&file_path)),
792
            ],
793
            || {
794
                RUNTIME.block_on(async {
795
                    let config = Config::default().from_env();
796
                    let loader = DefaultLoader::new(reqwest::Client::new(), config);
797
798
                    let signer = Signer::new("s3", &region);
799
800
                    let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region);
801
                    let mut req = Request::new("");
802
                    *req.method_mut() = http::Method::GET;
803
                    *req.uri_mut() =
804
                        http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap();
805
806
                    let cred = loader
807
                        .load()
808
                        .await
809
                        .expect("credential must be valid")
810
                        .unwrap();
811
812
                    signer.sign(&mut req, &cred).expect("sign must success");
813
814
                    debug!("signed request url: {:?}", req.uri().to_string());
815
                    debug!("signed request: {:?}", req);
816
817
                    let client = Client::new();
818
                    let resp = client.execute(req.try_into().unwrap()).await.unwrap();
819
820
                    let status = resp.status();
821
                    debug!("got response: {:?}", resp);
822
                    debug!("got response content: {:?}", resp.text().await.unwrap());
823
                    assert_eq!(status, StatusCode::NOT_FOUND);
824
                })
825
            },
826
        );
827
828
        Ok(())
829
    }
830
831
    /// Integration test: sign an S3 `GET` for a missing key using credentials obtained
    /// by `AssumeRoleLoader`, which assumes `REQSIGN_AWS_ASSUME_ROLE_ARN` on top of the
    /// credentials resolved by `DefaultLoader` from the environment (here with
    /// `AWS_WEB_IDENTITY_TOKEN_FILE` pointing at a file holding `GITHUB_ID_TOKEN`).
    ///
    /// Returns `Ok(())` without doing anything unless `REQSIGN_AWS_S3_TEST` is `on` and
    /// both `REQSIGN_AWS_ROLE_ARN` and `REQSIGN_AWS_ASSUME_ROLE_ARN` are set.
    /// Panics if `REQSIGN_AWS_S3_REGION` or `GITHUB_ID_TOKEN` is missing.
    /// Overwrites `testdata/services/aws/web_identity_token_file` under the current dir.
    /// Asserts the signed request is answered with `404 Not Found`.
    #[test]
    fn test_signer_with_web_loader_assume_role() -> Result<()> {
        let _ = env_logger::builder().is_test(true).try_init();

        dotenv::from_filename(".env").ok();

        // Read the toggle once instead of calling `env::var` twice and unwrapping.
        if !matches!(env::var("REQSIGN_AWS_S3_TEST").as_deref(), Ok("on")) {
            return Ok(());
        }

        // Ignore test if role_arn not set
        let role_arn = match env::var("REQSIGN_AWS_ROLE_ARN") {
            Ok(v) => v,
            Err(_) => return Ok(()),
        };
        // Ignore test if assume_role_arn not set
        let assume_role_arn = match env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") {
            Ok(v) => v,
            Err(_) => return Ok(()),
        };

        let region = env::var("REQSIGN_AWS_S3_REGION").expect("REQSIGN_AWS_S3_REGION not exist");

        // Persist the token to the file that AWS_WEB_IDENTITY_TOKEN_FILE will point at.
        let github_token = env::var("GITHUB_ID_TOKEN").expect("GITHUB_ID_TOKEN not exist");
        let file_path = format!(
            "{}/testdata/services/aws/web_identity_token_file",
            env::current_dir()
                .expect("current_dir must exist")
                .to_string_lossy()
        );
        fs::write(&file_path, github_token)?;

        temp_env::with_vars(
            vec![
                (AWS_REGION, Some(&region)),
                (AWS_ROLE_ARN, Some(&role_arn)),
                (AWS_WEB_IDENTITY_TOKEN_FILE, Some(&file_path)),
            ],
            || {
                RUNTIME.block_on(async {
                    let client = reqwest::Client::new();
                    // Source credentials come from the env-configured default chain,
                    // with the EC2 metadata lookup switched off.
                    let default_loader =
                        DefaultLoader::new(client.clone(), Config::default().from_env())
                            .with_disable_ec2_metadata();

                    let cfg = Config {
                        role_arn: Some(assume_role_arn.clone()),
                        region: Some(region.clone()),
                        sts_regional_endpoints: "regional".to_string(),
                        ..Default::default()
                    };
                    let loader =
                        AssumeRoleLoader::new(client.clone(), cfg, Box::new(default_loader))
                            .expect("AssumeRoleLoader must be valid");

                    let signer = Signer::new("s3", &region);
                    let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region);
                    let mut req = Request::new("");
                    *req.method_mut() = http::Method::GET;
                    *req.uri_mut() =
                        http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap();

                    // `load` yields `Result<Option<Credential>>`: fail loudly on either layer.
                    let cred = loader
                        .load()
                        .await
                        .expect("credential must be valid")
                        .expect("credential must be present");
                    signer.sign(&mut req, &cred).expect("sign must success");
                    debug!("signed request url: {:?}", req.uri().to_string());
                    debug!("signed request: {:?}", req);

                    // Reuse the client built above instead of constructing a second one.
                    let resp = client.execute(req.try_into().unwrap()).await.unwrap();
                    let status = resp.status();
                    debug!("got response: {:?}", resp);
                    debug!("got response content: {:?}", resp.text().await.unwrap());
                    assert_eq!(status, StatusCode::NOT_FOUND);
                })
            },
        );
        Ok(())
    }
915
916
    /// Deserializes a sample STS `AssumeRoleWithWebIdentity` XML response and checks
    /// that every credential field is extracted verbatim.
    #[test]
    fn test_parse_assume_role_with_web_identity_response() -> Result<()> {
        let _ = env_logger::builder().is_test(true).try_init();

        let body = r#"<AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
  <AssumeRoleWithWebIdentityResult>
    <Audience>test_audience</Audience>
    <AssumedRoleUser>
      <AssumedRoleId>role_id:reqsign</AssumedRoleId>
      <Arn>arn:aws:sts::123:assumed-role/reqsign/reqsign</Arn>
    </AssumedRoleUser>
    <Provider>arn:aws:iam::123:oidc-provider/example.com/</Provider>
    <Credentials>
      <AccessKeyId>access_key_id</AccessKeyId>
      <SecretAccessKey>secret_access_key</SecretAccessKey>
      <SessionToken>session_token</SessionToken>
      <Expiration>2022-05-25T11:45:17Z</Expiration>
    </Credentials>
    <SubjectFromWebIdentityToken>subject</SubjectFromWebIdentityToken>
  </AssumeRoleWithWebIdentityResult>
  <ResponseMetadata>
    <RequestId>b1663ad1-23ab-45e9-b465-9af30b202eba</RequestId>
  </ResponseMetadata>
</AssumeRoleWithWebIdentityResponse>"#;

        let parsed: AssumeRoleWithWebIdentityResponse =
            de::from_str(body).expect("xml deserialize must success");
        // Only the nested credentials are checked; bind them once.
        let creds = &parsed.result.credentials;

        assert_eq!(&creds.access_key_id, "access_key_id");
        assert_eq!(&creds.secret_access_key, "secret_access_key");
        assert_eq!(&creds.session_token, "session_token");
        assert_eq!(&creds.expiration, "2022-05-25T11:45:17Z");

        Ok(())
    }
954
955
    /// Deserializes a sample STS `AssumeRole` XML response and checks the extracted
    /// credential fields.
    #[test]
    fn test_parse_assume_role_response() -> Result<()> {
        let _ = env_logger::builder().is_test(true).try_init();

        // The `SessionToken` element spans several indented lines. The expected value
        // asserted below has the outer whitespace removed while the interior line
        // breaks and indentation are kept.
        let body = r#"<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
  <AssumeRoleResult>
  <SourceIdentity>Alice</SourceIdentity>
    <AssumedRoleUser>
      <Arn>arn:aws:sts::123456789012:assumed-role/demo/TestAR</Arn>
      <AssumedRoleId>ARO123EXAMPLE123:TestAR</AssumedRoleId>
    </AssumedRoleUser>
    <Credentials>
      <AccessKeyId>ASIAIOSFODNN7EXAMPLE</AccessKeyId>
      <SecretAccessKey>wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY</SecretAccessKey>
      <SessionToken>
       AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQW
       LWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd
       QrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU
       9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz
       +scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==
      </SessionToken>
      <Expiration>2019-11-09T13:34:41Z</Expiration>
    </Credentials>
    <PackedPolicySize>6</PackedPolicySize>
  </AssumeRoleResult>
  <ResponseMetadata>
    <RequestId>c6104cbe-af31-11e0-8154-cbc7ccf896c7</RequestId>
  </ResponseMetadata>
</AssumeRoleResponse>"#;

        let parsed: AssumeRoleResponse = de::from_str(body).expect("xml deserialize must success");
        let creds = &parsed.result.credentials;

        assert_eq!(&creds.access_key_id, "ASIAIOSFODNN7EXAMPLE");
        assert_eq!(
            &creds.secret_access_key,
            "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY"
        );
        assert_eq!(
            &creds.session_token,
            "AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQW
       LWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd
       QrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU
       9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz
       +scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA=="
        );
        assert_eq!(&creds.expiration, "2019-11-09T13:34:41Z");

        Ok(())
    }
1007
}