Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/windows/common/wslutil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,28 @@ constexpr GUID EndianSwap(GUID value)
return value;
}

std::regex BuildImageReferenceRegex()
{
// See: https://github.com/containers/image/blob/main/docker/reference/regexp.go

std::string alphaNum = "[a-z0-9]+";
std::string separator = "(?:[._]|__|[-]*)";
std::string domainComponent = "(?:[a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9-]*[a-zA-Z0-9])";
std::string tag = "[\\w][\\w.-]{0,127}";
std::string digest = "[A-Za-z][A-Za-z0-9]*(?:[-_+.][A-Za-z][A-Za-z0-9]*)*[:][[:xdigit:]]{32,}";

auto group = [](const auto& exp) { return std::format("(?:{})", exp); };
auto optional = [&group](const auto& exp) { return group(exp) + "?"; };
auto repeated = [&group](const auto& exp) { return group(exp) + "+"; };
auto capture = [](const auto& exp) { return std::format("({})", exp); };

auto nameComponent = alphaNum + optional(repeated(separator + alphaNum));
auto domain = domainComponent + optional(repeated("\\." + domainComponent)) + optional(":[0-9]+");
auto namePat = optional(domain + "\\/") + nameComponent + optional(repeated("\\/" + nameComponent));

return std::regex("^" + capture(namePat) + optional(":" + capture(tag)) + optional("@" + capture(digest)) + "$");
}

} // namespace

template <typename TInterface>
Expand Down Expand Up @@ -1165,6 +1187,35 @@ std::tuple<uint32_t, uint32_t, uint32_t> wsl::windows::common::wslutil::ParseWsl
}
}

std::pair<std::string, std::optional<std::string>> wsl::windows::common::wslutil::ParseImage(const std::string& Input)
{
static const auto regex = BuildImageReferenceRegex();
std::smatch match;
if (!std::regex_match(Input, match, regex))
{
THROW_HR_WITH_USER_ERROR(E_INVALIDARG, wsl::shared::Localization::MessageWslaInvalidImage(Input.c_str()));
}

const auto& repo = match[1];
const auto& tag = match[2];
const auto& digest = match[3];

THROW_HR_IF_MSG(E_UNEXPECTED, !repo.matched, "Unexpected regex match. Input: %hs", Input.c_str());

if (digest.matched) // <repo>:[tag]@<digest> (If both digest and tag are specified, digest takes precedence).
{
return {repo.str(), digest.str()};
}
else if (tag.matched) // <repo>:<tag>
{
return {repo.str(), tag.str()}; // <repo>
}
else
{
return {repo.str(), std::nullopt};
}
}

