-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient.go
More file actions
77 lines (63 loc) · 1.6 KB
/
client.go
File metadata and controls
77 lines (63 loc) · 1.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
package ollama
import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"time"
)
// Defaults used when talking to an Ollama server.
const (
// DefaultBaseURL is the base URL NewClient falls back to when it is given an
// empty string.
DefaultBaseURL = "http://localhost:11434"
// Model is the model name sent in the "model" field of the embed requests
// built by GetEmbedding.
Model = "llama3.2"
)
// Client is a small HTTP client for an Ollama server. Create one with
// NewClient.
type Client struct {
// baseURL is prepended to endpoint paths such as "/api/embed".
baseURL string
// http is the underlying HTTP client used to send requests.
http *http.Client
}
// NewClient returns a Client that sends requests to the Ollama server at
// baseURL.
//
// An empty baseURL selects DefaultBaseURL. Trailing slashes are stripped so
// that appending an endpoint path such as "/api/embed" never produces a
// double slash. The underlying HTTP client is created with a timeout so that
// a stalled server cannot block callers indefinitely.
func NewClient(baseURL string) *Client {
	// Deliberately generous: the first request after a model is (re)loaded
	// can be slow, but the call must still be bounded.
	const requestTimeout = 2 * time.Minute

	if baseURL == "" {
		baseURL = DefaultBaseURL
	}
	return &Client{
		baseURL: strings.TrimRight(baseURL, "/"),
		http:    &http.Client{Timeout: requestTimeout},
	}
}
// embedRequest is the JSON body that GetEmbedding posts to "/api/embed".
type embedRequest struct {
// Model is the name of the model to run; encoded as "model".
Model string `json:"model"`
// Input is the text to embed; encoded as "input".
Input string `json:"input"`
}
// embedResponse is the JSON body that GetEmbedding decodes from a successful
// "/api/embed" reply.
type embedResponse struct {
// Embeddings holds the vectors read from the "embeddings" field.
// GetEmbedding uses only the first entry.
Embeddings [][]float64 `json:"embeddings"`
}
// GetEmbedding posts text to the Ollama "/api/embed" endpoint using the
// package-level Model and returns the first embedding vector of the reply,
// narrowed from float64 to float32.
//
// It returns an error when the request cannot be encoded or sent, when the
// server answers with a status other than 200 OK, when the response body is
// not valid JSON for embedResponse, or when the reply contains no embeddings.
// For non-200 replies the server's own error message is appended to the
// returned error when the body carries one.
func (c *Client) GetEmbedding(text string) ([]float32, error) {
	jsonBody, err := json.Marshal(embedRequest{Model: Model, Input: text})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	resp, err := c.http.Post(c.baseURL+"/api/embed", "application/json", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to call ollama: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Ollama normally describes failures as a JSON object with an "error"
		// field (e.g. an unknown model). Surface that text when present;
		// otherwise fall back to the bare status so a non-JSON or empty body
		// still yields a useful error.
		var apiErr struct {
			Error string `json:"error"`
		}
		if decErr := json.NewDecoder(resp.Body).Decode(&apiErr); decErr == nil && apiErr.Error != "" {
			return nil, fmt.Errorf("ollama returned status %d: %s", resp.StatusCode, apiErr.Error)
		}
		return nil, fmt.Errorf("ollama returned status %d", resp.StatusCode)
	}

	var embedResp embedResponse
	if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	if len(embedResp.Embeddings) == 0 {
		return nil, fmt.Errorf("no embeddings returned")
	}

	// The API delivers float64 values; callers of this method work with
	// float32, so narrow each component.
	first := embedResp.Embeddings[0]
	embedding := make([]float32, len(first))
	for i, v := range first {
		embedding[i] = float32(v)
	}
	return embedding, nil
}