internal/store/accounts.go
package store
import (
"context"
"database/sql"
"fmt"
"slices"
"strings"
)
// Account is the in-memory form of one row of the accounts table, extended
// with derived fields (FullName) and tree links (Parent, Children, Level)
// that are filled in after the row has been scanned.
type Account struct {
	// GUID, Name and AccountType are scanned from the guid, name and
	// account_type columns.
	GUID, Name string
	// FullName is the colon-separated path produced by getFullAccountName
	// (e.g. "expenses:dining:pizza"); it is not a database column.
	FullName    string
	AccountType string

	// CommodityGUID is nil when commodity_guid is NULL.
	CommodityGUID *string
	// CommoditySCU and NonSTDSCU mirror the commodity_scu and non_std_scu
	// columns.
	CommoditySCU, NonSTDSCU int64

	// ParentGUID, Code and Description are nil when the corresponding column
	// (parent_guid, code, description) is NULL.
	ParentGUID, Code, Description *string
	// Hidden and Placeholder are nil when the hidden / placeholder column is
	// NULL.
	Hidden, Placeholder *int64

	// Parent, Children and Level are populated by buildTreeFromAccounts; they
	// hold their zero values on a freshly scanned account.
	Parent   *Account
	Children []*Account
	Level    int
}
// GetDescendants returns every account below a in the Children tree, in
// depth-first pre-order: each child is followed by its own descendants before
// the next sibling appears. The receiver itself is not included. The result
// is nil when a has no children.
func (a *Account) GetDescendants() []*Account {
	var found []*Account

	// pushReversed stacks children right-to-left so the leftmost child is
	// popped, and therefore visited, first.
	pushReversed := func(stack, children []*Account) []*Account {
		for i := len(children) - 1; i >= 0; i-- {
			stack = append(stack, children[i])
		}
		return stack
	}

	pending := pushReversed(nil, a.Children)
	for len(pending) > 0 {
		last := len(pending) - 1
		node := pending[last]
		pending = pending[:last]

		found = append(found, node)
		pending = pushReversed(pending, node.Children)
	}
	return found
}
// AccountQuery accumulates the pieces of a SELECT against the accounts table.
// Build renders the SQL text and Args returns the bound parameters.
type AccountQuery struct {
	// whereClauses are joined with " AND " by Build.
	whereClauses []string
	// args are the positional parameters collected by Where, in call order.
	args []any
	// orderFields are rendered in insertion order in the ORDER BY clause.
	orderFields []orderField
	// limit and offset are nil until Limit / Offset (or Page) is called; Build
	// omits the corresponding clause while they are nil.
	limit, offset *int
}
// NewAccountQuery returns an empty query with no filters, ordering, limit or
// offset. The slice fields are initialised to empty, non-nil slices.
func NewAccountQuery() *AccountQuery {
	q := new(AccountQuery)
	q.whereClauses = []string{}
	q.args = []any{}
	q.orderFields = []orderField{}
	return q
}
// Where adds a filter. clause is raw SQL that Build inserts verbatim, joined
// to other clauses with " AND "; args are appended to the positional
// parameters returned by Args. It returns q so calls can be chained.
func (q *AccountQuery) Where(clause string, args ...any) *AccountQuery {
	q.whereClauses, q.args = append(q.whereClauses, clause), append(q.args, args...)
	return q
}
// OrderBy appends a sort key. field is inserted verbatim into the ORDER BY
// clause by Build, followed by DESC when descending is true and ASC
// otherwise. Keys are applied in the order they were added. It returns q so
// calls can be chained.
func (q *AccountQuery) OrderBy(field string, descending bool) *AccountQuery {
	key := orderField{field: field, descending: descending}
	q.orderFields = append(q.orderFields, key)
	return q
}
// Limit sets the LIMIT value emitted by Build, replacing any earlier one.
// It returns q so calls can be chained.
func (q *AccountQuery) Limit(limit int) *AccountQuery {
	value := limit
	q.limit = &value
	return q
}
// Offset sets the OFFSET value emitted by Build, replacing any earlier one.
// It returns q so calls can be chained.
func (q *AccountQuery) Offset(offset int) *AccountQuery {
	value := offset
	q.offset = &value
	return q
}
// Page configures limit and offset for 1-based pagination: page 1 starts at
// offset 0, page 2 at offset pageSize, and so on. pageSize is passed to Limit
// unchanged.
//
// A page below 1 or a non-positive pageSize results in offset 0, so the
// generated SQL never carries a negative OFFSET.
func (q *AccountQuery) Page(page, pageSize int) *AccountQuery {
	offset := 0
	if page > 1 && pageSize > 0 {
		offset = (page - 1) * pageSize
	}
	return q.Limit(pageSize).Offset(offset)
}
// Build renders the query as SQL text. The column list is fixed and matches
// the order scanAccount expects. WHERE, ORDER BY, LIMIT and OFFSET are only
// emitted when the corresponding part of the query has been set.
//
// Where clauses and order-by field names are inserted verbatim; only the
// values returned by Args are bound as parameters.
func (q *AccountQuery) Build() string {
	var sql strings.Builder
	sql.WriteString(`
SELECT
guid,
name,
account_type,
commodity_guid,
commodity_scu,
non_std_scu,
parent_guid,
code,
description,
hidden,
placeholder
FROM accounts
`)

	if len(q.whereClauses) > 0 {
		sql.WriteString("\nWHERE ")
		sql.WriteString(strings.Join(q.whereClauses, " AND "))
	}

	if len(q.orderFields) > 0 {
		terms := make([]string, 0, len(q.orderFields))
		for _, key := range q.orderFields {
			dir := "ASC"
			if key.descending {
				dir = "DESC"
			}
			terms = append(terms, fmt.Sprintf("%s %s", key.field, dir))
		}
		sql.WriteString("\nORDER BY ")
		sql.WriteString(strings.Join(terms, ", "))
	}

	if q.limit != nil {
		fmt.Fprintf(&sql, "\nLIMIT %d", *q.limit)
	}
	if q.offset != nil {
		fmt.Fprintf(&sql, "\nOFFSET %d", *q.offset)
	}
	return sql.String()
}
// Args returns the positional parameters collected by Where, in the order
// they were added. The returned slice is the query's own slice, not a copy.
func (q *AccountQuery) Args() []any {
	bound := q.args
	return bound
}
// AccountsStorer is the set of account operations offered by the store.
type AccountsStorer interface {
	// All returns the accounts selected by q.
	All(ctx context.Context, q *AccountQuery) ([]*Account, error)

	// Get returns a single account identified by s. How s is interpreted is
	// controlled by opts (see AccountsOptFunc).
	Get(ctx context.Context, s string, opts ...AccountsOptFunc) (*Account, error)

	// Update persists the fields of account.
	Update(ctx context.Context, account *Account) error
}
// AccountsStore provides the Get, All and Update operations on accounts on
// top of a DBTX.
type AccountsStore struct {
	// db executes every query issued by the store's methods.
	db DBTX
	// Opts holds store-level account options.
	// NOTE(review): Get builds its options from defaultAccountsOpts plus its
	// own arguments and does not appear to read this field; confirm whether
	// anything else depends on it.
	Opts AccountsOpts
}
// AccountsOpts holds the settings that control how an account is looked up.
// Values are normally produced by applying AccountsOptFunc functions to the
// defaults.
type AccountsOpts struct {
	// withAccountFullName selects how Get interprets its lookup string.
	//
	// When false (the default) the string is matched against the guid column.
	// When true it is treated as a colon-separated full account name such as
	// "expenses:dining:pizza", and Get resolves it one path segment at a time
	// starting from the root account.
	withAccountFullName bool
}
// defaultAccountsOpts returns a fresh options value with every setting at its
// default: lookups are by GUID (withAccountFullName is false).
func defaultAccountsOpts() *AccountsOpts {
	var defaults AccountsOpts
	defaults.withAccountFullName = false
	return &defaults
}
// AccountsOptFunc is a functional option: it adjusts the AccountsOpts it is
// given. Get applies each supplied option, in order, to a copy of the
// defaults.
type AccountsOptFunc func(opts *AccountsOpts)
// WithAccountFullName returns an option that sets whether the lookup string
// passed to Get is a colon-separated full account name (true) or a GUID
// (false).
func WithAccountFullName(b bool) AccountsOptFunc {
	setFullName := func(opts *AccountsOpts) {
		opts.withAccountFullName = b
	}
	return setFullName
}
// getAccountSubtree loads the account identified by guid together with all of
// its descendants, using a recursive common table expression that follows
// parent_guid links downward. Each returned account has FullName populated
// via getFullAccountName.
//
// It returns a nil slice and a nil error when no account has the given GUID.
// Any query, scan or row-iteration error is returned and no accounts are
// returned alongside it.
func getAccountSubtree(ctx context.Context, db DBTX, guid string) ([]*Account, error) {
	query := `
WITH RECURSIVE account_tree AS (
SELECT
accounts.guid,
accounts.name,
accounts.account_type,
accounts.commodity_guid,
accounts.commodity_scu,
accounts.non_std_scu,
accounts.parent_guid,
accounts.code,
accounts.description,
accounts.hidden,
accounts.placeholder
FROM accounts
WHERE accounts.guid = ?
UNION ALL
SELECT
a.guid,
a.name,
a.account_type,
a.commodity_guid,
a.commodity_scu,
a.non_std_scu,
a.parent_guid,
a.code,
a.description,
a.hidden,
a.placeholder
FROM accounts a
INNER JOIN account_tree at ON a.parent_guid = at.guid
)
SELECT
guid,
name,
account_type,
commodity_guid,
commodity_scu,
non_std_scu,
parent_guid,
code,
description,
hidden,
placeholder
FROM account_tree
`
	rows, err := db.QueryContext(ctx, query, guid)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Drain the cursor completely before running any follow-up query: while
	// rows is open it holds its connection, and getFullAccountName issues
	// further queries on the same db.
	var accounts []*Account
	for rows.Next() {
		account, err := scanAccount(rows)
		if err != nil {
			return nil, err
		}
		accounts = append(accounts, account)
	}
	// Next returning false can mean "no more rows" or "iteration failed";
	// Err distinguishes the two so a failed read is not reported as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	for _, account := range accounts {
		fullName, err := getFullAccountName(ctx, db, account)
		if err != nil {
			return nil, err
		}
		account.FullName = fullName
	}
	return accounts, nil
}
// buildTreeFromAccounts links a flat list of accounts into a tree and grafts
// that tree onto root.
//
// Every account's Children is reset to an empty slice. Each account whose
// ParentGUID matches the GUID of another account in the list is appended to
// that parent's Children and gets Parent set. Level is then propagated
// downward from every account that has no parent within the list, so each
// linked account ends up with Level = parent.Level + 1 regardless of the
// order in which accounts appear. Accounts without a parent in the list keep
// their existing Level and Parent.
//
// Finally, if the list contains an account with root's GUID, root takes over
// that account's Children, Parent and Level, and those children are
// re-pointed at root. This matters when root is a different *Account value
// than the list entry with the same GUID.
func buildTreeFromAccounts(accounts []*Account, root *Account) {
	byGUID := make(map[string]*Account, len(accounts))
	for _, account := range accounts {
		account.Children = make([]*Account, 0)
		byGUID[account.GUID] = account
	}

	// Link children to parents and remember the accounts that have no parent
	// inside this list; those are the starting points for level propagation.
	var queue []*Account
	for _, account := range accounts {
		var parent *Account
		if account.ParentGUID != nil {
			parent = byGUID[*account.ParentGUID]
		}
		if parent == nil {
			queue = append(queue, account)
			continue
		}
		parent.Children = append(parent.Children, account)
		account.Parent = parent
		// Provisional value; corrected below if the parent's own level was
		// not final yet when this account was linked.
		account.Level = parent.Level + 1
	}

	// Breadth-first walk from the top accounts so every child is assigned its
	// level only after its parent's level is final.
	for len(queue) > 0 {
		node := queue[0]
		queue = queue[1:]
		for _, child := range node.Children {
			child.Level = node.Level + 1
			queue = append(queue, child)
		}
	}

	match := byGUID[root.GUID]
	if match == nil {
		return
	}
	root.Children = match.Children
	root.Parent = match.Parent
	root.Level = match.Level
	for _, child := range root.Children {
		child.Parent = root
	}
}
// getFullAccountName builds the colon-separated full name of account (for
// example "expenses:dining:pizza") by following parent_guid links upward, one
// query per ancestor.
//
// The walk ends when an account has no parent or when the parent's account
// type equals "root" case-insensitively; that root parent's name is not part
// of the result. The name of the account passed in is always included. Any
// lookup error aborts the walk and is returned with an empty string.
func getFullAccountName(ctx context.Context, db DBTX, account *Account) (string, error) {
	// Names are collected leaf-first and reversed at the end.
	segments := []string{account.Name}

	for current := account; current.ParentGUID != nil; {
		lookup := NewAccountQuery().Where("guid=?", current.ParentGUID)
		parent, err := scanAccount(db.QueryRowContext(ctx, lookup.Build(), lookup.Args()...))
		if err != nil {
			return "", err
		}
		if strings.ToLower(parent.AccountType) == "root" {
			break
		}
		segments = append(segments, parent.Name)
		current = parent
	}

	slices.Reverse(segments)
	return strings.Join(segments, ":"), nil
}
// getAccountFromAccountTree resolves a colon-separated account path such as
// "expenses:dining:pizza" to the account it names.
//
// It first loads the root account (type "ROOT", name "Root Account", no
// parent) and then, for each path segment in turn, looks up the child of the
// current account whose name matches the segment case-insensitively. The
// account reached by the last segment is returned.
//
// The first lookup that fails ends the search and its error is returned
// unchanged.
func getAccountFromAccountTree(ctx context.Context, db DBTX, s string) (*Account, error) {
	rootQuery := NewAccountQuery().Where("account_type=? AND name=? AND parent_guid IS NULL", "ROOT", "Root Account")
	current, err := scanAccount(db.QueryRowContext(ctx, rootQuery.Build(), rootQuery.Args()...))
	if err != nil {
		return nil, err
	}

	// strings.Split always yields at least one segment, so current is
	// replaced at least once and the root itself is never returned.
	for _, segment := range strings.Split(s, ":") {
		childQuery := NewAccountQuery().Where("name=? COLLATE NOCASE and parent_guid=?", segment, current.GUID)
		current, err = scanAccount(db.QueryRowContext(ctx, childQuery.Build(), childQuery.Args()...))
		if err != nil {
			return nil, err
		}
	}
	return current, nil
}
// getSubtreeWithTree loads root's subtree from the database and attaches it
// to root through buildTreeFromAccounts, filling in Children, Parent and
// Level. It returns sql.ErrNoRows when the subtree query yields no accounts,
// and passes through any error from loading the subtree.
func getSubtreeWithTree(ctx context.Context, db DBTX, root *Account) error {
	subtree, err := getAccountSubtree(ctx, db, root.GUID)
	switch {
	case err != nil:
		return err
	case len(subtree) == 0:
		return sql.ErrNoRows
	}

	buildTreeFromAccounts(subtree, root)
	return nil
}
// Get returns a single account with FullName set and its subtree attached
// (Children, Parent, Level).
//
// By default guidOrName is matched against the guid column. When the
// WithAccountFullName(true) option is supplied it is instead treated as a
// colon-separated full account name and resolved from the root account.
//
// Errors from the lookup, from building the full name, or from loading the
// subtree are returned unchanged.
func (s AccountsStore) Get(ctx context.Context, guidOrName string, opts ...AccountsOptFunc) (*Account, error) {
	cfg := defaultAccountsOpts()
	for _, apply := range opts {
		apply(cfg)
	}

	// Only the initial lookup differs between the two modes; everything after
	// it is shared.
	var (
		account *Account
		err     error
	)
	if cfg.withAccountFullName {
		account, err = getAccountFromAccountTree(ctx, s.db, guidOrName)
	} else {
		byGUID := NewAccountQuery().Where("guid= ?", guidOrName)
		account, err = scanAccount(s.db.QueryRowContext(ctx, byGUID.Build(), byGUID.Args()...))
	}
	if err != nil {
		return nil, err
	}

	fullName, err := getFullAccountName(ctx, s.db, account)
	if err != nil {
		return nil, err
	}
	account.FullName = fullName

	if err := getSubtreeWithTree(ctx, s.db, account); err != nil {
		return nil, err
	}
	return account, nil
}
// All runs q against the accounts table and returns the matching accounts
// with FullName populated. Tree fields (Parent, Children, Level) are not
// filled in. q must be non-nil.
//
// The result is a nil slice when nothing matches. On any query, scan,
// row-iteration or name-resolution error, the error is returned and no
// accounts are returned alongside it.
func (s AccountsStore) All(ctx context.Context, q *AccountQuery) ([]*Account, error) {
	rows, err := s.db.QueryContext(ctx, q.Build(), q.Args()...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Read every row before resolving names: while rows is open it holds its
	// connection, and getFullAccountName issues further queries on the same
	// db.
	var accounts []*Account
	for rows.Next() {
		account, err := scanAccount(rows)
		if err != nil {
			return nil, err
		}
		accounts = append(accounts, account)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	for _, account := range accounts {
		fullName, err := getFullAccountName(ctx, s.db, account)
		if err != nil {
			return nil, err
		}
		account.FullName = fullName
	}
	return accounts, nil
}
// scanAccount reads one account from scanner. The scanned columns must be, in
// order: guid, name, account_type, commodity_guid, commodity_scu,
// non_std_scu, parent_guid, code, description, hidden, placeholder.
//
// Nullable columns become nil pointers when NULL. FullName and the tree
// fields are left at their zero values. A Scan error is returned unchanged
// with a nil account.
func scanAccount(scanner rowScanner) (*Account, error) {
	var (
		commodityGUID, parentGUID, code, description sql.NullString
		hidden, placeholder                          sql.NullInt64
	)
	account := new(Account)

	dest := []any{
		&account.GUID,
		&account.Name,
		&account.AccountType,
		&commodityGUID,
		&account.CommoditySCU,
		&account.NonSTDSCU,
		&parentGUID,
		&code,
		&description,
		&hidden,
		&placeholder,
	}
	if err := scanner.Scan(dest...); err != nil {
		return nil, err
	}

	// Each helper receives its own copy of the Null* value, so the returned
	// pointer refers to storage private to that field.
	optString := func(v sql.NullString) *string {
		if !v.Valid {
			return nil
		}
		return &v.String
	}
	optInt := func(v sql.NullInt64) *int64 {
		if !v.Valid {
			return nil
		}
		return &v.Int64
	}

	account.CommodityGUID = optString(commodityGUID)
	account.ParentGUID = optString(parentGUID)
	account.Code = optString(code)
	account.Description = optString(description)
	account.Hidden = optInt(hidden)
	account.Placeholder = optInt(placeholder)
	return account, nil
}
// Update writes every stored column of account back to the accounts row with
// the same GUID. Nil pointer fields are written as NULL. FullName and the
// tree fields (Parent, Children, Level) are not persisted.
//
// The statement is restricted to a single rowid, so at most one row is
// changed. It returns sql.ErrNoRows when no row was affected, and otherwise
// passes through any error from executing the statement or reading the
// affected-row count.
func (s AccountsStore) Update(ctx context.Context, account *Account) error {
	const stmt = `
UPDATE accounts
SET
name = ?,
account_type = ?,
commodity_guid = ?,
commodity_scu = ?,
non_std_scu = ?,
parent_guid = ?,
code = ?,
description = ?,
hidden = ?,
placeholder = ?
WHERE guid = ? AND rowid IN (
SELECT rowid FROM accounts WHERE guid = ? LIMIT 1
)
`
	// A nil pointer maps to the zero Null* value, whose Valid is false.
	nullableString := func(p *string) sql.NullString {
		if p == nil {
			return sql.NullString{}
		}
		return sql.NullString{String: *p, Valid: true}
	}
	nullableInt := func(p *int64) sql.NullInt64 {
		if p == nil {
			return sql.NullInt64{}
		}
		return sql.NullInt64{Int64: *p, Valid: true}
	}

	// Parameter order follows the placeholders in stmt; the GUID appears
	// twice because both the outer and the inner WHERE bind it.
	params := []any{
		account.Name,
		account.AccountType,
		nullableString(account.CommodityGUID),
		account.CommoditySCU,
		account.NonSTDSCU,
		nullableString(account.ParentGUID),
		nullableString(account.Code),
		nullableString(account.Description),
		nullableInt(account.Hidden),
		nullableInt(account.Placeholder),
		account.GUID,
		account.GUID,
	}

	res, err := s.db.ExecContext(ctx, stmt, params...)
	if err != nil {
		return err
	}
	changed, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if changed == 0 {
		return sql.ErrNoRows
	}
	return nil
}