Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion cmd/pg-schema-diff/plan_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ type (
lockTimeoutModifiers []string
insertStatements []string
outputFormat outputFormat

templateDb string
}

timeoutModifier struct {
Expand All @@ -117,6 +119,8 @@ type (
statementTimeoutModifiers []timeoutModifier
lockTimeoutModifiers []timeoutModifier
insertStatements []insertStatement

templateDb string
}
)

Expand Down Expand Up @@ -173,6 +177,13 @@ func createPlanFlags(cmd *cobra.Command) *planFlags {

cmd.Flags().Var(&flags.outputFormat, "output-format", "Change the output format for what is printed. Defaults to pretty-printed human-readable output. (options: pretty, json)")

cmd.Flags().StringVar(
&flags.templateDb,
"template-db",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: name schema-template-db

"template0",
"Template database to use when creating temporary databases",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add to this doc string: "This is used to derive implicit schema objects for the target schema"

)

return flags
}

Expand Down Expand Up @@ -250,6 +261,7 @@ func parsePlanConfig(p planFlags) (planConfig, error) {
statementTimeoutModifiers: statementTimeoutModifiers,
lockTimeoutModifiers: lockTimeoutModifiers,
insertStatements: insertStatements,
templateDb: p.templateDb,
}, nil
}

Expand Down Expand Up @@ -384,7 +396,7 @@ func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnCo
copiedConfig := connConfig.Copy()
copiedConfig.Database = dbName
return openDbWithPgxConfig(copiedConfig)
}, tempdb.WithRootDatabase(connConfig.Database))
}, tempdb.WithRootDatabase(connConfig.Database), tempdb.WithTemplateDatabase(planConfig.templateDb))
if err != nil {
return diff.Plan{}, err
}
Expand Down
44 changes: 38 additions & 6 deletions cmd/pg-schema-diff/plan_cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"
"time"

"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -32,11 +33,11 @@ func TestParseTimeoutModifierStr(t *testing.T) {
},
{
opt: "timeout=15m",
expectedErrContains: "could not find key",
expectedErrContains: "could not find key", // "pattern" missing
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we remove these. I understand why they were included, but I prefer not to have these comments, since they might quickly become outdated

},
{
opt: `pattern="some pattern"`,
expectedErrContains: "could not find key",
expectedErrContains: "could not find key", // "timeout" missing
},
{
opt: `pattern="normal" timeout=5m some-unknown-key=5m`,
Expand Down Expand Up @@ -80,19 +81,19 @@ func TestParseInsertStatementStr(t *testing.T) {
},
{
opt: "statement=no-index timeout=5m6s lock_timeout=1m11s",
expectedErrContains: "could not find key",
expectedErrContains: "could not find key", // "index" missing
},
{
opt: "index=0 timeout=5m6s lock_timeout=1m11s",
expectedErrContains: "could not find key",
expectedErrContains: "could not find key", // "statement" missing
},
{
opt: "index=0 statement=no-timeout lock_timeout=1m11s",
expectedErrContains: "could not find key",
expectedErrContains: "could not find key", // "timeout" missing
},
{
opt: "index=0 statement=no-lock-timeout-timeout timeout=5m6s",
expectedErrContains: "could not find key",
expectedErrContains: "could not find key", // "lock_timeout" missing
},
{
opt: "index=not-an-int statement=some-statement timeout=5m6s lock_timeout=1m11s",
Expand All @@ -118,3 +119,34 @@ func TestParseInsertStatementStr(t *testing.T) {
})
}
}

func TestPlanFlagsTemplateDBDefault(t *testing.T) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably okay to not include these, since we don't have analog tests. The functionality it is testing is pretty straightforward too, i.e., flag parsing

cmd := &cobra.Command{}
flags := createPlanFlags(cmd)

err := cmd.ParseFlags([]string{
"--schema-dir=/no/such/dir",
})
require.NoError(t, err)

planCfg, err := parsePlanConfig(*flags)
require.NoError(t, err, "parsePlanConfig should not fail with a dummy --schema-dir")

