Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sdks/go/container/boot.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ func main() {
logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err)
}

// Inject pipeline options into context
ctx = artifact.WithPipelineOptions(ctx, info.GetPipelineOptions())

// (2) Retrieve the staged files.
//
// The Go SDK harness downloads the worker binary and invokes
Expand Down
41 changes: 37 additions & 4 deletions sdks/go/pkg/beam/artifact/materialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/util/errorx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx"
"google.golang.org/protobuf/proto"
structpb "google.golang.org/protobuf/types/known/structpb"
)

// TODO(lostluck): 2018/05/28 Extract these from their enum descriptors in the pipeline_v1 proto
Expand Down Expand Up @@ -131,6 +132,7 @@ func newMaterializeWithClient(ctx context.Context, client jobpb.ArtifactRetrieva
RoleUrn: URNStagingTo,
RolePayload: rolePayload,
},
expectedSha256: filePayload.Sha256,
})
}

Expand Down Expand Up @@ -183,8 +185,9 @@ func MustExtractFilePayload(artifact *pipepb.ArtifactInformation) (string, strin
}

// artifact bundles everything needed to retrieve and validate a single
// staged dependency: the retrieval client, the artifact's metadata, and
// the hash the downloaded bytes are expected to have.
type artifact struct {
	client jobpb.ArtifactRetrievalServiceClient
	dep    *pipepb.ArtifactInformation
	// expectedSha256 is the SHA256 from the artifact's file payload, if
	// present; retrieve compares the downloaded content against it.
	expectedSha256 string
}

func (a artifact) retrieve(ctx context.Context, dest string) error {
Expand Down Expand Up @@ -231,7 +234,15 @@ func (a artifact) retrieve(ctx context.Context, dest string) error {
stat, _ := fd.Stat()
log.Printf("Downloaded: %v (sha256: %v, size: %v)", filename, sha256Hash, stat.Size())

return fd.Close()
if err := fd.Close(); err != nil {
return err
}

if isArtifactValidationEnabled(ctx) && a.expectedSha256 != "" && sha256Hash != a.expectedSha256 {
return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.expectedSha256)
}

return nil
}

func writeChunks(stream jobpb.ArtifactRetrievalService_GetArtifactClient, w io.Writer) (string, error) {
Expand Down Expand Up @@ -442,7 +453,7 @@ func retrieve(ctx context.Context, client jobpb.LegacyArtifactRetrievalServiceCl
}

// Artifact Sha256 hash is an optional field in metadata so we should only validate when it's present.
if a.Sha256 != "" && sha256Hash != a.Sha256 {
if isArtifactValidationEnabled(ctx) && a.Sha256 != "" && sha256Hash != a.Sha256 {
return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.Sha256)
}
return nil
Expand Down Expand Up @@ -511,3 +522,25 @@ func queue2slice(q chan *jobpb.ArtifactMetadata) []*jobpb.ArtifactMetadata {
}
return ret
}

// contextKey is a private key type so that values stored by this package
// cannot collide with context values set by other packages.
type contextKey string

// pipelineOptionsKey is the context key under which the pipeline options
// struct is stored by WithPipelineOptions.
const pipelineOptionsKey contextKey = "pipeline_options"

// WithPipelineOptions returns a new context carrying the full pipeline options struct.
func WithPipelineOptions(ctx context.Context, options *structpb.Struct) context.Context {
	return context.WithValue(ctx, pipelineOptionsKey, options)
}

// isArtifactValidationEnabled parses pipeline options to check if "disable_integrity_checks" is enabled.
func isArtifactValidationEnabled(ctx context.Context) bool {
options, _ := ctx.Value(pipelineOptionsKey).(*structpb.Struct)
if options != nil {
for _, v := range options.GetFields()["options"].GetStructValue().GetFields()["experiments"].GetListValue().GetValues() {
if v.GetStringValue() == "disable_integrity_checks" {
return false
}
}
}
return true
}
126 changes: 126 additions & 0 deletions sdks/go/pkg/beam/artifact/materialize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
structpb "google.golang.org/protobuf/types/known/structpb"
)

// TestRetrieve tests that we can successfully retrieve fresh files.
Expand Down Expand Up @@ -82,6 +83,57 @@ func TestMultiRetrieve(t *testing.T) {
}
}