void wsl::windows::common::wslutil::PrintSystemError(_In_ HRESULT result, _Inout_ FILE* const stream)
{
fwprintf(stream, L"%ls\n", GetSystemErrorString(result).c_str());
Expand Down
2 changes: 2 additions & 0 deletions src/windows/common/wslutil.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ void ParseIpv6Address(const char* Address, in_addr6& Result);

std::tuple<uint32_t, uint32_t, uint32_t> ParseWslPackageVersion(_In_ const std::wstring& Version);

std::pair<std::string, std::optional<std::string>> ParseImage(const std::string& Input);

void PrintSystemError(_In_ HRESULT result, _Inout_ FILE* stream = stdout);

void PrintMessageImpl(_In_ const std::wstring& message, _In_ va_list& args, _Inout_ FILE* stream = stdout);
Expand Down
2 changes: 1 addition & 1 deletion src/windows/service/inc/wslaservice.idl
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ interface IWSLASession : IUnknown
HRESULT GetState([out] WSLASessionState* State);

// Image management.
HRESULT PullImage([in] LPCSTR ImageUri, [in, unique] const WslaRegistryAuthInformation* RegistryAuthenticationInformation, [in, unique] IProgressCallback* ProgressCallback);
HRESULT PullImage([in] LPCSTR Image, [in, unique] const WslaRegistryAuthInformation* RegistryAuthenticationInformation, [in, unique] IProgressCallback* ProgressCallback);
HRESULT BuildImage([in] const WSLABuildImageOptions* Options, [in, unique] IProgressCallback* ProgressCallback);
HRESULT LoadImage([in] ULONG ImageHandle, [in, unique] IProgressCallback* ProgressCallback, [in] ULONGLONG ContentLength);
HRESULT ImportImage([in] ULONG ImageHandle, [in] LPCSTR ImageName, [in, unique] IProgressCallback* ProgressCallback, [in] ULONGLONG ContentLength);
Expand Down
8 changes: 5 additions & 3 deletions src/windows/wslasession/DockerHTTPClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,16 @@ DockerHTTPClient::DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE
{
}

std::unique_ptr<DockerHTTPClient::HTTPRequestContext> DockerHTTPClient::PullImage(const std::string& Repo, const std::optional<std::string>& Tag)
std::unique_ptr<DockerHTTPClient::HTTPRequestContext> DockerHTTPClient::PullImage(const std::string& Repo, const std::optional<std::string>& tagOrDigest)
{
auto url = URL::Create("/images/create");

// TODO: Support pulling from other registries.
url.SetParameter("fromImage", std::format("library/{}", Repo));

if (Tag.has_value())
if (tagOrDigest.has_value())
{
url.SetParameter("tag", Tag.value());
url.SetParameter("tag", tagOrDigest.value());
}
Comment thread
OneBlue marked this conversation as resolved.

return SendRequestImpl(verb::post, url, {}, {});
Expand Down
2 changes: 1 addition & 1 deletion src/windows/wslasession/DockerHTTPClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class DockerHTTPClient
std::vector<std::string> labels;
};

std::unique_ptr<HTTPRequestContext> PullImage(const std::string& Repo, const std::optional<std::string>& Tag);
std::unique_ptr<HTTPRequestContext> PullImage(const std::string& Repo, const std::optional<std::string>& tagOrDigest);
Comment thread
OneBlue marked this conversation as resolved.
std::unique_ptr<HTTPRequestContext> ImportImage(const std::string& Repo, const std::string& Tag, uint64_t ContentLength);
std::unique_ptr<HTTPRequestContext> LoadImage(uint64_t ContentLength);
void TagImage(const std::string& Id, const std::string& Repo, const std::string& Tag);
Expand Down
40 changes: 16 additions & 24 deletions src/windows/wslasession/WSLASession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,6 @@ constexpr auto c_containerdStorage = "/var/lib/docker";

namespace {

std::pair<std::string, std::optional<std::string>> ParseImage(const std::string& Input)
{
size_t separator = Input.find_last_of(':');
if (separator == std::string::npos)
{
return {Input, {}};
}

THROW_HR_WITH_USER_ERROR_IF(E_INVALIDARG, Localization::MessageWslaInvalidImage(Input), separator >= Input.size() - 1 || separator == 0);

return {Input.substr(0, separator), Input.substr(separator + 1)};
}

void ValidateName(LPCSTR Name)
{
const auto& locale = std::locale::classic();
Expand Down Expand Up @@ -305,20 +292,26 @@ void WSLASession::StartDockerd()
m_dockerdProcess->GetExitEvent(), std::bind(&WSLASession::OnDockerdExited, this)));
}

HRESULT WSLASession::PullImage(LPCSTR ImageUri, const WslaRegistryAuthInformation* RegistryAuthenticationInformation, IProgressCallback* ProgressCallback)
HRESULT WSLASession::PullImage(LPCSTR Image, const WslaRegistryAuthInformation* RegistryAuthenticationInformation, IProgressCallback* ProgressCallback)
try
{
UNREFERENCED_PARAMETER(RegistryAuthenticationInformation);

COMServiceExecutionContext context;

RETURN_HR_IF_NULL(E_POINTER, ImageUri);

auto [repo, tag] = ParseImage(ImageUri);
RETURN_HR_IF_NULL(E_POINTER, Image);

auto lock = m_lock.lock_shared();
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_dockerClient.has_value());

auto [repo, tagOrDigest] = wslutil::ParseImage(Image);

if (!tagOrDigest.has_value())
{
tagOrDigest = "latest";
}
Comment on lines +302 to +312
Copy link