assert.Equal(t, "template0", planCfg.templateDb)
}

func TestPlanFlagsTemplateDBOverride(t *testing.T) {
cmd := &cobra.Command{}
flags := createPlanFlags(cmd)

err := cmd.ParseFlags([]string{
"--template-db=template1",
"--schema-dir=/no/such/dir",
})
require.NoError(t, err)

planCfg, err := parsePlanConfig(*flags)
require.NoError(t, err, "parsePlanConfig should not fail with dummy --schema-dir")

assert.Equal(t, "template1", planCfg.templateDb)
}
39 changes: 28 additions & 11 deletions pkg/tempdb/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const (
DefaultOnInstanceMetadataSchema = "pgschemadiff_tmp_metadata"
DefaultOnInstanceMetadataTable = "metadata"

DefaultTemplateDatabase = "template0"
DefaultStatementTimeout = 3 * time.Second
)

Expand Down Expand Up @@ -57,11 +58,12 @@ type (
)
type (
onInstanceFactoryOptions struct {
dbPrefix string
metadataSchema string
metadataTable string
logger log.Logger
rootDatabase string
dbPrefix string
metadataSchema string
metadataTable string
logger log.Logger
rootDatabase string
templateDatabase string
}

OnInstanceFactoryOpt func(*onInstanceFactoryOptions)
Expand Down Expand Up @@ -102,6 +104,13 @@ func WithRootDatabase(db string) OnInstanceFactoryOpt {
}
}

// WithTemplateDatabase sets the template DB that CREATE DATABASE will use.
func WithTemplateDatabase(templateDB string) OnInstanceFactoryOpt {
return func(opts *onInstanceFactoryOptions) {
opts.templateDatabase = templateDB
}
}