// TestRetrieveWithBadShaFails verifies that legacy retrieval reports an
// error when an artifact's advertised SHA256 does not match the content
// that was actually downloaded.
func TestRetrieveWithBadShaFails(t *testing.T) {
	conn := startServer(t)
	defer conn.Close()

	ctx := grpcx.WriteWorkerID(context.Background(), "idA")
	scope, mds := populate(ctx, conn, t, []string{"foo"}, 300, "whatever")

	dest := makeTempDir(t)
	defer os.RemoveAll(dest)

	client := jobpb.NewLegacyArtifactRetrievalServiceClient(conn)
	for _, md := range mds {
		md.Sha256 = "badhash" // corrupt the advertised hash
		err := Retrieve(ctx, client, md, scope, dest)
		if err == nil {
			t.Errorf("expected materialization to fail due to bad sha256 mismatch")
		}
	}
}

// TestRetrieveWithBadShaAndExperimentSucceeds verifies that a SHA256
// mismatch is ignored when the "disable_integrity_checks" experiment is
// present in the pipeline options carried by the context.
func TestRetrieveWithBadShaAndExperimentSucceeds(t *testing.T) {
	cc := startServer(t)
	defer cc.Close()

	options, err := structpb.NewStruct(map[string]interface{}{
		"options": map[string]interface{}{
			"experiments": []interface{}{"disable_integrity_checks"},
		},
	})
	// Fail fast instead of silently proceeding with nil options, which
	// would leave validation enabled and invalidate the test.
	if err != nil {
		t.Fatalf("failed to build pipeline options: %v", err)
	}
	ctx := WithPipelineOptions(grpcx.WriteWorkerID(context.Background(), "idA"), options)
	keys := []string{"foo"}
	st := "whatever"
	rt, artifacts := populate(ctx, cc, t, keys, 300, st)

	dst := makeTempDir(t)
	defer os.RemoveAll(dst)

	client := jobpb.NewLegacyArtifactRetrievalServiceClient(cc)
	for _, a := range artifacts {
		originalHash := a.Sha256
		a.Sha256 = "badhash" // mutate hash
		filename := makeFilename(dst, a.Name)
		if err := Retrieve(ctx, client, a, rt, dst); err != nil {
			t.Errorf("materialize failed but should have succeeded because validation was disabled via experiment: %v", err)
			continue
		}
		// The file should have been staged intact despite the bad hash.
		verifySHA256(t, filename, originalHash)
	}
}

