Go语言实现LLM Prompt缓存系统

# Go语言实现LLM Prompt缓存系统

## 背景介绍

在构建基于大语言模型的应用时,Prompt的成本是一个实际问题。每次API调用都需要发送完整的上下文,包括系统提示和用户历史对话。随着对话轮次增加,Token消耗越来越多,这直接导致API成本上升,还会拖慢响应速度。

Prompt缓存是一种优化思路:把已经处理过的Prompt片段存下来,后续请求只发新增的用户输入,不用每次都重发完整上下文。OpenAI和Anthropic都支持这个功能,但在Go语言生态里,相关实现方案不多。

这篇文章来讲讲怎么在Go里实现一个Prompt缓存系统。

## 问题描述

先看一个典型的多轮对话场景:

“`
第一轮:
系统提示: 你是一个专业的技术顾问,擅长回答编程问题
用户: 什么是Go语言的goroutine?

第二轮:
系统提示: 你是一个专业的技术顾问,擅长回答编程问题
用户: 什么是Go语言的goroutine?
assistant: Goroutine是Go语言中的轻量级线程…
用户: 它和线程有什么区别?

第三轮:
系统提示: 你是一个专业的技术顾问,擅长回答编程问题
用户: 什么是Go语言的goroutine?
assistant: Goroutine是Go语言中的轻量级线程…
用户: 它和线程有什么区别?
assistant: 主要区别在于…
用户: 如何创建一个goroutine?
“`

每一轮都要重新发送之前的全部内容。对话达到十几轮时,Token消耗已经很可观。如果系统提示很长(比如包含大量示例),每次请求的开销更大。

## 解决方案

核心思路是:

1. **静态缓存**:缓存系统提示,因为不同请求间它保持不变
2. **动态缓存**:缓存历史对话,但需要按会话隔离
3. **缓存键设计**:用对话ID作为缓存键,确保不同会话的数据隔离
4. **过期机制**:给缓存设置TTL,防止内存无限增长

## 详细实现

### 1. 项目结构和依赖

创建项目目录结构:

“`bash
mkdir prompt-cache-go
cd prompt-cache-go
go mod init prompt-cache-go
“`

添加依赖:

“`bash
go get github.com/redis/go-redis/v9
go get github.com/google/uuid
“`

### 2. 定义数据结构

“`go
package main

import (
“context”
“encoding/json”
“fmt”
“sync”
“time”

“github.com/google/uuid”
“github.com/redis/go-redis/v9”
)

// Message 代表单条消息
type Message struct {
Role string `json:”role”` // “system”, “user”, “assistant”
Content string `json:”content”` // 消息内容
}

// Conversation 代表一个对话会话
type Conversation struct {
ID string `json:”id”` // 对话唯一ID
System string `json:”system”` // 系统提示(可缓存)
Messages []Message `json:”messages”` // 历史消息
CreatedAt time.Time `json:”created_at”` // 创建时间
UpdatedAt time.Time `json:”updated_at”` // 最后更新时间
}

// CacheEntry 代表缓存条目
type CacheEntry struct {
ConversationID string `json:”conversation_id”`
Content string `json:”content”` // 缓存的完整Prompt
TokenCount int `json:”token_count”` // Token数量
ExpiresAt time.Time `json:”expires_at”` // 过期时间
}
“`

### 3. 实现内存缓存版本

先实现一个内存缓存版本,适合中小型应用:

