Rust 代码挑战系列 8 - 使用 Websocket
这是一个学习 Rust 的小系列文章,通过完成来自 shuttle 平台所举办的 2023 Christmas Code Hunt 里的每个小挑战,来学习 rust web 框架的使用,此为第八篇文章。
本篇内 Part 18 是数据库的操作,Part 19 是在 rust 中操作 websocket。
前言
截止目前,我们的项目里其实存在这么两个问题:
state 组织
我们已经定义了多个 state,比如 D12State
、D13State
,给不同挑战篇章各自定义的各自的 state 的写法仅是针对这种特别场景,在实际的项目中到不太会是这种组织形式。
此外,在 axum 不能同时存在多个 State
extractor,即 fn my_route(state12: State<D12Stat>, state13: State<D13Stat>) {}
的写法会编译报错。笔者一般会全局定义一个 AppState
存放所有不同路由函数所需的内容,所以在开始挑战前,我们把代码结构进行一下调整吧:
-
在
main.rs
里定义AppState
,将之前分散各处的 state 内容都放在这里:#[derive(Clone)] pub struct AppState { pub pool: MySqlPool, pub d12_store: Arc<Mutex<HashMap<String, Instant>>>, // f64 == timestamp }
-
然后修改
with_state()
,只需要一个with_state
,其他的都可移除:let router = Router::new() .route("/", get(demo)) // ... .with_state(AppState { pool: pool.clone(), d12_store: Default::default(), });
-
然后是对路由函数进行修改,将
State<D1XXState>
替换成State<AppState>
,并调整对应代码
在 axum 文档中的 Substates 部分介绍了通过 FromRef
实现子状态,读者们感兴趣可前往了解,然后可尝试将 AppState
里的 d12_store
重构成 struct D12SubState { store: xxx, pool: xxx }
这种子状态形式,让代码更“美丽”。
路由组织
main.rs
里路由注册不是那么的“模块化”,我们通过使用 .nest
来重新组织下写法,下述是调整后的代码:
let r1 = Router::new().route("/*nums", get(handler::d1));
let r4 = Router::new()
.route("/strength", post(handler::d4_1))
.route("/contest", post(handler::d4_2));
let r5 = Router::new()
// .route("/", post(handler::d5_1))
.route("/", post(handler::d5_2));
let r6 = Router::new().route("/", post(handler::d6_2));
let r7 = Router::new()
.route("/decode", get(handler::d7_1))
.route("/bake", get(handler::d7_3));
let r8 = Router::new()
.route("/weight/:id", get(handler::d8_1))
.route("/drop/:id", get(handler::d8_2));
let r11 = Router::new()
.nest_service("/assets", ServeDir::new("assets"))
.route("/red_pixels", post(handler::d11_2));
let r12 = Router::new()
.route("/save/:packet_id", post(handler::d12_1_save))
.route("/load/:packet_id", get(handler::d12_1_load))
.route("/ulids", post(handler::d12_2))
.route("/ulids/:weekday", post(handler::d12_3));
let r13 = Router::new()
.route("/sql", get(handler::d13_1))
.route("/reset", post(handler::d13_2_reset))
.route("/orders", post(handler::d13_2_orders))
.route("/orders/total", get(handler::d13_2_total))
.route("/orders/popular", get(handler::d13_3));
let r14 = Router::new()
.route("/unsafe", post(handler::d14_1))
.route("/safe", post(handler::d14_2));
let r15 = Router::new()
.route("/nice", post(handler::d15_1))
.route("/game", post(handler::d15_2));
let router = Router::new()
.route("/", get(demo))
.nest("/1", r1)
.nest("/4", r4)
.nest("/5", r5)
.nest("/6", r6)
.nest("/7", r7)
.nest("/8", r8)
.nest("/11", r11)
.nest("/12", r12)
.nest("/13", r13)
.nest("/14", r14)
.nest("/15", r15)
.with_state(AppState {
pool: pool.clone(),
d12_store: Default::default(),
});
读者们可进一步对文件结构进行调整,并且可以一个路由函数放到一个文件里,从而让项目结构变得更加清晰和可维护
现在我们可以开始挑战了
Part 18
该部分涉及的所有路由如下:
let r18 = Router::new()
.route("/reset", post(handler::d18_1_reset))
.route("/orders", post(handler::d18_1_orders))
.route("/regions", post(handler::d18_1_regions))
.route("/regions/total", get(handler::d18_1_total))
.route("/regions/top_list/:num", get(handler::d18_2));
let router = Router::new()
.nest("/18", r18)
该部分主要是对 sql 能力的考验,顺便强化下 sqlx
包的使用,下面是相关代码和一些说明:
Task 1
对数据库进行重置和初始化,使用事务来执行多条语句:
pub async fn d18_1_reset(state: State<AppState>) {
let pool = state.pool.clone();
let mut transaction = pool.begin().await.unwrap();
sqlx::query(r#"DROP TABLE IF EXISTS regions;"#)
.execute(&mut *transaction)
.await
.unwrap();
sqlx::query(r#"DROP TABLE IF EXISTS orders;"#)
.execute(&mut *transaction)
.await
.unwrap();
sqlx::query(r#"CREATE TABLE regions (id INT PRIMARY KEY, name VARCHAR(50));"#)
.execute(&mut *transaction)
.await
.unwrap();
sqlx::query(
r#"CREATE TABLE orders (
id INT PRIMARY KEY,
region_id INT,
gift_name VARCHAR(50),
quantity INT
);"#,
)
.execute(&mut *transaction)
.await
.unwrap();
}
插入相关数据的逻辑,同样使用了事务:
#[derive(Debug, Serialize, Deserialize)]
pub struct Order {
id: i32,
region_id: i32,
gift_name: String,
quantity: i32,
}
pub async fn d18_1_orders(state: State<AppState>, body: Json<Vec<Order>>) {
let pool = state.pool.clone();
let mut transaction = pool.begin().await.unwrap();
for item in body.0 {
sqlx::query(
r#"
INSERT INTO orders (id, region_id, gift_name, quantity) VALUES (?, ?, ?, ?)
"#,
)
.bind(item.id)
.bind(item.region_id)
.bind(item.gift_name)
.bind(item.quantity)
.execute(&mut *transaction)
.await
.unwrap();
}
transaction.commit().await.unwrap();
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Region {
id: i32,
name: String,
}
pub async fn d18_1_regions(state: State<AppState>, body: Json<Vec<Region>>) {
let pool = state.pool.clone();
let transaction = pool.begin().await.unwrap();
for item in body.0 {
sqlx::query(
r#"
INSERT INTO regions (id, name) VALUES (?, ?)
"#,
)
.bind(item.id)
.bind(item.name)
.execute(&pool)
.await
.unwrap();
}
transaction.commit().await.unwrap();
}
查询每个地区所销售的礼物数:
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct TotalQueryRes {
region: String,
total: i64,
}
pub async fn d18_1_total(state: State<AppState>) -> Json<Vec<TotalQueryRes>> {
let pool = state.pool.clone();
let res = sqlx::query_as::<_, TotalQueryRes>(
r#"
SELECT regions.name AS region, CAST(SUM(orders.quantity) AS SIGNED) AS total FROM orders JOIN regions ON orders.region_id = regions.id GROUP BY orders.gift_name, region ORDER BY region;
"#,
)
.fetch_all(&pool)
.await
.unwrap();
return Json::from(res);
}
我们需要确保 SQL 的返回值需要和 rust 定义的类型匹配,否则会运行时报错,其中 SQL 里的 SUM()
函数返回值对应 sqlx 里的 Decimal
类型,这里为了避免在 rust 里写一遍把 Decimal
转成数值类型的“冗长”代码,就直接在 SQL 语句里使用 CAST()
将返回类型转成了数值类型。
Task 2
查询礼物并返回 top n
#[derive(sqlx::FromRow, Debug, Serialize)]
pub struct D182SqlRes {
region: String,
gift_name: Option<String>,
total: Option<i64>,
}
#[derive(Debug, Serialize)]
pub struct D182Res {
region: String,
top_gifts: Vec<String>,
}
pub async fn d18_2(state: State<AppState>, Path(num): Path<usize>) -> Json<Vec<D182Res>> {
let pool = state.pool.clone();
// 使用 RIGHT OUTER JOIN 确保能查询到所有的 region
let res = sqlx::query_as::<_, D182SqlRes>(
r#"
SELECT regions.name AS region, gift_name, CAST(SUM(orders.quantity) AS SIGNED) AS total FROM orders RIGHT OUTER JOIN regions ON orders.region_id = regions.id GROUP BY orders.gift_name, region ORDER BY region ASC, total DESC, gift_name ASC;
"#,
)
.fetch_all(&pool)
.await
.unwrap();
let mut gifts_map: BTreeMap<String, Vec<String>> = BTreeMap::new();
res.iter().for_each(|item| {
let gift_name = item.gift_name.as_ref();
match gift_name {
Some(gift_name) => {
gifts_map
.entry(item.region.clone())
.and_modify(|gifts| {
gifts.push(gift_name.clone());
})
.or_insert(vec![gift_name.clone()]);
}
None => {
gifts_map.insert(item.region.clone(), vec![]);
}
}
});
let ret = gifts_map
.iter()
.map(|item| {
return D182Res {
region: item.0.to_owned(),
// 获取前 n 个元素
top_gifts: item.1.iter().take(num).cloned().collect::<Vec<String>>(),
};
})
.collect::<Vec<D182Res>>();
return Json::from(ret);
}
这个路由函数中,SQL 语句的内容不做讨论,主要看一些和 rust 相关的点:
-
为何使用
BTreeMap
而不是HashMap
?前者会保留插入 key 的顺序,后者不会,本例中我们需要保证顺序。 -
HashMap
/BTreeMap
提供的and_modify
和or_insert
方法,可以更优雅的实现往一个对象的某个 key 值数组里插入内容时,如果不存在则新建数组,存在则插入的逻辑。如果不使用这两个方法,则我们需要这么写:let gifts = gifts_map.get_mut(&item.region); match gifts { Some(gifts) => { gifts.push(gift_name.clone()); } None => { gifts_map.insert(item.region.clone(), vec![gift_name.clone()]); } }
-
上述我们使用了
item.1.iter().take(num)
来获取前 n 个元素,而没有使用item.1[num..]
的 slice 写法,为何呢?因为这并不是一个安全的做法:let v = vec![1,2,3,4]; v[4..]; // ok 按预期返回一个空数组 v[5..]; // 运行时报错
从此 stackoverflow 帖子中学习到:在使用 slice index 时,如果开始索引的值大于 slice 长度,则会 panic
Part 19
该部分涉及 websocket 的相关知识,在开始挑战前,我们先简单回顾下 ws 的一些知识:
websocket 可以实现客户端和服务端的双向主动通信,常用于比如即时聊天、即时游戏等,以及前端开发中 webpack-dev-server 也使用了 websocket 传递代码变更信息。websocket 连接的建立需要通过特定的握手过程,一般是通过 http/1.1 协议的 GET 接口进行协议升级,当连接建立后,两端便可以进行实时双向通信。
在 axum 的文档里我们可以学习了解到如何使用 websocket,当然要记得开启 ws
feature:
axum = { version = "0.7.2", features = ["ws"] }
现在来看看本节内容,该部分涉及的所有路由和全局状态新增如下:
pub struct RoomState {
pub sender: broadcast::Sender<String>,
}
impl Default for RoomState {
fn default() -> Self {
Self {
// 此处的 capacity 需要设置大一些,不能是比如 10,太小会导致未被接收的消息丢弃而导致不符合业务逻辑
sender: broadcast::channel(1024).0,
}
}
}
#[derive(Clone)]
pub struct AppState {
pub d19_started: bool,
pub d19_rooms: Arc<Mutex<HashMap<usize, RoomState>>>,
pub d19_count: Arc<AtomicU32>,
}
let r19 = Router::new()
.route("/ws/ping", get(handler::d19_1))
.route("/reset", post(handler::d19_2_reset))
.route("/views", get(handler::d19_2_views))
.route("/ws/room/:room/user/:user", get(handler::d19_2_room));
let router = Router::new()
.nest("/19", r19)
.with_state(AppState {
// ...
d19_started: false,
d19_count: Default::default(),
d19_rooms: Default::default(),
});
Task 1
该任务里实现一个 websocket 来处理客户端发来的 "ping"、"pong"、"server" 指令:
pub async fn d19_1(ws: WebSocketUpgrade, mut state: State<AppState>) -> Response {
ws.on_upgrade(|mut socket| async move {
while let Some(msg) = socket.recv().await {
let msg = if let Ok(msg) = msg {
msg
} else {
// client disconnected
return;
};
if let Ok(message) = msg.to_text() {
if message == "serve" {
state.d19_started = true;
}
if message == "ping" {
if state.d19_started {
let _ = socket.send(Message::Text("pong".to_string())).await;
}
}
}
}
})
}
上述使用 WebSocketUpgrade
处理 ws 内容,在 on_upgrade
的回调函数里处理连接建立后的逻辑。
对于经常使用 NodeJS 的伙伴来说,上面出现的 while 循环写法让人感到很奇怪,在 NodeJS 中几乎都是基于事件的写法,比如:
ws.on('open', function open() {
ws.send('something');
});
ws.on('message', function message(data) {
console.log('received: %s', data);
});
而 rust 并非是基于事件的设计,所以在编写代码时还需转变一下心智。
Task 2
这个挑战中需要实现一个进房间、发帖子、统计阅读数的功能。
该挑战里将会使用到 AppState
里新增的另外两个状态:
d19_rooms: Arc<Mutex<HashMap<usize, RoomState>>>
: 存放房间号和对应房间的信息d19_count: Arc<AtomicU32>
: 存放帖子的总阅读数。std::symc::atomic
模块里提供了原始类型的的原子版本,即可以在多线程环境下正常使用的版本,否则就得手动写成比如Arc<Mutex<u32>>
下述 d19_2_reset
实现了重置房间的功能:
pub async fn d19_2_reset(state: State<AppState>) {
let rooms = state.d19_rooms.clone();
let mut rooms = rooms.lock().unwrap();
rooms.clear();
state.d19_count.store(0, Ordering::Relaxed);
}
Ordering::Relaxed
和内存排序有关,读者们可自行网上查阅相关资料学习。
下属是获取当前的帖子浏览总数:
pub async fn d19_2_views(state: State<AppState>) -> String {
let c = state.d19_count.clone();
let v = c.load(std::sync::atomic::Ordering::Relaxed).to_string();
v
}
下属则为本节最具有挑战的部分,实现简单聊天室。
在开始前先来了解一些必要知识: axum 所提供的 websocket 并没有广播功能(用惯了 NodeJS 的 ws 包都差点误以为 websocket 先天自带广播功能),我们需要自行实现,其中会涉及到这四个概念:
- ws_receiver:用来获取客户端的用户发来的信息
- ws_sender:用来给客户端发送消息
- broadcast sender(sender):用来发送广播。当用户发送消息后,ws_receiver 会接收到消息,然后借助 sender 发送广播给相同房间的 receivers
- broadcast receiver(receiver):用来接收 sender 广播的消息。收到消息后,会通过 ws_sender 将消息发回给自己的 ws_client
ws_receiver 和 ws_sender 可以通过 on_upgrade 回调函数的参数来获取到。sender 和 receiver 则需要自行实现,每个 sender/receiver 会存在于不同的线程中,便涉及到了线程间的通信,常见模式有 MPMC(multi-producer multi-consumer)、MPSC(multi-producer, single-consumer),我们要使用的是 MPMC,但 rust 仅内置提供了 std::sync::mpsc
,不满足我们的需求,我们将使用到 tokio 提供的 broadcast 来实现广播功能:
pub async fn d19_2_room(
ws: WebSocketUpgrade,
Path((room_id, user)): Path<(usize, String)>,
state: State<AppState>,
) -> Response {
ws.on_upgrade(move |socket| async move {
let sender = {
let mut rooms = state.d19_rooms.lock().unwrap();
rooms
.entry(room_id)
.or_insert_with(|| {
let mut room = crate::RoomState::default();
room
})
.sender
.clone()
};
let mut receiver = sender.subscribe();
let (mut ws_sender, mut ws_receiver) = socket.split();
// 监听客户端传来的消息,
let mut send_task = tokio::spawn(async move {
let user: String = user.clone();
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(msg) => {
match msg {
Message::Text(msg) => {
if let Ok(user_msg) = serde_json::from_str::<UserMessage>(&msg) {
if user_msg.message.len() <= 128 {
// Send to all users in the same room
let _ = sender.send(
serde_json::to_string(&UserSentMessage {
user: user.clone(),
message: user_msg.message,
})
.unwrap(),
);
}
} else {
println!("{} sent unhandled format msg: {:?}", user, msg);
}
}
Message::Binary(_) | Message::Ping(_) | Message::Pong(_) => {
println!("{} sent unhandled type msg: {:?}", user, msg);
}
Message::Close(_) => {
break;
}
}
}
Err(err) => {
break;
}
}
}
});
let count = state.d19_count.clone();
// 聊天室里的人监听消息接收
let mut listen_task = tokio::spawn(async move {
loop {
match receiver.recv().await {
Ok(msg) => {
// 浏览+1
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
// 发回到客户端
let m = serde_json::from_str::<UserSentMessage>(&msg).unwrap();
let _ = ws_sender
.send(Message::Text(serde_json::to_string(&m).unwrap()))
.await;
}
Err(err) => {
println!("err: {:?}", err);
break;
}
}
}
});
// 当一个任务结束后,需要终止对应的另一个任务
tokio::select! {
r = (&mut send_task) => {
println!("abort listen_task, {:?}", r);
listen_task.abort()
},
r = (&mut listen_task) => {
println!("abort send_task, {:?}", r);
send_task.abort()
},
};
})
}
上述使用 socket.split()
获取 ws_sender 和 ws_receiver,我们需要 cargo add futures
安装 futures
crate,并需要导入相关内容:
use futures::{SinkExt, StreamExt};
sender 则是 AppState
里通过 broadcast::channel(1024).0
创建的,其中的 1024 是设置的 capacity 容量,此处不易设置过小,否则当端发送速度大于接收端处理速度时,如果积压了大于 capacity 的消息,则会丢弃掉最早的超出内容,从而造成帖子数记录不准确。
小结
rust 的异步相关知识让学习 rust 之旅变得更陡峭了,一定是 nodejs 把笔者“养”得太好了,赶紧挑灯继续阅读 Asynchronous Programming in Rust