Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ venv
.python-version
cloud_sql_python_connector.egg-info/
dist/
.idea
.coverage
sponge_log.xml
21 changes: 17 additions & 4 deletions google/cloud/sql/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,23 @@ async def _get_metadata(
# resolve dnsName into IP address for PSC
# Note that we have to check for PSC enablement also because CAS
# instances also set the dnsName field.
# Remove trailing period from DNS name. Required for SSL in Python
dns_name = ret_dict.get("dnsName", "").rstrip(".")
if dns_name and ret_dict.get("pscEnabled"):
ip_addresses["PSC"] = dns_name
if ret_dict.get("pscEnabled"):
# Find PSC instance DNS name in the dns_names field
psc_dns_names = [
d["name"]
for d in ret_dict.get("dnsNames", [])
if d["connectionType"] == "PRIVATE_SERVICE_CONNECT"
and d["dnsScope"] == "INSTANCE"
]
dns_name = psc_dns_names[0] if psc_dns_names else None

# Fall back do dns_name field if dns_names is not set
if dns_name is None:
dns_name = ret_dict.get("dnsName", None)

# Remove trailing period from DNS name. Required for SSL in Python
if dns_name:
ip_addresses["PSC"] = dns_name.rstrip(".")

return {
"ip_addresses": ip_addresses,
Expand Down
14 changes: 13 additions & 1 deletion tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def __init__(
"PRIMARY": "127.0.0.1",
"PRIVATE": "10.0.0.1",
},
legacy_dns_name: bool = False,
cert_before: datetime = datetime.datetime.now(datetime.timezone.utc),
cert_expiration: datetime = datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(hours=1),
Expand All @@ -237,6 +238,7 @@ def __init__(
self.psc_enabled = False
self.cert_before = cert_before
self.cert_expiration = cert_expiration
self.legacy_dns_name = legacy_dns_name
# create self signed CA cert
self.server_ca, self.server_key = generate_cert(
self.project, self.name, cert_before, cert_expiration
Expand All @@ -255,12 +257,22 @@ async def connect_settings(self, request: Any) -> web.Response:
"instance": self.name,
"expirationTime": str(self.cert_expiration),
},
"dnsName": "abcde.12345.us-central1.sql.goog",
"pscEnabled": self.psc_enabled,
"ipAddresses": ip_addrs,
"region": self.region,
"databaseVersion": self.db_version,
}
if self.legacy_dns_name:
response["dnsName"] = "abcde.12345.us-central1.sql.goog"
else:
response["dnsNames"] = [
{
"name": "abcde.12345.us-central1.sql.goog",
"connectionType": "PRIVATE_SERVICE_CONNECT",
"dnsScope": "INSTANCE",
}
]

return web.Response(content_type="application/json", body=json.dumps(response))

async def generate_ephemeral(self, request: Any) -> web.Response:
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,28 @@ async def test_get_metadata_with_psc(fake_client: CloudSQLClient) -> None:
assert isinstance(resp["server_ca_cert"], str)


@pytest.mark.asyncio
async def test_get_metadata_legacy_dns_with_psc(fake_client: CloudSQLClient) -> None:
"""
Test _get_metadata returns successfully with PSC IP type.
"""
# set PSC to enabled on test instance
fake_client.instance.psc_enabled = True
fake_client.instance.legacy_dns_name = True
resp = await fake_client._get_metadata(
"test-project",
"test-region",
"test-instance",
)
assert resp["database_version"] == "POSTGRES_15"
assert resp["ip_addresses"] == {
"PRIMARY": "127.0.0.1",
"PRIVATE": "10.0.0.1",
"PSC": "abcde.12345.us-central1.sql.goog",
}
assert isinstance(resp["server_ca_cert"], str)


@pytest.mark.asyncio
async def test_get_ephemeral(fake_client: CloudSQLClient) -> None:
"""
Expand Down
Loading