type (
CreateConnPoolForDbFn func(ctx context.Context, dbName string) (*sql.DB, error)

Expand All @@ -126,11 +135,12 @@ type (
// when the temporary database was created, e.g., to create a TTL
func NewOnInstanceFactory(ctx context.Context, createConnPoolForDb CreateConnPoolForDbFn, opts ...OnInstanceFactoryOpt) (_ Factory, _retErr error) {
options := onInstanceFactoryOptions{
dbPrefix: DefaultOnInstanceDbPrefix,
metadataSchema: DefaultOnInstanceMetadataSchema,
metadataTable: DefaultOnInstanceMetadataTable,
rootDatabase: "postgres",
logger: log.SimpleLogger(),
dbPrefix: DefaultOnInstanceDbPrefix,
metadataSchema: DefaultOnInstanceMetadataSchema,
metadataTable: DefaultOnInstanceMetadataTable,
rootDatabase: "postgres",
logger: log.SimpleLogger(),
templateDatabase: DefaultTemplateDatabase,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of making this a parameter of the factory, let's make this a parameter of the Create function, since one factory doesn't need to be restricted to one type of template

}
for _, opt := range opts {
opt(&options)
Expand Down Expand Up @@ -175,8 +185,15 @@ func (o *onInstanceFactory) Create(ctx context.Context) (_ *Database, _retErr er
defer rootConn.Close()

tempDbName := o.options.dbPrefix + strings.ReplaceAll(dbUUID.String(), "-", "_")

createDbSql := fmt.Sprintf(
"CREATE DATABASE %s TEMPLATE %s;",
tempDbName,
pgx.Identifier{o.options.templateDatabase}.Sanitize(),
)

// Create the temporary database using template0, the default Postgres template with no user-defined objects.
if _, err = rootConn.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s TEMPLATE template0;", tempDbName)); err != nil {
if _, err = rootConn.ExecContext(ctx, createDbSql); err != nil {
return nil, fmt.Errorf("creating temporary database: %w", err)
}
defer util.DoOnErrOrPanic(&_retErr, func() {
Expand Down
111 changes: 111 additions & 0 deletions pkg/tempdb/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,114 @@ func (suite *onInstanceTempDbFactorySuite) TestDropTempDB_CannotDropNonTempDb()
func TestOnInstanceFactorySuite(t *testing.T) {
suite.Run(t, new(onInstanceTempDbFactorySuite))
}

func (suite *onInstanceTempDbFactorySuite) TestCreate_UnknownTemplateDatabase() {
rootDbName := "some_other_root"
rootDb, err := suite.engine.CreateDatabaseWithName(rootDbName)
suite.Require().NoError(err, "failed to create the root DB")
suite.T().Cleanup(func() {
suite.Require().NoError(rootDb.DropDB())
})

factory := suite.mustBuildFactory(
WithRootDatabase(rootDbName),
WithTemplateDatabase("non_existent_template_db"),
)
defer func() {
suite.Require().NoError(factory.Close())
}()

_, err = factory.Create(context.Background())
suite.ErrorContains(
err,
"template database \"non_existent_template_db\" does not exist",
"Expected an error about non-existent template DB",
)
}

func (suite *onInstanceTempDbFactorySuite) TestCreate_UsesCustomTemplateDatabase() {
// 1) Create the template DB
templateDbName := "mytemplatedb"
templateDb, err := suite.engine.CreateDatabaseWithName(templateDbName)
suite.Require().NoError(err, "failed to create the custom template database")

// 2) Connect to it and create a table, then close all connections
templateDbPool, err := suite.getConnPoolForDb(templateDbName)
suite.Require().NoError(err, "could not get conn pool for template DB")

conn, err := templateDbPool.Conn(context.Background())
suite.Require().NoError(err, "could not get a connection from template DB pool")

_, err = conn.ExecContext(context.Background(), `
CREATE TABLE template_table (
id SERIAL PRIMARY KEY,
name TEXT
);
`)
suite.Require().NoError(err, "failed to create table in template DB")

suite.NoError(conn.Close())
suite.NoError(templateDbPool.Close())

// 3) Terminate any lingering sessions on "mytemplatedb"
rootConnPool, err := suite.getConnPoolForDb("postgres")
suite.Require().NoError(err, "failed to get conn pool for 'postgres'")

rootConn, err := rootConnPool.Conn(context.Background())
suite.Require().NoError(err, "failed to get root connection")

_, err = rootConn.ExecContext(context.Background(), `
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = $1
AND pid <> pg_backend_pid()
`, templateDbName)
suite.Require().NoError(err, "failed to terminate leftover connections to the template DB")

suite.NoError(rootConn.Close())
suite.NoError(rootConnPool.Close())

// 4) Create the "root" DB for CREATE DATABASE statements
rootDbName := "mytemplate_root"
rootDb, err := suite.engine.CreateDatabaseWithName(rootDbName)
suite.Require().NoError(err, "failed to create the root DB for create statements")

suite.T().Cleanup(func() {
suite.NoError(rootDb.DropDB())
suite.NoError(templateDb.DropDB())
})

// 5) Build the factory that uses our template DB
factory := suite.mustBuildFactory(
WithRootDatabase(rootDbName),
WithTemplateDatabase(templateDbName),
)

defer func() {
suite.Require().NoError(factory.Close())
}()

// 6) Create a new DB from our template
newDb, err := factory.Create(context.Background())
suite.Require().NoError(err, "should create a DB from mytemplatedb without error")

// 7) Verify the table is inherited in the newly created DB
newConn, err := newDb.ConnPool.Conn(context.Background())
suite.Require().NoError(err, "could not get conn from newly created DB pool")

var count int
err = newConn.QueryRowContext(context.Background(), `
SELECT count(*)
FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE c.relkind = 'r'
AND c.relname = 'template_table'
AND n.nspname = 'public';
`).Scan(&count)

suite.Require().NoError(err, "failed to check existence of 'template_table'")
suite.Equal(1, count, "expected 'template_table' to exist in the new DB")

suite.NoError(newConn.Close())
suite.NoError(newDb.Close(context.Background()))
}