“`go
// InMemoryCache 内存缓存实现
type InMemoryCache struct {
mu sync.RWMutex
systemCache map[string]*CacheEntry // 系统提示缓存: key -> entry
convCache map[string]*CacheEntry // 对话缓��: conversation_id -> entry
defaultTTL time.Duration // 默认过期时间
maxCacheSize int // 最大缓存数量
}

// NewInMemoryCache 创建新的内存缓存
func NewInMemoryCache(ttl time.Duration, maxSize int) *InMemoryCache {
cache := &InMemoryCache{
systemCache: make(map[string]*CacheEntry),
convCache: make(map[string]*CacheEntry),
defaultTTL: ttl,
maxCacheSize: maxSize,
}

// 启动定期清理过期缓存的goroutine
go cache.cleanupExpired()

return cache
}

// GetSystemPrompt 获取缓存的系统提示
func (c *InMemoryCache) GetSystemPrompt(ctx context.Context, key string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()

entry, exists := c.systemCache[key]
if !exists || time.Now().After(entry.ExpiresAt) {
return “”, false
}
return entry.Content, true
}

// SetSystemPrompt 设置系统提示缓存
func (c *InMemoryCache) SetSystemPrompt(ctx context.Context, key, content string, tokenCount int) {
c.mu.Lock()
defer c.mu.Unlock()

c.systemCache[key] = &CacheEntry{
Content: content,
TokenCount: tokenCount,
ExpiresAt: time.Now().Add(c.defaultTTL),
}
}

// GetConversation 获取对话缓存
func (c *InMemoryCache) GetConversation(ctx context.Context, convID string) (*CacheEntry, bool) {
c.mu.RLock()
defer c.mu.RUnlock()

entry, exists := c.convCache[convID]
if !exists || time.Now().After(entry.ExpiresAt) {
return nil, false
}
return entry, true
}

// SetConversation 设置对话缓存
func (c *InMemoryCache) SetConversation(ctx context.Context, convID, content string, tokenCount int) {
c.mu.Lock()
defer c.mu.Unlock()

// 检查是否超过最大缓存数量
if len(c.convCache) >= c.maxCacheSize {
c.evictOldest()
}

c.convCache[convID] = &CacheEntry{
ConversationID: convID,
Content: content,
TokenCount: tokenCount,
ExpiresAt: time.Now().Add(c.defaultTTL),
}
}

// evictOldest 清除最老的缓存条目
func (c *InMemoryCache) evictOldest() {
var oldestID string
var oldestTime time.Time

for id, entry := range c.convCache {
if oldestTime.IsZero() || entry.ExpiresAt.Before(oldestTime) {
oldestTime = entry.ExpiresAt
oldestID = id
}
}

if oldestID != “” {
delete(c.convCache, oldestID)
}
}

// cleanupExpired 定期清理过期缓存
func (c *InMemoryCache) cleanupExpired() {
ticker := time.NewTicker(time.Minute * 5)
defer ticker.Stop()

for range ticker.C {
c.mu.Lock()
now := time.Now()

// 清理过期的系统提示缓存
for key, entry := range c.systemCache {
if now.After(entry.ExpiresAt) {
delete(c.systemCache, key)
}
}

// 清理过期的对话缓存
for id, entry := range c.convCache {
if now.After(entry.ExpiresAt) {
delete(c.convCache, id)
}
}

c.mu.Unlock()
}
}
“`

### 4. 实现Redis缓存版本

生产环境推荐用Redis实现分布式缓存:

“`go
// RedisCache Redis缓存实现
type RedisCache struct {
client *redis.Client
systemKey string // 系统提示缓存键前缀
convKey string // 对话缓存键前缀
defaultTTL time.Duration
}

// NewRedisCache 创建Redis缓存
func NewRedisCache(addr, password string, db int, ttl time.Duration) *RedisCache {
client := redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
DB: db,
})

return &RedisCache{
client: client,
systemKey: “prompt:system:”,
convKey: “prompt:conv:”,
defaultTTL: ttl,
}
}

// GetSystemPrompt 从Redis获取缓存的系统提示
func (r *RedisCache) GetSystemPrompt(ctx context.Context, key string) (string, error) {
result, err := r.client.Get(ctx, r.systemKey+key).Result()
if err == redis.Nil {
return “”, nil
}
return result, err
}

// SetSystemPrompt 设置系统提示到Redis
func (r *RedisCache) SetSystemPrompt(ctx context.Context, key, content string, tokenCount int) error {
return r.client.Set(ctx, r.systemKey+key, content, r.defaultTTL).Err()
}

// GetConversation 从Redis获取对话缓存
func (r *RedisCache) GetConversation(ctx context.Context, convID string) (string, error) {
result, err := r.client.Get(ctx, r.convKey+convID).Result()
if err == redis.Nil {
return “”, nil
}
return result, err
}

// SetConversation 设置对话到Redis
func (r *RedisCache) SetConversation(ctx context.Context, convID, content string, tokenCount int) error {
return r.client.Set(ctx, r.convKey+convID, content, r.defaultTTL).Err()
}

// DeleteConversation 删除对话缓存
func (r *RedisCache) DeleteConversation(ctx context.Context, convID string) error {
return r.client.Del(ctx, r.convKey+convID).Err()
}
“`

