Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ocp/data/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ type DatabaseData interface {
// --------------------------------------------------------------------------------
SaveVmMetadata(ctx context.Context, record *vm_metadata.Record) error
GetVmMetadataByMint(ctx context.Context, mint string) (*vm_metadata.Record, error)
GetAllVms(ctx context.Context) ([]string, error)

// VM Storage
// --------------------------------------------------------------------------------
Expand Down Expand Up @@ -895,6 +896,9 @@ func (dp *DatabaseProvider) SaveVmMetadata(ctx context.Context, record *vm_metad
func (dp *DatabaseProvider) GetVmMetadataByMint(ctx context.Context, mint string) (*vm_metadata.Record, error) {
return dp.vmMetadata.GetByMint(ctx, mint)
}
func (dp *DatabaseProvider) GetAllVms(ctx context.Context) ([]string, error) {
return dp.vmMetadata.GetAllVms(ctx)
}

// VM RAM
// --------------------------------------------------------------------------------
Expand Down
19 changes: 19 additions & 0 deletions ocp/data/vm/metadata/memory/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@ func (s *store) Save(_ context.Context, record *metadata.Record) error {
return nil
}

// GetAllVms implements vm.metadata.Store.GetAllVms
func (s *store) GetAllVms(_ context.Context) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()

seen := make(map[string]struct{})
var vms []string
for _, item := range s.records {
if _, ok := seen[item.Vm]; !ok {
seen[item.Vm] = struct{}{}
vms = append(vms, item.Vm)
}
}
if len(vms) == 0 {
return nil, metadata.ErrNotFound
}
return vms, nil
}

// GetByMint implements vm.metadata.Store.GetByMint
func (s *store) GetByMint(_ context.Context, mint string) (*metadata.Record, error) {
s.mu.Lock()
Expand Down
14 changes: 14 additions & 0 deletions ocp/data/vm/metadata/postgres/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ func (m *model) dbSave(ctx context.Context, db *sqlx.DB) error {
})
}

func dbGetAllVms(ctx context.Context, db *sqlx.DB) ([]string, error) {
var res []string
query := `SELECT DISTINCT vm FROM ` + tableName

err := db.SelectContext(ctx, &res, query)
if err != nil {
return nil, err
}
if len(res) == 0 {
return nil, metadata.ErrNotFound
}
return res, nil
}

func dbGetByMint(ctx context.Context, db *sqlx.DB, mint string) (*model, error) {
var res model
query := `SELECT id, mint, authority, vm, vm_bump, omnibus, omnibus_bump, days_locked, state, version, created_at FROM ` + tableName + `
Expand Down
5 changes: 5 additions & 0 deletions ocp/data/vm/metadata/postgres/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ func (s *store) Save(ctx context.Context, record *metadata.Record) error {
return nil
}

// GetAllVms implements vm.metadata.Store.GetAllVms
func (s *store) GetAllVms(ctx context.Context) ([]string, error) {
return dbGetAllVms(ctx, s.db)
}

// GetByMint implements vm.metadata.Store.GetByMint
func (s *store) GetByMint(ctx context.Context, mint string) (*metadata.Record, error) {
obj, err := dbGetByMint(ctx, s.db, mint)
Expand Down
3 changes: 3 additions & 0 deletions ocp/data/vm/metadata/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ type Store interface {

// GetByMint returns the VM metadata record for the given mint
GetByMint(ctx context.Context, mint string) (*Record, error)

// GetAllVms returns all VM public keys
GetAllVms(ctx context.Context) ([]string, error)
}

func (r *Record) Validate() error {
Expand Down
52 changes: 52 additions & 0 deletions ocp/data/vm/metadata/tests/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
func RunTests(t *testing.T, s metadata.Store, teardown func()) {
for _, tf := range []func(t *testing.T, s metadata.Store){
testHappyPath,
testGetAllVms,
} {
tf(t, s)
teardown()
Expand Down Expand Up @@ -84,6 +85,57 @@ func testHappyPath(t *testing.T, s metadata.Store) {
})
}

func testGetAllVms(t *testing.T, s metadata.Store) {
t.Run("testGetAllVms", func(t *testing.T) {
ctx := context.Background()

// No records returns ErrNotFound
_, err := s.GetAllVms(ctx)
assert.Equal(t, metadata.ErrNotFound, err)

// Save a record and verify GetAllVms returns its VM
require.NoError(t, s.Save(ctx, &metadata.Record{
Mint: "mint1",
Authority: "authority1",
Vm: "vm1",
Omnibus: "omnibus1",
}))

vms, err := s.GetAllVms(ctx)
require.NoError(t, err)
assert.Len(t, vms, 1)
assert.Contains(t, vms, "vm1")

// Save another record with a different VM
require.NoError(t, s.Save(ctx, &metadata.Record{
Mint: "mint2",
Authority: "authority2",
Vm: "vm2",
Omnibus: "omnibus2",
}))

vms, err = s.GetAllVms(ctx)
require.NoError(t, err)
assert.Len(t, vms, 2)
assert.Contains(t, vms, "vm1")
assert.Contains(t, vms, "vm2")

// Save another record with the same VM, should not duplicate
require.NoError(t, s.Save(ctx, &metadata.Record{
Mint: "mint3",
Authority: "authority3",
Vm: "vm1",
Omnibus: "omnibus3",
}))

vms, err = s.GetAllVms(ctx)
require.NoError(t, err)
assert.Len(t, vms, 2)
assert.Contains(t, vms, "vm1")
assert.Contains(t, vms, "vm2")
})
}

func assertEquivalentRecords(t *testing.T, obj1, obj2 *metadata.Record) {
assert.Equal(t, obj1.Mint, obj2.Mint)
assert.Equal(t, obj1.Authority, obj2.Authority)
Expand Down
Loading