用Go语言实现WebSocket客户端

本文介绍了如何使用Go语言和gorilla/websocket库实现一个WebSocket客户端,包括连接、断开、消息接收、错误处理和重连机制。

defagi avatar
  • defagi
  • 4 min read

在近期的 AI AGENT 项目开发进程中,服务器端运用了 python + fastapi + langchain(langgraph)的技术架构予以实现。而客户端需在家庭 IP 环境下运行,并且要求服务器与客户端能够相互调用函数。在全面权衡各类技术方案后,察觉到 socket、mqtt 等技术虽能满足部分需求,然而鉴于服务器还需提供 web 服务,最终选定了 WebSocket。客户端决定选用 golang 进行开发,由于对 golang 较为陌生,所以只能请求 Claude 协助编写一个基础版的 websocket 客户端,后续再自行修改。

Claude编写go websocket客户端

  1. 提示词很直接…
go websocket 客户端

claude websocket prompt

  1. 感觉缺少写东西,好像缺少了,断开重连
添加断开重连

reconnect_prompt

  1. 这个实现有点陌生,叫claude使用"github.com/gorilla/websocket"
使用 "github.com/gorilla/websocket" 实现
``  
  
![gorilla_websocket_prompt](gorilla_websocket_prompt.png)
  
4. 还缺个回调
```md
添加状态回调功能

callback_prompt

  1. 最终实现了go websocket的基础功能,包含如下:
  • 配置(Config):定义了连接的 URL、重连等待时间、握手超时、读取限制和 HTTP 头部等参数。
  • 回调(Callbacks):定义了连接、断开、消息接收、错误和重连时的回调函数。
  • WSClient 结构体:包含了配置、回调、连接状态和同步机制。
  • 连接管理:处理连接的建立和断开。
  • 消息处理:接收和发送消息。
  • 重连机制:在连接丢失时尝试重新连接。
  • 心跳机制:定期发送心跳以保持连接活跃。
package main

import (
    "fmt"
    "log"
    "net/http"
    "sync"
    "time"

    "github.com/gorilla/websocket"
)

// Callbacks 定义所有可用的回调函数
type Callbacks struct {
    OnConnect    func()
    OnDisconnect func(err error)
    OnMessage    func(message []byte)
    OnError      func(err error)
    OnReconnect  func(attempt int)
}

// Config WebSocket客户端配置
type Config struct {
    URL             string
    ReconnectWait   time.Duration
    HandshakeTimeout time.Duration
    ReadLimit       int64
    Headers         http.Header
    MaxRetries      int // 最大重试次数,0表示无限重试
}

// DefaultConfig 返回默认配置
func DefaultConfig(url string) *Config {
    return &Config{
        URL:             url,
        ReconnectWait:   5 * time.Second,
        HandshakeTimeout: 10 * time.Second,
        ReadLimit:       512 * 1024, // 512KB
        Headers:         http.Header{},
        MaxRetries:      0,
    }
}

type WSClient struct {
    config     *Config
    callbacks  *Callbacks
    conn       *websocket.Conn
    mu         sync.Mutex
    isConnected bool
    done       chan struct{}
    retryCount int
}

func NewWSClient(config *Config, callbacks *Callbacks) *WSClient {
    if callbacks == nil {
        callbacks = &Callbacks{}
    }
    return &WSClient{
        config:    config,
        callbacks: callbacks,
        done:      make(chan struct{}),
    }
}

func (c *WSClient) Connect() error {
    c.mu.Lock()
    defer c.mu.Unlock()

    if c.isConnected {
        return nil
    }

    dialer := websocket.Dialer{
        HandshakeTimeout: c.config.HandshakeTimeout,
    }

    conn, _, err := dialer.Dial(c.config.URL, c.config.Headers)
    if err != nil {
        if c.callbacks.OnError != nil {
            c.callbacks.OnError(fmt.Errorf("dial error: %v", err))
        }
        return fmt.Errorf("dial error: %v", err)
    }

    conn.SetReadLimit(c.config.ReadLimit)
    conn.SetPingHandler(func(string) error {
        return conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(time.Second*5))
    })

    c.conn = conn
    c.isConnected = true

    // 触发连接成功回调
    if c.callbacks.OnConnect != nil {
        go c.callbacks.OnConnect()
    }

    return nil
}

func (c *WSClient) Disconnect() {
    c.mu.Lock()
    defer c.mu.Unlock()

    if c.conn != nil {
        c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
        c.conn.Close()
        c.conn = nil
    }

    if c.isConnected {
        c.isConnected = false
        // 触发断开连接回调
        if c.callbacks.OnDisconnect != nil {
            go c.callbacks.OnDisconnect(nil)
        }
    }
}

func (c *WSClient) reconnect() {
    c.retryCount++
    
    for {
        select {
        case <-c.done:
            return
        default:
            if c.config.MaxRetries > 0 && c.retryCount > c.config.MaxRetries {
                if c.callbacks.OnError != nil {
                    c.callbacks.OnError(fmt.Errorf("max retry attempts reached"))
                }
                return
            }

            // 触发重连回调
            if c.callbacks.OnReconnect != nil {
                c.callbacks.OnReconnect(c.retryCount)
            }

            err := c.Connect()
            if err == nil {
                c.retryCount = 0 // 重置重试计数
                return
            }

            time.Sleep(c.config.ReconnectWait)
        }
    }
}

func (c *WSClient) Send(message []byte) error {
    c.mu.Lock()
    defer c.mu.Unlock()

    if !c.isConnected {
        return fmt.Errorf("not connected")
    }

    err := c.conn.WriteMessage(websocket.TextMessage, message)
    if err != nil && c.callbacks.OnError != nil {
        go c.callbacks.OnError(fmt.Errorf("send error: %v", err))
    }
    return err
}

func (c *WSClient) handleReceive() {
    for {
        select {
        case <-c.done:
            return
        default:
            if !c.isConnected {
                time.Sleep(time.Second)
                continue
            }

            _, message, err := c.conn.ReadMessage()
            if err != nil {
                if c.callbacks.OnError != nil {
                    go c.callbacks.OnError(fmt.Errorf("read error: %v", err))
                }
                c.Disconnect()
                go c.reconnect()
                continue
            }

            // 触发消息回调
            if c.callbacks.OnMessage != nil {
                go c.callbacks.OnMessage(message)
            }
        }
    }
}

func (c *WSClient) heartbeat() {
    ticker := time.NewTicker(30 * time.Second)
    defer ticker.Stop()

    for {
        select {
        case <-c.done:
            return
        case <-ticker.C:
            c.mu.Lock()
            if c.isConnected && c.conn != nil {
                err := c.conn.WriteMessage(websocket.PingMessage, nil)
                if err != nil {
                    if c.callbacks.OnError != nil {
                        go c.callbacks.OnError(fmt.Errorf("heartbeat error: %v", err))
                    }
                    c.conn.Close()
                    c.isConnected = false
                    go c.reconnect()
                }
            }
            c.mu.Unlock()
        }
    }
}

func (c *WSClient) Start() error {
    err := c.Connect()
    if err != nil {
        return err
    }

    go c.handleReceive()
    go c.heartbeat()

    return nil
}

func (c *WSClient) Stop() {
    close(c.done)
    c.Disconnect()
}

// IsConnected 返回当前连接状态
func (c *WSClient) IsConnected() bool {
    c.mu.Lock()
    defer c.mu.Unlock()
    return c.isConnected
}

// 使用示例
func main() {
    // 创建配置
    config := DefaultConfig("ws://localhost:8080/ws")
    config.MaxRetries = 5 // 设置最大重试次数
    
    // 创建回调
    callbacks := &Callbacks{
        OnConnect: func() {
            log.Println("Connected to server")
        },
        OnDisconnect: func(err error) {
            if err != nil {
                log.Printf("Disconnected due to error: %v\n", err)
            } else {
                log.Println("Disconnected from server")
            }
        },
        OnMessage: func(message []byte) {
            log.Printf("Received message: %s\n", string(message))
        },
        OnError: func(err error) {
            log.Printf("Error occurred: %v\n", err)
        },
        OnReconnect: func(attempt int) {
            log.Printf("Attempting to reconnect... (attempt %d)\n", attempt)
        },
    }

    // 创建客户端
    client := NewWSClient(config, callbacks)

    // 启动客户端
    if err := client.Start(); err != nil {
        log.Fatal(err)
    }

    // 定期发送消息示例
    go func() {
        ticker := time.NewTicker(5 * time.Second)
        defer ticker.Stop()

        for {
            select {
            case <-client.done:
                return
            case <-ticker.C:
                if client.IsConnected() {
                    err := client.Send([]byte("Hello, Server!"))
                    if err != nil {
                        log.Printf("Send error: %v", err)
                    }
                }
            }
        }
    }()

    // 保持程序运行
    select {}
}

推荐

最全的微信机器人开源框架集合

最全的微信机器人开源框架集合

收集整理目前最全基于微信生态的机器人开源框架,结合LLM function call机制与业务API,实现对业务系统的精准访问。

Cursor白嫖方法-反复白嫖专业版

Cursor白嫖方法-反复白嫖专业版

Cursor是一款专业的编程辅助软件,可以帮助程序员提升编程效率,提高工作效率。如果想白嫖可使用邮箱别名无限登录试用,破解Cursor使用限制。