如何用 Go 封装大模型推理服务

引言

随着大语言模型(LLM)的快速发展,如何高效地封装和部署这些模型成为了一个重要的话题。本文将详细介绍如何使用 Go 语言构建一个高性能、可扩展的大模型推理服务。我们将从基础架构设计开始,逐步深入到具体实现细节。

目录

  1. 服务架构设计
  2. 核心接口定义
  3. 模型推理封装
  4. 并发控制与性能优化
  5. 错误处理与监控
  6. 完整代码实现
  7. 部署与扩展建议

1. 服务架构设计

在设计大模型推理服务时,我们需要考虑以下几个关键点:

  • 高并发处理能力
  • 资源使用效率
  • 服务可扩展性
  • 错误处理机制
  • 监控与可观测性

1.1 系统架构图

flowchart TB
    %% 定义节点
    Client[客户端]
    Gateway[API网关]
    LB[负载均衡器]
    SI1[服务实例1]
    SI2[服务实例2]
    SIN[服务实例N]
    MI1[模型实例1]
    MI2[模型实例2]
    MIN[模型实例N]
    Cache[(分布式缓存)]
    Monitor[监控系统]

    %% 定义连接
    Client -->|请求| Gateway
    Gateway -->|转发| LB
    LB -->|分发| SI1
    LB -->|分发| SI2
    LB -->|分发| SIN
    SI1 -->|调用| MI1
    SI2 -->|调用| MI2
    SIN -->|调用| MIN
    MI1 -->|读写| Cache
    MI2 -->|读写| Cache
    MIN -->|读写| Cache
    MI1 -->|上报| Monitor
    MI2 -->|上报| Monitor
    MIN -->|上报| Monitor

    %% 定义样式
    classDef default fill:#f8f9fa,stroke:#495057,stroke-width:2px;
    classDef cache fill:#e9ecef,stroke:#495057,stroke-width:2px;
    classDef monitor fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;
    class Cache cache;
    class Monitor monitor;

系统架构图展示了整个服务的核心组件和它们之间的关系:

  • 客户端通过 API 网关访问服务
  • 负载均衡器分发请求到多个服务实例
  • 服务实例与模型实例交互
  • 分布式缓存共享推理结果
  • 监控系统收集指标并触发告警

1.2 数据流转图

sequenceDiagram
    participant Client as 客户端
    participant Gateway as API网关
    participant Service as 服务层
    participant Cache as 缓存
    participant Pool as 工作池
    participant Model as 模型层
    participant Monitor as 监控系统

    Client->>Gateway: 发送推理请求
    Gateway->>Service: 转发请求
    Service->>Cache: 查询缓存
    alt 缓存命中
        Cache-->>Service: 返回缓存结果
    else 缓存未命中
        Service->>Pool: 提交任务
        Pool->>Model: 执行推理
        Model-->>Pool: 返回结果
        Pool-->>Service: 返回结果
        Service->>Cache: 更新缓存
    end
    Service->>Monitor: 上报指标
    Service-->>Gateway: 返回结果
    Gateway-->>Client: 响应请求

    Note over Service,Model: 异步处理
    Note over Service,Monitor: 实时监控

数据流转图展示了请求处理的完整流程:

  1. 客户端发送推理请求
  2. API 网关接收并转发请求
  3. 服务层检查缓存
  4. 如果缓存未命中,则通过工作池执行推理
  5. 模型层处理推理请求
  6. 结果返回并更新缓存
  7. 监控系统收集处理指标

1.3 组件关系图

