Rust 代码挑战系列 5 - 状态共享

这是一个学习 Rust 的小系列文章,通过完成来自 shuttle 平台所举办的 2023 Christmas Code Hunt 里的每个小挑战,来学习 rust web 框架的使用,此为第五篇文章。

本篇内 Part 12 涉及状态共享、时间的处理。

如何在不同路由间共享状态数据?在官方文档的 Sharing state with handlers 章节中提到了相关内容,我们将使用 State 提取器来实现状态共享。

在进行之前,我们先来了解/回顾一下 rust 中的多线程知识,以便更轻松地使用 State

Rust 多线程知识回顾

线程的基本使用

fn main() {
    println!("start");

    let th1 = std::thread::spawn(|| {
        println!("thread1");
    });
    let th2 = std::thread::spawn(|| {
        println!("thread2");
    });

    th1.join();
    th2.join();

    println!("done");
}

使用 std::thread 模块的 spawn() 来创建线程,调用 join() 来等待对应线程执行完毕,若不调用 join,主程序执行到尾部后会直接退出,此时线程可能还未执行。

多个线程访问同一个数据

#[derive(Debug, Clone, Copy)]
struct Counter {
    count: u32,
}

fn main() {
    let counter = Counter { count: 0 };

    // th1 里修改 count 数值
    let th1 = std::thread::spawn(move || {
        let mut c = counter;
        c.count = 1;
        println!("thread1: {:?}", c.count);
    });
    th1.join().unwrap();

    // th2 里仅访问 count 数值
    let th2 = std::thread::spawn(move || {
        let c = counter;
        println!("thread2: {:?}", c.count);
    });
    th2.join().unwrap();
}

上述代码中,当我们移除 CounterCloneCopy traits,会发现编译器报错 "move occurs because counter has type Counter, which does not implement the Copy trait",因为此时根据 rust 的 move 机制,counter 变量的所有权已经被移到了 th1 内,此时 th2 不能继续使用最开始的 counter 变量了,

实现 Copy trait 需要同时实现 Clone trait,按照报错提示加回 CloneCopy 后,现在来运行一下程序,会发现 th2 打印的 c.count 是 0 而非是以为的 1,这是因为两个线程内的 counter 并非指向同一个 counter,而是各自隐式 copy 了一份 counter,这也是为何没有所有权的报错,因为 rust 背地里进行了 copy。

读者们可重新回顾官方文档的理解所有权章节进一步加深认识。

如何让线程间共享同一个 counter 呢?

因为涉及数据的修改,我们需要使用 Mutex struct 为我们提供一把锁,确保同一时间只能有一个线程访问该数据,确保数据被正确更新:

// 为 Counter 添加上锁能力
let counter = std::sync::Mutex::new(Counter { count: 0 });

// 当要读写数据时,调用 lock() 锁住,然后进行其他操作
let mut c = counter.lock().unwrap();

现在代码变成了这样:

fn demo() {
    let counter = std::sync::Mutex::new(Counter { count: 0 });

    let th1 = std::thread::spawn(move || {
        let mut c = counter.lock().unwrap();
        c.count = 1;
        println!("thread1: {:?}", c.count);
    });
    th1.join().unwrap();

    let th2 = std::thread::spawn(move || {
        let mut c = counter.lock().unwrap();
        println!("thread2: {:?}", c.count);
    });
    th2.join().unwrap();
}

但是由于 Mutex 并未实现 Copy,因此编译报错了 "move occurs because counter has type std::sync::Mutex<Counter>, which does not implement the Copy trait",由于 Mutex 是内部类型,我们没法给它添加 Copy trait,该怎么办呢?

Rust 为我们提供了 Rc(Reference count) 引用计数来实现让一个值可以同时拥有多个所有者,对于多线程场景则需要使用线程安全的版本 Arc(Atomic reference count) struct,通过调用 clone() 来增加引用计数,会返回一个指向原数据的指针:

let counter = std::sync::Arc::new(std::sync::Mutex::new(Counter { count: 0 }));

// 调用 clone() 来新增一个 owner 供其他地方消费
let c = counter.clone();

将上面两个 struct 结合使用后,现在代码变成这样了:

