Rust 代码挑战系列 6 - 数据库的基本操作和 sqlx 的使用

这是一个学习 Rust 的小系列文章,通过完成来自 shuttle 平台所举办的 2023 Christmas Code Hunt 里的每个小挑战,来学习 rust web 框架的使用,此为第六篇文章。

本篇内的 Part 13 将介绍 rust 中数据库的基本使用。

sqlx 简单介绍

rust 中操作数据库的 crate 有很多,比如 sqlx、diesel、seaorm 等,我们将选用 sqlx,它支持 postgresql、mysql 等常见数据库,笔者将使用 mysql8 数据库,现在我们简单了解下 sqlx crate:

sqlx 并不是一个 ORM,它没有提供 find_ond()update()delete() 等封装,而是需要我们自行编写原始的 sql 语句,比如要查询所有订单:

sqlx::query!(r#"SELECT * FROM orders"#).fetch_all(&pool).await.unwrap();

害怕 sql 语句写的不对怎么办?sqlx 提供了编译时检查语句的功能,当我们使用的是 query!()query_as!() 等宏而不是 query()query_as() 等方法时,sqlx 会静态检查语句正确性,它会自动连接数据库(也支持离线检查),若有错误则会报错,其中数据库的连接地址取自 .env 文件的 DATABASE_URL,倘若连接失败也会编译报错。本文尽可能降低一些依赖行为,暂不选用这种写法。

此外,sqlx 还提供了 sqlx-cli 命令行工具,它提供了 migration 功能管理数据库内容变更,本文暂不讨论。

Part 13

在进行挑战前,需要先配置好数据库,笔者使用 docker 运行数据库:

docker run -p 3306:3306 -e MYSQL_ROOT_PASSWORD=jstips -e MYSQL_DATABASE=cch-demo mysql:8.0.36

然后在项目内安装所需依赖,sqlx 提供了两种异步运行时 tokio 和 async-std ,此处选用的是 tokio,同时开启对 mysql 的支持:

cargo add sqlx --features runtime-tokio,mysql

sqlx 提供了连接池,我们将其作为共享状态以方便不同路由函数的使用,编辑 main.rs 文件:

use sqlx::MySqlPool;

// ...
async fn main() {
    let pool = MySqlPool::connect("mysql://root:jstips@localhost:3306/cch-demo")
        .await
        .unwrap();

    let router = Router::new()
        // ...
        .route("/13/sql", get(handler::d13_1))
        .route("/13/reset", post(handler::d13_2_reset))
        .route("/13/orders", post(handler::d13_2_orders))
        .route("/13/orders/total", get(handler::d13_2_total))
        .route("/13/orders/popular", get(handler::d13_3))
        .with_state(handler::D13State { pool: pool.clone() });
}

上述笔者提前将 Part 13 里的路由函数配置进来了,下述便不再强调路由配置。

然后在 d13.rs 里我们新建对应的 D13State

#[derive(Clone)]
pub struct D13State {
    pub pool: MySqlPool,
}

上述的 pool 字段没有定义为 Arc<Mutex<MySqlPool>>,因为 MySqlPool 内部已经做了封装,我们无需再包裹一层。

Task 1

马上都要五一了,这里就不再描述原文里的圣诞场景了,没有氛围感,所以后面仅会简单介绍下挑战内容,不会再有圣诞老人出没。

第一个任务乃热身,来了解 sqlx 怎么执行 sql:

pub async fn d13_1(state: State<D13State>) -> String {
    let pool = state.pool.clone();

    let rec = sqlx::query_scalar::<_, i32>("SELECT 20231213")
        .fetch_one(&pool)
        .await
        .unwrap();

    rec.to_string()
}

Task 2

该任务里需要向 orders 表中插入数据并进行检索。

/13/reset 端口用来重置表内容:

pub async fn d13_2_reset(state: State<D13State>) {
    let pool = state.pool.clone();
    sqlx::query(r#"DROP TABLE IF EXISTS orders;"#)
        .execute(&pool)
        .await
        .unwrap();
    sqlx::query(
        r#"CREATE TABLE orders (
      id INT PRIMARY KEY,
      region_id INT,
      gift_name VARCHAR(50),
      quantity INT
    );"#,
    )
    .execute(&pool)
    .await
    .unwrap();
}

/13/orders 用来插入数据:

// 使用 FromRow 将 sql 查询结果转换为 rust 的数据结构
#[derive(Debug, Serialize, Deserialize, FromRow)]
pub struct Order {
    id: i32,
    region_id: i32,
    gift_name: String,
    quantity: i32,
}
pub async fn d13_2_orders(state: State<D13State>, body: Json<Vec<Order>>) {
    let pool = state.pool.clone();
    // 使用事务确保全部完整插入
    let mut transaction = pool.begin().await.unwrap();
    for item in body.0 {
        sqlx::query(
            r#"
    INSERT INTO orders (id, region_id, gift_name, quantity) VALUES (?, ?, ?, ?)
    "#,
        ) // mysql 的参数占位符是 ?   postgresql 的占位符是 $1 $2 ..
        .bind(item.id)
        .bind(item.region_id)
        .bind(item.gift_name)
        .bind(item.quantity)
        // 此处需要传入 transaction 而非 pool
        .execute(&mut *transaction)
        .await
        .unwrap();
    }
    // 记得调用 commit,否则在变量销毁时会自动执行 rollback
    transaction.commit().await.unwrap();
}

/13/orders/total 用来查询数据:

pub async fn d13_2_total(state: State<D13State>) -> Json<Value> {
    let pool = state.pool.clone();
    let orders = sqlx::query_scalar::<_, i64>("SELECT SUM(quantity) FROM orders")
        .fetch_one(&pool)
        .await
        .unwrap();

    Json::from(json!({
      "total": orders
    }))
}

在执行 cch23-validator 13 验证用例时会发现报错了,报错位置在 d13_2_total 里: ColumnDecode { index: "0", source: "mismatched types; Rust type i64 (as SQL type BIGINT) is not compatible with SQL type DECIMAL" }

从错误信息可以看出,sql 的 SUM 返回的 DICIMAL 类型无法转换成 rust 里的 i64 类型,关于 mysql 的类型和 rust 类型的映射可以参阅文档,可知若要使用 decimal 类型,需要开启 decimal feature,所以我们编辑下 Cargo.toml,给 sqlx 新增 rust_decimal feature,然后将上述的 i64 改为 sqlx::types::Decimal,由于该挑战需要的返回值是数值类型,所以还需将 decimal 转为数值:

Json::from(json!({
    "total": i64::from_str_radix(&orders.to_string(), 10).unwrap_or_default()
}))

Task 3

找到最受欢迎的礼物:

#[derive(Debug, FromRow)]
struct D133Rec {
    sum: sqlx::types::Decimal,
    gift_name: String,
}

pub async fn d13_3(state: State<D13State>) -> Json<Value> {
    let pool = state.pool.clone();
    let rec = sqlx::query_as::<_, D133Rec>(
        "SELECT SUM(quantity) AS sum, gift_name FROM orders GROUP BY gift_name",
    )
    .fetch_all(&pool)
    .await
    .unwrap();

    return Json::from(json!({
      "popular": if rec.len() ==0 { json!(null) } else { json!(rec.iter().max_by_key(|x| x.sum).unwrap().gift_name) }
    }));
}

小结

这次的内容对于熟悉 sql 的读者来说,难点主要在于 sqlx 的使用,上面我们用到了 queryquery_scalarquery_as 这些方法,以及跟着这些方法后面的 executefetch_onefetch_all 等方法,下面是笔者的一个简单个人笔记,读者们也可以阅读文档以及网上的更多例子来更深刻了解用法(虽然笔者很想说去看源码来理解,但是这个真太难看懂了,还是从上层理解吧):

他们都有 fetchfetch_allfetch_onefetch_optional 方法:

小结的小结

此系列会抓紧完结,不然下一个圣诞节都要来了。