Copilot AI Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WSLASession::PullImage() validates Image != nullptr but does not enforce WSLA_MAX_IMAGE_NAME_LENGTH like other image APIs (e.g., ImportImage, InspectImage). Since this now runs a fairly complex std::regex parse, a very long input could cause unnecessary CPU/memory usage. Add a length check (returning E_INVALIDARG) before calling wslutil::ParseImage.

Copilot uses AI. Check for mistakes.

auto requestContext = m_dockerClient->PullImage(repo, tag);
auto requestContext = m_dockerClient->PullImage(repo, tagOrDigest);

auto io = CreateIOContext();

Expand All @@ -340,7 +333,7 @@ try
}

std::string contentString{Content.begin(), Content.end()};
WSL_LOG("ImagePullProgress", TraceLoggingValue(ImageUri, "Image"), TraceLoggingValue(contentString.c_str(), "Content"));
WSL_LOG("ImagePullProgress", TraceLoggingValue(Image, "Image"), TraceLoggingValue(contentString.c_str(), "Content"));

if (ProgressCallback == nullptr)
{
Expand Down Expand Up @@ -575,15 +568,15 @@ try
RETURN_HR_IF_NULL(E_POINTER, ImageName);
RETURN_HR_IF(E_INVALIDARG, strlen(ImageName) > WSLA_MAX_IMAGE_NAME_LENGTH);

auto [repo, tag] = ParseImage(ImageName);
Comment thread
OneBlue marked this conversation as resolved.
auto [repo, tagOrDigest] = wslutil::ParseImage(ImageName);

THROW_HR_IF_MSG(E_INVALIDARG, !tag.has_value(), "Expected tag for image import: %hs", ImageName);
THROW_HR_IF_MSG(E_INVALIDARG, !tagOrDigest.has_value(), "Expected tag for image import: %hs", ImageName);

auto lock = m_lock.lock_shared();

THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_dockerClient.has_value());

auto requestContext = m_dockerClient->ImportImage(repo, tag.value(), ContentSize);
auto requestContext = m_dockerClient->ImportImage(repo, tagOrDigest.value(), ContentSize);

ImportImageImpl(*requestContext, ImageHandle);
return S_OK;
Expand Down Expand Up @@ -868,8 +861,7 @@ try

// Extract repo name from tag (format: "repo:tag")
// and lookup corresponding digest from the map
auto repoName = ParseImage(tag).first;
size_t colonPos = tag.find(':');
auto repoName = wslutil::ParseImage(tag).first;
auto it = repoToDigest.find(repoName);
if (it != repoToDigest.end())
{
Expand Down
2 changes: 1 addition & 1 deletion src/windows/wslasession/WSLASession.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession
IFACEMETHOD(GetState)(_Out_ WSLASessionState* State) override;

// Image management.
IFACEMETHOD(PullImage)(_In_ LPCSTR ImageUri, _In_opt_ const WslaRegistryAuthInformation* RegistryAuthenticationInformation, _In_opt_ IProgressCallback* ProgressCallback) override;
IFACEMETHOD(PullImage)(_In_ LPCSTR Image, _In_opt_ const WslaRegistryAuthInformation* RegistryAuthenticationInformation, _In_opt_ IProgressCallback* ProgressCallback) override;
IFACEMETHOD(BuildImage)(_In_ const WSLABuildImageOptions* Options, _In_opt_ IProgressCallback* ProgressCallback) override;
IFACEMETHOD(LoadImage)(_In_ ULONG ImageHandle, _In_ IProgressCallback* ProgressCallback, _In_ ULONGLONG ContentLength) override;
IFACEMETHOD(ImportImage)(_In_ ULONG ImageHandle, _In_ LPCSTR ImageName, _In_ IProgressCallback* ProgressCallback, _In_ ULONGLONG ContentLength) override;
Expand Down
142 changes: 132 additions & 10 deletions test/windows/WSLATests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,74 @@ class WSLATests

VERIFY_ARE_EQUAL(expectedError, comError->Message.get());
}

