// internal/store/transactions.go
package store
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
)
// Transaction is one row of the transactions table together with the splits
// that were joined to it when it was loaded.
type Transaction struct {
	// GUID is read from transactions.guid.
	GUID string
	// CurrencyGUID is read from transactions.currency_guid.
	CurrencyGUID string
	// Num is read from transactions.num.
	Num string
	// PostDate is nil when transactions.post_date is NULL.
	PostDate *time.Time
	// EnterDate is nil when transactions.enter_date is NULL.
	EnterDate *time.Time
	// Description is nil when transactions.description is NULL.
	Description *string
	// Splits holds the splits loaded for this transaction, in the order their
	// rows were scanned.
	Splits []*Split
}
// TransactionQuery accumulates the optional parts of a transaction SELECT:
// WHERE fragments with their bind arguments, ORDER BY entries, and
// LIMIT/OFFSET values.
type TransactionQuery struct {
	// whereClauses are SQL fragments; they are joined with " AND " when the
	// query text is built.
	whereClauses []string
	// args are the bind values collected by Where, in call order.
	args []any
	// orderFields are the ORDER BY entries, in call order.
	orderFields []orderField
	// limit is nil when no LIMIT was requested.
	limit *int
	// offset is nil when no OFFSET was requested.
	offset *int
}
// NewTransactionQuery returns a query with no filters, no ordering, and no
// LIMIT/OFFSET. Its slices are empty but non-nil.
func NewTransactionQuery() *TransactionQuery {
	q := new(TransactionQuery)
	q.whereClauses = []string{}
	q.args = []any{}
	q.orderFields = []orderField{}
	return q
}
// Where adds a filter fragment and the bind values for its placeholders, then
// returns q so calls can be chained. The clause text is used verbatim when the
// query is built; multiple clauses are combined with AND.
func (q *TransactionQuery) Where(clause string, args ...any) *TransactionQuery {
	q.args = append(q.args, args...)
	q.whereClauses = append(q.whereClauses, clause)
	return q
}
// OrderBy appends an ORDER BY entry and returns q for chaining. Entries are
// emitted in the order they were added. The field text is written into the SQL
// verbatim when the query is built, so it must not come from untrusted input.
func (q *TransactionQuery) OrderBy(field string, descending bool) *TransactionQuery {
	entry := orderField{field: field, descending: descending}
	q.orderFields = append(q.orderFields, entry)
	return q
}
// Limit sets the LIMIT value and returns q for chaining. A limit of 0 is
// ignored, leaving any previously set limit untouched.
func (q *TransactionQuery) Limit(limit int) *TransactionQuery {
	if limit == 0 {
		return q
	}
	q.limit = &limit
	return q
}
// Offset sets the OFFSET value (including 0) and returns q for chaining.
func (q *TransactionQuery) Offset(offset int) *TransactionQuery {
	value := offset
	q.offset = &value
	return q
}
// Page configures LIMIT/OFFSET for 1-based pagination: page 1 starts at offset
// 0, page 2 at offset pageSize, and so on. It returns q for chaining.
//
// A computed offset below zero (page < 1, or a negative pageSize) is clamped
// to 0 so the query never carries a negative OFFSET. pageSize is passed to
// Limit unchanged, so a pageSize of 0 leaves the limit unset.
func (q *TransactionQuery) Page(page, pageSize int) *TransactionQuery {
	offset := (page - 1) * pageSize
	if offset < 0 {
		offset = 0
	}
	return q.Limit(pageSize).Offset(offset)
}
// Build renders the full SELECT statement: every transaction column, LEFT
// JOINed to its splits and each split's account, followed by the WHERE,
// ORDER BY, LIMIT and OFFSET parts configured on q. Bind values for the
// placeholders are returned separately by Args.
//
// LIMIT and OFFSET apply to the joined result rows, so with more than one
// split per transaction they count split rows rather than transactions.
//
// When an offset is set without a limit, an explicit unbounded LIMIT is
// emitted first: SQLite and MySQL only accept OFFSET as part of a LIMIT
// clause, so a bare OFFSET would be a syntax error.
func (q *TransactionQuery) Build() string {
	// maxRows is the largest int64; it stands in for "no limit" when only an
	// offset was requested.
	const maxRows int64 = 1<<63 - 1

	var b strings.Builder
	b.WriteString(`
SELECT
transactions.guid,
transactions.currency_guid,
transactions.num,
transactions.post_date,
transactions.enter_date,
transactions.description,
splits.guid,
splits.account_guid,
splits.memo,
splits.action,
splits.reconcile_state,
splits.reconcile_date,
splits.value_num,
splits.value_denom,
splits.quantity_num,
splits.quantity_denom,
splits.lot_guid,
accounts.guid,
accounts.name,
accounts.account_type,
accounts.commodity_guid,
accounts.commodity_scu,
accounts.non_std_scu,
accounts.parent_guid,
accounts.code,
accounts.description,
accounts.hidden,
accounts.placeholder
FROM transactions
LEFT JOIN splits ON splits.tx_guid = transactions.guid
LEFT JOIN accounts ON accounts.guid = splits.account_guid
`)

	if len(q.whereClauses) > 0 {
		b.WriteString("\nWHERE ")
		b.WriteString(strings.Join(q.whereClauses, " AND "))
	}

	if len(q.orderFields) > 0 {
		orders := make([]string, len(q.orderFields))
		for i, f := range q.orderFields {
			direction := "ASC"
			if f.descending {
				direction = "DESC"
			}
			orders[i] = f.field + " " + direction
		}
		b.WriteString("\nORDER BY ")
		b.WriteString(strings.Join(orders, ", "))
	}

	switch {
	case q.limit != nil:
		fmt.Fprintf(&b, "\nLIMIT %d", *q.limit)
	case q.offset != nil:
		// OFFSET is only valid after LIMIT, so supply an unbounded one.
		fmt.Fprintf(&b, "\nLIMIT %d", maxRows)
	}
	if q.offset != nil {
		fmt.Fprintf(&b, "\nOFFSET %d", *q.offset)
	}

	return b.String()
}
// Copy returns an independent duplicate of q: the clause, argument and order
// slices get their own backing arrays and limit/offset get their own storage,
// so changing the copy never affects q. The argument values themselves are
// copied shallowly. A nil receiver yields nil.
func (q *TransactionQuery) Copy() *TransactionQuery {
	if q == nil {
		return nil
	}

	dup := new(TransactionQuery)

	dup.whereClauses = make([]string, len(q.whereClauses))
	copy(dup.whereClauses, q.whereClauses)

	dup.args = make([]any, len(q.args))
	copy(dup.args, q.args)

	dup.orderFields = make([]orderField, len(q.orderFields))
	copy(dup.orderFields, q.orderFields)

	if src := q.limit; src != nil {
		v := *src
		dup.limit = &v
	}
	if src := q.offset; src != nil {
		v := *src
		dup.offset = &v
	}
	return dup
}
// Args returns the bind values collected by Where, in the order they were
// added. The returned slice is the query's own slice, not a copy.
func (q *TransactionQuery) Args() []any {
	bound := q.args
	return bound
}
type TransactionsStorer interface {
All(ctx context.Context, q *TransactionQuery) ([]*Transaction, error)
Get(ctx context.Context, guid string) (*Transaction, error)
}
// TransactionsStore loads transactions through a DBTX handle. Its Get and All
// methods match the TransactionsStorer interface.
type TransactionsStore struct {
	// db is the handle every query is issued on.
	db DBTX
}
// Get loads the transaction whose guid matches, including its splits and each
// split's account, and fills in every loaded account's FullName.
//
// It returns sql.ErrNoRows when no transaction has that guid. Any query, scan
// or account-name lookup error is returned unchanged.
func (t TransactionsStore) Get(ctx context.Context, guid string) (*Transaction, error) {
	q := NewTransactionQuery().Where("transactions.guid=?", guid)
	rows, err := t.db.QueryContext(ctx, q.Build(), q.Args()...)
	if err != nil {
		return nil, err
	}
	// scanTransactions returns early on scan/parse errors without closing the
	// result set, so release it here on every path.
	defer rows.Close()

	transactions, err := scanTransactions(rows)
	if err != nil {
		return nil, err
	}
	if len(transactions) == 0 {
		return nil, sql.ErrNoRows
	}

	for _, transaction := range transactions {
		for _, split := range transaction.Splits {
			account := split.Account
			if account == nil {
				// The LEFT JOIN found no account row for this split, so
				// there is nothing to name.
				continue
			}
			accountFullName, err := getFullAccountName(ctx, t.db, account)
			if err != nil {
				return nil, err
			}
			account.FullName = accountFullName
		}
	}
	return transactions[0], nil
}
// All returns the transactions selected by q, each with its splits and the
// splits' accounts (including every loaded account's FullName).
//
// It runs in two steps so that LIMIT/OFFSET count transactions rather than
// joined split rows:
//
//  1. Select the matching guids from the transactions table alone, applying
//     q's WHERE, ORDER BY, LIMIT and OFFSET. Because this step does not join
//     splits or accounts, q's clauses and order fields must only reference
//     columns available from transactions.
//  2. Load the full joined rows for exactly those guids, re-applying q's
//     ORDER BY.
//
// A nil q is treated as an empty query. When nothing matches, an empty
// non-nil slice is returned. An offset without a limit is sent with an
// explicit unbounded LIMIT, because SQLite and MySQL reject a bare OFFSET.
func (t TransactionsStore) All(ctx context.Context, q *TransactionQuery) ([]*Transaction, error) {
	// maxRows is the largest int64; it stands in for "no limit" when only an
	// offset was requested.
	const maxRows int64 = 1<<63 - 1

	if q == nil {
		q = NewTransactionQuery()
	}

	// Step 1: page over transactions only.
	var guidQuery strings.Builder
	guidQuery.WriteString("SELECT guid FROM transactions")
	if len(q.whereClauses) > 0 {
		guidQuery.WriteString("\nWHERE ")
		guidQuery.WriteString(strings.Join(q.whereClauses, " AND "))
	}
	if len(q.orderFields) > 0 {
		orders := make([]string, len(q.orderFields))
		for i, f := range q.orderFields {
			direction := "ASC"
			if f.descending {
				direction = "DESC"
			}
			orders[i] = f.field + " " + direction
		}
		guidQuery.WriteString("\nORDER BY ")
		guidQuery.WriteString(strings.Join(orders, ", "))
	}
	switch {
	case q.limit != nil:
		fmt.Fprintf(&guidQuery, "\nLIMIT %d", *q.limit)
	case q.offset != nil:
		// OFFSET is only valid after LIMIT, so supply an unbounded one.
		fmt.Fprintf(&guidQuery, "\nLIMIT %d", maxRows)
	}
	if q.offset != nil {
		fmt.Fprintf(&guidQuery, "\nOFFSET %d", *q.offset)
	}

	guidRows, err := t.db.QueryContext(ctx, guidQuery.String(), q.Args()...)
	if err != nil {
		return nil, err
	}
	defer guidRows.Close()

	var guids []string
	for guidRows.Next() {
		var guid string
		if err := guidRows.Scan(&guid); err != nil {
			return nil, err
		}
		guids = append(guids, guid)
	}
	if err := guidRows.Err(); err != nil {
		return nil, err
	}

	transactions := []*Transaction{}
	if len(guids) == 0 {
		return transactions, nil
	}

	// Step 2: load the joined rows for the selected guids.
	placeholders := make([]string, len(guids))
	guidArgs := make([]any, len(guids))
	for i, guid := range guids {
		placeholders[i] = "?"
		guidArgs[i] = guid
	}
	fullQuery := NewTransactionQuery()
	fullQuery.Where(fmt.Sprintf("transactions.guid IN (%s)", strings.Join(placeholders, ",")), guidArgs...)
	for _, f := range q.orderFields {
		fullQuery.OrderBy(f.field, f.descending)
	}

	rows, err := t.db.QueryContext(ctx, fullQuery.Build(), fullQuery.Args()...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	transactions, err = scanTransactions(rows)
	if err != nil {
		return nil, err
	}

	for _, transaction := range transactions {
		for _, split := range transaction.Splits {
			account := split.Account
			if account == nil {
				// The LEFT JOIN found no account row for this split, so
				// there is nothing to name.
				continue
			}
			accountFullName, err := getFullAccountName(ctx, t.db, account)
			if err != nil {
				return nil, err
			}
			account.FullName = accountFullName
		}
	}
	return transactions, nil
}
// scanTransactions reads every row of a transactions/splits/accounts join (in
// the column order produced by TransactionQuery.Build) and folds the rows into
// one Transaction per guid, each carrying its splits.
//
// Transactions are returned in the order their guid first appears. A row whose
// split columns are NULL contributes a transaction without a split, and a
// split whose account columns are NULL keeps a nil Account. Timestamps are
// parsed with the layout "2006-01-02 15:04:05"; a parse failure aborts the
// scan. rows is not closed here; that is left to the caller.
func scanTransactions(rows *sql.Rows) ([]*Transaction, error) {
	const timestampLayout = "2006-01-02 15:04:05"

	// toTime converts a nullable timestamp column; NULL becomes a nil pointer.
	toTime := func(col sql.NullString) (*time.Time, error) {
		if !col.Valid {
			return nil, nil
		}
		parsed, err := time.Parse(timestampLayout, col.String)
		if err != nil {
			return nil, err
		}
		return &parsed, nil
	}

	byGUID := make(map[string]*Transaction)
	ordered := make([]*Transaction, 0)

	for rows.Next() {
		var (
			txGUID, txCurrencyGUID, txNum          sql.NullString
			txPostDate, txEnterDate, txDescription sql.NullString

			spGUID, spAccountGUID, spMemo, spAction      sql.NullString
			spReconcileState, spReconcileDate, spLotGUID sql.NullString
			spValueNum, spValueDenom                     sql.NullInt64
			spQuantityNum, spQuantityDenom               sql.NullInt64

			acGUID, acName, acType                 sql.NullString
			acCommodityGUID, acParentGUID          sql.NullString
			acCode, acDescription                  sql.NullString
			acCommoditySCU, acNonStdSCU            sql.NullInt64
			acHidden, acPlaceholder                sql.NullInt64
		)

		// Destination order must match the SELECT column order.
		if err := rows.Scan(
			&txGUID, &txCurrencyGUID, &txNum, &txPostDate, &txEnterDate, &txDescription,
			&spGUID, &spAccountGUID, &spMemo, &spAction, &spReconcileState, &spReconcileDate,
			&spValueNum, &spValueDenom, &spQuantityNum, &spQuantityDenom, &spLotGUID,
			&acGUID, &acName, &acType, &acCommodityGUID, &acCommoditySCU, &acNonStdSCU,
			&acParentGUID, &acCode, &acDescription, &acHidden, &acPlaceholder,
		); err != nil {
			return nil, err
		}

		tx, seen := byGUID[txGUID.String]
		if !seen {
			postDate, err := toTime(txPostDate)
			if err != nil {
				return nil, err
			}
			enterDate, err := toTime(txEnterDate)
			if err != nil {
				return nil, err
			}

			tx = &Transaction{
				GUID:         txGUID.String,
				CurrencyGUID: txCurrencyGUID.String,
				Num:          txNum.String,
				PostDate:     postDate,
				EnterDate:    enterDate,
				Splits:       make([]*Split, 0),
			}
			if txDescription.Valid {
				tx.Description = &txDescription.String
			}

			byGUID[tx.GUID] = tx
			ordered = append(ordered, tx)
		}

		if !spGUID.Valid {
			// The LEFT JOIN produced no split for this transaction row.
			continue
		}

		split := &Split{
			GUID:           spGUID.String,
			TXGUID:         txGUID.String,
			AccountGUID:    spAccountGUID.String,
			ReconcileState: spReconcileState.String,
		}
		if spMemo.Valid {
			split.Memo = spMemo.String
		}
		if spAction.Valid {
			split.Action = spAction.String
		}
		// Numerator and denominator are only stored as a pair.
		if spValueNum.Valid && spValueDenom.Valid {
			split.ValueNum, split.ValueDenom = spValueNum.Int64, spValueDenom.Int64
		}
		if spQuantityNum.Valid && spQuantityDenom.Valid {
			split.QuantityNum, split.QuantityDenom = spQuantityNum.Int64, spQuantityDenom.Int64
		}
		if spReconcileDate.Valid {
			reconciled, err := toTime(spReconcileDate)
			if err != nil {
				return nil, err
			}
			split.ReconcileDate = reconciled
		}
		if spLotGUID.Valid {
			// The splits.lot_guid column is stored in the LogGUID field.
			split.LogGUID = &spLotGUID.String
		}

		if acGUID.Valid {
			account := &Account{
				GUID:         acGUID.String,
				Name:         acName.String,
				AccountType:  acType.String,
				CommoditySCU: acCommoditySCU.Int64,
				NonSTDSCU:    acNonStdSCU.Int64,
			}
			if acCommodityGUID.Valid {
				account.CommodityGUID = &acCommodityGUID.String
			}
			if acParentGUID.Valid {
				account.ParentGUID = &acParentGUID.String
			}
			if acCode.Valid {
				account.Code = &acCode.String
			}
			if acDescription.Valid {
				account.Description = &acDescription.String
			}
			if acHidden.Valid {
				account.Hidden = &acHidden.Int64
			}
			if acPlaceholder.Valid {
				account.Placeholder = &acPlaceholder.Int64
			}
			split.Account = account
		}

		tx.Splits = append(tx.Splits, split)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}
	return ordered, nil
}