add token counting for openai model

This commit is contained in:
2025-01-27 09:17:25 +01:00
parent be41a4e84b
commit 9687f327fe
4 changed files with 36 additions and 7 deletions

View File

@@ -9,6 +9,7 @@ import (
"crowsnest/internal/util"
"log"
"net/http"
"os"
_ "github.com/lib/pq"
)
@@ -18,13 +19,15 @@ func main() {
if err != nil {
log.Fatal("failed to connect to database due to", err.Error())
}
log.Println("connected to database successfully")
// summarize documents
documents := &database.DocumentRepository{DB: db}
openai := &util.OpenAi{ApiKey: os.Getenv("OPENAI_API_KEY")}
sumDoc := func(doc *model.Document) *model.Document {
if doc.Summary == "" {
summaryText, err := util.Summarize(doc.Content)
summaryText, err := openai.Summarize(doc.Content)
if err == nil {
doc.Summary = summaryText
return doc

View File

@@ -18,6 +18,7 @@ require (
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/containerd/console v1.0.3 // indirect
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/cors v1.7.2 // indirect
@@ -33,6 +34,7 @@ require (
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/kennygrant/sanitize v1.2.4 // indirect
@@ -50,6 +52,7 @@ require (
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.7 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca // indirect
github.com/spf13/cobra v1.7.0 // indirect

View File

@@ -54,6 +54,8 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -132,6 +134,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
@@ -181,6 +185,8 @@ github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuR
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=

View File

@@ -7,10 +7,11 @@ import (
"fmt"
"io"
"net/http"
"os"
"github.com/pkoukk/tiktoken-go"
)
type Response struct {
type response struct {
Choices []struct {
Message struct {
Content string `json:"content"`
@@ -18,9 +19,12 @@ type Response struct {
} `json:"choices"`
}
func Summarize(text string) (string, error) {
// OpenAi carries the credentials used by the OpenAI helper methods
// defined on it.
type OpenAi struct {
// ApiKey is sent as the bearer token in the Authorization header of
// requests made by Summarize.
ApiKey string
}
func (oai *OpenAi) Summarize(text string) (string, error) {
apiURL := "https://api.openai.com/v1/chat/completions"
apiKey := os.Getenv("OPENAI_API_KEY")
// Request payload
payload := map[string]interface{}{
@@ -51,7 +55,7 @@ func Summarize(text string) (string, error) {
// Add headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oai.ApiKey))
// Send the request
client := &http.Client{}
@@ -68,7 +72,7 @@ func Summarize(text string) (string, error) {
}
// Unmarshal the JSON response
var response Response
var response response
err = json.Unmarshal(body, &response)
if err != nil {
return "", err
@@ -84,3 +88,16 @@ func Summarize(text string) (string, error) {
return content, nil
}
// CountTokens returns the number of tokens that text encodes to under the
// "o200k_base" tiktoken encoding.
//
// It returns -1 when the encoding cannot be loaded. The int-only signature
// gives no channel for handing the underlying error to the caller, so the
// error is not wrapped here (a wrapped copy would never be read).
func (oai *OpenAi) CountTokens(text string) int {
	enc, err := tiktoken.GetEncoding("o200k_base")
	if err != nil {
		return -1
	}

	// No allowed or disallowed special-token lists are passed to Encode.
	return len(enc.Encode(text, nil, nil))
}