diff --git a/src/cmd/frontend/main.go b/src/cmd/frontend/main.go index 65d91a2..6de08b3 100644 --- a/src/cmd/frontend/main.go +++ b/src/cmd/frontend/main.go @@ -9,6 +9,7 @@ import ( "crowsnest/internal/util" "log" "net/http" + "os" _ "github.com/lib/pq" ) @@ -18,13 +19,15 @@ func main() { if err != nil { log.Fatal("failed to connect to database due to", err.Error()) } + log.Println("connected to database successfully") // summarize documents documents := &database.DocumentRepository{DB: db} + openai := &util.OpenAi{ApiKey: os.Getenv("OPENAI_API_KEY")} sumDoc := func(doc *model.Document) *model.Document { if doc.Summary == "" { - summaryText, err := util.Summarize(doc.Content) + summaryText, err := openai.Summarize(doc.Content) if err == nil { doc.Summary = summaryText return doc diff --git a/src/go.mod b/src/go.mod index e1d9fa8..d7e914f 100644 --- a/src/go.mod +++ b/src/go.mod @@ -18,6 +18,7 @@ require ( github.com/cloudwego/iasm v0.2.0 // indirect github.com/containerd/console v1.0.3 // indirect github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/cors v1.7.2 // indirect @@ -33,6 +34,7 @@ require ( github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect + github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/kennygrant/sanitize v1.2.4 // indirect @@ -50,6 +52,7 @@ require ( github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pkoukk/tiktoken-go v0.1.7 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca // indirect github.com/spf13/cobra v1.7.0 // indirect diff --git a/src/go.sum b/src/go.sum index 72bdfe0..7e6f824 100644 --- a/src/go.sum +++ b/src/go.sum @@ -54,6 +54,8 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -132,6 +134,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= @@ -181,6 +185,8 @@ github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuR github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= +github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= diff --git a/src/internal/util/summary.go b/src/internal/util/openai.go similarity index 76% rename from src/internal/util/summary.go rename to src/internal/util/openai.go index d474480..02603a6 100644 --- a/src/internal/util/summary.go +++ b/src/internal/util/openai.go @@ -7,10 +7,11 @@ import ( "fmt" "io" "net/http" - "os" + + "github.com/pkoukk/tiktoken-go" ) -type Response struct { +type response struct { Choices []struct { Message struct { Content string `json:"content"` @@ -18,9 +19,12 @@ type Response struct { } `json:"choices"` } -func Summarize(text string) (string, error) { +type OpenAi struct { + ApiKey string +} + +func (oai *OpenAi) Summarize(text string) (string, error) { apiURL := "https://api.openai.com/v1/chat/completions" - apiKey := os.Getenv("OPENAI_API_KEY") // Request payload payload := map[string]interface{}{ @@ -51,7 +55,7 @@ func Summarize(text string) (string, error) { // Add headers req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oai.ApiKey)) // Send the request client := &http.Client{} @@ -68,7 +72,7 @@ func Summarize(text string) (string, error) { } // Unmarshal the JSON response - var response Response + var response response err = json.Unmarshal(body, &response) if err != nil { return "", err @@ -84,3 +88,16 @@ func Summarize(text string) (string, error) { return content, nil } + +func (oai *OpenAi) CountTokens(text string) int { + tkm, err := tiktoken.GetEncoding("o200k_base") + if err != nil { + err = fmt.Errorf("getEncoding: %v", err) + return -1 + } + + // encode + token := tkm.Encode(text, nil, nil) + + return len(token) +}