Rust 代码挑战系列 8 - 使用 Websocket

这是一个学习 Rust 的小系列文章,通过完成来自 shuttle 平台所举办的 2023 Christmas Code Hunt 里的每个小挑战,来学习 rust web 框架的使用,此为第八篇文章。

本篇内 Part 18 是数据库的操作,Part 19 是在 rust 中操作 websocket。

前言

截止目前,我们的项目里其实存在这么两个问题:

state 组织

我们已经定义了多个 state,比如 D12StateD13State,给不同挑战篇章各自定义的各自的 state 的写法仅是针对这种特别场景,在实际的项目中到不太会是这种组织形式。

此外,在 axum 不能同时存在多个 State extractor,即 fn my_route(state12: State<D12Stat>, state13: State<D13Stat>) {} 的写法会编译报错。笔者一般会全局定义一个 AppState 存放所有不同路由函数所需的内容,所以在开始挑战前,我们把代码结构进行一下调整吧:

  1. main.rs 里定义 AppState,将之前分散各处的 state 内容都放在这里:

    #[derive(Clone)]
    pub struct AppState {
        pub pool: MySqlPool,
    
        pub d12_store: Arc<Mutex<HashMap<String, Instant>>>, // f64 == timestamp
    }
    
  2. 然后修改 with_state(),只需要一个 with_state,其他的都可移除:

    let router = Router::new()
            .route("/", get(demo))
            // ...
            .with_state(AppState {
                pool: pool.clone(),
                d12_store: Default::default(),
            });
    
  3. 然后是对路由函数进行修改,将 State<D1XXState> 替换成 State<AppState>,并调整对应代码

在 axum 文档中的 Substates 部分介绍了通过 FromRef 实现子状态,读者们感兴趣可前往了解,然后可尝试将 AppState 里的 d12_store 重构成 struct D12SubState { store: xxx, pool: xxx } 这种子状态形式,让代码更“美丽”。

路由组织

main.rs 里路由注册不是那么的“模块化”,我们通过使用 .nest 来重新组织下写法,下述是调整后的代码:

let r1 = Router::new().route("/*nums", get(handler::d1));
let r4 = Router::new()
    .route("/strength", post(handler::d4_1))
    .route("/contest", post(handler::d4_2));

let r5 = Router::new()
    // .route("/", post(handler::d5_1))
    .route("/", post(handler::d5_2));

let r6 = Router::new().route("/", post(handler::d6_2));

let r7 = Router::new()
    .route("/decode", get(handler::d7_1))
    .route("/bake", get(handler::d7_3));

let r8 = Router::new()
    .route("/weight/:id", get(handler::d8_1))
    .route("/drop/:id", get(handler::d8_2));

let r11 = Router::new()
    .nest_service("/assets", ServeDir::new("assets"))
    .route("/red_pixels", post(handler::d11_2));

let r12 = Router::new()
    .route("/save/:packet_id", post(handler::d12_1_save))
    .route("/load/:packet_id", get(handler::d12_1_load))
    .route("/ulids", post(handler::d12_2))
    .route("/ulids/:weekday", post(handler::d12_3));

let r13 = Router::new()
    .route("/sql", get(handler::d13_1))
    .route("/reset", post(handler::d13_2_reset))
    .route("/orders", post(handler::d13_2_orders))
    .route("/orders/total", get(handler::d13_2_total))
    .route("/orders/popular", get(handler::d13_3));

let r14 = Router::new()
    .route("/unsafe", post(handler::d14_1))
    .route("/safe", post(handler::d14_2));

let r15 = Router::new()
    .route("/nice", post(handler::d15_1))
    .route("/game", post(handler::d15_2));

let router = Router::new()
    .route("/", get(demo))
    .nest("/1", r1)
    .nest("/4", r4)
    .nest("/5", r5)
    .nest("/6", r6)
    .nest("/7", r7)
    .nest("/8", r8)
    .nest("/11", r11)
    .nest("/12", r12)
    .nest("/13", r13)
    .nest("/14", r14)
    .nest("/15", r15)
    .with_state(AppState {
        pool: pool.clone(),
        d12_store: Default::default(),
    });

读者们可进一步对文件结构进行调整,并且可以一个路由函数放到一个文件里,从而让项目结构变得更加清晰和可维护

现在我们可以开始挑战了

Part 18

该部分涉及的所有路由如下:

let r18 = Router::new()
    .route("/reset", post(handler::d18_1_reset))
    .route("/orders", post(handler::d18_1_orders))
    .route("/regions", post(handler::d18_1_regions))
    .route("/regions/total", get(handler::d18_1_total))
    .route("/regions/top_list/:num", get(handler::d18_2));

