14 changes: 12 additions & 2 deletions proxy.go
@@ -238,7 +238,12 @@ func executeWithRetry(
// comment s.host.dec() line to avoid double increment; issue #322
// s.host.dec()
s.host.SetIsActive(false)
nextHost := s.cluster.getHost()
var nextHost *topology.Node
if s.replicaNum > 0 || s.nodeNum > 0 {
nextHost = s.cluster.getSpecificHost(s.replicaNum, s.nodeNum)
} else {
nextHost = s.cluster.getHost()
}
// The query could be retried if it has no stickiness to a certain server
if numRetry < maxRetry && nextHost.IsActive() && s.sessionId == "" {
// the query execution has been failed
@@ -917,6 +922,11 @@ func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) {
return nil, http.StatusForbidden, fmt.Errorf("cluster user %q is not allowed to access", cu.name)
}

s := newScope(req, u, c, cu, sessionId, sessionTimeout)
replicaNum, nodeNum, err := getSpecificHostNum(req, c)
if err != nil {
return nil, http.StatusBadRequest, err
}

s := newScope(req, u, c, cu, sessionId, sessionTimeout, replicaNum, nodeNum)
return s, 0, nil
}
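
The branching above (and its twins in newScope and waitUntilAllowStart in scope.go below) always resolves a host in the same order. A minimal sketch of that order, condensed into a hypothetical helper — pickHost is not part of this PR, while cluster and topology.Node are the repo's existing types: session stickiness first, then an explicit replica_num/node_num pin, then the default least-loaded round-robin host.

// pickHost is an illustrative helper only (not in the diff): it condenses
// the host-selection order introduced by this change.
func pickHost(c *cluster, sessionId string, replicaNum, nodeNum int) *topology.Node {
	switch {
	case sessionId != "":
		// sticky sessions keep hitting the same host
		return c.getHostSticky(sessionId)
	case replicaNum > 0 || nodeNum > 0:
		// explicit pin from the replica_num/node_num query parameters
		return c.getSpecificHost(replicaNum, nodeNum)
	default:
		// least loaded + round-robin host across all replicas
		return c.getHost()
	}
}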
90 changes: 85 additions & 5 deletions scope.go
@@ -45,6 +45,8 @@ type scope struct {

sessionId string
sessionTimeout int
replicaNum int
nodeNum int

remoteAddr string
localAddr string
@@ -57,10 +59,14 @@ type scope {
requestPacketSize int
}

func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int) *scope {
h := c.getHost()
func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int, replicaNum, nodeNum int) *scope {
var h *topology.Node
if sessionId != "" {
h = c.getHostSticky(sessionId)
} else if replicaNum > 0 || nodeNum > 0 {
h = c.getSpecificHost(replicaNum, nodeNum)
} else {
h = c.getHost()
}
var localAddr string
if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
@@ -75,6 +81,8 @@ func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId
clusterUser: cu,
sessionId: sessionId,
sessionTimeout: sessionTimeout,
replicaNum: replicaNum,
nodeNum: nodeNum,

remoteAddr: req.RemoteAddr,
localAddr: localAddr,
@@ -185,11 +193,13 @@ func (s *scope) waitUntilAllowStart(sleep time.Duration, deadline time.Time, lab
var h *topology.Node
// Choose new host, since the previous one may become obsolete
// after sleeping.
if s.sessionId == "" {
h = s.cluster.getHost()
} else {
if s.sessionId != "" {
// if request has session_id, set same host
h = s.cluster.getHostSticky(s.sessionId)
} else if s.replicaNum > 0 || s.nodeNum > 0 {
h = s.cluster.getSpecificHost(s.replicaNum, s.nodeNum)
} else {
h = s.cluster.getHost()
}

s.host = h
@@ -720,6 +730,8 @@ func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c
return nil, err
}
r.hosts = hosts
c.maxNodeNum = len(r.hosts)
c.maxReplicaNum = 1
return []*replica{r}, nil
}

@@ -735,7 +747,9 @@
}
r.hosts = hosts
replicas[i] = r
c.maxNodeNum = max(c.maxNodeNum, len(r.hosts))
}
c.maxReplicaNum = len(replicas)
return replicas, nil
}
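
For illustration, a hedged sketch of what the two new counters end up holding — the replica values here are hypothetical, while replica, cluster and topology.Node are the repo's own types. maxReplicaNum and maxNodeNum become the upper bounds that getSpecificHostNum (utils.go below) uses to validate replica_num and node_num.

// Two replicas of unequal size (hypothetical topology):
a := &replica{hosts: make([]*topology.Node, 2)}
b := &replica{hosts: make([]*topology.Node, 3)}
c := &cluster{replicas: []*replica{a, b}}

c.maxReplicaNum = len(c.replicas) // 2 -> replica_num may be 0..2 (0 = any replica)
for _, r := range c.replicas {
	c.maxNodeNum = max(c.maxNodeNum, len(r.hosts)) // 3 -> node_num may be 0..3 (0 = any node)
}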

@@ -775,6 +789,9 @@ type cluster struct {
replicas []*replica
nextReplicaIdx uint32

maxReplicaNum int
maxNodeNum int

users map[string]*clusterUser

killQueryUserName string
@@ -937,6 +954,59 @@ func (r *replica) getHostSticky(sessionId string) *topology.Node {
return h
}

// getSpecificReplica returns the replica selected by replicaNum from the cluster; when only nodeNum is set, it picks the least loaded replica that has at least nodeNum hosts.
//
// Always returns non-nil.
func (c *cluster) getSpecificReplica(replicaNum, nodeNum int) *replica {
if replicaNum > 0 {
return c.replicas[replicaNum-1]
}
if nodeNum == 0 {
return c.getReplica()
}

idx := atomic.AddUint32(&c.nextReplicaIdx, 1)
n := uint32(len(c.replicas))
if n == 1 {
return c.replicas[0]
}

var r *replica
reqs := ^uint32(0)

// Scan all the replicas for the least loaded replica that has at least nodeNum hosts.
for i := uint32(0); i < n; i++ {
tmpIdx := (idx + i) % n
tmpR := c.replicas[tmpIdx]
if nodeNum > len(tmpR.hosts) {
continue
}
if tmpR.isActive() || r == nil {
tmpReqs := tmpR.load()
if tmpReqs < reqs || !r.isActive() {
r = tmpR
reqs = tmpReqs
}
}
}

// The returned replica may be inactive. This is OK,
// since it means all the replicas with at least nodeNum hosts are inactive,
// so let's try proxying the request to one of them anyway.
return r
}

// getSpecificHost returns the host selected by nodeNum from the replica.
//
// Always returns non-nil.
func (r *replica) getSpecificHost(nodeNum int) *topology.Node {
if nodeNum > 0 {
return r.hosts[nodeNum-1]
}

return r.getHost()
}

// getHost returns least loaded + round-robin host from replica.
//
// Always returns non-nil.
@@ -991,6 +1061,16 @@ func (c *cluster) getHostSticky(sessionId string) *topology.Node {
return r.getHostSticky(sessionId)
}

// getSpecificHost returns the host selected by replicaNum and nodeNum from the cluster.
// Both replicaNum and nodeNum are 1-based and must lie in [0, maxReplicaNum] and [0, maxNodeNum]; 0 means no specific num.
// If both are 0, getSpecificHost is equivalent to getHost.
//
// Always returns non-nil.
func (c *cluster) getSpecificHost(replicaNum, nodeNum int) *topology.Node {
r := c.getSpecificReplica(replicaNum, nodeNum)
return r.getSpecificHost(nodeNum)
}

// getHost returns least loaded + round-robin host from cluster.
//
// Always returns non-nil.
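
To make the numbering concrete, a short hedged sketch of how the cluster-level getSpecificHost resolves the two nums; the addresses refer to the three-replica topology built by testGetCluster in scope_test.go below, and 0 always means "no pin".

h := c.getSpecificHost(1, 2) // replicas[0].hosts[1] -> 127.0.0.22
h = c.getSpecificHost(2, 0)  // least-loaded host of replicas[1] -> 127.0.0.33 or 127.0.0.44
h = c.getSpecificHost(0, 1)  // hosts[0] of the least-loaded replica with at least 1 host
h = c.getSpecificHost(0, 0)  // no pin at all -> equivalent to c.getHost()
_ = h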
78 changes: 78 additions & 0 deletions scope_test.go
@@ -410,6 +410,78 @@ func TestGetHostSticky(t *testing.T) {
}
}

func TestGetSpecificHost(t *testing.T) {
c := testGetCluster()

t.Run("SpecifyReplicaNum", func(t *testing.T) {
h := c.getSpecificHost(1, 0)
if h.Host() != "127.0.0.11" && h.Host() != "127.0.0.22" {
t.Fatalf("Expected host from replica1, got: %s", h.Host())
}

h = c.getSpecificHost(2, 0)
if h.Host() != "127.0.0.33" && h.Host() != "127.0.0.44" {
t.Fatalf("Expected host from replica2, got: %s", h.Host())
}

h = c.getSpecificHost(3, 0)
if h.Host() != "127.0.0.55" && h.Host() != "127.0.0.66" {
t.Fatalf("Expected host from replica3, got: %s", h.Host())
}
})

t.Run("SpecifyNodeNum", func(t *testing.T) {
h := c.getSpecificHost(0, 1)
if h.Host() != "127.0.0.11" && h.Host() != "127.0.0.33" && h.Host() != "127.0.0.55" {
t.Fatalf("Expected first node from any replica, got: %s", h.Host())
}

h = c.getSpecificHost(0, 2)
if h.Host() != "127.0.0.22" && h.Host() != "127.0.0.44" && h.Host() != "127.0.0.66" {
t.Fatalf("Expected second node from any replica, got: %s", h.Host())
}
})

t.Run("SpecifyReplicaNumAndNodeNum", func(t *testing.T) {
h := c.getSpecificHost(1, 1)
if h.Host() != "127.0.0.11" {
t.Fatalf("Expected 127.0.0.11, got: %s", h.Host())
}

h = c.getSpecificHost(1, 2)
if h.Host() != "127.0.0.22" {
t.Fatalf("Expected 127.0.0.22, got: %s", h.Host())
}

h = c.getSpecificHost(2, 1)
if h.Host() != "127.0.0.33" {
t.Fatalf("Expected 127.0.0.33, got: %s", h.Host())
}
})

t.Run("SpecifyBothNumsZero", func(t *testing.T) {
h := c.getSpecificHost(0, 0)
if h == nil {
t.Fatalf("getSpecificHost(0, 0) returned nil")
}
found := false
for _, r := range c.replicas {
for _, node := range r.hosts {
if h.Host() == node.Host() {
found = true
break
}
}
if found {
break
}
}
if !found {
t.Fatalf("getSpecificHost(0, 0) returned unknown host: %s", h.Host())
}
})
}

func TestIncQueued(t *testing.T) {
u := testGetUser()
cu := testGetClusterUser()
@@ -485,6 +557,12 @@ func testGetCluster() *cluster {
topology.NewNode(&url.URL{Host: "127.0.0.66"}, nil, "", r3.name, topology.WithDefaultActiveState(true)),
}
r3.name = "replica3"

c.maxReplicaNum = len(c.replicas)
for _, r := range c.replicas {
c.maxNodeNum = max(c.maxNodeNum, len(r.hosts))
}

return c
}

47 changes: 47 additions & 0 deletions utils.go
@@ -63,6 +63,53 @@ func getSessionTimeout(req *http.Request) int {
return 60
}

// getSpecificHostNum retrieves the specific host nums (replica num and node num) from the request query parameters.
// Nums start from 1; 0 means no specific host num.
// shard_num is an alias for node_num and overrides node_num if both are specified.
func getSpecificHostNum(req *http.Request, c *cluster) (int, int, error) {
params := req.URL.Query()
var replicaNum, nodeNum int
var err error
// replica num
replicaNumStr := params.Get("replica_num")
if replicaNumStr != "" {
replicaNum, err = strconv.Atoi(replicaNumStr)
if err != nil {
return -1, -1, fmt.Errorf("invalid replica num %q", replicaNumStr)
}
if replicaNum < 0 || replicaNum > c.maxReplicaNum {
return -1, -1, fmt.Errorf("invalid replica num %q", replicaNumStr)
}
}
// node num (shard_num is an alias for node_num)
nodeNumStr := params.Get("node_num")
if nodeNumStr != "" {
nodeNum, err = strconv.Atoi(nodeNumStr)
if err != nil {
return -1, -1, fmt.Errorf("invalid node num %q", nodeNumStr)
}
if nodeNum < 0 || nodeNum > c.maxNodeNum {
return -1, -1, fmt.Errorf("invalid node num %q", nodeNumStr)
}
}
shardNumStr := params.Get("shard_num")
if shardNumStr != "" {
nodeNum, err = strconv.Atoi(shardNumStr)
if err != nil {
return -1, -1, fmt.Errorf("invalid shard num %q", shardNumStr)
}
if nodeNum < 0 || nodeNum > c.maxNodeNum {
return -1, -1, fmt.Errorf("invalid shard num %q", shardNumStr)
}
}
// when both replicaNum and nodeNum are specified, validate that the chosen replica actually has nodeNum hosts
if replicaNum > 0 && nodeNum > 0 && nodeNum > len(c.replicas[replicaNum-1].hosts) {
return -1, -1, fmt.Errorf("invalid host num (%q, %q)", replicaNumStr, nodeNumStr)
}

return replicaNum, nodeNum, nil
}

// getQuerySnippet returns query snippet.
//
// getQuerySnippet must be called only for error reporting.
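
From the client side the feature is driven purely by query parameters. A hedged usage sketch — the proxy listen address and port are assumptions, while the parameter names match getSpecificHostNum above:

package main

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
)

func main() {
	// Pin the query to node 2 of replica 1. Both nums are 1-based;
	// omitting a parameter (or sending 0) leaves that dimension unpinned.
	// shard_num is an alias for node_num and overrides it if both are sent.
	q := url.Values{}
	q.Set("query", "SELECT 1")
	q.Set("replica_num", "1")
	q.Set("node_num", "2")

	// Hypothetical chproxy listen address.
	resp, err := http.Get("http://127.0.0.1:9090/?" + q.Encode())
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Out-of-range values (e.g. replica_num=99) are rejected with
	// 400 Bad Request before any host is chosen.
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(body))
}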