flowchart TB
    %% 定义子图
    subgraph Access[接入层]
        direction TB
        Gateway[API网关]
        Auth[认证授权]
        RateLimit[限流器]
        Gateway --> Auth
        Gateway --> RateLimit
    end

    subgraph Service[服务层]
        direction TB
        Instance[服务实例]
        WorkerPool[工作池]
        Cache[(缓存)]
        Instance --> WorkerPool
        Instance --> Cache
    end

    subgraph Model[模型层]
        direction TB
        ModelInstance[模型实例]
        Quantizer[量化器]
        Optimizer[优化器]
        ModelInstance --> Quantizer
        ModelInstance --> Optimizer
    end

    subgraph Monitor[监控层]
        direction TB
        Metrics[指标收集]
        Alert[告警系统]
        Log[日志系统]
        Metrics --> Alert
        Metrics --> Log
    end

    %% 定义连接
    Gateway --> Instance
    Auth --> Instance
    RateLimit --> Instance
    Instance --> ModelInstance
    ModelInstance --> Metrics

    %% 定义样式
    classDef default fill:#f8f9fa,stroke:#495057,stroke-width:2px;
    classDef subgraph fill:#e9ecef,stroke:#495057,stroke-width:1px;
    classDef cache fill:#e9ecef,stroke:#495057,stroke-width:2px;
    classDef monitor fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;
    class Access,Service,Model,Monitor subgraph;
    class Cache cache;
    class Monitor monitor;

组件关系图展示了系统各个层级之间的详细关系:

  • 接入层:处理认证、授权和限流
  • 服务层:管理服务实例、工作池和缓存
  • 模型层:处理模型实例、量化和优化
  • 监控层:收集指标、日志和追踪信息

1.4 部署架构图

flowchart TB
    %% 定义子图
    subgraph Client[客户端]
        direction TB
        Web[Web应用]
        Mobile[移动应用]
        API[API客户端]
    end

    subgraph LB[负载均衡]
        direction TB
        LoadBalancer[负载均衡器]
        HealthCheck[健康检查]
        LoadBalancer --> HealthCheck
    end

    subgraph K8s[Kubernetes集群]
        direction TB
        subgraph Pods[服务实例]
            direction TB
            Pod1[Pod 1]
            Pod2[Pod 2]
            PodN[Pod N]
        end

        subgraph Storage[存储]
            direction TB
            Persistent[(持久化存储)]
            ConfigMap[配置管理]
            Secret[密钥管理]
        end

        subgraph Monitor[监控]
            direction TB
            Prometheus[Prometheus]
            Grafana[Grafana]
            AlertManager[告警管理器]
            Prometheus --> Grafana
            Prometheus --> AlertManager
        end
    end

    %% 定义连接
    Web --> LoadBalancer
    Mobile --> LoadBalancer
    API --> LoadBalancer
    LoadBalancer --> Pod1
    LoadBalancer --> Pod2
    LoadBalancer --> PodN
    Pod1 --> Persistent
    Pod2 --> Persistent
    PodN --> Persistent
    Pod1 --> Prometheus
    Pod2 --> Prometheus
    PodN --> Prometheus

    %% 定义样式
    classDef default fill:#f8f9fa,stroke:#495057,stroke-width:2px;
    classDef subgraph fill:#e9ecef,stroke:#495057,stroke-width:1px;
    classDef storage fill:#e9ecef,stroke:#495057,stroke-width:2px;
    classDef monitor fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;
    class Client,LB,K8s,Pods,Storage,Monitor subgraph;
    class Persistent storage;
    class Monitor monitor;

部署架构图展示了在 Kubernetes 环境下的部署架构:

  • 多种客户端访问方式
  • 负载均衡器分发请求
  • Kubernetes 集群中的服务实例
  • 持久化存储
  • 监控和可视化系统

2. 核心接口定义

首先,让我们定义核心接口:

// model.go
package llm

import (
    "context"
    "time"
)

// ModelConfig 定义模型配置
type ModelConfig struct {
    ModelPath    string            `json:"model_path"`
    MaxTokens    int              `json:"max_tokens"`
    Temperature  float64          `json:"temperature"`
    TopP         float64          `json:"top_p"`
    BatchSize    int              `json:"batch_size"`
    Device       string           `json:"device"`        // CPU/GPU
    ModelType    string           `json:"model_type"`    // ONNX/TensorRT
    CacheSize    int              `json:"cache_size"`    // 缓存大小
    Threads      int              `json:"threads"`       // 线程数
}

// InferenceRequest 定义推理请求
type InferenceRequest struct {
    Prompt      string            `json:"prompt"`
    MaxTokens   int              `json:"max_tokens,omitempty"`
    Temperature float64          `json:"temperature,omitempty"`
    TopP        float64          `json:"top_p,omitempty"`
    Stream      bool             `json:"stream,omitempty"`
    StopWords   []string         `json:"stop_words,omitempty"`
    UserID      string           `json:"user_id,omitempty"`
    SessionID   string           `json:"session_id,omitempty"`
}

