Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions cmd/atelet/oci.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,19 @@ func untar(ctx context.Context, tarData io.Reader, rootPath string) error {

switch hdr.Typeflag {
case tar.TypeReg: // Regular file
// Same "later entry wins" handling: if any entry exists at the target path,
// remove it first. This ensures that:
// 1. If it's a symlink, we don't write through it (security vulnerability / incorrectness).
// 2. If it's a hardlink, we unlink it instead of truncating the shared inode.
// 3. If it's a directory, we recursively remove it so we can write the file.
if _, err := root.Lstat(name); err == nil {
if err := root.RemoveAll(name); err != nil {
return fmt.Errorf("while replacing existing path at %q before regular file: %w", name, err)
}
} else if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("while checking existing path at %q before regular file: %w", name, err)
}

// Stream directly from tarReader to target file to avoid buffering in memory.
outFile, err := root.OpenFile(name, os.O_CREATE|os.O_RDWR|os.O_TRUNC, mode)
if err != nil {
Expand Down
90 changes: 90 additions & 0 deletions cmd/atelet/oci_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,96 @@ func TestUntar_LaterEntryWins(t *testing.T) {
t.Errorf("symlink target = %q, want %q", got, "x")
}
})

t.Run("symlink overwritten by file", func(t *testing.T) {
entries := []tarEntry{
{name: "etc/", typeflag: tar.TypeDir},
{name: "etc/x", typeflag: tar.TypeReg, body: "original"},
{name: "etc/link", typeflag: tar.TypeSymlink, linkname: "x"},
{name: "etc/link", typeflag: tar.TypeReg, body: "replacement"},
}
dir, err := runUntar(t, entries)
if err != nil {
t.Fatalf("untar: %v", err)
}
fi, err := os.Lstat(filepath.Join(dir, "etc/link"))
if err != nil {
t.Fatalf("lstat etc/link: %v", err)
}
if fi.Mode().IsRegular() {
got, err := os.ReadFile(filepath.Join(dir, "etc/link"))
if err != nil {
t.Fatalf("read etc/link: %v", err)
}
if string(got) != "replacement" {
t.Errorf("etc/link content = %q, want %q", got, "replacement")
}
} else {
t.Errorf("etc/link mode is not regular file: %v", fi.Mode())
}
// Also verify etc/x was NOT overwritten
gotX, err := os.ReadFile(filepath.Join(dir, "etc/x"))
if err != nil {
t.Fatalf("read etc/x: %v", err)
}
if string(gotX) != "original" {
t.Errorf("etc/x content was overwritten to %q", gotX)
}
})

t.Run("file overwritten by symlink", func(t *testing.T) {
entries := []tarEntry{
{name: "etc/", typeflag: tar.TypeDir},
{name: "etc/link", typeflag: tar.TypeReg, body: "original-file"},
{name: "etc/link", typeflag: tar.TypeSymlink, linkname: "target-doesnt-exist"},
}
dir, err := runUntar(t, entries)
if err != nil {
t.Fatalf("untar: %v", err)
}
fi, err := os.Lstat(filepath.Join(dir, "etc/link"))
if err != nil {
t.Fatalf("lstat etc/link: %v", err)
}
if fi.Mode()&os.ModeSymlink == 0 {
t.Errorf("etc/link mode is not a symlink: %v", fi.Mode())
}
got, err := os.Readlink(filepath.Join(dir, "etc/link"))
if err != nil {
t.Fatalf("readlink etc/link: %v", err)
}
if got != "target-doesnt-exist" {
t.Errorf("etc/link target = %q, want %q", got, "target-doesnt-exist")
}
})

t.Run("hardlink overwritten by file", func(t *testing.T) {
entries := []tarEntry{
{name: "bin/", typeflag: tar.TypeDir},
{name: "bin/sh", typeflag: tar.TypeReg, body: "sh-original"},
{name: "bin/bash", typeflag: tar.TypeLink, linkname: "bin/sh"},
{name: "bin/bash", typeflag: tar.TypeReg, body: "bash-new"},
}
dir, err := runUntar(t, entries)
if err != nil {
t.Fatalf("untar: %v", err)
}
gotBash, err := os.ReadFile(filepath.Join(dir, "bin/bash"))
if err != nil {
t.Fatalf("read bin/bash: %v", err)
}
if string(gotBash) != "bash-new" {
t.Errorf("bin/bash content = %q, want %q", gotBash, "bash-new")
}
// Verify bin/sh was NOT modified!
gotSh, err := os.ReadFile(filepath.Join(dir, "bin/sh"))
if err != nil {
t.Fatalf("read bin/sh: %v", err)
}
if string(gotSh) != "sh-original" {
t.Errorf("bin/sh content was overwritten to %q (hardlink was not unlinked)", gotSh)
}
})
}

func TestUntar_PathTraversal(t *testing.T) {
Expand Down
Loading