// Validate that PullImage() returns the appropriate error if the session is terminated.
{
VERIFY_SUCCEEDED(m_defaultSession->Terminate());

auto cleanup = wil::scope_exit([&]() {
ResetTestSession(); // Reopen the test session since the session was terminated.
});

VERIFY_ARE_EQUAL(m_defaultSession->PullImage("hello-world:linux", nullptr, nullptr), HRESULT_FROM_WIN32(ERROR_INVALID_STATE));
}
}

TEST_METHOD(PullImageAdvanced)
{
WSL2_TEST_ONLY();

// TODO: Enable once custom registries are supported, to avoid hitting public registry rate limits.
SKIP_TEST_UNSTABLE();
Comment thread
OneBlue marked this conversation as resolved.

auto validatePull = [&](const std::string& Image, const std::optional<std::string>& ExpectedTag = {}) {
VERIFY_SUCCEEDED(m_defaultSession->PullImage(Image.c_str(), nullptr, nullptr));

Comment thread
OneBlue marked this conversation as resolved.
auto cleanup = wil::scope_exit([&]() {
WSLADeleteImageOptions options{.Flags = WSLADeleteImageFlagsForce};
options.Image = ExpectedTag.has_value() ? ExpectedTag->c_str() : Image.c_str();
wil::unique_cotaskmem_array_ptr<WSLADeletedImageInformation> deletedImages;
LOG_IF_FAILED(m_defaultSession->DeleteImage(&options, &deletedImages, deletedImages.size_address<ULONG>()));
});

if (!ExpectedTag.has_value())
{

wil::unique_cotaskmem_array_ptr<WSLAImageInformation> images;
VERIFY_SUCCEEDED(m_defaultSession->ListImages(nullptr, images.addressof(), images.size_address<ULONG>()));

for (const auto& e : images)
{
wil::unique_cotaskmem_ansistring json;
VERIFY_SUCCEEDED(m_defaultSession->InspectImage(e.Hash, &json));

auto parsed = wsl::shared::FromJson<wsl::windows::common::wsla_schema::InspectImage>(json.get());

for (const auto& repoTag : parsed.RepoDigests.value_or({}))
{
if (Image == repoTag)
{
return;
}
}
}

LogError("Expected digest '%hs' not found ", Image.c_str());

VERIFY_FAIL();
}
else
{
ExpectImagePresent(*m_defaultSession, ExpectedTag->c_str());
}
};

validatePull("ubuntu@sha256:2e863c44b718727c860746568e1d54afd13b2fa71b160f5cd9058fc436217b30", {});

validatePull("ubuntu", "ubuntu:latest");
validatePull("debian:bookworm", "debian:bookworm");

// TODO: Add test coverage with custom registries once supported.
}

TEST_METHOD(ListImages)
Expand Down Expand Up @@ -5329,23 +5397,23 @@ class WSLATests
std::string longName(WSLA_MAX_CONTAINER_NAME_LENGTH + 1, 'a');
expectInvalidArg(longName);

auto expectInvalidPull = [&](const char* name, const char* errorPattern) {
auto expectInvalidPull = [&](const char* name) {
VERIFY_ARE_EQUAL(m_defaultSession->PullImage(name, nullptr, nullptr), E_INVALIDARG);

auto comError = wsl::windows::common::wslutil::GetCOMErrorInfo();
VERIFY_IS_TRUE(comError.has_value());

VerifyPatternMatch(wsl::shared::string::WideToMultiByte(comError->Message.get()), errorPattern);
VERIFY_ARE_EQUAL(comError->Message.get(), std::format(L"Invalid image: '{}'", name));
};

expectInvalidPull("?foo&bar/url\n:name", "invalid reference format");
expectInvalidPull("?:&", "invalid reference format");
expectInvalidPull("/:/", "invalid reference format");
expectInvalidPull("\n: ", "invalid reference format");
expectInvalidPull("invalid\nrepo:valid-image", "invalid reference format");
expectInvalidPull("bad!repo:valid-image", "invalid reference format");
expectInvalidPull("repo:badimage!name", "invalid tag format");
expectInvalidPull("bad+image", "invalid reference format");
expectInvalidPull("?foo&bar/url\n:name");
expectInvalidPull("?:&");
expectInvalidPull("/:/");
expectInvalidPull("\n: ");
expectInvalidPull("invalid\nrepo:valid-image");
expectInvalidPull("bad!repo:valid-image");
expectInvalidPull("repo:badimage!name");
expectInvalidPull("bad+image");
}