#[derive(Debug)]
struct Counter {
    count: u32,
}

fn main() {
    let counter = std::sync::Arc::new(std::sync::Mutex::new(Counter { count: 0 }));

    let c1 = counter.clone();
    let th1 = std::thread::spawn(move || {
        let mut c = c1.lock().unwrap();
        c.count = 1;
        println!("thread1: {:?}", c.count);
    });
    th1.join().unwrap();

    let c2 = counter.clone();
    let th2 = std::thread::spawn(move || {
        let c = c2.lock().unwrap();
        println!("thread2: {:?}", c.count);
    });
    th2.join().unwrap();
}

运行会发现 th2 里能够打印最新值了。在多线程场景,我们会经常碰到 ArcMutex 的搭配使用。

现在我们开始本 Part 的任务吧。

Part 12

Task 1

圣诞老人的礼品制造车间在以惊人的速度进行包装,为了收集礼物制作的数据,圣诞老人需要一个多秒表(multi-stopwatch),可以同时记录多个包裹的时间。现在需要一个接口保存包裹,另一个接口获取该包裹已经存放的时间:

# 输入输出示例
curl -X POST http://localhost:8000/12/save/packet20231212
sleep 2
curl http://localhost:8000/12/load/packet20231212
echo
sleep 2
curl http://localhost:8000/12/load/packet20231212
echo
curl -X POST http://localhost:8000/12/save/packet20231212
curl http://localhost:8000/12/load/packet20231212

# 大约 ~4 秒后:
2
4
0

新增路由:

Router.route("/12/save/:packet_id", post(handler::d12_1_save))
      .route("/12/load/:packet_id", get(handler::d12_1_load))
      .with_state(handler::D12State::default())

上述使用 .with_state() 为应用添加共享状态,现在看看 d12.rs 里如何写:

use std::{
  collections::HashMap,
  sync::{Arc, Mutex},
  time::{self, Duration, Instant},
};

// axum 的 state 必须要实现 Clone
#[derive(Clone)]
pub struct D12State {
  // HashMap 的 key 为 packet id,value 为当前时刻
  store: Arc<Mutex<HashMap<String, Instant>>>,
}

// 实现 Default trait 后,我们便可以调用 MyStruct::default() 返回一个默认值,不需要每次手动给属性赋值
impl Default for D12State {
  fn default() -> Self {
    Self {
      store: Default::default(),
    }
  }
}

pub async fn d12_1_save(Path(packet_id): Path<String>, State(state): State<D12State>) {
  let mut store = state.store.lock().unwrap();
  store.insert(packet_id, Instant::now());
}

pub async fn d12_1_load(Path(packet_id): Path<String>, State(state): State<D12State>) -> String {
  match state.store.lock().unwrap().get(&packet_id) {
    Some(time) => time.elapsed().as_secs().to_string(),
    None => 0.to_string(),
  }
}

上述中通过使用 State 提取器来获取到我们的全局状态 D12State,此文还接触到了 std::time 模块的 Instant struct,它表示一个单调不递减时钟(monotonically nondecreasing clock),调用 elapsed() 会返回 一个 Duration,表示一段时间(span of time)。

Task 2

圣诞老人喜欢老式的技术(old-school tech),现在看到一些包使用了现代的 ULID 标识符,我们需要用他能理解的老格式 UUID 给他看。

# 输入输出示例
curl -X POST http://localhost:8000/12/ulids \
  -H 'Content-Type: application/json' \
  -d '[
    "01BJQ0E1C3Z56ABCD0E11HYX4M",
    "01BJQ0E1C3Z56ABCD0E11HYX5N",
    "01BJQ0E1C3Z56ABCD0E11HYX6Q",
    "01BJQ0E1C3Z56ABCD0E11HYX7R",
    "01BJQ0E1C3Z56ABCD0E11HYX8P"
  ]'

[
  "015cae07-0583-f94c-a5b1-a070431f7516",
  "015cae07-0583-f94c-a5b1-a070431f74f8",
  "015cae07-0583-f94c-a5b1-a070431f74d7",
  "015cae07-0583-f94c-a5b1-a070431f74b5",
  "015cae07-0583-f94c-a5b1-a070431f7494"
]