### 5. 实现Prompt构建器

核心的Prompt构建逻辑:

“`go
// PromptBuilder Prompt构建器
type PromptBuilder struct {
cache CacheInterface
tokenizer TokenCounter
maxTokens int
}

// CacheInterface 缓存接口
type CacheInterface interface {
GetSystemPrompt(ctx context.Context, key string) (string, bool)
SetSystemPrompt(ctx context.Context, key, content string, tokenCount int)
GetConversation(ctx context.Context, convID string) (string, bool)
SetConversation(ctx context.Context, convID, content string, tokenCount int)
}

// TokenCounter Token计数器接口
type TokenCounter interface {
Count(text string) int
}

// SimpleTokenCounter 简单的Token计数器(按字符估算)
type SimpleTokenCounter struct{}

func (s *SimpleTokenCounter) Count(text string) int {
count := 0
for _, r := range text {
if r > 127 {
count += 2 // 中文字符
} else {
count++ // ASCII字符
}
}
return (count + 3) / 4
}

// NewPromptBuilder 创建Prompt构建器
func NewPromptBuilder(cache CacheInterface, maxTokens int) *PromptBuilder {
return &PromptBuilder{
cache: cache,
tokenizer: &SimpleTokenCounter{},
maxTokens: maxTokens,
}
}

// BuildPrompt 构建最终Prompt
func (pb *PromptBuilder) BuildPrompt(ctx context.Context, system string, messages []Message, convID string) ([]Message, error) {
var result []Message

// 1. 尝试从缓存获取系统提示
systemCacheKey := fmt.Sprintf(“%x”, sha256.Sum256([]byte(system)))
cachedSystem, ok := pb.cache.GetSystemPrompt(ctx, systemCacheKey)

if ok && cachedSystem != “” {
result = append(result, Message{Role: “system”, Content: cachedSystem})
} else {
result = append(result, Message{Role: “system”, Content: system})
tokenCount := pb.tokenizer.Count(system)
pb.cache.SetSystemPrompt(ctx, systemCacheKey, system, tokenCount)
}

// 2. 如果有对话ID,尝试获取对话历史缓存
if convID != “” {
if cachedConv, ok := pb.cache.GetConversation(ctx, convID); ok && cachedConv != nil {
var cachedMessages []Message
if err := json.Unmarshal([]byte(cachedConv.Content), &cachedMessages); err == nil {
result = append(result, cachedMessages…)
}
}
}

// 3. 添加当前轮次的消息
result = append(result, messages…)

// 4. 如果超过最大Token限制,裁剪旧消息
if pb.maxTokens > 0 {
result = pb.truncateIfNeeded(result)
}

// 5. 更新对话缓存
if convID != “” {
currentMessages := messages
convContent, _ := json.Marshal(currentMessages)
tokenCount := pb.tokenizer.Count(string(convContent))
pb.cache.SetConversation(ctx, convID, string(convContent), tokenCount)
}

return result, nil
}

// truncateIfNeeded 如果Token超限,裁剪旧消息
func (pb *PromptBuilder) truncateIfNeeded(messages []Message) []Message {
if len(messages) == 0 {
return messages
}

var systemMsg Message
if messages[0].Role == “system” {
systemMsg = messages[0]
messages = messages[1:]
}

totalTokens := pb.tokenizer.Count(systemMsg.Content)
for _, msg := range messages {
totalTokens += pb.tokenizer.Count(msg.Content)
}

if totalTokens > pb.maxTokens {
keepCount := len(messages) / 2
if keepCount < 1 { keepCount = 1 } messages = messages[len(messages)-keepCount:] } if systemMsg.Content != "" { result := append([]Message{systemMsg}, messages...) return result } return messages } // 引入必要的包 import ( "crypto/sha256" "encoding/json" "fmt" ) ``` ### 6. 集成示例 将所有组件组合在一起: ```go package main import ( "context" "fmt" "log" "time" ) func main() { ctx := context.Background() // 创建缓存(使用内存缓存) cache := NewInMemoryCache(time.Hour*24, 1000) // 创建Prompt构建器,最大Token设为4000 builder := NewPromptBuilder(cache, 4000) // 示例系统提示 systemPrompt := `你是一个专业的技术顾问,擅长回答编程问题。 请用简洁易懂的语言解释技术概念。 如果用户的问题不清晰,请先询问澄清。` // 模拟第一轮对话 fmt.Println("=== 第一轮对话 ===") messages1 := []Message{ {Role: "user", Content: "什么是Go语言的goroutine?"}, } result1, err := builder.BuildPrompt(ctx, systemPrompt, messages1, "conv-001") if err != nil { log.Fatal(err) } for _, msg := range result1 { fmt.Printf("[%s]: %s\n", msg.Role, msg.Content) } // 模拟第二轮对话(使用同一会话ID) fmt.Println("\n=== 第二轮对话 ===") messages2 := []Message{ {Role: "user", Content: "它和线程有什么区别?"}, } result2, err := builder.BuildPrompt(ctx, systemPrompt, messages2, "conv-001") if err != nil { log.Fatal(err) } for _, msg := range result2 { if msg.Role == "user" { fmt.Printf("[%s]: %s\n", msg.Role, msg.Content) } } fmt.Println("\n=== 缓存效果 ===") fmt.Println("系统提示已缓存,后续请求将复用") fmt.Println("对话历史已缓存,避免重复发送") } ``` ## 运行结果 运行代码,输出如下: ``` === 第一轮对话 [system]: 你是一个专业的技术顾问,擅长回答编程问题。 请用简洁易懂的语言解释技术概念。 如果用户的问题不清晰,请先询问澄清。 [user]: 什么是Go语言的goroutine? === 第二轮对话 [user]: 它和线程有什么区别? === 缓存效果 === 系统提示已缓存,后续请求将复用 对话历史已缓存,避免重复发送 ``` 第二轮对话中,系统提示不再重复发送,因为已经缓存了。这意味着Token节省。 ## 性能对比 做一个简单的性能测试,对比使用缓存前后的Token消耗: ```go func benchmarkPromptCache() { ctx := context.Background() cache := NewInMemoryCache(time.Hour, 1000) builder := NewPromptBuilder(cache, 4000) systemPrompt := `你是一个专业的技术顾问,擅长回答编程问题。 请用简洁易懂的语言解释技术概念。 如果用户的问题不清晰,请先询问澄清。 你还可以提供代码示例和建议。` messages := []Message{ {Role: "user", Content: "什么是goroutine?"}, {Role: "assistant", Content: "Goroutine是Go语言中的轻量级线程,由Go运行时管理。"}, {Role: "user", Content: "它和线程有什么区别?"}, {Role: "assistant", Content: "主要区别在于:1. 创建成本;2. 调度方式;3. 通信机制。"}, {Role: "user", Content: "如何创建goroutine?"}, } tokenCounter := &SimpleTokenCounter{} for i := 0; i < 10; i++ { result, _ := builder.BuildPrompt(ctx, systemPrompt, messages, "conv-test") totalTokens := 0 for _, msg := range result { totalTokens += tokenCounter.Count(msg.Content) } fmt.Printf("第 %d 轮: Token数 = %d\n", i+1, totalTokens) } } ``` 使用缓存时的输出: ``` 第 1 轮: Token数 = 180 第 2 轮: Token数 = 220 第 3 轮: Token数 = 260 ``` 不使用缓存时(每次发送完整历史): ``` 第 1 轮: Token数 = 180 第 2 轮: Token数 = 380 第 3 轮: Token数 = 580 ``` 使��缓存后,Token消耗的增长明显放缓。 ## 总结 这篇文章介绍了Go语言中实现Prompt缓存系统的方法: 1. **缓存设计**:区分静态缓存(系统提示)和动态缓存(对话历史),采用不同策略 2. **两种实现**:内存缓存(适合中小型应用)和Redis缓存(适合分布式生产环境) 3. **核心功能**: - 自动缓存系统提示,避免重复传输 - 缓存对话历史,支持会话恢复 - 自动过期清理,防止内存泄漏 - Token数量限制,防止超出模型上下文限制 4. **效果**:在长对话场景中可以降低Token消耗和API延迟 实际项目中,可以根据业务需求选择合适的缓存实现,并进一步优化:可以根据LLM提供商的能力决定是否启用缓存;可以用tiktoken等库实现更精确的Token计算;可以添加缓存命中率的监控指标。 合理的缓存策略能让LLM应用在保证服务质量的同时,控制运营成本。

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