// InferenceResponse 定义推理响应
type InferenceResponse struct {
    Text        string    `json:"text"`
    Tokens      int       `json:"tokens"`
    Latency     float64   `json:"latency"`
    Error       string    `json:"error,omitempty"`
    Usage       Usage     `json:"usage"`
    FinishReason string   `json:"finish_reason"`
}

// Usage 定义使用统计
type Usage struct {
    PromptTokens     int `json:"prompt_tokens"`
    CompletionTokens int `json:"completion_tokens"`
    TotalTokens      int `json:"total_tokens"`
}

// Model 定义模型接口
type Model interface {
    // Load 加载模型
    Load(config ModelConfig) error
    
    // Infer 执行推理
    Infer(ctx context.Context, req InferenceRequest) (*InferenceResponse, error)
    
    // BatchInfer 批量推理
    BatchInfer(ctx context.Context, reqs []InferenceRequest) ([]*InferenceResponse, error)
    
    // StreamInfer 流式推理
    StreamInfer(ctx context.Context, req InferenceRequest, ch chan<- *InferenceResponse) error
    
    // Close 关闭模型
    Close() error
    
    // GetModelInfo 获取模型信息
    GetModelInfo() ModelInfo
}

// ModelInfo 定义模型信息
type ModelInfo struct {
    Name        string    `json:"name"`
    Version     string    `json:"version"`
    Type        string    `json:"type"`
    Parameters  int64     `json:"parameters"`
    MaxTokens   int       `json:"max_tokens"`
    CreatedAt   time.Time `json:"created_at"`
}

3. 模型推理封装

下面是一个具体的模型实现示例:

// model_impl.go
package llm

import (
    "context"
    "fmt"
    "sync"
    "time"
    
    "github.com/your-org/llm-sdk"  // 假设的模型SDK
)

// ModelImpl 实现 Model 接口
type ModelImpl struct {
    config     ModelConfig
    model      interface{} // 实际的模型实例
    mu         sync.RWMutex
    isLoaded   bool
    cache      *Cache
    metrics    *Metrics
}

// Cache 实现简单的LRU缓存
type Cache struct {
    size       int
    data       map[string]*InferenceResponse
    mu         sync.RWMutex
}

// NewCache 创建新的缓存实例
func NewCache(size int) *Cache {
    return &Cache{
        size: size,
        data: make(map[string]*InferenceResponse),
    }
}

// Get 从缓存获取结果
func (c *Cache) Get(key string) (*InferenceResponse, bool) {
    c.mu.RLock()
    defer c.mu.RUnlock()
    resp, ok := c.data[key]
    return resp, ok
}

// Set 设置缓存结果
func (c *Cache) Set(key string, resp *InferenceResponse) {
    c.mu.Lock()
    defer c.mu.Unlock()
    
    if len(c.data) >= c.size {
        // 简单的LRU实现:删除最早的条目
        for k := range c.data {
            delete(c.data, k)
            break
        }
    }
    c.data[key] = resp
}

// NewModel 创建新的模型实例
func NewModel() Model {
    return &ModelImpl{
        cache:   NewCache(1000),
        metrics: &Metrics{},
    }
}

// Load 实现模型加载
func (m *ModelImpl) Load(config ModelConfig) error {
    m.mu.Lock()
    defer m.mu.Unlock()

    if m.isLoaded {
        return nil
    }

    // 根据模型类型选择不同的加载方式
    switch config.ModelType {
    case "ONNX":
        m.model = loadONNXModel(config)
    case "TensorRT":
        m.model = loadTensorRTModel(config)
    default:
        return fmt.Errorf("unsupported model type: %s", config.ModelType)
    }
    
    m.config = config
    m.isLoaded = true
    return nil
}