let router = Router::new()
    .nest("/18", r18)

该部分主要是对 sql 能力的考验,顺便强化下 sqlx 包的使用,下面是相关代码和一些说明:

Task 1

对数据库进行重置和初始化,使用事务来执行多条语句:

pub async fn d18_1_reset(state: State<AppState>) {
    let pool = state.pool.clone();
    let mut transaction = pool.begin().await.unwrap();

    sqlx::query(r#"DROP TABLE IF EXISTS regions;"#)
        .execute(&mut *transaction)
        .await
        .unwrap();
    sqlx::query(r#"DROP TABLE IF EXISTS orders;"#)
        .execute(&mut *transaction)
        .await
        .unwrap();
    sqlx::query(r#"CREATE TABLE regions (id INT PRIMARY KEY, name VARCHAR(50));"#)
        .execute(&mut *transaction)
        .await
        .unwrap();
    sqlx::query(
        r#"CREATE TABLE orders (
    id INT PRIMARY KEY,
    region_id INT,
    gift_name VARCHAR(50),
    quantity INT
  );"#,
    )
    .execute(&mut *transaction)
    .await
    .unwrap();
}

插入相关数据的逻辑,同样使用了事务:

#[derive(Debug, Serialize, Deserialize)]
pub struct Order {
    id: i32,
    region_id: i32,
    gift_name: String,
    quantity: i32,
}
pub async fn d18_1_orders(state: State<AppState>, body: Json<Vec<Order>>) {
    let pool = state.pool.clone();
    let mut transaction = pool.begin().await.unwrap();
    for item in body.0 {
        sqlx::query(
            r#"
  INSERT INTO orders (id, region_id, gift_name, quantity) VALUES (?, ?, ?, ?)
  "#,
        )
        .bind(item.id)
        .bind(item.region_id)
        .bind(item.gift_name)
        .bind(item.quantity)
        .execute(&mut *transaction)
        .await
        .unwrap();
    }
    transaction.commit().await.unwrap();
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Region {
    id: i32,
    name: String,
}
pub async fn d18_1_regions(state: State<AppState>, body: Json<Vec<Region>>) {
    let pool = state.pool.clone();
    let transaction = pool.begin().await.unwrap();

    for item in body.0 {
        sqlx::query(
            r#"
INSERT INTO regions (id, name) VALUES (?, ?)
"#,
        )
        .bind(item.id)
        .bind(item.name)
        .execute(&pool)
        .await
        .unwrap();
    }
    transaction.commit().await.unwrap();
}

查询每个地区所销售的礼物数:

#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct TotalQueryRes {
    region: String,
    total: i64,
}
pub async fn d18_1_total(state: State<AppState>) -> Json<Vec<TotalQueryRes>> {
    let pool = state.pool.clone();
    let res = sqlx::query_as::<_, TotalQueryRes>(
        r#"
SELECT regions.name AS region, CAST(SUM(orders.quantity) AS SIGNED) AS total FROM orders JOIN regions ON orders.region_id = regions.id GROUP BY orders.gift_name, region ORDER BY region;
"#,
    )
    .fetch_all(&pool)
    .await
    .unwrap();

    return Json::from(res);
}

我们需要确保 SQL 的返回值需要和 rust 定义的类型匹配,否则会运行时报错,其中 SQL 里的 SUM() 函数返回值对应 sqlx 里的 Decimal 类型,这里为了避免在 rust 里写一遍把 Decimal 转成数值类型的“冗长”代码,就直接在 SQL 语句里使用 CAST() 将返回类型转成了数值类型。

Task 2

查询礼物并返回 top n

#[derive(sqlx::FromRow, Debug, Serialize)]
pub struct D182SqlRes {
    region: String,
    gift_name: Option<String>,
    total: Option<i64>,
}

#[derive(Debug, Serialize)]
pub struct D182Res {
    region: String,
    top_gifts: Vec<String>,
}
pub async fn d18_2(state: State<AppState>, Path(num): Path<usize>) -> Json<Vec<D182Res>> {
    let pool = state.pool.clone();
    // 使用 RIGHT OUTER JOIN 确保能查询到所有的 region
    let res = sqlx::query_as::<_, D182SqlRes>(
      r#"
SELECT regions.name AS region, gift_name, CAST(SUM(orders.quantity) AS SIGNED) AS total FROM orders RIGHT OUTER JOIN regions ON orders.region_id = regions.id GROUP BY orders.gift_name, region  ORDER BY region ASC, total DESC, gift_name ASC;
"#,
  )
  .fetch_all(&pool)
  .await
  .unwrap();

    let mut gifts_map: BTreeMap<String, Vec<String>> = BTreeMap::new();
    res.iter().for_each(|item| {
        let gift_name = item.gift_name.as_ref();
        match gift_name {
            Some(gift_name) => {
                gifts_map
                    .entry(item.region.clone())
                    .and_modify(|gifts| {
                        gifts.push(gift_name.clone());
                    })
                    .or_insert(vec![gift_name.clone()]);
            }
            None => {
                gifts_map.insert(item.region.clone(), vec![]);
            }
        }
    });

    let ret = gifts_map
        .iter()
        .map(|item| {
            return D182Res {
                region: item.0.to_owned(),
                // 获取前 n 个元素
                top_gifts: item.1.iter().take(num).cloned().collect::<Vec<String>>(),
            };
        })
        .collect::<Vec<D182Res>>();

    return Json::from(ret);
}

这个路由函数中,SQL 语句的内容不做讨论,主要看一些和 rust 相关的点:

Part 19

该部分涉及 websocket 的相关知识,在开始挑战前,我们先简单回顾下 ws 的一些知识:

websocket 可以实现客户端和服务端的双向主动通信,常用于比如即时聊天、即时游戏等,以及前端开发中 webpack-dev-server 也使用了 websocket 传递代码变更信息。websocket 连接的建立需要通过特定的握手过程,一般是通过 http/1.1 协议的 GET 接口进行协议升级,当连接建立后,两端便可以进行实时双向通信。

在 axum 的文档里我们可以学习了解到如何使用 websocket,当然要记得开启 ws feature:

axum = { version = "0.7.2", features = ["ws"] }

现在来看看本节内容,该部分涉及的所有路由和全局状态新增如下:

pub struct RoomState {
    pub sender: broadcast::Sender<String>,
}

impl Default for RoomState {
    fn default() -> Self {
        Self {
            // 此处的 capacity 需要设置大一些,不能是比如 10,太小会导致未被接收的消息丢弃而导致不符合业务逻辑
            sender: broadcast::channel(1024).0,
        }
    }
}

#[derive(Clone)]
pub struct AppState {
    pub d19_started: bool,
    pub d19_rooms: Arc<Mutex<HashMap<usize, RoomState>>>,
    pub d19_count: Arc<AtomicU32>,
}

let r19 = Router::new()
    .route("/ws/ping", get(handler::d19_1))
    .route("/reset", post(handler::d19_2_reset))
    .route("/views", get(handler::d19_2_views))
    .route("/ws/room/:room/user/:user", get(handler::d19_2_room));

let router = Router::new()
    .nest("/19", r19)
    .with_state(AppState {
        // ...
        d19_started: false,
        d19_count: Default::default(),
        d19_rooms: Default::default(),
    });

Task 1

该任务里实现一个 websocket 来处理客户端发来的 "ping"、"pong"、"server" 指令:

pub async fn d19_1(ws: WebSocketUpgrade, mut state: State<AppState>) -> Response {
    ws.on_upgrade(|mut socket| async move {
        while let Some(msg) = socket.recv().await {
            let msg = if let Ok(msg) = msg {
                msg
            } else {
                // client disconnected
                return;
            };
            if let Ok(message) = msg.to_text() {
                if message == "serve" {
                    state.d19_started = true;
                }
                if message == "ping" {
                    if state.d19_started {
                        let _ = socket.send(Message::Text("pong".to_string())).await;
                    }
                }
            }
        }
    })
}

上述使用 WebSocketUpgrade 处理 ws 内容,在 on_upgrade 的回调函数里处理连接建立后的逻辑。

对于经常使用 NodeJS 的伙伴来说,上面出现的 while 循环写法让人感到很奇怪,在 NodeJS 中几乎都是基于事件的写法,比如:

ws.on('open', function open() {
  ws.send('something');
});

ws.on('message', function message(data) {
  console.log('received: %s', data);
});

而 rust 并非是基于事件的设计,所以在编写代码时还需转变一下心智。

Task 2

这个挑战中需要实现一个进房间、发帖子、统计阅读数的功能。

该挑战里将会使用到 AppState 里新增的另外两个状态:

下述 d19_2_reset 实现了重置房间的功能:

pub async fn d19_2_reset(state: State<AppState>) {
    let rooms = state.d19_rooms.clone();
    let mut rooms = rooms.lock().unwrap();
    rooms.clear();
    state.d19_count.store(0, Ordering::Relaxed);
}

Ordering::Relaxed 和内存排序有关,读者们可自行网上查阅相关资料学习。

下属是获取当前的帖子浏览总数:

pub async fn d19_2_views(state: State<AppState>) -> String {
    let c = state.d19_count.clone();
    let v = c.load(std::sync::atomic::Ordering::Relaxed).to_string();
    v
}

下属则为本节最具有挑战的部分,实现简单聊天室。

在开始前先来了解一些必要知识: axum 所提供的 websocket 并没有广播功能(用惯了 NodeJS 的 ws 包都差点误以为 websocket 先天自带广播功能),我们需要自行实现,其中会涉及到这四个概念:

ws_receiver 和 ws_sender 可以通过 on_upgrade 回调函数的参数来获取到。sender 和 receiver 则需要自行实现,每个 sender/receiver 会存在于不同的线程中,便涉及到了线程间的通信,常见模式有 MPMC(multi-producer multi-consumer)、MPSC(multi-producer, single-consumer),我们要使用的是 MPMC,但 rust 仅内置提供了 std::sync::mpsc,不满足我们的需求,我们将使用到 tokio 提供的 broadcast 来实现广播功能:

pub async fn d19_2_room(
    ws: WebSocketUpgrade,
    Path((room_id, user)): Path<(usize, String)>,
    state: State<AppState>,
) -> Response {
    ws.on_upgrade(move |socket| async move {
        let sender = {
            let mut rooms = state.d19_rooms.lock().unwrap();
            rooms
                .entry(room_id)
                .or_insert_with(|| {
                    let mut room = crate::RoomState::default();
                    room
                })
                .sender
                .clone()
        };
        let mut receiver = sender.subscribe();

        let (mut ws_sender, mut ws_receiver) = socket.split();

        // 监听客户端传来的消息,
        let mut send_task = tokio::spawn(async move {
            let user: String = user.clone();
            while let Some(msg) = ws_receiver.next().await {
                match msg {
                    Ok(msg) => {
                        match msg {
                            Message::Text(msg) => {
                                if let Ok(user_msg) = serde_json::from_str::<UserMessage>(&msg) {
                                    if user_msg.message.len() <= 128 {
                                        // Send to all users in the same room
                                        let _ = sender.send(
                                            serde_json::to_string(&UserSentMessage {
                                                user: user.clone(),
                                                message: user_msg.message,
                                            })
                                            .unwrap(),
                                        );
                                    }
                                } else {
                                    println!("{} sent unhandled format msg: {:?}", user, msg);
                                }
                            }
                            Message::Binary(_) | Message::Ping(_) | Message::Pong(_) => {
                                println!("{} sent unhandled type msg: {:?}", user, msg);
                            }
                            Message::Close(_) => {
                                break;
                            }
                        }
                    }
                    Err(err) => {
                        break;
                    }
                }
            }
        });

        let count = state.d19_count.clone();
        // 聊天室里的人监听消息接收
        let mut listen_task = tokio::spawn(async move {
            loop {
                match receiver.recv().await {
                    Ok(msg) => {
                        // 浏览+1
                        count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                        // 发回到客户端
                        let m = serde_json::from_str::<UserSentMessage>(&msg).unwrap();
                        let _ = ws_sender
                            .send(Message::Text(serde_json::to_string(&m).unwrap()))
                            .await;
                    }
                    Err(err) => {
                        println!("err: {:?}", err);
                        break;
                    }
                }
            }
        });

        // 当一个任务结束后,需要终止对应的另一个任务
        tokio::select! {
            r = (&mut send_task) => {
                println!("abort listen_task, {:?}", r);
                listen_task.abort()
            },
            r = (&mut listen_task) => {
                println!("abort send_task, {:?}", r);
                send_task.abort()
            },
        };
    })
}

上述使用 socket.split() 获取 ws_sender 和 ws_receiver,我们需要 cargo add futures 安装 futures crate,并需要导入相关内容:

use futures::{SinkExt, StreamExt};

sender 则是 AppState 里通过 broadcast::channel(1024).0 创建的,其中的 1024 是设置的 capacity 容量,此处不易设置过小,否则当端发送速度大于接收端处理速度时,如果积压了大于 capacity 的消息,则会丢弃掉最早的超出内容,从而造成帖子数记录不准确。

小结

rust 的异步相关知识让学习 rust 之旅变得更陡峭了,一定是 nodejs 把笔者“养”得太好了,赶紧挑灯继续阅读 Asynchronous Programming in Rust