// populate stages a set of artifacts with the given keys, each with
// slightly different sizes and chunk sizes.
func populate(ctx context.Context, cc *grpc.ClientConn, t *testing.T, keys []string, size int, st string) (string, []*jobpb.ArtifactMetadata) {
Expand Down Expand Up @@ -266,6 +318,65 @@ func TestNewRetrieveWithResolution(t *testing.T) {
checkStagedFiles(mds, dest, expected, t)
}

// TestIsArtifactValidationEnabled checks both the default (enabled) state
// on an empty context and the disabled state when the
// "disable_integrity_checks" experiment is set in the pipeline options.
func TestIsArtifactValidationEnabled(t *testing.T) {
	ctx := context.Background()
	if !isArtifactValidationEnabled(ctx) {
		t.Errorf("empty context should have validation enabled")
	}

	options, err := structpb.NewStruct(map[string]interface{}{
		"options": map[string]interface{}{
			"experiments": []interface{}{"disable_integrity_checks"},
		},
	})
	// Fail fast instead of silently continuing with nil options, which
	// would make the assertion below meaningless.
	if err != nil {
		t.Fatalf("failed to build pipeline options: %v", err)
	}
	ctx2 := WithPipelineOptions(ctx, options)
	if isArtifactValidationEnabled(ctx2) {
		t.Errorf("populated context should have validation disabled")
	}
}

// TestNewRetrieveWithBadShaFails verifies that materialization through the
// current retrieval API fails when an artifact's SHA256 does not match the
// downloaded content.
func TestNewRetrieveWithBadShaFails(t *testing.T) {
	files := map[string]string{"a.txt": "a"}
	service := &fakeRetrievalService{artifacts: files}
	dest := makeTempDir(t)
	defer os.RemoveAll(dest)

	ctx := grpcx.WriteWorkerID(context.Background(), "worker")
	if _, err := newMaterializeWithClient(ctx, service, service.fileArtifactsWithBadSha(), dest); err == nil {
		t.Fatalf("expected materialization to fail due to bad sha256 mismatch")
	}
}

// TestNewRetrieveWithBadShaAndExperimentSucceeds verifies that a SHA256
// mismatch through the current retrieval API is ignored when the
// "disable_integrity_checks" experiment is set in the pipeline options.
func TestNewRetrieveWithBadShaAndExperimentSucceeds(t *testing.T) {
	expected := map[string]string{"a.txt": "a"}
	client := &fakeRetrievalService{artifacts: expected}
	dest := makeTempDir(t)
	defer os.RemoveAll(dest)

	options, err := structpb.NewStruct(map[string]interface{}{
		"options": map[string]interface{}{
			"experiments": []interface{}{"disable_integrity_checks"},
		},
	})
	// Fail fast instead of silently continuing with nil options, which
	// would leave validation enabled and invalidate the test.
	if err != nil {
		t.Fatalf("failed to build pipeline options: %v", err)
	}
	ctx := WithPipelineOptions(grpcx.WriteWorkerID(context.Background(), "worker"), options)

	mds, err := newMaterializeWithClient(ctx, client, client.fileArtifactsWithBadSha(), dest)
	if err != nil {
		t.Fatalf("materialize failed but should have succeeded because validation was disabled via experiment: %v", err)
	}

	// Reconstruct the expected staged-name payloads from the returned
	// artifact metadata for comparison against the files on disk.
	generated := make(map[string]string)
	for _, md := range mds {
		name, _ := MustExtractFilePayload(md)
		payload, _ := proto.Marshal(&pipepb.ArtifactStagingToRolePayload{
			StagedName: name})
		generated[name] = string(payload)
	}

	checkStagedFiles(mds, dest, generated, t)
}

func checkStagedFiles(mds []*pipepb.ArtifactInformation, dest string, expected map[string]string, t *testing.T) {
if len(mds) != len(expected) {
t.Errorf("wrong number of artifacts staged %v vs %v", len(mds), len(expected))
Expand Down Expand Up @@ -323,6 +434,21 @@ func (fake *fakeRetrievalService) fileArtifactsWithoutStagingTo() []*pipepb.Arti
return artifacts
}

// fileArtifactsWithBadSha returns file-typed artifact descriptors for every
// fake artifact, each advertising a deliberately wrong SHA256 so tests can
// exercise integrity-check failures.
func (fake *fakeRetrievalService) fileArtifactsWithBadSha() []*pipepb.ArtifactInformation {
	infos := make([]*pipepb.ArtifactInformation, 0, len(fake.artifacts))
	for name := range fake.artifacts {
		payload, _ := proto.Marshal(&pipepb.ArtifactFilePayload{
			Path:   filepath.Join("/tmp", name),
			Sha256: "badhash",
		})
		info := &pipepb.ArtifactInformation{
			TypeUrn:     URNFileArtifact,
			TypePayload: payload,
		}
		infos = append(infos, info)
	}
	return infos
}

func (fake *fakeRetrievalService) urlArtifactsWithoutStagingTo() []*pipepb.ArtifactInformation {
var artifacts []*pipepb.ArtifactInformation
for name := range fake.artifacts {
Expand Down
3 changes: 3 additions & 0 deletions sdks/java/container/boot.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ func main() {
logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err)
}

// Inject pipeline options into context
ctx = artifact.WithPipelineOptions(ctx, info.GetPipelineOptions())

// (2) Retrieve the staged user jars. We ignore any disk limit,
// because the staged jars are mandatory.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,8 +595,9 @@ def _stage_resources(self, pipeline, options):
else:
remote_name = os.path.basename(type_payload.path)
is_staged_role = False

if self._enable_caching and not type_payload.sha256:
# compute sha256 even if caching is disabled.
# This is used to check the payload integrity along with caching.
if not type_payload.sha256:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the same feature-flag purpose, shall we also skip computing sha256 (to preserve the original behavior) if the disable_integrity_checks experiment is set?

type_payload.sha256 = self._compute_sha256(type_payload.path)

if type_payload.sha256 and type_payload.sha256 in staged_hashes:
Expand Down
42 changes: 24 additions & 18 deletions sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,13 +1340,19 @@ def test_stage_resources(self):
])
}))
client = apiclient.DataflowApplicationClient(pipeline_options)
with mock.patch.object(apiclient._LegacyDataflowStager,
'stage_job_resources') as mock_stager:
client._stage_resources(pipeline, pipeline_options)
with mock.patch.object(apiclient.DataflowApplicationClient,
'_compute_sha256',
side_effect=lambda path: 'hash' + path):
with mock.patch.object(apiclient._LegacyDataflowStager,
'stage_job_resources') as mock_stager:
client._stage_resources(pipeline, pipeline_options)
mock_stager.assert_called_once_with(
[('/tmp/foo1', 'foo1', ''), ('/tmp/bar1', 'bar1', ''),
('/tmp/baz', 'baz1', ''), ('/tmp/renamed1', 'renamed1', 'abcdefg'),
('/tmp/foo2', 'foo2', ''), ('/tmp/bar2', 'bar2', '')],
[('/tmp/foo1', 'foo1', 'hash/tmp/foo1'),
('/tmp/bar1', 'bar1', 'hash/tmp/bar1'),
('/tmp/baz', 'baz1', 'hash/tmp/baz'),
('/tmp/renamed1', 'renamed1', 'abcdefg'),
('/tmp/foo2', 'foo2', 'hash/tmp/foo2'),
('/tmp/bar2', 'bar2', 'hash/tmp/bar2')],
staging_location='gs://test-location/staging')