// Infer 实现推理逻辑
func (m *ModelImpl) Infer(ctx context.Context, req InferenceRequest) (*InferenceResponse, error) {
    if !m.isLoaded {
        return nil, fmt.Errorf("model not loaded")
    }

    // 检查缓存
    cacheKey := generateCacheKey(req)
    if resp, ok := m.cache.Get(cacheKey); ok {
        return resp, nil
    }

    start := time.Now()
    
    // 实现具体的推理逻辑
    var response *InferenceResponse
    var err error
    
    switch m.config.ModelType {
    case "ONNX":
        response, err = m.inferONNX(ctx, req)
    case "TensorRT":
        response, err = m.inferTensorRT(ctx, req)
    default:
        return nil, fmt.Errorf("unsupported model type: %s", m.config.ModelType)
    }
    
    if err != nil {
        m.metrics.UpdateMetrics(time.Since(start).Seconds(), err)
        return nil, err
    }
    
    // 更新指标
    m.metrics.UpdateMetrics(time.Since(start).Seconds(), nil)
    
    // 缓存结果
    m.cache.Set(cacheKey, response)
    
    return response, nil
}

// BatchInfer 实现批量推理
func (m *ModelImpl) BatchInfer(ctx context.Context, reqs []InferenceRequest) ([]*InferenceResponse, error) {
    if !m.isLoaded {
        return nil, fmt.Errorf("model not loaded")
    }
    
    results := make([]*InferenceResponse, len(reqs))
    errChan := make(chan error, len(reqs))
    
    // 使用工作池处理批量请求
    for i, req := range reqs {
        go func(idx int, r InferenceRequest) {
            resp, err := m.Infer(ctx, r)
            if err != nil {
                errChan <- err
                return
            }
            results[idx] = resp
        }(i, req)
    }
    
    // 等待所有请求完成
    for i := 0; i < len(reqs); i++ {
        select {
        case err := <-errChan:
            return nil, err
        case <-ctx.Done():
            return nil, ctx.Err()
        }
    }
    
    return results, nil
}

// StreamInfer 实现流式推理
func (m *ModelImpl) StreamInfer(ctx context.Context, req InferenceRequest, ch chan<- *InferenceResponse) error {
    if !m.isLoaded {
        return fmt.Errorf("model not loaded")
    }
    
    // 实现流式推理逻辑
    go func() {
        defer close(ch)
        
        // 模拟流式输出
        for i := 0; i < 10; i++ {
            select {
            case <-ctx.Done():
                return
            default:
                response := &InferenceResponse{
                    Text:    fmt.Sprintf("Stream chunk %d", i),
                    Tokens:  i,
                    Latency: float64(i) * 0.1,
                }
                ch <- response
                time.Sleep(100 * time.Millisecond)
            }
        }
    }()
    
    return nil
}

// Close 实现资源清理
func (m *ModelImpl) Close() error {
    m.mu.Lock()
    defer m.mu.Unlock()
    
    if !m.isLoaded {
        return nil
    }
    
    // 清理模型资源
    switch m.config.ModelType {
    case "ONNX":
        cleanupONNXModel(m.model)
    case "TensorRT":
        cleanupTensorRTModel(m.model)
    }
    
    m.isLoaded = false
    return nil
}

// GetModelInfo 获取模型信息
func (m *ModelImpl) GetModelInfo() ModelInfo {
    return ModelInfo{
        Name:       "example-model",
        Version:    "1.0.0",
        Type:       m.config.ModelType,
        Parameters: 7_000_000_000,
        MaxTokens:  m.config.MaxTokens,
        CreatedAt:  time.Now(),
    }
}

// 辅助函数
func generateCacheKey(req InferenceRequest) string {
    // 实现缓存键生成逻辑
    return fmt.Sprintf("%s_%d_%.2f_%.2f", 
        req.Prompt, 
        req.MaxTokens, 
        req.Temperature, 
        req.TopP)
}

func loadONNXModel(config ModelConfig) interface{} {
    // 实现ONNX模型加载逻辑
    return nil
}

func loadTensorRTModel(config ModelConfig) interface{} {
    // 实现TensorRT模型加载逻辑
    return nil
}

func (m *ModelImpl) inferONNX(ctx context.Context, req InferenceRequest) (*InferenceResponse, error) {
    // 实现ONNX推理逻辑
    return &InferenceResponse{
        Text:    "ONNX inference result",
        Tokens:  100,
        Latency: 0.1,
    }, nil
}

