Rust实现WebSockets

创建于 2024年7月29日修改于 2024年7月29日
RustWebSocket

简介

在这篇文章中,我们将实现一个Echo服务器,每个套接字都会返回接收到的内容。我们会额外添加一些功能,但还远达不到完备所有功能的状态。

最后,我会讨论这种实现的缺陷以及我们未来要做的一些更改。请记住,我们只使用标准库来完成这些工作,因此某些东西会比使用第三方库时更冗长。

Contents

WebSockets

在之前的文章中,我们简要讨论了WebSockets,主要是客户端和服务器之间初始握手的上下文,使用此图作为概述:

Handshake

WebSocket连接的四个阶段

完整描述可以在 RFC 中找到,它写得很好且易于理解。我会涵盖文档中的一些内容,比如握手的完成方式以及帧的组成和解析,但更多的是如何在Rust中实现它。代码中的注释重复了很多这里的内容。代码库可以在这里找到,我建议我们在前进的过程中参考它。

概览

这是我们的 websocket.rs 模块的概览:

// websocket.rs

pub struct WebSocket {
    stream: TcpStream,
}

impl WebSocket {
    pub fn new(stream: TcpStream);

    pub fn connect(&mut self) 
    fn handle_handshake(&mut self, request: &str)

    pub fn handle_connection(&mut self) 
    fn parse_frame(&mut self, buffer: &[u8])
    fn send_ping(&mut self) 
    fn send_pong(&mut self) 
    fn send_text(&mut self, data: &str) 
}

目前这里内容不多,尽管其中一些函数做了大量的工作。我们将主要关注 connecthandle_connectionparse_frame,依次进行。目前只有两个需要公开,尽管稍后我们会公开更多函数来处理用户之间的通信。

首先,我们来看一下如何在我们的 main.rs 中使用这个WebSocket,然后我们会深入了解套接字本身的工作原理。

启动我们的服务器

要启动我们的服务器,我们需要创建一个循环来监听和接受用户的传入连接。我们还需要验证这些请求以检查它们是否是WebSockets。为此,我们将在主函数中创建一个循环,使用 TcpListener 接受所有传入请求,然后将它们传递给 handle_client

// main.rs
fn main() {
    let listener = TcpListener::bind("127.0.0.1:8080").expect("Could not bind to port");
    println!("WebSocket server is running on ws://127.0.0.1:8080/");

    for stream in listener.incoming() {
        match stream {
            Ok(stream) => {
                thread::spawn(move || {
                    handle_client(stream);
                });
            }
            Err(e) => {
                println!("Failed to accept client: {}", e);
            }
        }
    }
}

在这里,当我们收到一个新的传入连接请求时,它会生成一个新线程来处理它。注意,我们不处理任何跨线程通信,每个套接字都是一个独立的单元。这看起来并不理想,但这是我们将在本系列后面部分解决的问题。

当生成一个线程时,它会将 TcpStream 传递给 handle_client,后者执行以下操作:

// main.rs

fn handle_client(stream: TcpStream) {
    let mut ws = WebSocket::new(stream);

    match ws.connect() {
        Ok(()) => {
            println!("WebSocket connection established");
            match ws.handle_connection() {
                Ok(_) => {
                    println!("Connection ended without error");
                }
                Err(e) => {
                    println!("Connection ended with error {:?}", e);
                }
            }
        }
        Err(e) => {
            println!("Failed to establish a WebSocket connection: {}", e);
        }
    }
}

这里发生的事情是我们创建了一个 WebSocket 实例,并调用 connect,然后是 handle_connection。正如我们很快会看到的,handle_connection 函数包含一个循环,它会回显消息直到连接关闭。当连接关闭时,它会从 handle_connection 返回,然后是 handle_client,最后终止我们在main中生成的线程。

总结一下,我们有以下过程:

  1. 一个无限循环监听TCP连接。
  2. 当建立连接时,它被接受。
  3. 连接流被传递给handle_client。
  4. 我们创建一个WebSocket实例。
  5. 我们调用 connect 函数。
  6. 如果没有问题,我们调用 handle_connection
  7. handle_connection 返回时,连接和线程都终止。

希望这相对直观,因为在WebSocket中的工作细节稍微复杂一些。

连接(HTTP请求)

main.rs 中,我们首先调用WebSocket上的 connect 函数。这个函数的任务是验证连接,确保它是我们期望的WebSocket连接。记住,我们一开始接受所有连接,所以 connect 的工作是筛选掉不属于WebSocket的请求,例如请求cat.png或发送银行信息的POST请求。

