84 lines
2.3 KiB
Go
84 lines
2.3 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// DB is the package's database handle: a *sql.DB connection pool paired with
// the filesystem path it was opened from. Because *sql.DB is embedded, its
// methods (QueryRow, Exec, Close, ...) are callable directly on a *DB.
type DB struct {
	*sql.DB // underlying pool; embedded so its methods are promoted onto DB

	path string // database file location, as given to Open
}
|
|
|
|
var (
|
|
database *DB
|
|
migrations = []string{
|
|
`CREATE TABLE IF NOT EXISTS nodes (id TEXT PRIMARY KEY, title TEXT NOT NULL, content TEXT, due_date TEXT, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP)`,
|
|
`CREATE TABLE IF NOT EXISTS tags (node_id TEXT NOT NULL, tag TEXT NOT NULL, PRIMARY KEY (node_id, tag), FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE)`,
|
|
`CREATE TABLE IF NOT EXISTS rels (from_id TEXT NOT NULL, to_id TEXT NOT NULL, rel_type TEXT NOT NULL, PRIMARY KEY (from_id, to_id, rel_type), FOREIGN KEY (from_id) REFERENCES nodes(id) ON DELETE CASCADE, FOREIGN KEY (to_id) REFERENCES nodes(id) ON DELETE CASCADE)`,
|
|
`CREATE INDEX IF NOT EXISTS idx_tags_tag ON tags(tag)`, `CREATE INDEX IF NOT EXISTS idx_rels_from ON rels(from_id)`, `CREATE INDEX IF NOT EXISTS idx_rels_to ON rels(to_id)`,
|
|
}
|
|
)
|
|
|
|
func GetDB() (*DB, error) {
|
|
if database != nil {
|
|
return database, nil
|
|
}
|
|
dir, err := filepath.Abs(".")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for {
|
|
if _, err := os.Stat(filepath.Join(dir, ".ax.db")); err == nil {
|
|
if database, err = Open(filepath.Join(dir, ".ax.db")); err != nil {
|
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
return database, nil
|
|
}
|
|
if parent := filepath.Dir(dir); parent == dir {
|
|
return nil, errors.New("no .ax.db found (run 'ax init' first)")
|
|
} else {
|
|
dir = parent
|
|
}
|
|
}
|
|
}
|
|
|
|
func Init(path string) error {
|
|
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
|
return err
|
|
}
|
|
var err error
|
|
database, err = Open(path)
|
|
return err
|
|
}
|
|
|
|
func Open(path string) (*DB, error) {
|
|
db, err := sql.Open("sqlite", path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, q := range append([]string{"PRAGMA journal_mode=WAL", "PRAGMA busy_timeout=5000", "PRAGMA foreign_keys=ON"}, migrations...) {
|
|
if _, err := db.Exec(q); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return &DB{DB: db, path: path}, nil
|
|
}
|
|
|
|
func (db *DB) GetUserByUsername(username string) (string, error) {
|
|
var id string
|
|
err := db.QueryRow(`
|
|
SELECT n.id FROM nodes n
|
|
JOIN tags t ON n.id = t.node_id
|
|
WHERE n.title = ? AND t.tag = '_type::user'`, username).Scan(&id)
|
|
if err == sql.ErrNoRows {
|
|
return "", nil
|
|
}
|
|
return id, err
|
|
}
|