TEST_METHOD(PageReporting)
Expand Down Expand Up @@ -5886,4 +5954,58 @@ class WSLATests
VERIFY_ARE_EQUAL(m_defaultSession->PruneContainers(&filter, 1, 0, nullptr), HRESULT_FROM_WIN32(RPC_X_NULL_REF_POINTER));
}
}

TEST_METHOD(ImageParsing)
{
using wsl::windows::common::wslutil::ParseImage;

auto ValidateImageParsing = [](const std::string& input, const std::string& expectedRepo, const std::optional<std::string>& expectedTag) {
auto [repo, tag] = ParseImage(input);
VERIFY_ARE_EQUAL(repo, expectedRepo);
VERIFY_ARE_EQUAL(tag.value_or("<empty>"), expectedTag.value_or("<empty>"));
};

ValidateImageParsing("ubuntu:22.04", "ubuntu", "22.04");
ValidateImageParsing("ubuntu", "ubuntu", {});
ValidateImageParsing("library/ubuntu:latest", "library/ubuntu", "latest");
ValidateImageParsing("myregistry.io:5000/myimage:v1", "myregistry.io:5000/myimage", "v1");
ValidateImageParsing("myregistry.io:5000/myimage", "myregistry.io:5000/myimage", {});

ValidateImageParsing(
"registry.example.com:8080/org/project/image:stable", "registry.example.com:8080/org/project/image", "stable");

ValidateImageParsing("localhost:5000/myimage:latest", "localhost:5000/myimage", "latest");
ValidateImageParsing("ghcr.io/owner/repo:sha-abc123", "ghcr.io/owner/repo", "sha-abc123");

ValidateImageParsing(
"ubuntu@sha256:2e863c44b718727c860746568e1d54afd13b2fa71b160f5cd9058fc436217b30",
"ubuntu",
"sha256:2e863c44b718727c860746568e1d54afd13b2fa71b160f5cd9058fc436217b30");

// Validate that the digest takes precedence over the tag.
ValidateImageParsing(
"ubuntu:latest@sha256:2e863c44b718727c860746568e1d54afd13b2fa71b160f5cd9058fc436217b30",
"ubuntu",
"sha256:2e863c44b718727c860746568e1d54afd13b2fa71b160f5cd9058fc436217b30");

ValidateImageParsing(
"myregistry.io:5000/myimage@sha256:2e863c44b718727c860746568e1d54afd13b2fa71b160f5cd9058fc436217b30",
"myregistry.io:5000/myimage",
"sha256:2e863c44b718727c860746568e1d54afd13b2fa71b160f5cd9058fc436217b30");

ValidateImageParsing(
"ubuntu:22.04@sha256:2e863c44b718727c860746568e1d54afd13b2fa71b160f5cd9058fc436217b30",
"ubuntu",
"sha256:2e863c44b718727c860746568e1d54afd13b2fa71b160f5cd9058fc436217b30");

ValidateImageParsing("pytorch/pytorch", "pytorch/pytorch", {});

// Invalid inputs
VERIFY_ARE_EQUAL(wil::ResultFromException([]() { ParseImage(":debian:latest"); }), E_INVALIDARG);
VERIFY_ARE_EQUAL(wil::ResultFromException([]() { ParseImage("debian:latest@"); }), E_INVALIDARG);
VERIFY_ARE_EQUAL(wil::ResultFromException([]() { ParseImage(""); }), E_INVALIDARG);
VERIFY_ARE_EQUAL(wil::ResultFromException([]() { ParseImage(":"); }), E_INVALIDARG);
VERIFY_ARE_EQUAL(wil::ResultFromException([]() { ParseImage("a:"); }), E_INVALIDARG);
VERIFY_ARE_EQUAL(wil::ResultFromException([]() { ParseImage(":b"); }), E_INVALIDARG);
}
};