connect 函数分为两部分:处理初始HTTP请求以确定它是否为GET请求,以及进行WebSocket握手。其过程大致如下:

  1. 定义一个缓冲区
  2. 将HTTP请求读入缓冲区
  3. 检查请求是否为GET
  4. 进行握手

如果你在跟随代码库Repo,你会发现(4)是一个私有函数 handle_handshake。我们稍后会详细介绍它,因为它包含许多步骤。现在我们按照上述步骤逐一进行:

步骤1 - 定义一个缓冲区

// websocket.rs - connect()

let mut buffer: [u8; 1024] = [0; 1024];

我们可以将缓冲区定义为任何大小,但这应该是足够的。如果我们预计接收到的请求长度超过1024字节,我们可以增加大小。

步骤2 - 将HTTP请求读入缓冲区

// websocket.rs - connect()

// From the stream read in the HTTP request
let byte_length = match self.stream.read(&mut buffer) {
    Ok(bytes) => bytes,
    Err(e) => return Err(WebSocketError::IoError(e)),
};

这里,stream 是在 main 中通过 handle_client 实例化时定义的结构体成员,并通过这一行给出:

// main.rs

let mut ws = WebSocket::new(stream);

使用这个,我们读取read请求。read 函数返回字节长度,我们用它来创建一个名为 requeststr 变量。

// websocket.rs - connect()

let request = str::from_utf8(&buffer[..byte_length])?;

在我们将完整请求作为字符串之后,我们可以继续。

步骤3 - 检查请求是否为GET

然后我们确保我们的请求是GET请求:

// websocket.rs - connect()

// 我们只处理用于升级的GET请求
if !request.starts_with("GET") {
    return Err(WebSocketError::NonGetRequest);
}

WebSocket连接请求应始终为HTTP GET,因此我们在这里检查并在否则情况下抛出错误。

WebSocket connection requests should always be an HTTP GET, so we check that here and throw an error otherwise.

步骤4 - 进行握手

最后一步分为两部分,其中一个我们将在下面的部分详细介绍,另一个是将响应返回给用户:

// websocket.rs - connect()

// 获取HTTP响应头并发送回去
let response = self.handle_handshake(request)?;

// 使用响应
self.stream
    .write_all(response.as_bytes())
    .map_err(WebSocketError::IoError)?;

self.stream.flush().map_err(WebSocketError::IoError)?;
Ok(())

在我们能够对连接给予“ok”之前,最后一件事是握手,它将在我们的 handle_handshake 函数中进行。你可以看到,在握手完成后,我们会将响应response写回流中,向用户表示他们应该将连接升级为WebSocket。

另一部分,握手handshake,是我们需要进一步研究的内容。

连接 (握手 handshake)

握手过程是我们在前文中简要讨论过的,当时我们看了用户向服务器传递密钥的过程以及这些密钥如何作为验证用户希望升级连接的方式。这并不难理解,尽管某些部分需要你自己进行更多的研究以确切了解它们为何以这种方式完成。

握手过程包括以下步骤:

  1. 检查 Sec-WebSocket-Key 是否存在。
  2. 创建响应密钥。
  3. 对响应密钥进行哈希。
  4. 使用Base64编码响应密钥。
  5. 创建一个升级连接的HTTP响应头。

请记住,这些步骤中的每一个都是 connect 函数中的(4)的一部分,因此实际上从启动连接到我们可以开始处理消息,总共有八个步骤。

步骤1 - 检查Sec-WebSocket-Key是否存在

handle_handshake 方法接受我们在 connect 中创建的请求,然后设置SHA-1和Base64实例,以便我们分别进行哈希和编码。

The handle_handshake method takes in the request we created in connect and then sets up our SHA-1 and Base64 instances so we can hash and encode, respectively.

// websocket.rs - handle_handshake()

let mut base64 = Base64::new();
let mut sha1 = Sha1::new();

let key_header = "Sec-WebSocket-Key: ";

// 给定请求,我们找到以key_header开头的行,然后找到客户端发送的密钥。
let key = request
    .lines()
    .find(|line| line.starts_with(key_header))
    .map(|line| line[key_header.len()..].trim())
    .ok_or_else(|| {
        WebSocketError::HandshakeError(
            "Could not find Sec-WebSocket-Key in HTTP request header".to_string(),
        )
    })?;

密钥的解析有点复杂,但可以分解为:

  1. 找到 Sec-WebSocket-Key 所在的行。
  2. 切掉 Sec-WebSocket-Key 部分,留下密钥。
  3. 如果出错,则退出。

如果我们有 key,我们可以继续下一步。

步骤2 - 创建响应密钥

这部分相当简单:

// websocket.rs - handle_handshake()

// 按照WebSocket协议规范,附加必要的ID
let response_key = format!("{}258EAFA5-E914-47DA-95CA-C5AB0DC85B11", key);

根据RFC:

对于此头字段,服务器必须获取值(如头字段中所示,例如,去除所有前导和尾随空白的base64编码[RFC4648]版本)并将其与全局唯一标识符(GUID,[RFC4122])“258EAFA5-E914-47DA-95CA-C5AB0DC85B11”以字符串形式连接,这在不了解WebSocket协议的网络端点中不太可能使用。然后将此连接的SHA-1哈希(160位)[FIPS.180-3],base64编码(见[RFC4648]的第4节)返回到服务器的握手中。

这就是说我们将用户发送给我们的任何 response_key 与预定义的GUID连接起来,这对于每个WebSocket连接都是相同的。之后,我们将进行哈希和编码的下两个步骤。

步骤3 - 对响应密钥进行哈希

// websocket.rs - handle_handshake()

// 首先对客户端发送的随机密钥进行哈希
let hash = sha1.hash(response_key).map_err(|_| {
    WebSocketError::HandshakeError("Failed to hash the response key".to_string())
})?;

这将给我们一个固定长度的哈希。

步骤4 - 使用Base64编码响应密钥

哈希然后被编码为Base64。

// websocket.rs - handle_handshake()

// 然后我们将该哈希编码为Base64
let header_key = base64.encode(hash).map_err(|_| {
    WebSocketError::HandshakeError("Failed to encode the hash as Base64".to_string())
})?;

在这两种情况下,如果发生任何失败,我们将退出并返回 HandshakeError,这是在 websocket.rs 中定义的自定义错误类型。

步骤5 - 创建一个升级连接的HTTP响应头

最后,我们使用 header_key 并将其插入响应头:

// websocket.rs - handle_handshake()\

// 最后,我们将该密钥附加到响应头
Ok(format!(
    "HTTP/1.1 101 Switching Protocols\r\n\
    Upgrade: websocket\r\n\
    Connection: Upgrade\r\n\
    Sec-WebSocket-Accept: {}\r\n\r\n",
    header_key
))

在这里我们必须包括 UpgradeConnection 以确保WebSocket的建立。返回的密钥也将由客户端验证。如果它无效,连接将失败。

此时,如果没有失败,连接已经建立,并且在我们的 handle_client 函数中调用 ws.connect() 之后,我们已经回到 main.rs。接下来我们要做的是维护连接并处理任何传入的消息。这将在 handle_connection 中进行。

处理消息

handle_connection 函数,是 WebSocket 的一部分,看起来相当冗长,但大部分内容都用于处理我们可能收到的不同类型的消息。实际上,发生的事情如下:

  1. 启动一个无限循环。
  2. 从我们的连接中读取。
  3. 检查消息类型。
  4. 处理该消息。
  5. 如果连接关闭,则退出循环。

还有一些其他内容,但我们稍后会讨论。现在请注意,我们有一个主循环,用于接收建立 WebSocket 连接的请求。当这些连接建立后,它们会有自己的循环,只处理该用户的消息。

如果你正在查看 handle_connection 函数的代码,你会注意到顶部有一些关于 ping 和 pong 的内容:

// websocket.rs - handle_connection

// 2048 的缓冲区应该足够处理传入的数据。
let mut buffer = [0; 2048];

// 发送初始 ping
self.send_ping()?;
let mut last_ping = std::time::Instant::now();
let mut pong_received = false;

缓冲区只是用于接收单个消息,这应该是不言自明的,但第二部分有点令人好奇。

这是为一种称为心跳的机制做准备,这是一种服务器检查连接是否仍然活跃的方法,如果没有则关闭它。你可能会想,既然往下看有 Frame::Close 来处理关闭连接,为什么还需要这个?

通常情况下,确实会处理关闭,但 Frame::Close 需要用户发送一条消息来声明这一点。由于你的浏览器会处理 WebSocket 连接的细节,即使你强制退出浏览器,它通常也会发送这条关闭消息,但如果没有呢?在这种情况下,我们需要一个备用计划,那就是在循环内进行心跳检查:

// websocket.rs - handle_connection
// 发送初始 ping
self.send_ping()?;
let mut last_ping = std::time::Instant::now();
let mut pong_received = false;

// 在 main.rs 中生成的线程内运行的主要循环
loop {
    // 这是检查连接是否超时的代码。
    // 我们硬编码为 10 秒,但以后最好能配置。
    if last_ping.elapsed() > Duration::from_secs(10) {
        if !pong_received {
            println!("Pong not received; disconnecting client.");
            break;
        }

        if self.send_ping().is_err() {
            println!("Ping failed; disconnecting client.");
            break;
        }

        pong_received = false;
        last_ping = std::time::Instant::now();
    }
    ...
 }