这个任务中是将 ULID 转换为 UUID,关于 UUID 和 ULID 的简单介绍:都是用来表示唯一标识符的一种格式,后者包含了毫秒时间戳信息。

现在我们新增路由:

// 需要写在 with_state(..) 前面
Router.route("/12/ulids", post(handler::d12_2))

我们会使用 uliduuid crate 来处理 ULID 和 UUID:

cargo add uuid ulid

路由函数实现:

use ulid::Ulid;
use uuid::Uuid;

pub async fn d12_2(body: Json<Vec<String>>) -> Json<Vec<String>> {
    let result = body
        .iter()
        .map(|x| {
            let ulid = Ulid::from_string(&x).unwrap();
            let uuid: Uuid = Uuid::from_u128(ulid.0);
            uuid.to_string()
        })
        .rev()
        .collect::<Vec<_>>();

    Json::from(result)
}

该任务里除了了解到两种标识符,我们还会发现这里的命名竟然是 Ulid 而不像其他语言中常见的全大写 ULID 命名,因为在 rust 的命名规范中,对于 UpperCamelCase 的场景,字母缩写要当做一个单词:Uuid 而不是 UUIDUsize 而不是 USizeStdin 而不是 StdIn

Task 3

圣诞老人了解了 ULID 后,他需要我们帮他分析在车间角落里发现的一些包裹的生产日期。

# 输入输出示例
curl -X POST http://localhost:8000/12/ulids/5 \
  -H 'Content-Type: application/json' \
  -d '[
    "00WEGGF0G0J5HEYXS3D7RWZGV8",
    "76EP4G39R8JD1N8AQNYDVJBRCF",
    "018CJ7KMG0051CDCS3B7BFJ3AK",
    "00Y986KPG0AMGB78RD45E9109K",
    "010451HTG0NYWMPWCEXG6AJ8F2",
    "01HH9SJEG0KY16H81S3N1BMXM4",
    "01HH9SJEG0P9M22Z9VGHH9C8CX",
    "017F8YY0G0NQA16HHC2QT5JD6X",
    "03QCPC7P003V1NND3B3QJW72QJ"
  ]'

{
  "christmas eve": 3,
  "weekday": 1, // 要求0为周一,6为周日
  "in the future": 2,
  "LSB is 1": 5 // LSB = Least Significant Bit 
}

新增路由:

// 需要写在 with_state 前面
Router.route("/12/ulids/:weekday", post(handler::d12_3))

rust 里涉及日期处理时,几乎都会使用到 chrono 这个 crate,现在安装它:

cargo add chrono

路由函数:

use chrono::{DateTime, Datelike, Utc};

// 响应体定义
#[derive(Serialize, Default, Debug)]
pub struct DaysCount {
    #[serde(rename = "christmas eve")]
    christmas_eve: usize,
    weekday: usize,
    #[serde(rename = "in the future")]
    future: usize,
    #[serde(rename = "LSB is 1")]
    lsb1: usize,
}

pub async fn d12_3(path: Path<usize>, body: Json<Vec<String>>) -> Json<DaysCount> {
    let weekday = path.0;
    let mut result = DaysCount::default();
    body.iter().for_each(|x| {
        let ulid = Ulid::from_string(&x).unwrap();
        // ulid.datetime() 返回 std::time::SystemTime, 再调用 into() 将其转为 chrono 的日期数据类型
        let datetime: DateTime<Utc> = ulid.datetime().into();

        let month = datetime.month();
        let day = datetime.day();
        // 获取星期时,需要留意起始星期是周日还是周一,返回值是从 0 开始还是从 1 开始
        let wd = datetime.weekday().num_days_from_monday();

        if wd == weekday as u32 {
            result.weekday += 1;
        }

        if month == 12 && day == 24 {
            result.christmas_eve += 1;
        }

        if Utc::now() < datetime {
            result.future += 1;
        }

        if ulid.0 & 1 == 1 {
            result.lsb1 += 1;
        }
    });

    Json::from(result)
}

要实现本任务,需要对 chrono 的基本使用有所了解,大家可自行阅读器文档。

小结

不是在学习这个 crate,就是在安装下一个 crate 😅