pipeline_expected = beam_runner_api_pb2.Pipeline(
Expand All @@ -1357,26 +1363,26 @@ def test_stage_resources(self):
beam_runner_api_pb2.ArtifactInformation(
type_urn=common_urns.artifact_types.URL.urn,
type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
url='gs://test-location/staging/foo1'
).SerializeToString(),
url='gs://test-location/staging/foo1',
sha256='hash/tmp/foo1').SerializeToString(),
role_urn=common_urns.artifact_roles.STAGING_TO.urn,
role_payload=beam_runner_api_pb2.
ArtifactStagingToRolePayload(
staged_name='foo1').SerializeToString()),
beam_runner_api_pb2.ArtifactInformation(
type_urn=common_urns.artifact_types.URL.urn,
type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
url='gs://test-location/staging/bar1').
SerializeToString(),
url='gs://test-location/staging/bar1',
sha256='hash/tmp/bar1').SerializeToString(),
role_urn=common_urns.artifact_roles.STAGING_TO.urn,
role_payload=beam_runner_api_pb2.
ArtifactStagingToRolePayload(
staged_name='bar1').SerializeToString()),
beam_runner_api_pb2.ArtifactInformation(
type_urn=common_urns.artifact_types.URL.urn,
type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
url='gs://test-location/staging/baz1').
SerializeToString(),
url='gs://test-location/staging/baz1',
sha256='hash/tmp/baz').SerializeToString(),
role_urn=common_urns.artifact_roles.STAGING_TO.urn,
role_payload=beam_runner_api_pb2.
ArtifactStagingToRolePayload(
Expand All @@ -1396,26 +1402,26 @@ def test_stage_resources(self):
beam_runner_api_pb2.ArtifactInformation(
type_urn=common_urns.artifact_types.URL.urn,
type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
url='gs://test-location/staging/foo2').
SerializeToString(),
url='gs://test-location/staging/foo2',
sha256='hash/tmp/foo2').SerializeToString(),
role_urn=common_urns.artifact_roles.STAGING_TO.urn,
role_payload=beam_runner_api_pb2.
ArtifactStagingToRolePayload(
staged_name='foo2').SerializeToString()),
beam_runner_api_pb2.ArtifactInformation(
type_urn=common_urns.artifact_types.URL.urn,
type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
url='gs://test-location/staging/bar2').
SerializeToString(),
url='gs://test-location/staging/bar2',
sha256='hash/tmp/bar2').SerializeToString(),
role_urn=common_urns.artifact_roles.STAGING_TO.urn,
role_payload=beam_runner_api_pb2.
ArtifactStagingToRolePayload(
staged_name='bar2').SerializeToString()),
beam_runner_api_pb2.ArtifactInformation(
type_urn=common_urns.artifact_types.URL.urn,
type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
url='gs://test-location/staging/baz1').
SerializeToString(),
url='gs://test-location/staging/baz1',
sha256='hash/tmp/baz').SerializeToString(),
role_urn=common_urns.artifact_roles.STAGING_TO.urn,
role_payload=beam_runner_api_pb2.
ArtifactStagingToRolePayload(
Expand Down
3 changes: 3 additions & 0 deletions sdks/python/container/boot.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ func launchSDKProcess() error {
logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err)
}

// Inject pipeline options into context
ctx = artifact.WithPipelineOptions(ctx, info.GetPipelineOptions())

experiments := getExperiments(options)
pipNoBuildIsolation = false
if slices.Contains(experiments, "pip_no_build_isolation") {
Expand Down
3 changes: 3 additions & 0 deletions sdks/typescript/container/boot.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ func main() {
logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err)
}

// Inject pipeline options into context
ctx = artifact.WithPipelineOptions(ctx, info.GetPipelineOptions())

// (2) Retrieve and install the staged packages.

dir := filepath.Join(*semiPersistDir, *id, "staged")
Expand Down
Loading