Add token counting for OpenAI model
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"crowsnest/internal/util"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
@@ -18,13 +19,15 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal("failed to connect to database due to", err.Error())
|
||||
}
|
||||
log.Println("connected to database successfully")
|
||||
|
||||
// summarize documents
|
||||
documents := &database.DocumentRepository{DB: db}
|
||||
openai := &util.OpenAi{ApiKey: os.Getenv("OPENAI_API_KEY")}
|
||||
|
||||
sumDoc := func(doc *model.Document) *model.Document {
|
||||
if doc.Summary == "" {
|
||||
summaryText, err := util.Summarize(doc.Content)
|
||||
summaryText, err := openai.Summarize(doc.Content)
|
||||
if err == nil {
|
||||
doc.Summary = summaryText
|
||||
return doc
|
||||
|
||||
@@ -18,6 +18,7 @@ require (
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/containerd/console v1.0.3 // indirect
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/cors v1.7.2 // indirect
|
||||
@@ -33,6 +34,7 @@ require (
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/kennygrant/sanitize v1.2.4 // indirect
|
||||
@@ -50,6 +52,7 @@ require (
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.7 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca // indirect
|
||||
github.com/spf13/cobra v1.7.0 // indirect
|
||||
|
||||
@@ -54,6 +54,8 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
@@ -132,6 +134,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
@@ -181,6 +185,8 @@ github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuR
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
type Response struct {
|
||||
type response struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
@@ -18,9 +19,12 @@ type Response struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
func Summarize(text string) (string, error) {
|
||||
// OpenAi bundles the credentials used by the OpenAI helper methods.
// ApiKey is sent as the bearer token in the Authorization header of requests.
type OpenAi struct {
|
||||
ApiKey string
|
||||
}
|
||||
|
||||
func (oai *OpenAi) Summarize(text string) (string, error) {
|
||||
apiURL := "https://api.openai.com/v1/chat/completions"
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
|
||||
// Request payload
|
||||
payload := map[string]interface{}{
|
||||
@@ -51,7 +55,7 @@ func Summarize(text string) (string, error) {
|
||||
|
||||
// Add headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oai.ApiKey))
|
||||
|
||||
// Send the request
|
||||
client := &http.Client{}
|
||||
@@ -68,7 +72,7 @@ func Summarize(text string) (string, error) {
|
||||
}
|
||||
|
||||
// Unmarshal the JSON response
|
||||
var response Response
|
||||
var response response
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -84,3 +88,16 @@ func Summarize(text string) (string, error) {
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (oai *OpenAi) CountTokens(text string) int {
|
||||
tkm, err := tiktoken.GetEncoding("o200k_base")
|
||||
if err != nil {
|
||||
err = fmt.Errorf("getEncoding: %v", err)
|
||||
return -1
|
||||
}
|
||||
|
||||
// encode
|
||||
token := tkm.Encode(text, nil, nil)
|
||||
|
||||
return len(token)
|
||||
}
|
||||
Reference in New Issue
Block a user