这是一段相当大的代码,但它做的事情相对简单:

  1. 在循环之前,向用户发送初始 ping。
  2. 记录发送 ping 的时间。
  3. ping_received 标志设置为 false。
  4. 进入循环。
  5. 检查是否已经过去了 10 秒。
  6. 如果没有收到 pong,退出循环。
  7. 如果收到 pong,重置 pong_receivedlast_ping

每 10 秒左右,我们检查用户是否回复了我们的 ping,如果没有,则假设连接已断开。这里我设置检查时间为 10 秒,但如果你愿意,可以设置为 60 秒或更长。

稍后,我们将使用相同的想法来踢出空闲用户。例如,假设有人连接到我们的服务器但只是坐在那里。如果是聊天客户端,这可能没问题,但如果是游戏,那个人就占用了实际在玩游戏的玩家的带宽。你可以设置一个类似的空闲时间为 10 分钟,如果他们除了 pong 外没有发送任何消息,就把用户踢回登录界面。

现在让我们看看读取信息,它看起来与我们在主程序中做的几乎一样。不同的是这里发送的数据是需要解析的 WebSocket frames。我们在流上的读取循环如下:

// websocket.rs - handle_connection

// 在 main.rs 中生成的线程内运行的主要循环
loop {
    // 心跳代码
    ...

    // 读取当前的流或数据。
    match self.stream.read(&mut buffer) {
        // read(&mut buffer) 将返回一个 usize,我们只在大于 0 时处理它。然后我们在 parse_frame 函数中解析帧。
        Ok(n) if n > 0 => match self.parse_frame(&buffer[..n]) {
    ...
    }
}

大部分工作是在 parse_frame 函数中完成的,它接收缓冲区的当前值并确定将采取的操作。操作是指用户是发送了 ping、pong、close、text 还是 binary 消息?为了弄清楚这一点,让我们看看什么是帧以及如何解析它。

WebSocket 帧

在客户端和服务器之间发送的 WebSocket 消息称为帧,它们采用如 ASCII RFC 图所示的特定形式。下面是一个稍微不同格式的版本:

Websocket Frame

一个 WebSocket 帧

如上所示,每个部分代表一个字节,负载数据是一些任意数量的字节。你可以在 RFC 中阅读更多关于这些的内容,虽然我会在我们浏览代码时详细讨论每个部分。

parse_frame 中,我们首先检查缓冲区的长度是否至少为用于关闭连接的最小两字节帧:

// websocket.rs - parse_frame

if buffer.len() < 2 {
    return Err(WebSocketError::ProtocolError("Frame too short".to_string()));
}

let first_byte = buffer[0];

如果是这样,我们继续提取第一个和第二个字节作为掩码位和负载长度:

// websocket.rs - parse_frame
let first_byte = buffer[0];
let opcode = first_byte & 0x0F; // 确定操作码

let second_byte = buffer[1];
let masked = (second_byte & 0x80) != 0;
let mut payload_len = (second_byte & 0x7F) as usize;

查看第一个字节,我们提取它的后半部分,即帧的操作码。如果你不熟悉操作码(opcodes),它们表示我们将执行的操作。在 WebSocket 的情况下,它表示我们如何处理消息,包括:

  1. continuation
  2. text
  3. binary
  4. close
  5. ping
  6. pong

还有一些保留的操作码,预留给未来可能的 WebSocket 功能,如 RFC 中所述。

回到上面的代码,在第一个字节中,我们忽略了用于 FINRSV 值的前四位。FIN 位是一个为多部分/分段消息设置的标志。在这个例子中,我们不会实现分段消息,但将来可能会实现。RSV 值保留给未来的可能功能,类似于操作码。

在第二个字节中,我们有一个 MASK 位,表示我们是否使用掩码。根据 RFC:

定义“负载数据”是否被掩码。如果设置为 1,掩码键将出现在掩码键中,并按照第 5.3 节的规定取消掩码。“客户端到服务器的所有帧都将此位设置为 1”。

这将我们引向下一段代码,确保它设置为 1。

// websocket.rs - parse_frame

 // 如果没有掩码,退出
if !masked {
    return Err(WebSocketError::ProtocolError(
        "Frames from client must be masked".to_string(),
    ));
}

如果你不熟悉什么是掩码,维基百科上的文章相当不错。请注意,这个位仅表示我们是否使用在扩展负载部分后面的四个字节中定义的掩码键。

在第一个字节和掩码位之后,我们有上图所示的 PAYLOAD LENGTH。这是七位,最大值为 127。有三种长度值很重要:

  1. 小于或等于 125
  2. 等于 126
  3. 等于 127

是的,负载长度可以变动,并且可以表示为 64 位整数(虽然最高有效位必须为 0),这非常大。我想这可能是为了未来的需要,以防将来有人想通过 WebSocket 连接发送全部未压缩的人类信息。

为了检查这些负载长度,我们使用以下代码:

// websocket.rs - parse_frame

// 初始设置为 2,以便跳过前两个字节
let mut offset = 2;

if payload_len == 126 {
    if buffer.len() < 4 {
        return Err(WebSocketError::ProtocolError(
            "Frame too short for extended payload length".to_string(),
        ));
    }

    payload_len = u16::from_be_bytes([buffer[offset], buffer[offset + 1]]) as usize;
    offset += 2;
} else if payload_len == 127 {
    return Err(WebSocketError::ProtocolError(
        "Extended payload length too large".to_string(),
    ));
}

你会注意到,代码在初始两个字节之后设置偏移量,因为那是扩展负载长度所在。我们仅在 payload_len 等于 126 时允许代码继续运行,因为如上所述,我们跳过扩展负载。

下一步是检查缓冲区长度是否有效,然后将掩码应用于数据。

// websocket.rs - parse_frame

 if buffer.len() < offset + 4 + payload_len {
    return Err(WebSocketError::ProtocolError(
        "Frame too short for mask and data".to_string(),
    ));
}

// 提取掩码键
let mask = &buffer[offset..offset + 4];

// 跳过掩码键并开始处理数据
offset += 4;

// 提取并通过 XOR 应用掩码键
let mut data = Vec::with_capacity(payload_len);
for i in 0..payload_len {
    data.push(buffer[offset + i] ^ mask[i % 4]);
}

掩码是通过在缓冲区的每一部分上使用 XOR 完成的。在取消掩码之前,它看起来像乱码。这个取消掩码的数据将是我们从用户接收到的数据,我们然后根据之前提取的操作码来处理它:

// websocket.rs - parse_frame

// 返回操作码和数据(如果有)
Ok(match opcode {
    0x01 => Frame::Text(data),   // text frame 文本帧
    0x02 => Frame::Binary(data), // binary frame 二进制帧
    0x08 => Frame::Close,        // close frame 关闭帧
    0x09 => Frame::Ping,         // ping frame ping 帧
    0x0A => Frame::Pong,         // pong frame pong 帧
    _ => return Err(WebSocketError::ProtocolError("Unknown opcode".to_string())),
})

这将我们带回到 handle_connection 函数,我们根据返回的内容采取一些行动。在我们只关心一个回声服务器时,这些行动就是我们最后要讨论的内容。

处理帧

在处理帧时,Pong 将 pong_received 值设置为 true,而 Text 会回显数据。只有在找到 Frame::Text 时,我们才真正做任何事情:

// websocket.rs - handle_connection

Ok(Frame::Text(data)) => match String::from_utf8(data) {
    Ok(valid_text) => {
        println!("Received data: {}", valid_text);
        if self.send_text(&valid_text).is_err() {
            println!("Failed to send echo message");
            break;
        }
    }
    Err(utf8_err) => {
        return Err(WebSocketError::Utf8Error(utf8_err.utf8_error()));
    }
},

如果在处理数据时出现任何问题,连接将终止,这可能不是理想的,但目前可以使用。

随着我们浏览这些例子,你会看到很多自定义错误,如 WebSocketError,我建议查看仓库中的完整代码,以更好地了解每个错误的处理。在 handle_connection 中的循环内,这些错误将终止循环,从而终止连接。

运行程序

你可以按照 README 或以下说明在自己的电脑上测试代码。

首先,克隆仓库:

git clone https://github.com/kilroyjones/series_game_from_scratch

然后运行:

cd websockets_from_scratch/2_websocket
cargo run

在同一文件夹中,有一个 client 文件夹,里面有一个单独的 HTML 页面,可以用来测试服务器。如何运行这个页面取决于你,但如果你安装了 python3,你可以这样做:

python3 -m http.server

之后,你可以导航到 http://localhost:8000 查看一切是否正常。

问题

当前实现中存在许多问题,我们将继续解决。我们已经讨论了一些,你可能还看到了其他问题,但这里是我们计划在不久的将来解决的一些问题的简短列表:

很可能还有其他问题,还望海涵。

原文:https://www.thespatula.io/rust/rust_websocket/