func (m *ModelImpl) inferTensorRT(ctx context.Context, req InferenceRequest) (*InferenceResponse, error) {
    // 实现TensorRT推理逻辑
    return &InferenceResponse{
        Text:    "TensorRT inference result",
        Tokens:  100,
        Latency: 0.05,
    }, nil
}

func cleanupONNXModel(model interface{}) {
    // 实现ONNX模型清理逻辑
}

func cleanupTensorRTModel(model interface{}) {
    // 实现TensorRT模型清理逻辑
}

4. 并发控制与性能优化

为了实现高效的并发处理,我们使用工作池模式:

// worker_pool.go
package llm

import (
    "context"
    "sync"
    "time"
)

// WorkerPool 工作池实现
type WorkerPool struct {
    workers    int
    tasks      chan *Task
    wg         sync.WaitGroup
    metrics    *Metrics
    timeout    time.Duration
}

// Task 定义任务结构
type Task struct {
    Request  InferenceRequest
    Response chan *InferenceResponse
    Error    chan error
    Start    time.Time
}

// NewWorkerPool 创建新的工作池
func NewWorkerPool(workers int) *WorkerPool {
    return &WorkerPool{
        workers: workers,
        tasks:   make(chan *Task, workers*2),
        metrics: &Metrics{},
        timeout: 30 * time.Second,
    }
}

// Start 启动工作池
func (p *WorkerPool) Start(ctx context.Context, model Model) {
    for i := 0; i < p.workers; i++ {
        p.wg.Add(1)
        go p.worker(ctx, model)
    }
}

// worker 工作协程实现
func (p *WorkerPool) worker(ctx context.Context, model Model) {
    defer p.wg.Done()
    
    for {
        select {
        case task := <-p.tasks:
            // 检查任务是否超时
            if time.Since(task.Start) > p.timeout {
                task.Error <- fmt.Errorf("task timeout")
                continue
            }
            
            // 执行推理
            resp, err := model.Infer(ctx, task.Request)
            if err != nil {
                task.Error <- err
                p.metrics.UpdateMetrics(time.Since(task.Start).Seconds(), err)
            } else {
                task.Response <- resp
                p.metrics.UpdateMetrics(time.Since(task.Start).Seconds(), nil)
            }
            
        case <-ctx.Done():
            return
        }
    }
}

// Submit 提交任务到工作池
func (p *WorkerPool) Submit(task *Task) error {
    task.Start = time.Now()
    select {
    case p.tasks <- task:
        return nil
    default:
        return fmt.Errorf("worker pool is full")
    }
}

// Stop 停止工作池
func (p *WorkerPool) Stop() {
    close(p.tasks)
    p.wg.Wait()
}

5. 错误处理与监控

实现错误处理和监控机制:

// metrics.go
package llm

import (
    "sync/atomic"
    "time"
)

// Metrics 定义监控指标
type Metrics struct {
    TotalRequests    uint64
    FailedRequests   uint64
    TotalLatency     float64
    LastError        error
    LastErrorTime    time.Time
    QueueSize        int32
    ActiveWorkers    int32
}

// UpdateMetrics 更新监控指标
func (m *Metrics) UpdateMetrics(latency float64, err error) {
    atomic.AddUint64(&m.TotalRequests, 1)
    if err != nil {
        atomic.AddUint64(&m.FailedRequests, 1)
        m.LastError = err
        m.LastErrorTime = time.Now()
    }
    atomic.AddUint64((*uint64)(&m.TotalLatency), uint64(latency*1000))
}

// GetMetrics 获取当前指标
func (m *Metrics) GetMetrics() map[string]interface{} {
    return map[string]interface{}{
        "total_requests":   atomic.LoadUint64(&m.TotalRequests),
        "failed_requests":  atomic.LoadUint64(&m.FailedRequests),
        "total_latency":    atomic.LoadUint64((*uint64)(&m.TotalLatency)),
        "last_error":       m.LastError,
        "last_error_time":  m.LastErrorTime,
        "queue_size":       atomic.LoadInt32(&m.QueueSize),
        "active_workers":   atomic.LoadInt32(&m.ActiveWorkers),
    }
}

