用Go语言实现WebSocket客户端
本文介绍了如何使用Go语言和gorilla/websocket库实现一个WebSocket客户端,包括连接、断开、消息接收、错误处理和重连机制。
- defagi
- 4 min read
在近期的 AI AGENT 项目开发进程中,服务器端运用了 python + fastapi + langchain(langgraph)的技术架构予以实现。而客户端需在家庭 IP 环境下运行,并且要求服务器与客户端能够相互调用函数。在全面权衡各类技术方案后,察觉到 socket、mqtt 等技术虽能满足部分需求,然而鉴于服务器还需提供 web 服务,最终选定了 WebSocket。客户端决定选用 golang 进行开发,由于对 golang 较为陌生,所以只能请求 Claude 协助编写一个基础版的 websocket 客户端,后续再自行修改。
Claude编写go websocket客户端
- 提示词很直接…
go websocket 客户端
- 感觉缺少写东西,好像缺少了,断开重连
添加断开重连
- 这个实现有点陌生,叫claude使用"github.com/gorilla/websocket"
使用 "github.com/gorilla/websocket" 实现
``
![gorilla_websocket_prompt](gorilla_websocket_prompt.png)
4. 还缺个回调
```md
添加状态回调功能
- 最终实现了go websocket的基础功能,包含如下:
- 配置(Config):定义了连接的 URL、重连等待时间、握手超时、读取限制和 HTTP 头部等参数。
- 回调(Callbacks):定义了连接、断开、消息接收、错误和重连时的回调函数。
- WSClient 结构体:包含了配置、回调、连接状态和同步机制。
- 连接管理:处理连接的建立和断开。
- 消息处理:接收和发送消息。
- 重连机制:在连接丢失时尝试重新连接。
- 心跳机制:定期发送心跳以保持连接活跃。
package main
import (
"fmt"
"log"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Callbacks 定义所有可用的回调函数
type Callbacks struct {
OnConnect func()
OnDisconnect func(err error)
OnMessage func(message []byte)
OnError func(err error)
OnReconnect func(attempt int)
}
// Config WebSocket客户端配置
type Config struct {
URL string
ReconnectWait time.Duration
HandshakeTimeout time.Duration
ReadLimit int64
Headers http.Header
MaxRetries int // 最大重试次数,0表示无限重试
}
// DefaultConfig 返回默认配置
func DefaultConfig(url string) *Config {
return &Config{
URL: url,
ReconnectWait: 5 * time.Second,
HandshakeTimeout: 10 * time.Second,
ReadLimit: 512 * 1024, // 512KB
Headers: http.Header{},
MaxRetries: 0,
}
}
type WSClient struct {
config *Config
callbacks *Callbacks
conn *websocket.Conn
mu sync.Mutex
isConnected bool
done chan struct{}
retryCount int
}
func NewWSClient(config *Config, callbacks *Callbacks) *WSClient {
if callbacks == nil {
callbacks = &Callbacks{}
}
return &WSClient{
config: config,
callbacks: callbacks,
done: make(chan struct{}),
}
}
func (c *WSClient) Connect() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.isConnected {
return nil
}
dialer := websocket.Dialer{
HandshakeTimeout: c.config.HandshakeTimeout,
}
conn, _, err := dialer.Dial(c.config.URL, c.config.Headers)
if err != nil {
if c.callbacks.OnError != nil {
c.callbacks.OnError(fmt.Errorf("dial error: %v", err))
}
return fmt.Errorf("dial error: %v", err)
}
conn.SetReadLimit(c.config.ReadLimit)
conn.SetPingHandler(func(string) error {
return conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(time.Second*5))
})
c.conn = conn
c.isConnected = true
// 触发连接成功回调
if c.callbacks.OnConnect != nil {
go c.callbacks.OnConnect()
}
return nil
}
func (c *WSClient) Disconnect() {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.conn.Close()
c.conn = nil
}
if c.isConnected {
c.isConnected = false
// 触发断开连接回调
if c.callbacks.OnDisconnect != nil {
go c.callbacks.OnDisconnect(nil)
}
}
}
func (c *WSClient) reconnect() {
c.retryCount++
for {
select {
case <-c.done:
return
default:
if c.config.MaxRetries > 0 && c.retryCount > c.config.MaxRetries {
if c.callbacks.OnError != nil {
c.callbacks.OnError(fmt.Errorf("max retry attempts reached"))
}
return
}
// 触发重连回调
if c.callbacks.OnReconnect != nil {
c.callbacks.OnReconnect(c.retryCount)
}
err := c.Connect()
if err == nil {
c.retryCount = 0 // 重置重试计数
return
}
time.Sleep(c.config.ReconnectWait)
}
}
}
func (c *WSClient) Send(message []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.isConnected {
return fmt.Errorf("not connected")
}
err := c.conn.WriteMessage(websocket.TextMessage, message)
if err != nil && c.callbacks.OnError != nil {
go c.callbacks.OnError(fmt.Errorf("send error: %v", err))
}
return err
}
func (c *WSClient) handleReceive() {
for {
select {
case <-c.done:
return
default:
if !c.isConnected {
time.Sleep(time.Second)
continue
}
_, message, err := c.conn.ReadMessage()
if err != nil {
if c.callbacks.OnError != nil {
go c.callbacks.OnError(fmt.Errorf("read error: %v", err))
}
c.Disconnect()
go c.reconnect()
continue
}
// 触发消息回调
if c.callbacks.OnMessage != nil {
go c.callbacks.OnMessage(message)
}
}
}
}
func (c *WSClient) heartbeat() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.done:
return
case <-ticker.C:
c.mu.Lock()
if c.isConnected && c.conn != nil {
err := c.conn.WriteMessage(websocket.PingMessage, nil)
if err != nil {
if c.callbacks.OnError != nil {
go c.callbacks.OnError(fmt.Errorf("heartbeat error: %v", err))
}
c.conn.Close()
c.isConnected = false
go c.reconnect()
}
}
c.mu.Unlock()
}
}
}
func (c *WSClient) Start() error {
err := c.Connect()
if err != nil {
return err
}
go c.handleReceive()
go c.heartbeat()
return nil
}
func (c *WSClient) Stop() {
close(c.done)
c.Disconnect()
}
// IsConnected 返回当前连接状态
func (c *WSClient) IsConnected() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.isConnected
}
// 使用示例
func main() {
// 创建配置
config := DefaultConfig("ws://localhost:8080/ws")
config.MaxRetries = 5 // 设置最大重试次数
// 创建回调
callbacks := &Callbacks{
OnConnect: func() {
log.Println("Connected to server")
},
OnDisconnect: func(err error) {
if err != nil {
log.Printf("Disconnected due to error: %v\n", err)
} else {
log.Println("Disconnected from server")
}
},
OnMessage: func(message []byte) {
log.Printf("Received message: %s\n", string(message))
},
OnError: func(err error) {
log.Printf("Error occurred: %v\n", err)
},
OnReconnect: func(attempt int) {
log.Printf("Attempting to reconnect... (attempt %d)\n", attempt)
},
}
// 创建客户端
client := NewWSClient(config, callbacks)
// 启动客户端
if err := client.Start(); err != nil {
log.Fatal(err)
}
// 定期发送消息示例
go func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-client.done:
return
case <-ticker.C:
if client.IsConnected() {
err := client.Send([]byte("Hello, Server!"))
if err != nil {
log.Printf("Send error: %v", err)
}
}
}
}
}()
// 保持程序运行
select {}
}
- Tags:
- Ai-News
- Programming