-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathchat_session.go
More file actions
132 lines (109 loc) · 3.32 KB
/
chat_session.go
File metadata and controls
132 lines (109 loc) · 3.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package gemini
import (
"context"
"encoding/json"
"fmt"
"iter"
"sync"
"google.golang.org/genai"
)
// DefaultModel is the Gemini model name that ListModels always places first
// in its result.
// NOTE(review): presumably also the model callers pass to NewChatSession when
// none is chosen explicitly — confirm at the call sites.
const DefaultModel = "gemini-2.5-flash"
// ChatSession represents a gemini powered chat session.
//
// It wraps a genai client and the currently active genai chat. It also
// remembers the model name and generation config, so the chat can be
// re-created when the model, history, or system instruction changes.
// Use NewChatSession to create one; the zero value has no client.
//
// NOTE(review): methods that replace chat/model/config do so without any
// locking; treat a ChatSession as not safe for concurrent use unless callers
// synchronize externally.
type ChatSession struct {
	// ctx is captured at construction and used for every API request the
	// session makes.
	ctx context.Context
	// client is the genai client used to create chats and query models.
	client *genai.Client
	// chat is the active chat. It is replaced, rather than mutated, when the
	// model, history, or system instruction changes.
	chat *genai.Chat
	// config is the generation config passed to every chat the session creates.
	config *genai.GenerateContentConfig
	// model is the name of the model the active chat was created with.
	model string
	// loadModels ensures ListModels queries the API at most once.
	loadModels sync.Once
	// models caches the names assembled by ListModels.
	models []string
}
// NewChatSession returns a new [ChatSession].
//
// It creates a genai client with a nil client config and opens a chat on
// model using contentConfig.
// NOTE(review): with a nil client config the SDK presumably reads credentials
// from the environment — confirm.
//
// The session keeps ctx, model and contentConfig. ctx is used for every later
// request. model and contentConfig are reused whenever the chat is re-created.
// Client or chat creation failures are wrapped and returned with a nil session.
func NewChatSession(ctx context.Context, model string,
	contentConfig *genai.GenerateContentConfig) (*ChatSession, error) {
	session := &ChatSession{
		ctx:    ctx,
		config: contentConfig,
		model:  model,
	}

	var err error
	if session.client, err = genai.NewClient(ctx, nil); err != nil {
		return nil, fmt.Errorf("failed to create client: %w", err)
	}
	if session.chat, err = session.client.Chats.Create(ctx, model, contentConfig, nil); err != nil {
		return nil, fmt.Errorf("failed to create chat: %w", err)
	}
	return session, nil
}
// SendMessage sends input to the model as one text part of the ongoing chat
// and returns the model's response. The request runs under the context that
// was captured when the session was created.
func (c *ChatSession) SendMessage(input string) (*genai.GenerateContentResponse, error) {
	textPart := genai.Part{Text: input}
	return c.chat.SendMessage(c.ctx, textPart)
}
// SendMessageStream is the streaming counterpart of SendMessage. It sends input
// as one text part and returns an iterator that yields response chunks (or
// errors) as the model produces them, using the session's captured context.
func (c *ChatSession) SendMessageStream(input string) iter.Seq2[*genai.GenerateContentResponse, error] {
	textPart := genai.Part{Text: input}
	return c.chat.SendMessageStream(c.ctx, textPart)
}
// ModelInfo returns information about the chat generative model in JSON format.
//
// It looks up the session's current model through the API on every call and
// encodes the result with json.MarshalIndent. Both a failed lookup and a
// failed encoding are returned as wrapped errors (usable with errors.Is/As),
// matching the error style of the other ChatSession methods.
func (c *ChatSession) ModelInfo() (string, error) {
	modelInfo, err := c.client.Models.Get(c.ctx, c.model, nil)
	if err != nil {
		// Previously returned bare; wrap so callers can tell which model failed.
		return "", fmt.Errorf("failed to get model info for %q: %w", c.model, err)
	}
	encoded, err := json.MarshalIndent(modelInfo, "", " ")
	if err != nil {
		return "", fmt.Errorf("error encoding model info: %w", err)
	}
	return string(encoded), nil
}
// ListModels returns a list of the supported generative model names.
//
// The list always starts with DefaultModel, followed by the Name of every
// model the API's model listing yields. Listing errors are skipped, so on
// failure the result may hold only DefaultModel. The list is built once per
// session (guarded by loadModels) and is never retried. Later calls return
// the same cached slice, which callers should not modify.
func (c *ChatSession) ListModels() []string {
	fetch := func() {
		names := []string{DefaultModel}
		for m, err := range c.client.Models.All(c.ctx) {
			if err == nil {
				names = append(names, m.Name)
			}
		}
		c.models = names
	}
	c.loadModels.Do(fetch)
	return c.models
}
// SetModel switches the session to model. It creates a new chat on that model
// with the current config and the history returned by GetHistory. The session's
// model and chat change only if creation succeeds; otherwise the error is
// wrapped and the previous chat stays active.
func (c *ChatSession) SetModel(model string) error {
	history := c.GetHistory()
	newChat, err := c.client.Chats.Create(c.ctx, model, c.config, history)
	if err != nil {
		return fmt.Errorf("failed to set model: %w", err)
	}
	c.model, c.chat = model, newChat
	return nil
}
// GetHistory returns the chat session history, obtained from the active chat's
// History method with its flag argument set to true.
// NOTE(review): in the genai SDK that flag presumably requests the "curated"
// history (only valid turns) — confirm against the SDK documentation.
func (c *ChatSession) GetHistory() []*genai.Content {
	const curated = true
	return c.chat.History(curated)
}
// SetHistory replaces the chat session history by starting a new chat that is
// seeded with history, using the current model and config. If creation fails,
// the error is wrapped and the previous chat stays active.
func (c *ChatSession) SetHistory(history []*genai.Content) error {
	replacement, err := c.client.Chats.Create(c.ctx, c.model, c.config, history)
	if err == nil {
		c.chat = replacement
		return nil
	}
	return fmt.Errorf("failed to set history: %w", err)
}
// ClearHistory clears the chat session history by delegating to SetHistory
// with an empty (nil) history. The current model and config are kept.
func (c *ChatSession) ClearHistory() error {
	var empty []*genai.Content
	return c.SetHistory(empty)
}
// SetSystemInstruction sets the chat session system instruction.
//
// It re-creates the chat on the current model with the history returned by
// GetHistory and a config whose SystemInstruction is replaced. Behavior:
//   - The new config is a shallow copy of the session's config. The config
//     originally handed to NewChatSession is never mutated.
//   - A nil session config is treated as an empty one instead of causing a
//     nil-pointer dereference.
//   - The session's config and chat change only after the new chat has been
//     created, so on error the session is left exactly as it was.
func (c *ChatSession) SetSystemInstruction(systemInstruction *genai.Content) error {
	var config genai.GenerateContentConfig
	if c.config != nil {
		config = *c.config
	}
	config.SystemInstruction = systemInstruction

	chat, err := c.client.Chats.Create(c.ctx, c.model, &config, c.GetHistory())
	if err != nil {
		return fmt.Errorf("failed to set system instruction: %w", err)
	}
	// Commit only on success so config and chat never disagree.
	c.config = &config
	c.chat = chat
	return nil
}