Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 47 additions & 30 deletions spregistry/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package spregistry

import (
"context"
"encoding/json"
"fmt"
"math/big"
"strings"
Expand Down Expand Up @@ -349,19 +350,8 @@ func (c *Contract) GetProvider(ctx context.Context, providerID *big.Int) (*GetPr
return nil, fmt.Errorf("getProvider call failed: %w", err)
}

var res struct {
ProviderID *big.Int `abi:"providerId"`
Info struct {
ServiceProvider common.Address `abi:"serviceProvider"`
Payee common.Address `abi:"payee"`
Name string `abi:"name"`
Description string `abi:"description"`
IsActive bool `abi:"isActive"`
} `abi:"info"`
}

err = c.abi.UnpackIntoInterface(&res, "getProvider", result)
if err != nil {
var res getProviderByAddressOutput
if err := unpackSingleTuple(c.abi, "getProvider", result, &res); err != nil {
return nil, fmt.Errorf("failed to unpack getProvider result: %w", err)
}

Expand Down Expand Up @@ -391,19 +381,8 @@ func (c *Contract) GetProviderByAddress(ctx context.Context, addr common.Address
return nil, fmt.Errorf("getProviderByAddress call failed: %w", err)
}

var res struct {
ProviderID *big.Int `abi:"providerId"`
Info struct {
ServiceProvider common.Address `abi:"serviceProvider"`
Payee common.Address `abi:"payee"`
Name string `abi:"name"`
Description string `abi:"description"`
IsActive bool `abi:"isActive"`
} `abi:"info"`
}

err = c.abi.UnpackIntoInterface(&res, "getProviderByAddress", result)
if err != nil {
var res getProviderByAddressOutput
if err := unpackSingleTuple(c.abi, "getProviderByAddress", result, &res); err != nil {
return nil, fmt.Errorf("failed to unpack getProviderByAddress result: %w", err)
}

Expand All @@ -419,6 +398,44 @@ func (c *Contract) GetProviderByAddress(ctx context.Context, addr common.Address
}, nil
}

// getProviderByAddressOutput mirrors the (providerId, info) tuple
// getProviderByAddress returns. Tagged for json round-trip via
// unpackSingleTuple below.
type getProviderByAddressOutput struct {
ProviderID *big.Int `json:"providerId"`
Info getProviderByAddressOutputInfo `json:"info"`
}

type getProviderByAddressOutputInfo struct {
ServiceProvider common.Address `json:"serviceProvider"`
Payee common.Address `json:"payee"`
Name string `json:"name"`
Description string `json:"description"`
IsActive bool `json:"isActive"`
}

// unpackSingleTuple decodes an ABI method's single-tuple return into dst
// via abi.Unpack + json round-trip. UnpackIntoInterface mishandles this
// shape; Unpack returns the right anonymous struct, json copies it into
// dst by matching json tags. dst must be a pointer to a tagged struct.
func unpackSingleTuple(parsed abi.ABI, method string, payload []byte, dst any) error {
out, err := parsed.Unpack(method, payload)
if err != nil {
return err
}
if len(out) != 1 {
return fmt.Errorf("%s: expected 1 output, got %d", method, len(out))
}
buf, err := json.Marshal(out[0])
if err != nil {
return fmt.Errorf("%s: marshal unpacked tuple: %w", method, err)
}
if err := json.Unmarshal(buf, dst); err != nil {
return fmt.Errorf("%s: decode into %T: %w", method, dst, err)
}
return nil
}

func (c *Contract) GetProviderIDByAddress(ctx context.Context, addr common.Address) (*big.Int, error) {
data, err := c.abi.Pack("getProviderIdByAddress", addr)
if err != nil {
Expand Down Expand Up @@ -446,10 +463,10 @@ func (c *Contract) GetProviderIDByAddress(ctx context.Context, addr common.Addre
}

type GetProviderWithProductResult struct {
ProviderID *big.Int
ProviderInfo RawProviderInfo
Product RawProduct
ProductCapabilityValues [][]byte
ProviderID *big.Int
ProviderInfo RawProviderInfo
Product RawProduct
ProductCapabilityValues [][]byte
}

func (c *Contract) GetProviderWithProduct(ctx context.Context, providerID *big.Int, productType uint8) (*GetProviderWithProductResult, error) {
Expand Down
77 changes: 77 additions & 0 deletions spregistry/contract_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package spregistry

import (
"math/big"
"strings"
"testing"

"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
)

// TestUnpackSingleTuple_GetProviderByAddress exercises the unpack path
// Contract.GetProviderByAddress uses, against a synthetic return blob.
// Reproduces the calibnet bug if unpackSingleTuple regresses to
// UnpackIntoInterface (which mishandles this shape).
func TestUnpackSingleTuple_GetProviderByAddress(t *testing.T) {
parsedABI, err := abi.JSON(strings.NewReader(SPRegistryABIJSON))
if err != nil {
t.Fatalf("parse ABI: %v", err)
}

method, ok := parsedABI.Methods["getProviderByAddress"]
if !ok {
t.Fatalf("getProviderByAddress not found in ABI")
}

type infoT struct {
ServiceProvider common.Address `abi:"serviceProvider"`
Payee common.Address `abi:"payee"`
Name string `abi:"name"`
Description string `abi:"description"`
IsActive bool `abi:"isActive"`
}
type outT struct {
ProviderID *big.Int `abi:"providerId"`
Info infoT `abi:"info"`
}
want := outT{
ProviderID: big.NewInt(24),
Info: infoT{
ServiceProvider: common.HexToAddress("0xE3e842B9D89ed2Ee3976b9b8916827302618c29e"),
Payee: common.HexToAddress("0xE3e842B9D89ed2Ee3976b9b8916827302618c29e"),
Name: "sp-playground",
Description: "calibnet test SP",
IsActive: true,
},
}

payload, err := method.Outputs.Pack(want)
if err != nil {
t.Fatalf("pack synthetic return: %v", err)
}

var got getProviderByAddressOutput
if err := unpackSingleTuple(parsedABI, "getProviderByAddress", payload, &got); err != nil {
t.Fatalf("unpackSingleTuple: %v", err)
}

if got.ProviderID == nil || got.ProviderID.Cmp(big.NewInt(24)) != 0 {
t.Errorf("ProviderID = %v, want 24", got.ProviderID)
}
if got.Info.ServiceProvider != want.Info.ServiceProvider {
t.Errorf("ServiceProvider = %s, want %s", got.Info.ServiceProvider, want.Info.ServiceProvider)
}
if got.Info.Payee != want.Info.Payee {
t.Errorf("Payee = %s, want %s", got.Info.Payee, want.Info.Payee)
}
if got.Info.Name != want.Info.Name {
t.Errorf("Name = %q, want %q", got.Info.Name, want.Info.Name)
}
if got.Info.Description != want.Info.Description {
t.Errorf("Description = %q, want %q", got.Info.Description, want.Info.Description)
}
if !got.Info.IsActive {
t.Errorf("IsActive = false, want true")
}
}
Loading