Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ TAG := $(shell git rev-list --tags --max-count=1)
VERSION := $(shell git describe --tags ${TAG})
.PHONY: build check fmt lint test test-race vet test-cover-html help install proto ui compose-up-dev
.DEFAULT_GOAL := build
PROTON_COMMIT := "80ce71da5211fde064d83731a80cee2dffbdf84b"
PROTON_COMMIT := "adedf8597689fa6ed65fa5695b287fea679c5dc3"

ui:
@echo " > generating ui build"
Expand Down
82 changes: 74 additions & 8 deletions billing/credit/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"time"

"github.com/pkg/errors"

"github.com/raystack/frontier/billing/customer"
"github.com/raystack/frontier/core/auditrecord"
"github.com/raystack/frontier/internal/bootstrap/schema"
pkgAuditRecord "github.com/raystack/frontier/pkg/auditrecord"
)

type TransactionRepository interface {
Expand All @@ -19,13 +21,25 @@ type TransactionRepository interface {
GetBalanceForRange(ctx context.Context, accountID string, start time.Time, end time.Time) (int64, error)
}

type CustomerRepository interface {
GetByID(ctx context.Context, id string) (customer.Customer, error)
}

type AuditRecordRepository interface {
Create(ctx context.Context, record auditrecord.AuditRecord) (auditrecord.AuditRecord, error)
}

type Service struct {
transactionRepository TransactionRepository
customerRepository CustomerRepository
auditRepository AuditRecordRepository
}

func NewService(repository TransactionRepository) *Service {
func NewService(repository TransactionRepository, customerRepo CustomerRepository, auditRepo AuditRecordRepository) *Service {
return &Service{
transactionRepository: repository,
customerRepository: customerRepo,
auditRepository: auditRepo,
}
}

Expand All @@ -42,15 +56,16 @@ func (s Service) Add(ctx context.Context, cred Credit) error {
txSource = cred.Source
}

_, err := s.transactionRepository.CreateEntry(ctx, Transaction{
debitEntry := Transaction{
CustomerID: schema.PlatformOrgID.String(),
Type: DebitType,
Amount: cred.Amount,
Description: cred.Description,
Source: txSource,
UserID: cred.UserID,
Metadata: cred.Metadata,
}, Transaction{
}
creditEntry := Transaction{
ID: cred.ID,
Type: CreditType,
CustomerID: cred.CustomerID,
Expand All @@ -59,13 +74,22 @@ func (s Service) Add(ctx context.Context, cred Credit) error {
Source: txSource,
UserID: cred.UserID,
Metadata: cred.Metadata,
})
}

_, err := s.transactionRepository.CreateEntry(ctx, debitEntry, creditEntry)
if err != nil {
if errors.Is(err, ErrAlreadyApplied) {
return ErrAlreadyApplied
}
return fmt.Errorf("transactionRepository.CreateEntry: %w", err)
}

if creditEntry.CustomerID != schema.PlatformOrgID.String() {
if err := s.createAuditRecord(ctx, creditEntry.CustomerID, pkgAuditRecord.BillingTransactionCreditEvent, creditEntry.ID, creditEntry); err != nil {
return err
}
}

return nil
}

Expand All @@ -82,7 +106,7 @@ func (s Service) Deduct(ctx context.Context, cred Credit) error {
txSource = cred.Source
}

if _, err := s.transactionRepository.CreateEntry(ctx, Transaction{
debitEntry := Transaction{
ID: cred.ID,
CustomerID: cred.CustomerID,
Type: DebitType,
Expand All @@ -91,22 +115,34 @@ func (s Service) Deduct(ctx context.Context, cred Credit) error {
Source: txSource,
UserID: cred.UserID,
Metadata: cred.Metadata,
}, Transaction{
}
creditEntry := Transaction{
Type: CreditType,
CustomerID: schema.PlatformOrgID.String(),
Amount: cred.Amount,
Description: cred.Description,
Source: txSource,
UserID: cred.UserID,
Metadata: cred.Metadata,
}); err != nil {
}

_, err := s.transactionRepository.CreateEntry(ctx, debitEntry, creditEntry)
if err != nil {
if errors.Is(err, ErrAlreadyApplied) {
return ErrAlreadyApplied
} else if errors.Is(err, ErrInsufficientCredits) {
return ErrInsufficientCredits
}
return fmt.Errorf("failed to deduct credits: %w", err)
}

// Create audit record after transaction succeeds
if debitEntry.CustomerID != schema.PlatformOrgID.String() {
if err := s.createAuditRecord(ctx, debitEntry.CustomerID, pkgAuditRecord.BillingTransactionDebitEvent, debitEntry.ID, debitEntry); err != nil {
return err
}
}

return nil
}

Expand All @@ -131,3 +167,33 @@ func (s Service) GetBalanceForRange(ctx context.Context, accountID string, start
func (s Service) GetByID(ctx context.Context, id string) (Transaction, error) {
return s.transactionRepository.GetByID(ctx, id)
}

// createAuditRecord creates an audit record for billing transaction events.
func (s Service) createAuditRecord(ctx context.Context, customerID string, eventType pkgAuditRecord.Event, txID string, txEntry Transaction) error {
customerAcc, err := s.customerRepository.GetByID(ctx, customerID)
if err != nil {
return err
}

_, err = s.auditRepository.Create(ctx, auditrecord.AuditRecord{
Event: eventType,
Resource: auditrecord.Resource{
ID: customerID,
Type: pkgAuditRecord.BillingCustomerType,
Name: customerAcc.Name,
},
Target: &auditrecord.Target{
ID: txID,
Type: pkgAuditRecord.BillingTransactionType,
Metadata: map[string]interface{}{
"amount": txEntry.Amount,
"source": txEntry.Source,
"description": txEntry.Description,
"user_id": txEntry.UserID,
},
},
OccurredAt: time.Now(),
OrgID: customerAcc.OrgID,
})
return err
}
10 changes: 5 additions & 5 deletions billing/credit/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func mockService(t *testing.T) (*credit.Service, *mocks.TransactionRepository) {
t.Helper()
mockTransaction := mocks.NewTransactionRepository(t)
return credit.NewService(mockTransaction), mockTransaction
return credit.NewService(mockTransaction, nil, nil), mockTransaction
}

func TestService_GetBalance(t *testing.T) {
Expand Down Expand Up @@ -208,7 +208,7 @@ func TestService_Add(t *testing.T) {
cred: credit.Credit{
ID: "12",
Amount: 10,
CustomerID: "",
CustomerID: schema.PlatformOrgID.String(),
Metadata: metadata.Metadata{
"a": "a",
},
Expand All @@ -217,7 +217,7 @@ func TestService_Add(t *testing.T) {
want: nil,
setup: func() *credit.Service {
s, mockTransactionRepo := mockService(t)
mockTransactionRepo.EXPECT().CreateEntry(ctx, credit.Transaction{CustomerID: schema.PlatformOrgID.String(), Type: credit.DebitType, Amount: 10, Source: "system", Metadata: metadata.Metadata{"a": "a"}}, credit.Transaction{Type: credit.CreditType, Amount: 10, ID: "12", Source: "system", Metadata: metadata.Metadata{"a": "a"}}).Return([]credit.Transaction{}, nil)
mockTransactionRepo.EXPECT().CreateEntry(ctx, credit.Transaction{CustomerID: schema.PlatformOrgID.String(), Type: credit.DebitType, Amount: 10, Source: "system", Metadata: metadata.Metadata{"a": "a"}}, credit.Transaction{Type: credit.CreditType, Amount: 10, ID: "12", CustomerID: schema.PlatformOrgID.String(), Source: "system", Metadata: metadata.Metadata{"a": "a"}}).Return([]credit.Transaction{}, nil)
return s
},
},
Expand Down Expand Up @@ -333,7 +333,7 @@ func TestService_Deduct(t *testing.T) {
cred: credit.Credit{
ID: "12",
Amount: 10,
CustomerID: "customer_id",
CustomerID: schema.PlatformOrgID.String(),
Metadata: metadata.Metadata{
"a": "a",
},
Expand All @@ -342,7 +342,7 @@ func TestService_Deduct(t *testing.T) {
want: nil,
setup: func() *credit.Service {
s, mockTransactionRepo := mockService(t)
mockTransactionRepo.EXPECT().CreateEntry(ctx, credit.Transaction{ID: "12", CustomerID: "customer_id", Type: credit.DebitType, Amount: 10, Source: "system", Metadata: metadata.Metadata{"a": "a"}}, credit.Transaction{Type: credit.CreditType, CustomerID: schema.PlatformOrgID.String(), Amount: 10, Source: "system", Metadata: metadata.Metadata{"a": "a"}}).Return([]credit.Transaction{}, nil)
mockTransactionRepo.EXPECT().CreateEntry(ctx, credit.Transaction{ID: "12", CustomerID: schema.PlatformOrgID.String(), Type: credit.DebitType, Amount: 10, Source: "system", Metadata: metadata.Metadata{"a": "a"}}, credit.Transaction{Type: credit.CreditType, CustomerID: schema.PlatformOrgID.String(), Amount: 10, Source: "system", Metadata: metadata.Metadata{"a": "a"}}).Return([]credit.Transaction{}, nil)
return s
},
},
Expand Down
9 changes: 7 additions & 2 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,15 @@ func buildAPIDependencies(
}
stripeClient := GetStripeClientFunc(logger, cfg)

creditService := credit.NewService(postgres.NewBillingTransactionRepository(dbc))
billingCustomerRepository := postgres.NewBillingCustomerRepository(dbc)
creditService := credit.NewService(
postgres.NewBillingTransactionRepository(dbc),
billingCustomerRepository,
auditRecordRepository,
)
customerService := customer.NewService(
stripeClient,
postgres.NewBillingCustomerRepository(dbc), cfg.Billing, creditService)
billingCustomerRepository, cfg.Billing, creditService)
featureRepository := postgres.NewBillingFeatureRepository(dbc)
priceRepository := postgres.NewBillingPriceRepository(dbc)
productService := product.NewService(
Expand Down
2 changes: 2 additions & 0 deletions core/auditrecord/models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type AuditRecord struct {
Target *Target `json:"target"`
OccurredAt time.Time `json:"occurred_at"`
OrgID string `json:"org_id"`
OrgName string `json:"org_name"`
RequestID *string `json:"request_id"`
CreatedAt time.Time `json:"created_at,omitempty"`
Metadata metadata.Metadata `json:"metadata"`
Expand Down Expand Up @@ -65,6 +66,7 @@ type AuditRecordRQLSchema struct {
TargetName string `rql:"name=target_name,type=string"`
OccurredAt time.Time `rql:"name=occurred_at,type=datetime"`
OrgID string `rql:"name=org_id,type=string"`
OrgName string `rql:"name=org_name,type=string"`
RequestID string `rql:"name=request_id,type=string"`
CreatedAt time.Time `rql:"name=created_at,type=datetime"`
IdempotencyKey string `rql:"name=idempotency_key,type=string"`
Expand Down
1 change: 1 addition & 0 deletions internal/api/v1beta1connect/audit_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ func TransformAuditRecordToPB(record auditrecord.AuditRecord) (*frontierv1beta1.
Target: target,
OccurredAt: timestamppb.New(record.OccurredAt),
OrgId: record.OrgID,
OrgName: record.OrgName,
RequestId: requestID,
CreatedAt: timestamppb.New(record.CreatedAt),
Metadata: metaData,
Expand Down
12 changes: 12 additions & 0 deletions internal/store/postgres/audit_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type AuditRecord struct {
TargetName sql.NullString `db:"target_name"`
TargetMetadata types.NullJSONText `db:"target_metadata"`
OrganizationID uuid.UUID `db:"org_id"`
OrganizationName string `db:"org_name"`
RequestID sql.NullString `db:"request_id"`
OccurredAt time.Time `db:"occurred_at"`
CreatedAt time.Time `db:"created_at" goqu:"skipinsert"`
Expand Down Expand Up @@ -80,6 +81,7 @@ func (ar *AuditRecord) transformToDomain() (auditrecord.AuditRecord, error) {
Target: nullStringToTargetPtr(ar.TargetID, ar.TargetType, ar.TargetName, ar.TargetMetadata),
OccurredAt: ar.OccurredAt,
OrgID: ar.OrganizationID.String(),
OrgName: ar.OrganizationName,
RequestID: nullStringToPtr(ar.RequestID),
CreatedAt: ar.CreatedAt,
Metadata: nullJSONTextToMetadata(ar.Metadata),
Expand Down Expand Up @@ -267,6 +269,16 @@ func BuildAuditRecord(ctx context.Context, event pkgAuditRecord.Event, resource

// InsertAuditRecordInTx inserts an audit record within a transaction
func InsertAuditRecordInTx(ctx context.Context, tx *sqlx.Tx, record AuditRecord) error {
// Enrich the organization name from DB
if record.OrganizationID != uuid.Nil {
var orgName string
query, params, err := buildOrgNameQuery(record.OrganizationID)
if err == nil {
_ = tx.QueryRowContext(ctx, query, params...).Scan(&orgName)
record.OrganizationName = orgName
}
}

query, params, err := dialect.Insert(TABLE_AUDITRECORDS).
Rows(record).
ToSQL()
Expand Down
24 changes: 20 additions & 4 deletions internal/store/postgres/audit_record_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,24 @@ func NewAuditRecordRepository(dbc *db.Client) *AuditRecordRepository {
var (
auditRecordRQLFilterSupportedColumns = []string{
"id", "event", "actor_id", "actor_type", "actor_name", "resource_id", "resource_type", "resource_name",
"target_id", "target_type", "target_name", "occurred_at", "org_id", "request_id", "created_at", "idempotency_key",
"target_id", "target_type", "target_name", "occurred_at", "org_id", "org_name", "request_id", "created_at", "idempotency_key",
}
auditRecordRQLSearchSupportedColumns = []string{
"id", "event", "actor_id", "actor_type", "actor_name", "resource_id", "resource_type", "resource_name",
"target_id", "target_type", "target_name", "org_id", "request_id", "idempotency_key",
"target_id", "target_type", "target_name", "org_id", "org_name", "request_id", "idempotency_key",
}
auditRecordRQLGroupSupportedColumns = []string{
"event", "actor_type", "resource_type", "target_type", "org_id",
"event", "actor_type", "resource_type", "target_type", "org_id", "org_name",
}
)

func buildOrgNameQuery(orgID interface{}) (string, []interface{}, error) {
return dialect.Select("name").
From(TABLE_ORGANIZATIONS).
Where(goqu.Ex{"id": orgID}).
ToSQL()
}

func (r AuditRecordRepository) Create(ctx context.Context, auditRecord auditrecord.AuditRecord) (auditrecord.AuditRecord, error) {
// External RPC calls will have actor already enriched by service.
// Internal service calls will have empty actor and need enrichment from context.
Expand All @@ -76,6 +83,14 @@ func (r AuditRecordRepository) Create(ctx context.Context, auditRecord auditreco
return auditrecord.AuditRecord{}, err
}

// Enrich organization name from DB
if auditRecord.OrgID != "" {
query, params, err := buildOrgNameQuery(auditRecord.OrgID)
if err == nil {
_ = r.dbc.QueryRowxContext(ctx, query, params...).Scan(&dbRecord.OrganizationName)
}
}

var auditRecordModel AuditRecord
query, params, err := dialect.Insert(TABLE_AUDITRECORDS).Rows(dbRecord).Returning(&AuditRecord{}).ToSQL()
if err != nil {
Expand Down Expand Up @@ -257,6 +272,7 @@ func (r AuditRecordRepository) Export(ctx context.Context, rqlQuery *rql.Query)
goqu.L(`COALESCE(target_name, '')`),
goqu.L(`COALESCE(target_metadata::text, '{}')`),
goqu.L(`COALESCE(org_id::text, '')`),
goqu.L(`COALESCE(org_name, '')`),
goqu.L(`COALESCE(request_id, '')`),
goqu.L(`to_char(occurred_at AT TIME ZONE 'UTC', 'YYYY-MM-DD HH24:MI:SS.MS TZ')`),
goqu.L(`to_char(created_at AT TIME ZONE 'UTC', 'YYYY-MM-DD HH24:MI:SS.MS TZ')`),
Expand Down Expand Up @@ -373,7 +389,7 @@ func (r AuditRecordRepository) streamCursorToCSV(ctx context.Context, tx *sql.Tx
headers := []string{
"Record ID", "Idempotency Key", "Event", "Actor ID", "Actor Type", "Actor Name", "Actor Metadata",
"Resource ID", "Resource Type", "Resource Name", "Resource Metadata", "Target ID", "Target Type",
"Target name", "Target Metadata", "Organization ID", "Request ID", "Occurred At", "Created At", "Metadata",
"Target name", "Target Metadata", "Organization ID", "Organization Name", "Request ID", "Occurred At", "Created At", "Metadata",
}
if err := csvWriter.Write(headers); err != nil {
return fmt.Errorf("failed to write CSV headers: %w", err)
Expand Down
Loading
Loading