2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
@@ -81,6 +81,7 @@ jobs:
POSTGRES_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CAS_PASS
POSTGRES_CUSTOMER_CAS_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME
POSTGRES_CUSTOMER_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS
POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME
SQLSERVER_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_CONNECTION_NAME
SQLSERVER_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER
SQLSERVER_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS
@@ -102,6 +103,7 @@ jobs:
POSTGRES_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CAS_PASS }}"
POSTGRES_CUSTOMER_CAS_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_CONNECTION_NAME }}"
POSTGRES_CUSTOMER_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS }}"
POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME }}"
SQLSERVER_CONNECTION_NAME: "${{ steps.secrets.outputs.SQLSERVER_CONNECTION_NAME }}"
SQLSERVER_USER: "${{ steps.secrets.outputs.SQLSERVER_USER }}"
SQLSERVER_PASS: "${{ steps.secrets.outputs.SQLSERVER_PASS }}"
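For context, a minimal sketch of how the secret added above reaches the test code: the workflow step fetches it from Secret Manager and exports it as an environment variable of the same name. The example value is a hypothetical placeholder, not a value from this change:

```python
import os

# The system tests read the exported variable directly; its value is expected
# to be a DNS domain name that maps to the customer-managed-CAS instance
# (e.g. "db.example.com" -- hypothetical).
dns_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"]
print(dns_name)
```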
32 changes: 30 additions & 2 deletions tests/system/test_asyncpg_connection.py
@@ -16,13 +16,15 @@

import asyncio
import os
from typing import Any
from typing import Any, Union

import asyncpg
import sqlalchemy
import sqlalchemy.ext.asyncio

from google.cloud.sql.connector import Connector
from google.cloud.sql.connector import DefaultResolver
from google.cloud.sql.connector import DnsResolver


async def create_sqlalchemy_engine(
@@ -31,6 +33,7 @@ async def create_sqlalchemy_engine(
password: str,
db: str,
refresh_strategy: str = "background",
resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver,
) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, Connector]:
"""Creates a connection pool for a Cloud SQL instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
@@ -64,9 +67,16 @@ async def create_sqlalchemy_engine(
Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
or "background". For serverless environments use "lazy" to avoid
errors resulting from CPU being throttled.
resolver (Union[type[DefaultResolver], type[DnsResolver]]):
Resolver class for resolving instance connection name. Use
google.cloud.sql.connector.DnsResolver when resolving DNS domain
names or google.cloud.sql.connector.DefaultResolver for regular
instance connection names ("my-project:my-region:my-instance").
"""
loop = asyncio.get_running_loop()
connector = Connector(loop=loop, refresh_strategy=refresh_strategy)
connector = Connector(
loop=loop, refresh_strategy=refresh_strategy, resolver=resolver
)

async def getconn() -> asyncpg.Connection:
conn: asyncpg.Connection = await connector.connect_async(
@@ -183,6 +193,24 @@ async def test_lazy_sqlalchemy_connection_with_asyncpg() -> None:
await connector.close_async()


async def test_custom_SAN_with_dns_sqlalchemy_connection_with_asyncpg() -> None:
"""Basic test to get time from database."""
inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"]
user = os.environ["POSTGRES_USER"]
password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"]
db = os.environ["POSTGRES_DB"]

pool, connector = await create_sqlalchemy_engine(
inst_conn_name, user, password, db, resolver=DnsResolver
)

async with pool.connect() as conn:
res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone()
assert res[0] == 1

await connector.close_async()


async def test_connection_with_asyncpg() -> None:
"""Basic test to get time from database."""
inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"]
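To show what the new `resolver` parameter enables outside of the test harness, here is a minimal sketch of an asyncpg connection routed through `DnsResolver`. The DNS name `prod-db.example.com` and the credentials are hypothetical placeholders, not values from this change:

```python
import asyncio

import asyncpg

from google.cloud.sql.connector import Connector, DnsResolver


async def main() -> None:
    loop = asyncio.get_running_loop()
    # With resolver=DnsResolver, the first argument to connect_async() is
    # treated as a DNS domain name instead of a
    # "my-project:my-region:my-instance" connection name.
    connector = Connector(loop=loop, refresh_strategy="lazy", resolver=DnsResolver)
    conn: asyncpg.Connection = await connector.connect_async(
        "prod-db.example.com",  # hypothetical DNS name for the instance
        "asyncpg",
        user="my-user",
        password="my-password",
        db="my-db",
    )
    print(await conn.fetchval("SELECT 1"))
    await conn.close()
    await connector.close_async()


asyncio.run(main())
```

Note that the resolver is passed as a class, not an instance, mirroring the `create_sqlalchemy_engine` helper above.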
30 changes: 29 additions & 1 deletion tests/system/test_pg8000_connection.py
@@ -18,10 +18,14 @@
import os

# [START cloud_sql_connector_postgres_pg8000]
from typing import Union

import pg8000
import sqlalchemy

from google.cloud.sql.connector import Connector
from google.cloud.sql.connector import DefaultResolver
from google.cloud.sql.connector import DnsResolver


def create_sqlalchemy_engine(
@@ -30,6 +34,7 @@ def create_sqlalchemy_engine(
password: str,
db: str,
refresh_strategy: str = "background",
resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver,
) -> tuple[sqlalchemy.engine.Engine, Connector]:
"""Creates a connection pool for a Cloud SQL instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
@@ -64,8 +69,13 @@ def create_sqlalchemy_engine(
Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
or "background". For serverless environments use "lazy" to avoid
errors resulting from CPU being throttled.
resolver (Union[type[DefaultResolver], type[DnsResolver]]):
Resolver class for resolving instance connection name. Use
google.cloud.sql.connector.DnsResolver when resolving DNS domain
names or google.cloud.sql.connector.DefaultResolver for regular
instance connection names ("my-project:my-region:my-instance").
"""
connector = Connector(refresh_strategy=refresh_strategy)
connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver)

def getconn() -> pg8000.dbapi.Connection:
conn: pg8000.dbapi.Connection = connector.connect(
@@ -153,3 +163,21 @@ def test_customer_managed_CAS_pg8000_connection() -> None:
curr_time = time[0]
assert type(curr_time) is datetime
connector.close()


def test_custom_SAN_with_dns_pg8000_connection() -> None:
"""Basic test to get time from database."""
inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"]
user = os.environ["POSTGRES_USER"]
password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"]
db = os.environ["POSTGRES_DB"]

engine, connector = create_sqlalchemy_engine(
inst_conn_name, user, password, db, resolver=DnsResolver
)
with engine.connect() as conn:
time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
conn.commit()
curr_time = time[0]
assert type(curr_time) is datetime
connector.close()
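
Likewise, a minimal synchronous sketch of the same pattern with pg8000 and a SQLAlchemy pool; the DNS name and credentials are hypothetical placeholders:

```python
import pg8000
import sqlalchemy

from google.cloud.sql.connector import Connector, DnsResolver

# Resolver classes are passed as types; DnsResolver makes the connector
# interpret the connection string as a DNS domain name.
connector = Connector(refresh_strategy="lazy", resolver=DnsResolver)


def getconn() -> pg8000.dbapi.Connection:
    return connector.connect(
        "prod-db.example.com",  # hypothetical DNS name for the instance
        "pg8000",
        user="my-user",
        password="my-password",
        db="my-db",
    )


engine = sqlalchemy.create_engine("postgresql+pg8000://", creator=getconn)
with engine.connect() as conn:
    print(conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone())
connector.close()
```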