14 changes: 14 additions & 0 deletions cmd/gcs/main.go
@@ -31,6 +31,7 @@ import (
"github.com/Microsoft/hcsshim/internal/log"
"github.com/Microsoft/hcsshim/internal/oc"
"github.com/Microsoft/hcsshim/internal/version"
"github.com/Microsoft/hcsshim/pkg/amdsevsnp"
"github.com/Microsoft/hcsshim/pkg/securitypolicy"
)

@@ -359,9 +360,22 @@ func main() {
logrus.WithError(err).Fatal("failed to initialize new runc runtime")
}
mux := bridge.NewBridgeMux()

forceSequential, err := amdsevsnp.IsSNP()
if err != nil {
// IsSNP cannot fail on LCOW
logrus.Errorf("Got unexpected error from IsSNP(): %v", err)
// If it fails, we proceed with forceSequential enabled to be safe
forceSequential = true
}

b := bridge.Bridge{
Handler: mux,
EnableV4: *v4,

// For confidential containers, we protect ourselves against attacks caused
// by concurrent modifications by processing one request at a time.
ForceSequential: forceSequential,
}
h := hcsv2.NewHost(rtime, tport, initialEnforcer, logWriter)
// Initialize virtual pod support in the host
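For context, the sequential-mode decision above hinges on detecting AMD SEV-SNP from inside the guest. A minimal sketch of such a probe, assuming detection via the `/dev/sev-guest` device node (hypothetical; the real `amdsevsnp.IsSNP` may use a different mechanism):

```go
package main

import (
	"errors"
	"fmt"
	"os"
)

// isSNPGuest is a hypothetical stand-in for amdsevsnp.IsSNP: it reports
// whether the kernel has exposed the SEV-SNP guest device. The device
// path is an assumption made for illustration.
func isSNPGuest() (bool, error) {
	_, err := os.Stat("/dev/sev-guest")
	if err == nil {
		return true, nil
	}
	if errors.Is(err, os.ErrNotExist) {
		return false, nil
	}
	// Unexpected failure (e.g. permissions); surface it so the caller
	// can fail safe by assuming SNP, as main() does above.
	return false, fmt.Errorf("probing /dev/sev-guest: %w", err)
}

func main() {
	snp, err := isSNPGuest()
	fmt.Printf("SNP guest: %v (err: %v)\n", snp, err)
}
```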
53 changes: 53 additions & 0 deletions internal/gcs/unrecoverable_error.go
@@ -0,0 +1,53 @@
package gcs

import (
"context"
"fmt"
"os"
"runtime"
"time"

"github.com/Microsoft/hcsshim/internal/log"
"github.com/Microsoft/hcsshim/pkg/amdsevsnp"
"github.com/sirupsen/logrus"
)

// UnrecoverableError logs the error and then puts the current thread into an
// infinite sleep loop. This is to be used instead of panicking, as the
// behaviour of GCS panics is unpredictable. This function can be extended to,
// for example, try to shut down the VM cleanly.
func UnrecoverableError(err error) {
buf := make([]byte, 300*(1<<10))
stackSize := runtime.Stack(buf, true)
stackTrace := string(buf[:stackSize])

errPrint := fmt.Sprintf(
"Unrecoverable error in GCS: %v\n%s",
err, stackTrace,
)

// Use a separate error variable so the original err is not clobbered;
// it is still needed for the log and the panic below.
isSnp, snpErr := amdsevsnp.IsSNP()
if snpErr != nil {
// IsSNP() cannot fail on LCOW,
// but if it does, we proceed as if we're on SNP to be safe.
isSnp = true
}

if isSnp {
errPrint += "\nThis thread will now enter an infinite loop."
}
log.G(context.Background()).WithError(err).Logf(
logrus.FatalLevel,
"%s",
errPrint,
)

if !isSnp {
panic("Unrecoverable error in GCS: " + err.Error())
} else {
fmt.Fprintf(os.Stderr, "%s\n", errPrint)
for {
time.Sleep(time.Hour)
}
}
}
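A hypothetical call site in the same package, shown only to illustrate the intended use (the helper name is invented, not part of this PR):

```go
// mustSucceed is a hypothetical helper: it replaces panic for fatal
// invariants, so that on SNP the process parks in an infinite sleep
// instead of unwinding unpredictably.
func mustSucceed(op string, err error) {
	if err != nil {
		UnrecoverableError(fmt.Errorf("%s: %w", op, err))
	}
}
```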
94 changes: 71 additions & 23 deletions internal/guest/bridge/bridge.go
@@ -177,6 +177,10 @@ type Bridge struct {
Handler Handler
// EnableV4 enables the v4+ bridge and the schema v2+ interfaces.
EnableV4 bool
// Setting ForceSequential to true will force the bridge to only process one
// request at a time, except for certain long-running operations (as defined
// in alwaysAsyncMessages).
ForceSequential bool

// responseChan is the response channel used for both request/response
// and publish notification workflows.
@@ -191,6 +195,14 @@ type Bridge struct {
protVer prot.ProtocolVersion
}

// alwaysAsyncMessages lists the messages that are processed asynchronously
// even in sequential mode. Note that in sequential mode these messages still
// wait for any in-progress non-async messages to be handled before they are
// processed, but once they are "acknowledged", the rest is done asynchronously.
var alwaysAsyncMessages = map[prot.MessageIdentifier]bool{
prot.ComputeSystemWaitForProcessV1: true,
}

// AssignHandlers creates and assigns the appropriate bridge
// events to be listened for and intercepted on `mux` before forwarding
// to `gcs` for handling.
@@ -238,6 +250,10 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
defer close(requestErrChan)
defer bridgeIn.Close()

if b.ForceSequential {
log.G(context.Background()).Info("bridge: ForceSequential enabled")
}

// Receive bridge requests and schedule them to be processed.
go func() {
var recverr error
@@ -340,30 +356,36 @@
}()
// Process each bridge request async and create the response writer.
go func() {
for req := range requestChan {
go func(r *Request) {
br := bridgeResponse{
ctx: r.Context,
header: &prot.MessageHeader{
Type: prot.GetResponseIdentifier(r.Header.Type),
ID: r.Header.ID,
},
}
resp, err := b.Handler.ServeMsg(r)
if resp == nil {
resp = &prot.MessageResponseBase{}
}
resp.Base().ActivityID = r.ActivityID
if err != nil {
span := trace.FromContext(r.Context)
if span != nil {
oc.SetSpanStatus(span, err)
}
setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */)
doOneRequest := func(r *Request) {
br := bridgeResponse{
ctx: r.Context,
header: &prot.MessageHeader{
Type: prot.GetResponseIdentifier(r.Header.Type),
ID: r.Header.ID,
},
}
resp, err := b.Handler.ServeMsg(r)
if resp == nil {
resp = &prot.MessageResponseBase{}
}
resp.Base().ActivityID = r.ActivityID
if err != nil {
span := trace.FromContext(r.Context)
if span != nil {
oc.SetSpanStatus(span, err)
}
br.response = resp
b.responseChan <- br
}(req)
setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */)
}
br.response = resp
b.responseChan <- br
}

for req := range requestChan {
if b.ForceSequential && !alwaysAsyncMessages[req.Header.Type] {
runSequentialRequestHandler(req, doOneRequest)
} else {
go doOneRequest(req)
}
}
}()
// Process each bridge response sync. This channel is for request/response and publish workflows.
Expand Down Expand Up @@ -423,6 +445,32 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
}
}

// runSequentialRequestHandler runs handleFn(r), printing a warning if
// handleFn takes too long to return (or never returns).
func runSequentialRequestHandler(r *Request, handleFn func(*Request)) {
// Note that this context is only used to trigger the blockage warning;
// the request processing itself still uses r.Context. We don't want to
// cancel the request handling when we reach the 5s timeout.
timeoutCtx, cancel := context.WithTimeout(r.Context, 5*time.Second)
go func() {
<-timeoutCtx.Done()
if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
log.G(timeoutCtx).WithFields(logrus.Fields{
// We log these fields explicitly even though we're providing r.Context,
// since if the request never finishes, the span-end log will never be
// written, and we might otherwise never learn the following:
"message-type": r.Header.Type.String(),
"message-id": r.Header.ID,
"activity-id": r.ActivityID,
"container-id": r.ContainerID,
}).Warnf("bridge: request processing thread in sequential mode blocked on the current request for more than 5 seconds")
}
}()
defer cancel()
handleFn(r)
}

// PublishNotification writes a specific notification to the bridge.
func (b *Bridge) PublishNotification(n *prot.ContainerNotification) {
ctx, span := oc.StartSpan(context.Background(),
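The watchdog shape used by `runSequentialRequestHandler` — a derived timeout context that only drives a warning timer, while the work itself keeps running uncancelled — generalizes. A standalone sketch under invented names:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// warnIfSlow runs work and prints a warning if it is still running after
// warnAfter. The derived context is only a timer: the work is never
// cancelled by it, mirroring the bridge's sequential-mode watchdog.
func warnIfSlow(ctx context.Context, warnAfter time.Duration, work func()) {
	timeoutCtx, cancel := context.WithTimeout(ctx, warnAfter)
	defer cancel() // releases the timer goroutine once work returns
	go func() {
		<-timeoutCtx.Done()
		if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
			fmt.Printf("warning: still processing after %v\n", warnAfter)
		}
	}()
	work()
}

func main() {
	warnIfSlow(context.Background(), 100*time.Millisecond, func() {
		time.Sleep(250 * time.Millisecond) // simulate a slow request
	})
}
```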
8 changes: 1 addition & 7 deletions internal/guest/bridge/bridge_v2.go
@@ -467,16 +467,10 @@ func (b *Bridge) deleteContainerStateV2(r *Request) (_ RequestResponse, err erro
return nil, errors.Wrapf(err, "failed to unmarshal JSON in message \"%s\"", r.Message)
}

c, err := b.hostState.GetCreatedContainer(request.ContainerID)
err = b.hostState.DeleteContainerState(ctx, request.ContainerID)
if err != nil {
return nil, err
}
// remove container state regardless of delete's success
defer b.hostState.RemoveContainer(request.ContainerID)

if err := c.Delete(ctx); err != nil {
return nil, err
}

return &prot.MessageResponseBase{}, nil
}
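The inlined get/defer-remove/delete sequence removed above presumably moved behind `hostState.DeleteContainerState`. A sketch of what such a consolidated method could look like, reusing only the calls visible in the deleted lines (illustrative, not necessarily the PR's actual implementation):

```go
// DeleteContainerState deletes the container and drops its cached state.
// State is removed regardless of whether the runtime delete succeeds,
// matching the behaviour of the inlined code it replaces.
func (h *Host) DeleteContainerState(ctx context.Context, containerID string) error {
	c, err := h.GetCreatedContainer(containerID)
	if err != nil {
		return err
	}
	// remove container state regardless of delete's success
	defer h.RemoveContainer(containerID)
	return c.Delete(ctx)
}
```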
13 changes: 13 additions & 0 deletions internal/guest/network/network.go
@@ -9,6 +9,7 @@ import (
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"time"

@@ -32,6 +33,18 @@ var (
// maxDNSSearches is limited to 6 in `man 5 resolv.conf`
const maxDNSSearches = 6

var validHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9_\-\.]{0,255}$`)

// ValidateHostname checks that the hostname is safe. The check is more
// permissive than the hostname RFCs technically allow, but it ensures that
// when the hostname is inserted into /etc/hosts, it cannot lead to
// injection attacks.
func ValidateHostname(hostname string) error {
if !validHostnameRegex.MatchString(hostname) {
return errors.Errorf("hostname %q invalid: must match %s", hostname, validHostnameRegex.String())
}
return nil
}

// GenerateEtcHostsContent generates a /etc/hosts file based on `hostname`.
func GenerateEtcHostsContent(ctx context.Context, hostname string) string {
_, span := oc.StartSpan(ctx, "network::GenerateEtcHostsContent")
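To see the injection this check forecloses, consider an unvalidated hostname spliced into a generated /etc/hosts. A minimal sketch (the formatting below is illustrative; the real `GenerateEtcHostsContent` output may differ):

```go
package main

import "fmt"

func main() {
	// A hostname smuggling a newline injects a forged hosts entry,
	// remapping an arbitrary name to an attacker-chosen address.
	hostname := "localhost\n203.0.113.7 update.internal"
	fmt.Printf("127.0.0.1 localhost\n127.0.0.1 %s\n", hostname)
	// The output now contains the forged line:
	//   203.0.113.7 update.internal
	// ValidateHostname rejects this because '\n' falls outside
	// the allowed class [a-zA-Z0-9_\-\.].
}
```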
35 changes: 35 additions & 0 deletions internal/guest/network/network_test.go
@@ -7,6 +7,7 @@ import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
@@ -122,6 +123,40 @@ func Test_MergeValues(t *testing.T) {
}
}

func Test_ValidateHostname(t *testing.T) {
validNames := []string{
"localhost",
"my-hostname",
"my.hostname",
"my-host-name123",
"_underscores.are.allowed.too",
"", // Allow not specifying a hostname
}

invalidNames := []string{
"localhost\n13.104.0.1 ip6-localhost ip6-loopback localhost",
"localhost\n2603:1000::1 ip6-localhost ip6-loopback",
"hello@microsoft.com",
"has space",
"has,comma",
"\x00",
"a\nb",
strings.Repeat("a", 1000),
}

for _, n := range validNames {
if err := ValidateHostname(n); err != nil {
t.Fatalf("expected %q to be valid, got: %v", n, err)
}
}

for _, n := range invalidNames {
if err := ValidateHostname(n); err == nil {
t.Fatalf("expected %q to be invalid, but got nil error", n)
}
}
}

func Test_GenerateEtcHostsContent(t *testing.T) {
type testcase struct {
name string
3 changes: 3 additions & 0 deletions internal/guest/runtime/hcsv2/container.go
@@ -73,6 +73,9 @@ type Container struct {
// and deal with the extra pointer dereferencing overhead.
status atomic.Uint32

// Set to true when the init process for the container has exited
terminated atomic.Bool

// scratchDirPath represents the path inside the UVM where the scratch directory
// of this container is located. Usually, this is either `/run/gcs/c/<containerID>` or
// `/run/gcs/c/<UVMID>/container_<containerID>` if scratch is shared with UVM scratch.
1 change: 1 addition & 0 deletions internal/guest/runtime/hcsv2/process.go
@@ -99,6 +99,7 @@ func newProcess(c *Container, spec *oci.Process, process runtime.Process, pid ui
log.G(ctx).WithError(err).Error("failed to wait for runc process")
}
p.exitCode = exitCode
c.terminated.Store(true)
log.G(ctx).WithField("exitCode", p.exitCode).Debug("process exited")

// Free any process waiters
3 changes: 3 additions & 0 deletions internal/guest/runtime/hcsv2/sandbox_container.go
@@ -54,6 +54,9 @@ func setupSandboxContainerSpec(ctx context.Context, id string, spec *oci.Spec) (

// Write the hostname
hostname := spec.Hostname
if err = network.ValidateHostname(hostname); err != nil {
return err
}
if hostname == "" {
var err error
hostname, err = os.Hostname()
3 changes: 3 additions & 0 deletions internal/guest/runtime/hcsv2/standalone_container.go
@@ -61,6 +61,9 @@ func setupStandaloneContainerSpec(ctx context.Context, id string, spec *oci.Spec
}()

hostname := spec.Hostname
if err = network.ValidateHostname(hostname); err != nil {
return err
}
if hostname == "" {
var err error
hostname, err = os.Hostname()