// PrometheusMetrics 实现 Prometheus 指标
type PrometheusMetrics struct {
    *Metrics
    // 添加 Prometheus 特定的指标
}

// NewPrometheusMetrics 创建新的 Prometheus 指标实例
func NewPrometheusMetrics() *PrometheusMetrics {
    return &PrometheusMetrics{
        Metrics: &Metrics{},
    }
}

// RegisterMetrics 注册 Prometheus 指标
func (m *PrometheusMetrics) RegisterMetrics() {
    // 实现 Prometheus 指标注册
}

6. 完整服务实现

最后,我们实现完整的 HTTP 服务:

// server.go
package llm

import (
    "context"
    "encoding/json"
    "net/http"
    "time"
    
    "github.com/gorilla/mux"
    "go.uber.org/zap"
)

// Server 定义 HTTP 服务
type Server struct {
    model    Model
    pool     *WorkerPool
    metrics  *Metrics
    logger   *zap.Logger
    router   *mux.Router
    config   ServerConfig
}

// ServerConfig 定义服务器配置
type ServerConfig struct {
    Port            int           `json:"port"`
    ReadTimeout     time.Duration `json:"read_timeout"`
    WriteTimeout    time.Duration `json:"write_timeout"`
    MaxConns        int           `json:"max_conns"`
    EnableMetrics   bool          `json:"enable_metrics"`
    EnableProfiling bool          `json:"enable_profiling"`
}

// NewServer 创建新的服务实例
func NewServer(model Model, config ServerConfig) *Server {
    logger, _ := zap.NewProduction()
    
    return &Server{
        model:   model,
        pool:    NewWorkerPool(config.MaxConns),
        metrics: &Metrics{},
        logger:  logger,
        router:  mux.NewRouter(),
        config:  config,
    }
}

// Start 启动服务
func (s *Server) Start() error {
    ctx := context.Background()
    s.pool.Start(ctx, s.model)
    
    // 注册路由
    s.registerRoutes()
    
    // 创建 HTTP 服务器
    server := &http.Server{
        Addr:         fmt.Sprintf(":%d", s.config.Port),
        Handler:      s.router,
        ReadTimeout:  s.config.ReadTimeout,
        WriteTimeout: s.config.WriteTimeout,
    }
    
    s.logger.Info("Starting server",
        zap.Int("port", s.config.Port),
        zap.Duration("read_timeout", s.config.ReadTimeout),
        zap.Duration("write_timeout", s.config.WriteTimeout))
    
    return server.ListenAndServe()
}

// registerRoutes 注册路由
func (s *Server) registerRoutes() {
    // 健康检查
    s.router.HandleFunc("/health", s.handleHealth).Methods(http.MethodGet)
    
    // 推理接口
    s.router.HandleFunc("/v1/infer", s.handleInference).Methods(http.MethodPost)
    s.router.HandleFunc("/v1/batch_infer", s.handleBatchInference).Methods(http.MethodPost)
    s.router.HandleFunc("/v1/stream_infer", s.handleStreamInference).Methods(http.MethodPost)
    
    // 模型信息
    s.router.HandleFunc("/v1/model_info", s.handleModelInfo).Methods(http.MethodGet)
    
    // 指标接口
    if s.config.EnableMetrics {
        s.router.HandleFunc("/metrics", s.handleMetrics).Methods(http.MethodGet)
    }
    
    // 性能分析
    if s.config.EnableProfiling {
        s.router.PathPrefix("/debug/pprof/").Handler(http.DefaultServeMux)
    }
}

// handleHealth 处理健康检查
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
    w.WriteHeader(http.StatusOK)
    json.NewEncoder(w).Encode(map[string]string{
        "status": "ok",
        "time":   time.Now().Format(time.RFC3339),
    })
}

