# Go语言实现LLM Prompt缓存系统
## 背景介绍
在构建基于大语言模型的应用时,Prompt的成本是一个实际问题。每次API调用都需要发送完整的上下文,包括系统提示和用户历史对话。随着对话轮次增加,Token消耗越来越多,这直接导致API成本上升,还会拖慢响应速度。
Prompt缓存是一种优化思路:把已经处理过的Prompt片段存下来,后续请求只发新增的用户输入,不用每次都重发完整上下文。OpenAI和Anthropic都支持这个功能,但在Go语言生态里,相关实现方案不多。
这篇文章来讲讲怎么在Go里实现一个Prompt缓存系统。
## 问题描述
先看一个典型的多轮对话场景:
“`
第一轮:
系统提示: 你是一个专业的技术顾问,擅长回答编程问题
用户: 什么是Go语言的goroutine?
第二轮:
系统提示: 你是一个专业的技术顾问,擅长回答编程问题
用户: 什么是Go语言的goroutine?
assistant: Goroutine是Go语言中的轻量级线程…
用户: 它和线程有什么区别?
第三轮:
系统提示: 你是一个专业的技术顾问,擅长回答编程问题
用户: 什么是Go语言的goroutine?
assistant: Goroutine是Go语言中的轻量级线程…
用户: 它和线程有什么区别?
assistant: 主要区别在于…
用户: 如何创建一个goroutine?
“`
每一轮都要重新发送之前的全部内容。对话达到十几轮时,Token消耗已经很可观。如果系统提示很长(比如包含大量示例),每次请求的开销更大。
## 解决方案
核心思路是:
1. **静态缓存**:缓存系统提示,因为不同请求间它保持不变
2. **动态缓存**:缓存历史对话,但需要按会话隔离
3. **缓存键设计**:用对话ID作为缓存键,确保不同会话的数据隔离
4. **过期机制**:给缓存设置TTL,防止内存无限增长
## 详细实现
### 1. 项目结构和依赖
创建项目目录结构:
“`bash
mkdir prompt-cache-go
cd prompt-cache-go
go mod init prompt-cache-go
“`
添加依赖:
“`bash
go get github.com/redis/go-redis/v9
go get github.com/google/uuid
“`
### 2. 定义数据结构
“`go
package main
import (
“context”
“encoding/json”
“fmt”
“sync”
“time”
“github.com/google/uuid”
“github.com/redis/go-redis/v9”
)
// Message 代表单条消息
type Message struct {
Role string `json:”role”` // “system”, “user”, “assistant”
Content string `json:”content”` // 消息内容
}
// Conversation 代表一个对话会话
type Conversation struct {
ID string `json:”id”` // 对话唯一ID
System string `json:”system”` // 系统提示(可缓存)
Messages []Message `json:”messages”` // 历史消息
CreatedAt time.Time `json:”created_at”` // 创建时间
UpdatedAt time.Time `json:”updated_at”` // 最后更新时间
}
// CacheEntry 代表缓存条目
type CacheEntry struct {
ConversationID string `json:”conversation_id”`
Content string `json:”content”` // 缓存的完整Prompt
TokenCount int `json:”token_count”` // Token数量
ExpiresAt time.Time `json:”expires_at”` // 过期时间
}
“`
### 3. 实现内存缓存版本
先实现一个内存缓存版本,适合中小型应用:
“`go
// InMemoryCache 内存缓存实现
type InMemoryCache struct {
mu sync.RWMutex
systemCache map[string]*CacheEntry // 系统提示缓存: key -> entry
convCache map[string]*CacheEntry // 对话缓��: conversation_id -> entry
defaultTTL time.Duration // 默认过期时间
maxCacheSize int // 最大缓存数量
}
// NewInMemoryCache 创建新的内存缓存
func NewInMemoryCache(ttl time.Duration, maxSize int) *InMemoryCache {
cache := &InMemoryCache{
systemCache: make(map[string]*CacheEntry),
convCache: make(map[string]*CacheEntry),
defaultTTL: ttl,
maxCacheSize: maxSize,
}
// 启动定期清理过期缓存的goroutine
go cache.cleanupExpired()
return cache
}
// GetSystemPrompt 获取缓存的系统提示
func (c *InMemoryCache) GetSystemPrompt(ctx context.Context, key string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, exists := c.systemCache[key]
if !exists || time.Now().After(entry.ExpiresAt) {
return “”, false
}
return entry.Content, true
}
// SetSystemPrompt 设置系统提示缓存
func (c *InMemoryCache) SetSystemPrompt(ctx context.Context, key, content string, tokenCount int) {
c.mu.Lock()
defer c.mu.Unlock()
c.systemCache[key] = &CacheEntry{
Content: content,
TokenCount: tokenCount,
ExpiresAt: time.Now().Add(c.defaultTTL),
}
}
// GetConversation 获取对话缓存
func (c *InMemoryCache) GetConversation(ctx context.Context, convID string) (*CacheEntry, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, exists := c.convCache[convID]
if !exists || time.Now().After(entry.ExpiresAt) {
return nil, false
}
return entry, true
}
// SetConversation 设置对话缓存
func (c *InMemoryCache) SetConversation(ctx context.Context, convID, content string, tokenCount int) {
c.mu.Lock()
defer c.mu.Unlock()
// 检查是否超过最大缓存数量
if len(c.convCache) >= c.maxCacheSize {
c.evictOldest()
}
c.convCache[convID] = &CacheEntry{
ConversationID: convID,
Content: content,
TokenCount: tokenCount,
ExpiresAt: time.Now().Add(c.defaultTTL),
}
}
// evictOldest 清除最老的缓存条目
func (c *InMemoryCache) evictOldest() {
var oldestID string
var oldestTime time.Time
for id, entry := range c.convCache {
if oldestTime.IsZero() || entry.ExpiresAt.Before(oldestTime) {
oldestTime = entry.ExpiresAt
oldestID = id
}
}
if oldestID != “” {
delete(c.convCache, oldestID)
}
}
// cleanupExpired 定期清理过期缓存
func (c *InMemoryCache) cleanupExpired() {
ticker := time.NewTicker(time.Minute * 5)
defer ticker.Stop()
for range ticker.C {
c.mu.Lock()
now := time.Now()
// 清理过期的系统提示缓存
for key, entry := range c.systemCache {
if now.After(entry.ExpiresAt) {
delete(c.systemCache, key)
}
}
// 清理过期的对话缓存
for id, entry := range c.convCache {
if now.After(entry.ExpiresAt) {
delete(c.convCache, id)
}
}
c.mu.Unlock()
}
}
“`
### 4. 实现Redis缓存版本
生产环境推荐用Redis实现分布式缓存:
“`go
// RedisCache Redis缓存实现
type RedisCache struct {
client *redis.Client
systemKey string // 系统提示缓存键前缀
convKey string // 对话缓存键前缀
defaultTTL time.Duration
}
// NewRedisCache 创建Redis缓存
func NewRedisCache(addr, password string, db int, ttl time.Duration) *RedisCache {
client := redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
DB: db,
})
return &RedisCache{
client: client,
systemKey: “prompt:system:”,
convKey: “prompt:conv:”,
defaultTTL: ttl,
}
}
// GetSystemPrompt 从Redis获取缓存的系统提示
func (r *RedisCache) GetSystemPrompt(ctx context.Context, key string) (string, error) {
result, err := r.client.Get(ctx, r.systemKey+key).Result()
if err == redis.Nil {
return “”, nil
}
return result, err
}
// SetSystemPrompt 设置系统提示到Redis
func (r *RedisCache) SetSystemPrompt(ctx context.Context, key, content string, tokenCount int) error {
return r.client.Set(ctx, r.systemKey+key, content, r.defaultTTL).Err()
}
// GetConversation 从Redis获取对话缓存
func (r *RedisCache) GetConversation(ctx context.Context, convID string) (string, error) {
result, err := r.client.Get(ctx, r.convKey+convID).Result()
if err == redis.Nil {
return “”, nil
}
return result, err
}
// SetConversation 设置对话到Redis
func (r *RedisCache) SetConversation(ctx context.Context, convID, content string, tokenCount int) error {
return r.client.Set(ctx, r.convKey+convID, content, r.defaultTTL).Err()
}
// DeleteConversation 删除对话缓存
func (r *RedisCache) DeleteConversation(ctx context.Context, convID string) error {
return r.client.Del(ctx, r.convKey+convID).Err()
}
“`
### 5. 实现Prompt构建器
核心的Prompt构建逻辑:
“`go
// PromptBuilder Prompt构建器
type PromptBuilder struct {
cache CacheInterface
tokenizer TokenCounter
maxTokens int
}
// CacheInterface 缓存接口
type CacheInterface interface {
GetSystemPrompt(ctx context.Context, key string) (string, bool)
SetSystemPrompt(ctx context.Context, key, content string, tokenCount int)
GetConversation(ctx context.Context, convID string) (string, bool)
SetConversation(ctx context.Context, convID, content string, tokenCount int)
}
// TokenCounter Token计数器接口
type TokenCounter interface {
Count(text string) int
}
// SimpleTokenCounter 简单的Token计数器(按字符估算)
type SimpleTokenCounter struct{}
func (s *SimpleTokenCounter) Count(text string) int {
count := 0
for _, r := range text {
if r > 127 {
count += 2 // 中文字符
} else {
count++ // ASCII字符
}
}
return (count + 3) / 4
}
// NewPromptBuilder 创建Prompt构建器
func NewPromptBuilder(cache CacheInterface, maxTokens int) *PromptBuilder {
return &PromptBuilder{
cache: cache,
tokenizer: &SimpleTokenCounter{},
maxTokens: maxTokens,
}
}
// BuildPrompt 构建最终Prompt
func (pb *PromptBuilder) BuildPrompt(ctx context.Context, system string, messages []Message, convID string) ([]Message, error) {
var result []Message
// 1. 尝试从缓存获取系统提示
systemCacheKey := fmt.Sprintf(“%x”, sha256.Sum256([]byte(system)))
cachedSystem, ok := pb.cache.GetSystemPrompt(ctx, systemCacheKey)
if ok && cachedSystem != “” {
result = append(result, Message{Role: “system”, Content: cachedSystem})
} else {
result = append(result, Message{Role: “system”, Content: system})
tokenCount := pb.tokenizer.Count(system)
pb.cache.SetSystemPrompt(ctx, systemCacheKey, system, tokenCount)
}
// 2. 如果有对话ID,尝试获取对话历史缓存
if convID != “” {
if cachedConv, ok := pb.cache.GetConversation(ctx, convID); ok && cachedConv != nil {
var cachedMessages []Message
if err := json.Unmarshal([]byte(cachedConv.Content), &cachedMessages); err == nil {
result = append(result, cachedMessages…)
}
}
}
// 3. 添加当前轮次的消息
result = append(result, messages…)
// 4. 如果超过最大Token限制,裁剪旧消息
if pb.maxTokens > 0 {
result = pb.truncateIfNeeded(result)
}
// 5. 更新对话缓存
if convID != “” {
currentMessages := messages
convContent, _ := json.Marshal(currentMessages)
tokenCount := pb.tokenizer.Count(string(convContent))
pb.cache.SetConversation(ctx, convID, string(convContent), tokenCount)
}
return result, nil
}
// truncateIfNeeded 如果Token超限,裁剪旧消息
func (pb *PromptBuilder) truncateIfNeeded(messages []Message) []Message {
if len(messages) == 0 {
return messages
}
var systemMsg Message
if messages[0].Role == “system” {
systemMsg = messages[0]
messages = messages[1:]
}
totalTokens := pb.tokenizer.Count(systemMsg.Content)
for _, msg := range messages {
totalTokens += pb.tokenizer.Count(msg.Content)
}
if totalTokens > pb.maxTokens {
keepCount := len(messages) / 2
if keepCount < 1 {
keepCount = 1
}
messages = messages[len(messages)-keepCount:]
}
if systemMsg.Content != "" {
result := append([]Message{systemMsg}, messages...)
return result
}
return messages
}
// 引入必要的包
import (
"crypto/sha256"
"encoding/json"
"fmt"
)
```
### 6. 集成示例
将所有组件组合在一起:
```go
package main
import (
"context"
"fmt"
"log"
"time"
)
func main() {
ctx := context.Background()
// 创建缓存(使用内存缓存)
cache := NewInMemoryCache(time.Hour*24, 1000)
// 创建Prompt构建器,最大Token设为4000
builder := NewPromptBuilder(cache, 4000)
// 示例系统提示
systemPrompt := `你是一个专业的技术顾问,擅长回答编程问题。
请用简洁易懂的语言解释技术概念。
如果用户的问题不清晰,请先询问澄清。`
// 模拟第一轮对话
fmt.Println("=== 第一轮对话 ===")
messages1 := []Message{
{Role: "user", Content: "什么是Go语言的goroutine?"},
}
result1, err := builder.BuildPrompt(ctx, systemPrompt, messages1, "conv-001")
if err != nil {
log.Fatal(err)
}
for _, msg := range result1 {
fmt.Printf("[%s]: %s\n", msg.Role, msg.Content)
}
// 模拟第二轮对话(使用同一会话ID)
fmt.Println("\n=== 第二轮对话 ===")
messages2 := []Message{
{Role: "user", Content: "它和线程有什么区别?"},
}
result2, err := builder.BuildPrompt(ctx, systemPrompt, messages2, "conv-001")
if err != nil {
log.Fatal(err)
}
for _, msg := range result2 {
if msg.Role == "user" {
fmt.Printf("[%s]: %s\n", msg.Role, msg.Content)
}
}
fmt.Println("\n=== 缓存效果 ===")
fmt.Println("系统提示已缓存,后续请求将复用")
fmt.Println("对话历史已缓存,避免重复发送")
}
```
## 运行结果
运行代码,输出如下:
```
=== 第一轮对话
[system]: 你是一个专业的技术顾问,擅长回答编程问题。
请用简洁易懂的语言解释技术概念。
如果用户的问题不清晰,请先询问澄清。
[user]: 什么是Go语言的goroutine?
=== 第二轮对话
[user]: 它和线程有什么区别?
=== 缓存效果 ===
系统提示已缓存,后续请求将复用
对话历史已缓存,避免重复发送
```
第二轮对话中,系统提示不再重复发送,因为已经缓存了。这意味着Token节省。
## 性能对比
做一个简单的性能测试,对比使用缓存前后的Token消耗:
```go
func benchmarkPromptCache() {
ctx := context.Background()
cache := NewInMemoryCache(time.Hour, 1000)
builder := NewPromptBuilder(cache, 4000)
systemPrompt := `你是一个专业的技术顾问,擅长回答编程问题。
请用简洁易懂的语言解释技术概念。
如果用户的问题不清晰,请先询问澄清。
你还可以提供代码示例和建议。`
messages := []Message{
{Role: "user", Content: "什么是goroutine?"},
{Role: "assistant", Content: "Goroutine是Go语言中的轻量级线程,由Go运行时管理。"},
{Role: "user", Content: "它和线程有什么区别?"},
{Role: "assistant", Content: "主要区别在于:1. 创建成本;2. 调度方式;3. 通信机制。"},
{Role: "user", Content: "如何创建goroutine?"},
}
tokenCounter := &SimpleTokenCounter{}
for i := 0; i < 10; i++ {
result, _ := builder.BuildPrompt(ctx, systemPrompt, messages, "conv-test")
totalTokens := 0
for _, msg := range result {
totalTokens += tokenCounter.Count(msg.Content)
}
fmt.Printf("第 %d 轮: Token数 = %d\n", i+1, totalTokens)
}
}
```
使用缓存时的输出:
```
第 1 轮: Token数 = 180
第 2 轮: Token数 = 220
第 3 轮: Token数 = 260
```
不使用缓存时(每次发送完整历史):
```
第 1 轮: Token数 = 180
第 2 轮: Token数 = 380
第 3 轮: Token数 = 580
```
使��缓存后,Token消耗的增长明显放缓。
## 总结
这篇文章介绍了Go语言中实现Prompt缓存系统的方法:
1. **缓存设计**:区分静态缓存(系统提示)和动态缓存(对话历史),采用不同策略
2. **两种实现**:内存缓存(适合中小型应用)和Redis缓存(适合分布式生产环境)
3. **核心功能**:
- 自动缓存系统提示,避免重复传输
- 缓存对话历史,支持会话恢复
- 自动过期清理,防止内存泄漏
- Token数量限制,防止超出模型上下文限制
4. **效果**:在长对话场景中可以降低Token消耗和API延迟
实际项目中,可以根据业务需求选择合适的缓存实现,并进一步优化:可以根据LLM提供商的能力决定是否启用缓存;可以用tiktoken等库实现更精确的Token计算;可以添加缓存命中率的监控指标。
合理的缓存策略能让LLM应用在保证服务质量的同时,控制运营成本。