// handleInference 处理推理请求
func (s *Server) handleInference(w http.ResponseWriter, r *http.Request) {
    var req InferenceRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        s.logger.Error("Failed to decode request",
            zap.Error(err))
        http.Error(w, err.Error(), http.StatusBadRequest)
        return
    }

    task := &Task{
        Request:  req,
        Response: make(chan *InferenceResponse, 1),
        Error:    make(chan error, 1),
    }

    if err := s.pool.Submit(task); err != nil {
        s.logger.Error("Failed to submit task",
            zap.Error(err))
        http.Error(w, err.Error(), http.StatusServiceUnavailable)
        return
    }

    select {
    case resp := <-task.Response:
        w.Header().Set("Content-Type", "application/json")
        json.NewEncoder(w).Encode(resp)
    case err := <-task.Error:
        s.logger.Error("Inference failed",
            zap.Error(err))
        http.Error(w, err.Error(), http.StatusInternalServerError)
    case <-time.After(s.config.WriteTimeout):
        s.logger.Error("Request timeout")
        http.Error(w, "Request timeout", http.StatusGatewayTimeout)
    }
}

// handleBatchInference 处理批量推理请求
func (s *Server) handleBatchInference(w http.ResponseWriter, r *http.Request) {
    var reqs []InferenceRequest
    if err := json.NewDecoder(r.Body).Decode(&reqs); err != nil {
        s.logger.Error("Failed to decode batch request",
            zap.Error(err))
        http.Error(w, err.Error(), http.StatusBadRequest)
        return
    }

    responses, err := s.model.BatchInfer(r.Context(), reqs)
    if err != nil {
        s.logger.Error("Batch inference failed",
            zap.Error(err))
        http.Error(w, err.Error(), http.StatusInternalServerError)
        return
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(responses)
}

// handleStreamInference 处理流式推理请求
func (s *Server) handleStreamInference(w http.ResponseWriter, r *http.Request) {
    var req InferenceRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        s.logger.Error("Failed to decode stream request",
            zap.Error(err))
        http.Error(w, err.Error(), http.StatusBadRequest)
        return
    }

    w.Header().Set("Content-Type", "text/event-stream")
    w.Header().Set("Cache-Control", "no-cache")
    w.Header().Set("Connection", "keep-alive")

    ch := make(chan *InferenceResponse, 10)
    err := s.model.StreamInfer(r.Context(), req, ch)
    if err != nil {
        s.logger.Error("Stream inference failed",
            zap.Error(err))
        http.Error(w, err.Error(), http.StatusInternalServerError)
        return
    }

    for resp := range ch {
        data, err := json.Marshal(resp)
        if err != nil {
            s.logger.Error("Failed to marshal response",
                zap.Error(err))
            continue
        }
        fmt.Fprintf(w, "data: %s\n\n", data)
        w.(http.Flusher).Flush()
    }
}

// handleModelInfo 处理模型信息请求
func (s *Server) handleModelInfo(w http.ResponseWriter, r *http.Request) {
    info := s.model.GetModelInfo()
    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(info)
}

// handleMetrics 处理指标请求
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
    metrics := s.metrics.GetMetrics()
    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(metrics)
}

7. 部署与扩展建议

  1. 容器化部署

    • 使用 Docker 打包服务
    • 使用 Kubernetes 进行容器编排
    • 实现健康检查和自动扩缩容
  2. 性能优化

    • 使用 GPU 加速
    • 实现模型量化
    • 使用批处理优化
    • 实现请求合并
    • 使用内存池优化
  3. 监控告警

    • 集成 Prometheus 监控
    • 设置关键指标告警
    • 实现日志追踪
    • 使用 OpenTelemetry 进行分布式追踪
  4. 扩展建议

    • 实现模型版本管理
    • 添加负载均衡
    • 实现服务发现
    • 添加缓存层
    • 实现限流和熔断
    • 添加认证和授权
    • 实现模型热更新

总结

本文详细介绍了如何使用 Go 语言构建大模型推理服务。通过合理的架构设计和实现,我们可以构建出高性能、可扩展的推理服务。关键点包括:

  1. 清晰的接口定义
  2. 高效的并发处理
  3. 完善的错误处理
  4. 可靠的监控机制
  5. 良好的扩展性

在实际部署时,还需要根据具体场景进行优化和调整。希望本文能够帮助读者更好地理解和实现大模型推理服务。