diff --git a/internal/analyzer/analyzer.go b/internal/analyzer/analyzer.go
index 3d7e3a0287..674f283db9 100644
--- a/internal/analyzer/analyzer.go
+++ b/internal/analyzer/analyzer.go
@@ -110,7 +110,21 @@ func (c *CachedAnalyzer) Close(ctx context.Context) error {
 	return c.a.Close(ctx)
 }
 
+func (c *CachedAnalyzer) EnsureConn(ctx context.Context, migrations []string) error {
+	return c.a.EnsureConn(ctx, migrations)
+}
+
+func (c *CachedAnalyzer) GetColumnNames(ctx context.Context, query string) ([]string, error) {
+	return c.a.GetColumnNames(ctx, query)
+}
+
 type Analyzer interface {
 	Analyze(context.Context, ast.Node, string, []string, *named.ParamSet) (*analysis.Analysis, error)
 	Close(context.Context) error
+	// EnsureConn initializes the database connection with the given migrations.
+	// This is required for database-only mode where we need to connect before analyzing queries.
+	EnsureConn(ctx context.Context, migrations []string) error
+	// GetColumnNames returns the column names for a query by preparing it against the database.
+	// This is used for star expansion in database-only mode.
+	GetColumnNames(ctx context.Context, query string) ([]string, error)
 }
diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go
index 00e8871c7e..05b5445ebb 100644
--- a/internal/cmd/generate.go
+++ b/internal/cmd/generate.go
@@ -295,7 +295,7 @@ func remoteGenerate(ctx context.Context, configPath string, conf *config.Config,
 
 func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) {
 	defer trace.StartRegion(ctx, "parse").End()
-	c, err := compiler.NewCompiler(sql, combo)
+	c, err := compiler.NewCompiler(sql, combo, parserOpts)
 	defer func() {
 		if c != nil {
 			c.Close(ctx)
diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go
index 84fbb20a3c..1a95b586f4 100644
--- a/internal/compiler/compile.go
+++ b/internal/compiler/compile.go
@@ -1,6 +1,7 @@
 package compiler
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -39,11 +40,20 @@ func (c *Compiler) parseCatalog(schemas []string) error {
 		}
 		contents := migrations.RemoveRollbackStatements(string(blob))
 		c.schema = append(c.schema, contents)
+
+		// In database-only mode, we parse the schema to validate syntax
+		// but don't update the catalog - the database will be the source of truth
 		stmts, err := c.parser.Parse(strings.NewReader(contents))
 		if err != nil {
 			merr.Add(filename, contents, 0, err)
 			continue
 		}
+
+		// Skip catalog updates in database-only mode
+		if c.databaseOnlyMode {
+			continue
+		}
+
 		for i := range stmts {
 			if err := c.catalog.Update(stmts[i], c); err != nil {
 				merr.Add(filename, contents, stmts[i].Pos(), err)
@@ -58,6 +68,15 @@ func (c *Compiler) parseCatalog(schemas []string) error {
 }
 
 func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
+	ctx := context.Background()
+
+	// In database-only mode, initialize the database connection before parsing queries
+	if c.databaseOnlyMode && c.analyzer != nil {
+		if err := c.analyzer.EnsureConn(ctx, c.schema); err != nil {
+			return nil, fmt.Errorf("failed to initialize database connection: %w", err)
+		}
+	}
+
 	var q []*Query
 	merr := multierr.New()
 	set := map[string]struct{}{}
@@ -113,6 +132,7 @@ func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
 	if len(q) == 0 {
 		return nil, fmt.Errorf("no queries contained in paths %s", strings.Join(c.conf.Queries, ","))
 	}
+
 	return &Result{
 		Catalog: c.catalog,
 		Queries: q,
diff --git a/internal/compiler/engine.go b/internal/compiler/engine.go
index 75749cd6df..64fdf3d5c7 100644
--- a/internal/compiler/engine.go
+++ b/internal/compiler/engine.go
@@ -14,6 +14,7 @@ import (
 	sqliteanalyze "github.com/sqlc-dev/sqlc/internal/engine/sqlite/analyzer"
 	"github.com/sqlc-dev/sqlc/internal/opts"
 	"github.com/sqlc-dev/sqlc/internal/sql/catalog"
+	"github.com/sqlc-dev/sqlc/internal/x/expander"
 )
 
 type Compiler struct {
@@ -27,9 +28,15 @@ type Compiler struct {
 	selector selector
 
 	schema []string
+
+	// databaseOnlyMode indicates that the compiler should use database-only analysis
+	// and skip building the internal catalog from schema files (analyzer.database: only)
+	databaseOnlyMode bool
+	// expander is used to expand SELECT * and RETURNING * in database-only mode
+	expander *expander.Expander
 }
 
-func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, error) {
+func NewCompiler(conf config.SQL, combo config.CombinedSettings, parserOpts opts.Parser) (*Compiler, error) {
 	c := &Compiler{conf: conf, combo: combo}
 
 	if conf.Database != nil && conf.Database.Managed {
@@ -37,13 +44,33 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err
 		c.client = client
 	}
 
+	// Check for database-only mode (analyzer.database: only)
+	// This feature requires the analyzerv2 experiment to be enabled
+	databaseOnlyMode := conf.Analyzer.Database.IsOnly() && parserOpts.Experiment.AnalyzerV2
+
 	switch conf.Engine {
 	case config.EngineSQLite:
-		c.parser = sqlite.NewParser()
+		parser := sqlite.NewParser()
+		c.parser = parser
 		c.catalog = sqlite.NewCatalog()
 		c.selector = newSQLiteSelector()
-		if conf.Database != nil {
-			if conf.Analyzer.Database == nil || *conf.Analyzer.Database {
+
+		if databaseOnlyMode {
+			// Database-only mode requires a database connection
+			if conf.Database == nil {
+				return nil, fmt.Errorf("analyzer.database: only requires database configuration")
+			}
+			if conf.Database.URI == "" && !conf.Database.Managed {
+				return nil, fmt.Errorf("analyzer.database: only requires database.uri or database.managed")
+			}
+			c.databaseOnlyMode = true
+			// Create the SQLite analyzer (implements Analyzer interface)
+			sqliteAnalyzer := sqliteanalyze.New(*conf.Database)
+			c.analyzer = analyzer.Cached(sqliteAnalyzer, combo.Global, *conf.Database)
+			// Create the expander using the analyzer as the column getter
+			c.expander = expander.New(c.analyzer, parser, parser)
+		} else if conf.Database != nil {
+			if conf.Analyzer.Database.IsEnabled() {
 				c.analyzer = analyzer.Cached(
 					sqliteanalyze.New(*conf.Database),
 					combo.Global,
@@ -56,11 +83,27 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err
 		c.catalog = dolphin.NewCatalog()
 		c.selector = newDefaultSelector()
 	case config.EnginePostgreSQL:
-		c.parser = postgresql.NewParser()
+		parser := postgresql.NewParser()
+		c.parser = parser
 		c.catalog = postgresql.NewCatalog()
 		c.selector = newDefaultSelector()
-		if conf.Database != nil {
-			if conf.Analyzer.Database == nil || *conf.Analyzer.Database {
+
+		if databaseOnlyMode {
+			// Database-only mode requires a database connection
+			if conf.Database == nil {
+				return nil, fmt.Errorf("analyzer.database: only requires database configuration")
+			}
+			if conf.Database.URI == "" && !conf.Database.Managed {
+				return nil, fmt.Errorf("analyzer.database: only requires database.uri or database.managed")
+			}
+			c.databaseOnlyMode = true
+			// Create the PostgreSQL analyzer (implements Analyzer interface)
+			pgAnalyzer := pganalyze.New(c.client, *conf.Database)
+			c.analyzer = analyzer.Cached(pgAnalyzer, combo.Global, *conf.Database)
+			// Create the expander using the analyzer as the column getter
+			c.expander = expander.New(c.analyzer, parser, parser)
+		} else if conf.Database != nil {
+			if conf.Analyzer.Database.IsEnabled() {
 				c.analyzer = analyzer.Cached(
 					pganalyze.New(c.client, *conf.Database),
 					combo.Global,
diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go
index 681d291122..751cb3271a 100644
--- a/internal/compiler/parse.go
+++ b/internal/compiler/parse.go
@@ -71,7 +71,56 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
 	}
 
 	var anlys *analysis
-	if c.analyzer != nil {
+	if c.databaseOnlyMode && c.expander != nil {
+		// In database-only mode, use the expander for star expansion
+		// and rely entirely on the database analyzer for type resolution
+		expandedQuery, err := c.expander.Expand(ctx, rawSQL)
+		if err != nil {
+			return nil, fmt.Errorf("star expansion failed: %w", err)
+		}
+
+		// Parse named parameters from the expanded query
+		expandedStmts, err := c.parser.Parse(strings.NewReader(expandedQuery))
+		if err != nil {
+			return nil, fmt.Errorf("parsing expanded query failed: %w", err)
+		}
+		if len(expandedStmts) == 0 {
+			return nil, errors.New("no statements in expanded query")
+		}
+		expandedRaw := expandedStmts[0].Raw
+
+		// Use the analyzer to get type information from the database
+		result, err := c.analyzer.Analyze(ctx, expandedRaw, expandedQuery, c.schema, nil)
+		if err != nil {
+			return nil, err
+		}
+
+		// Convert the analyzer result to the internal analysis format
+		var cols []*Column
+		for _, col := range result.Columns {
+			cols = append(cols, convertColumn(col))
+		}
+		var params []Parameter
+		for _, p := range result.Params {
+			params = append(params, Parameter{
+				Number: int(p.Number),
+				Column: convertColumn(p.Column),
+			})
+		}
+
+		// Determine the insert table if applicable
+		var table *ast.TableName
+		if insert, ok := expandedRaw.Stmt.(*ast.InsertStmt); ok {
+			table, _ = ParseTableName(insert.Relation)
+		}
+
+		anlys = &analysis{
+			Table:      table,
+			Columns:    cols,
+			Parameters: params,
+			Query:      expandedQuery,
+		}
+	} else if c.analyzer != nil {
 		inference, _ := c.inferQuery(raw, rawSQL)
 		if inference == nil {
 			inference = &analysis{}
diff --git a/internal/config/config.go b/internal/config/config.go
index 0ff805fccd..d3e610ef05 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -122,8 +122,75 @@ type SQL struct {
 	Analyzer Analyzer `json:"analyzer" yaml:"analyzer"`
 }
 
+// AnalyzerDatabase represents the database analyzer setting.
+// It can be a boolean (true/false) or the string "only" for database-only mode.
+type AnalyzerDatabase struct {
+	value  *bool // nil means not set, true/false for boolean values
+	isOnly bool  // true when set to "only"
+}
+
+// IsEnabled returns true if the database analyzer should be used.
+// Returns true for both `true` and `"only"` settings.
+func (a AnalyzerDatabase) IsEnabled() bool {
+	if a.isOnly {
+		return true
+	}
+	return a.value == nil || *a.value
+}
+
+// IsOnly returns true if the analyzer is set to "only" mode.
+func (a AnalyzerDatabase) IsOnly() bool {
+	return a.isOnly
+}
+
+func (a *AnalyzerDatabase) UnmarshalJSON(data []byte) error {
+	// Try to unmarshal as boolean first
+	var b bool
+	if err := json.Unmarshal(data, &b); err == nil {
+		a.value = &b
+		a.isOnly = false
+		return nil
+	}
+
+	// Try to unmarshal as string
+	var s string
+	if err := json.Unmarshal(data, &s); err == nil {
+		if s == "only" {
+			a.isOnly = true
+			a.value = nil
+			return nil
+		}
+		return errors.New("analyzer.database must be true, false, or \"only\"")
+	}
+
+	return errors.New("analyzer.database must be true, false, or \"only\"")
+}
+
+func (a *AnalyzerDatabase) UnmarshalYAML(unmarshal func(interface{}) error) error {
+	// Try to unmarshal as boolean first
+	var b bool
+	if err := unmarshal(&b); err == nil {
+		a.value = &b
+		a.isOnly = false
+		return nil
+	}
+
+	// Try to unmarshal as string
+	var s string
+	if err := unmarshal(&s); err == nil {
+		if s == "only" {
+			a.isOnly = true
+			a.value = nil
+			return nil
+		}
+		return errors.New("analyzer.database must be true, false, or \"only\"")
+	}
+
+	return errors.New("analyzer.database must be true, false, or \"only\"")
+}
+
 type Analyzer struct {
-	Database *bool `json:"database" yaml:"database"`
+	Database AnalyzerDatabase `json:"database" yaml:"database"`
 }
 
 // TODO: Figure out a better name for this
diff --git a/internal/config/v_one.json b/internal/config/v_one.json
index a0667a7c9c..e5ce9ec549 100644
--- a/internal/config/v_one.json
+++ b/internal/config/v_one.json
@@ -79,7 +79,10 @@
           "type": "object",
           "properties": {
             "database": {
-              "type": "boolean"
+              "oneOf": [
+                {"type": "boolean"},
+                {"const": "only"}
+              ]
             }
           }
         },
diff --git a/internal/config/v_two.json b/internal/config/v_two.json
index acf914997d..22591d7335 100644
--- a/internal/config/v_two.json
+++ b/internal/config/v_two.json
@@ -82,7 +82,10 @@
           "type": "object",
           "properties": {
             "database": {
-              "type": "boolean"
+              "oneOf": [
+                {"type": "boolean"},
+                {"const": "only"}
+              ]
             }
          }
        },
diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go
index cd7072a7a9..7634918446 100644
--- a/internal/endtoend/endtoend_test.go
+++ b/internal/endtoend/endtoend_test.go
@@ -263,8 +263,9 @@ func TestReplay(t *testing.T) {
 
 			opts := cmd.Options{
 				Env: cmd.Env{
-					Debug:    opts.DebugFromString(args.Env["SQLCDEBUG"]),
-					NoRemote: true,
+					Debug:      opts.DebugFromString(args.Env["SQLCDEBUG"]),
+					Experiment: opts.ExperimentFromString(args.Env["SQLCEXPERIMENT"]),
+					NoRemote:   true,
 				},
 				Stderr:       &stderr,
 				MutateConfig: testctx.Mutate(t, path),
diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/exec.json b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/exec.json
new file mode 100644
index 0000000000..aaf587c793
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/exec.json
@@ -0,0 +1,6 @@
+{
+  "contexts": ["managed-db"],
+  "env": {
+    "SQLCEXPERIMENT": "analyzerv2"
+  }
+}
diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/db.go
new file mode 100644
index 0000000000..3b320aa168
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/db.go
@@ -0,0 +1,31 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+
+package querytest
+
+import (
+	"context"
+	"database/sql"
+)
+
+type DBTX interface {
+	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
+	PrepareContext(context.Context, string) (*sql.Stmt, error)
+	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
+	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
+}
+
+func New(db DBTX) *Queries {
+	return &Queries{db: db}
+}
+
+type Queries struct {
+	db DBTX
+}
+
+func (q *Queries) WithTx(tx *sql.Tx) *Queries {
+	return &Queries{
+		db: tx,
+	}
+}
diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/models.go
new file mode 100644
index 0000000000..90b88c3389
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/models.go
@@ -0,0 +1,11 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+
+package querytest
+
+type Product struct {
+	ID    int32
+	Name  string
+	Price string
+}
diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/query.sql.go
new file mode 100644
index 0000000000..8d31d41cdf
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/query.sql.go
@@ -0,0 +1,65 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+// source: query.sql
+
+package querytest
+
+import (
+	"context"
+)
+
+const getProductStats = `-- name: GetProductStats :one
+WITH product_stats AS (
+    SELECT COUNT(*) as total, AVG(price) as avg_price FROM products
+)
+SELECT total, avg_price FROM product_stats
+`
+
+type GetProductStatsRow struct {
+	Total    int64
+	AvgPrice string
+}
+
+func (q *Queries) GetProductStats(ctx context.Context) (GetProductStatsRow, error) {
+	row := q.db.QueryRowContext(ctx, getProductStats)
+	var i GetProductStatsRow
+	err := row.Scan(&i.Total, &i.AvgPrice)
+	return i, err
+}
+
+const listExpensiveProducts = `-- name: ListExpensiveProducts :many
+WITH expensive AS (
+    SELECT id, name, price FROM products WHERE price > 100
+)
+SELECT id, name, price FROM expensive
+`
+
+type ListExpensiveProductsRow struct {
+	ID    int32
+	Name  string
+	Price string
+}
+
+func (q *Queries) ListExpensiveProducts(ctx context.Context) ([]ListExpensiveProductsRow, error) {
+	rows, err := q.db.QueryContext(ctx, listExpensiveProducts)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var items []ListExpensiveProductsRow
+	for rows.Next() {
+		var i ListExpensiveProductsRow
+		if err := rows.Scan(&i.ID, &i.Name, &i.Price); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/query.sql b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/query.sql
new file mode 100644
index 0000000000..4626fe0f04
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/query.sql
@@ -0,0 +1,11 @@
+-- name: ListExpensiveProducts :many
+WITH expensive AS (
+    SELECT * FROM products WHERE price > 100
+)
+SELECT * FROM expensive;
+
+-- name: GetProductStats :one
+WITH product_stats AS (
+    SELECT COUNT(*) as total, AVG(price) as avg_price FROM products
+)
+SELECT * FROM product_stats;
diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/schema.sql
new file mode 100644
index 0000000000..17aaa6e650
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/schema.sql
@@ -0,0 +1,5 @@
+CREATE TABLE products (
+    id SERIAL PRIMARY KEY,
+    name TEXT NOT NULL,
+    price NUMERIC(10,2) NOT NULL
+);
diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/sqlc.yaml b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/sqlc.yaml
new file mode 100644
index 0000000000..629b01dea6
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/sqlc.yaml
@@ -0,0 +1,13 @@
+version: "2"
+sql:
+  - engine: postgresql
+    schema: "schema.sql"
+    queries: "query.sql"
+    database:
+      managed: true
+    analyzer:
+      database: "only"
+    gen:
+      go:
+        package: "querytest"
+        out: "go"
diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/exec.json b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/exec.json
new file mode 100644
index 0000000000..aaf587c793
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/exec.json
@@ -0,0 +1,6 @@
+{
+  "contexts": ["managed-db"],
+  "env": {
+    "SQLCEXPERIMENT": "analyzerv2"
+  }
+}
diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/db.go
new file mode 100644
index 0000000000..3b320aa168
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/db.go
@@ -0,0 +1,31 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+
+package querytest
+
+import (
+	"context"
+	"database/sql"
+)
+
+type DBTX interface {
+	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
+	PrepareContext(context.Context, string) (*sql.Stmt, error)
+	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
+	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
+}
+
+func New(db DBTX) *Queries {
+	return &Queries{db: db}
+}
+
+type Queries struct {
+	db DBTX
+}
+
+func (q *Queries) WithTx(tx *sql.Tx) *Queries {
+	return &Queries{
+		db: tx,
+	}
+}
diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/models.go
new file mode 100644
index 0000000000..2b42787339
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/models.go
@@ -0,0 +1,59 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+
+package querytest
+
+import (
+	"database/sql/driver"
+	"fmt"
+)
+
+type Status string
+
+const (
+	StatusPending   Status = "pending"
+	StatusActive    Status = "active"
+	StatusCompleted Status = "completed"
+)
+
+func (e *Status) Scan(src interface{}) error {
+	switch s := src.(type) {
+	case []byte:
+		*e = Status(s)
+	case string:
+		*e = Status(s)
+	default:
+		return fmt.Errorf("unsupported scan type for Status: %T", src)
+	}
+	return nil
+}
+
+type NullStatus struct {
+	Status Status
+	Valid  bool // Valid is true if Status is not NULL
+}
+
+// Scan implements the Scanner interface.
+func (ns *NullStatus) Scan(value interface{}) error {
+	if value == nil {
+		ns.Status, ns.Valid = "", false
+		return nil
+	}
+	ns.Valid = true
+	return ns.Status.Scan(value)
+}
+
+// Value implements the driver Valuer interface.
+func (ns NullStatus) Value() (driver.Value, error) {
+	if !ns.Valid {
+		return nil, nil
+	}
+	return string(ns.Status), nil
+}
+
+type Task struct {
+	ID     int32
+	Title  string
+	Status Status
+}
diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/query.sql.go
new file mode 100644
index 0000000000..263a6b6736
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/query.sql.go
@@ -0,0 +1,80 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+// source: query.sql
+
+package querytest
+
+import (
+	"context"
+)
+
+const createTask = `-- name: CreateTask :one
+INSERT INTO tasks (title, status) VALUES ($1, $2) RETURNING id, title, status
+`
+
+type CreateTaskParams struct {
+	Title  string
+	Status Status
+}
+
+func (q *Queries) CreateTask(ctx context.Context, arg CreateTaskParams) (Task, error) {
+	row := q.db.QueryRowContext(ctx, createTask, arg.Title, arg.Status)
+	var i Task
+	err := row.Scan(&i.ID, &i.Title, &i.Status)
+	return i, err
+}
+
+const getTasksByStatus = `-- name: GetTasksByStatus :many
+SELECT id, title, status FROM tasks WHERE status = $1
+`
+
+func (q *Queries) GetTasksByStatus(ctx context.Context, status Status) ([]Task, error) {
+	rows, err := q.db.QueryContext(ctx, getTasksByStatus, status)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var items []Task
+	for rows.Next() {
+		var i Task
+		if err := rows.Scan(&i.ID, &i.Title, &i.Status); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
+const listTasks = `-- name: ListTasks :many
+SELECT id, title, status FROM tasks
+`
+
+func (q *Queries) ListTasks(ctx context.Context) ([]Task, error) {
+	rows, err := q.db.QueryContext(ctx, listTasks)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var items []Task
+	for rows.Next() {
+		var i Task
+		if err := rows.Scan(&i.ID, &i.Title, &i.Status); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/query.sql b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/query.sql
new file mode 100644
index 0000000000..11dcd9bf48
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/query.sql
@@ -0,0 +1,8 @@
+-- name: ListTasks :many
+SELECT * FROM tasks;
+
+-- name: GetTasksByStatus :many
+SELECT * FROM tasks WHERE status = $1;
+
+-- name: CreateTask :one
+INSERT INTO tasks (title, status) VALUES ($1, $2) RETURNING *;
diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/schema.sql
new file mode 100644
index 0000000000..443ae9845f
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/schema.sql
@@ -0,0 +1,7 @@
+CREATE TYPE status AS ENUM ('pending', 'active', 'completed');
+
+CREATE TABLE tasks (
+    id SERIAL PRIMARY KEY,
+    title TEXT NOT NULL,
+    status status NOT NULL DEFAULT 'pending'
+);
diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/sqlc.yaml b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/sqlc.yaml
new file mode 100644
index 0000000000..629b01dea6
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/sqlc.yaml
@@ -0,0 +1,13 @@
+version: "2"
+sql:
+  - engine: postgresql
+    schema: "schema.sql"
+    queries: "query.sql"
+    database:
+      managed: true
+    analyzer:
+      database: "only"
+    gen:
+      go:
+        package: "querytest"
+        out: "go"
diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/exec.json b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/exec.json
new file mode 100644
index 0000000000..aaf587c793
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/exec.json
@@ -0,0 +1,6 @@
+{
+  "contexts": ["managed-db"],
+  "env": {
+    "SQLCEXPERIMENT": "analyzerv2"
+  }
+}
diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/db.go b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/db.go
new file mode 100644
index 0000000000..3b320aa168
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/db.go
@@ -0,0 +1,31 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+
+package querytest
+
+import (
+	"context"
+	"database/sql"
+)
+
+type DBTX interface {
+	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
+	PrepareContext(context.Context, string) (*sql.Stmt, error)
+	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
+	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
+}
+
+func New(db DBTX) *Queries {
+	return &Queries{db: db}
+}
+
+type Queries struct {
+	db DBTX
+}
+
+func (q *Queries) WithTx(tx *sql.Tx) *Queries {
+	return &Queries{
+		db: tx,
+	}
+}
diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/models.go b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/models.go
new file mode 100644
index 0000000000..eaf05e5c00
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/models.go
@@ -0,0 +1,15 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+
+package querytest
+
+import (
+	"database/sql"
+)
+
+type Author struct {
+	ID   int64
+	Name string
+	Bio  sql.NullString
+}
diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/query.sql.go b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/query.sql.go
new file mode 100644
index 0000000000..203224ead2
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/query.sql.go
@@ -0,0 +1,65 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+// source: query.sql
+
+package querytest
+
+import (
+	"context"
+	"database/sql"
+)
+
+const createAuthor = `-- name: CreateAuthor :one
+INSERT INTO authors (name, bio) VALUES (?, ?) RETURNING id, name, bio
+`
+
+type CreateAuthorParams struct {
+	Name string
+	Bio  sql.NullString
+}
+
+func (q *Queries) CreateAuthor(ctx context.Context, arg CreateAuthorParams) (Author, error) {
+	row := q.db.QueryRowContext(ctx, createAuthor, arg.Name, arg.Bio)
+	var i Author
+	err := row.Scan(&i.ID, &i.Name, &i.Bio)
+	return i, err
+}
+
+const getAuthor = `-- name: GetAuthor :one
+SELECT id, name, bio FROM authors WHERE id = ?
+`
+
+func (q *Queries) GetAuthor(ctx context.Context, id int64) (Author, error) {
+	row := q.db.QueryRowContext(ctx, getAuthor, id)
+	var i Author
+	err := row.Scan(&i.ID, &i.Name, &i.Bio)
+	return i, err
+}
+
+const listAuthors = `-- name: ListAuthors :many
+SELECT id, name, bio FROM authors
+`
+
+func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) {
+	rows, err := q.db.QueryContext(ctx, listAuthors)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var items []Author
+	for rows.Next() {
+		var i Author
+		if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/query.sql b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/query.sql
new file mode 100644
index 0000000000..8fe23a8600
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/query.sql
@@ -0,0 +1,8 @@
+-- name: GetAuthor :one
+SELECT * FROM authors WHERE id = ?;
+
+-- name: ListAuthors :many
+SELECT * FROM authors;
+
+-- name: CreateAuthor :one
+INSERT INTO authors (name, bio) VALUES (?, ?) RETURNING *;
diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/schema.sql b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/schema.sql
new file mode 100644
index 0000000000..22fc0993c1
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/schema.sql
@@ -0,0 +1,5 @@
+CREATE TABLE authors (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    bio TEXT
+);
diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/sqlc.yaml b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/sqlc.yaml
new file mode 100644
index 0000000000..d2da6c31b2
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/sqlc.yaml
@@ -0,0 +1,13 @@
+version: "2"
+sql:
+  - engine: sqlite
+    schema: "schema.sql"
+    queries: "query.sql"
+    database:
+      managed: true
+    analyzer:
+      database: "only"
+    gen:
+      go:
+        package: "querytest"
+        out: "go"
diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/exec.json b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/exec.json
new file mode 100644
index 0000000000..aaf587c793
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/exec.json
@@ -0,0 +1,6 @@
+{
+  "contexts": ["managed-db"],
+  "env": {
+    "SQLCEXPERIMENT": "analyzerv2"
+  }
+}
diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/db.go
new file mode 100644
index 0000000000..3b320aa168
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/db.go
@@ -0,0 +1,31 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+
+package querytest
+
+import (
+	"context"
+	"database/sql"
+)
+
+type DBTX interface {
+	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
+	PrepareContext(context.Context, string) (*sql.Stmt, error)
+	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
+	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
+}
+
+func New(db DBTX) *Queries {
+	return &Queries{db: db}
+}
+
+type Queries struct {
+	db DBTX
+}
+
+func (q *Queries) WithTx(tx *sql.Tx) *Queries {
+	return &Queries{
+		db: tx,
+	}
+}
diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/models.go
new file mode 100644
index 0000000000..ec1cb8d670
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/models.go
@@ -0,0 +1,15 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+
+package querytest
+
+import (
+	"database/sql"
+)
+
+type Author struct {
+	ID   int32
+	Name string
+	Bio  sql.NullString
+}
diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/query.sql.go
new file mode 100644
index 0000000000..9e2820cdbd
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/query.sql.go
@@ -0,0 +1,93 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+//   sqlc v1.30.0
+// source: query.sql
+
+package querytest
+
+import (
+	"context"
+	"database/sql"
+)
+
+const createAuthor = `-- name: CreateAuthor :one
+INSERT INTO authors (name, bio) VALUES ($1, $2) RETURNING id, name, bio
+`
+
+type CreateAuthorParams struct {
+	Name string
+	Bio  sql.NullString
+}
+
+func (q *Queries) CreateAuthor(ctx context.Context, arg CreateAuthorParams) (Author, error) {
+	row := q.db.QueryRowContext(ctx, createAuthor, arg.Name, arg.Bio)
+	var i Author
+	err := row.Scan(&i.ID, &i.Name, &i.Bio)
+	return i, err
+}
+
+const deleteAuthor = `-- name: DeleteAuthor :one
+DELETE FROM authors WHERE id = $1 RETURNING id, name, bio
+`
+
+func (q *Queries) DeleteAuthor(ctx context.Context, id int32) (Author, error) {
+	row := q.db.QueryRowContext(ctx, deleteAuthor, id)
+	var i Author
+	err := row.Scan(&i.ID, &i.Name, &i.Bio)
+	return i, err
+}
+
+const getAuthor = `-- name: GetAuthor :one
+SELECT id, name, bio FROM authors WHERE id = $1
+`
+
+func (q *Queries) GetAuthor(ctx context.Context, id int32) (Author, error) {
+	row := q.db.QueryRowContext(ctx, getAuthor, id)
+	var i Author
+	err := row.Scan(&i.ID, &i.Name, &i.Bio)
+	return i, err
+}
+
+const listAuthors = `-- name: ListAuthors :many
+SELECT id, name, bio FROM authors
+`
+
+func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) {
+	rows, err := q.db.QueryContext(ctx, listAuthors)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var items []Author
+	for rows.Next() {
+		var i Author
+		if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
+const updateAuthor = `-- name: UpdateAuthor :one
+UPDATE authors SET name = $1, bio = $2 WHERE id = $3 RETURNING id, name, bio
+`
+
+type UpdateAuthorParams struct {
+	Name string
+	Bio  sql.NullString
+	ID   int32
+}
+
+func (q *Queries) UpdateAuthor(ctx context.Context, arg UpdateAuthorParams) (Author, error) {
+	row := q.db.QueryRowContext(ctx, updateAuthor, arg.Name, arg.Bio, arg.ID)
+	var i Author
+	err := row.Scan(&i.ID, &i.Name, &i.Bio)
+	return i, err
+}
diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/query.sql b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/query.sql
new file mode 100644
index 0000000000..e091a5eaef
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/query.sql
@@ -0,0 +1,14 @@
+-- name: ListAuthors :many
+SELECT * FROM authors;
+
+-- name: GetAuthor :one
+SELECT * FROM authors WHERE id = $1;
+
+-- name: CreateAuthor :one
+INSERT INTO authors (name, bio) VALUES ($1, $2) RETURNING *;
+
+-- name: UpdateAuthor :one
+UPDATE authors SET name = $1, bio = $2 WHERE id = $3 RETURNING *;
+
+-- name: DeleteAuthor :one
+DELETE FROM authors WHERE id = $1 RETURNING *;
diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/schema.sql
new file mode 100644
index 0000000000..ca6ad1e2cf
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/schema.sql
@@ -0,0 +1,5 @@
+CREATE TABLE authors (
+    id SERIAL PRIMARY KEY,
+    name TEXT NOT NULL,
+    bio TEXT
+);
diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/sqlc.yaml b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/sqlc.yaml
new file mode 100644
index 0000000000..629b01dea6
--- /dev/null
+++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/sqlc.yaml
@@ -0,0 +1,13 @@
+version: "2"
+sql:
+  - engine: postgresql
+    schema: "schema.sql"
+    queries: "query.sql"
+    database:
+      managed: true
+    analyzer:
+      database: "only"
+    gen:
+      go:
+        package: "querytest"
+        out: "go"
diff --git a/internal/engine/postgresql/analyzer/analyze.go b/internal/engine/postgresql/analyzer/analyze.go
index 5a08fa98ec..ee03e4d3c5 100644
--- a/internal/engine/postgresql/analyzer/analyze.go
+++ b/internal/engine/postgresql/analyzer/analyze.go
@@ -17,6 +17,7 @@ import (
 	"github.com/sqlc-dev/sqlc/internal/opts"
 	"github.com/sqlc-dev/sqlc/internal/shfmt"
 	"github.com/sqlc-dev/sqlc/internal/sql/ast"
+	"github.com/sqlc-dev/sqlc/internal/sql/catalog"
 	"github.com/sqlc-dev/sqlc/internal/sql/named"
 	"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
 )
@@ -320,3 +321,227 @@ func (a *Analyzer) Close(_ context.Context) error {
 	}
 	return nil
 }
+
+// SQL queries for schema introspection
+const introspectTablesQuery = `
+SELECT
+    n.nspname AS schema_name,
+    c.relname AS table_name,
+    a.attname AS column_name,
+    pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
+    a.attnotnull AS not_null,
+    a.attndims AS array_dims,
+    COALESCE(
+        (SELECT true FROM pg_index i
+         WHERE i.indrelid = c.oid
+         AND i.indisprimary
+         AND a.attnum = ANY(i.indkey)),
+        false
+    ) AS is_primary_key
+FROM
+    pg_catalog.pg_class c
+JOIN
+    pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+JOIN
+    pg_catalog.pg_attribute a ON a.attrelid = c.oid
+WHERE
+    c.relkind IN ('r', 'v', 'p') -- tables, views, partitioned tables
+    AND a.attnum > 0 -- skip system columns
+    AND NOT a.attisdropped
+    AND n.nspname = ANY($1)
+ORDER BY
+    n.nspname, c.relname, a.attnum
+`
+
+const introspectEnumsQuery = `
+SELECT
+    n.nspname AS schema_name,
+    t.typname AS type_name,
+    e.enumlabel AS enum_value
+FROM
+    pg_catalog.pg_type t
+JOIN
+    pg_catalog.pg_namespace n ON n.oid = t.typnamespace
+JOIN
+    pg_catalog.pg_enum e ON e.enumtypid = t.oid
+WHERE
+    t.typtype = 'e'
+    AND n.nspname = ANY($1)
+ORDER BY
+    n.nspname, t.typname, e.enumsortorder
+`
+
+type introspectedColumn struct {
+	SchemaName   string `db:"schema_name"`
+	TableName    string `db:"table_name"`
+	ColumnName   string `db:"column_name"`
+	DataType     string `db:"data_type"`
+	NotNull      bool   `db:"not_null"`
+	ArrayDims    int    `db:"array_dims"`
+	IsPrimaryKey bool   `db:"is_primary_key"`
+}
+
+type introspectedEnum struct {
+	SchemaName string `db:"schema_name"`
+	TypeName   string `db:"type_name"`
+	EnumValue  string `db:"enum_value"`
+}
+
+// IntrospectSchema queries the database to build a catalog containing
+// tables, columns, and enum types for the specified schemas.
+func (a *Analyzer) IntrospectSchema(ctx context.Context, schemas []string) (*catalog.Catalog, error) {
+	if a.pool == nil {
+		return nil, fmt.Errorf("database connection not initialized")
+	}
+
+	c, err := a.pool.Acquire(ctx)
+	if err != nil {
+		return nil, err
+	}
+	defer c.Release()
+
+	// Query tables and columns
+	rows, err := c.Query(ctx, introspectTablesQuery, schemas)
+	if err != nil {
+		return nil, fmt.Errorf("introspect tables: %w", err)
+	}
+	columns, err := pgx.CollectRows(rows, pgx.RowToStructByName[introspectedColumn])
+	if err != nil {
+		return nil, fmt.Errorf("collect table rows: %w", err)
+	}
+
+	// Query enums
+	enumRows, err := c.Query(ctx, introspectEnumsQuery, schemas)
+	if err != nil {
+		return nil, fmt.Errorf("introspect enums: %w", err)
+	}
+	enums, err := pgx.CollectRows(enumRows, pgx.RowToStructByName[introspectedEnum])
+	if err != nil {
+		return nil, fmt.Errorf("collect enum rows: %w", err)
+	}
+
+	// Build catalog
+	cat := &catalog.Catalog{
+		DefaultSchema: "public",
+		SearchPath:    schemas,
+	}
+
+	// Create schema map for quick lookup
+	schemaMap := make(map[string]*catalog.Schema)
+	for _, schemaName := range schemas {
+		schema := &catalog.Schema{Name: schemaName}
+		cat.Schemas = append(cat.Schemas, schema)
+		schemaMap[schemaName] = schema
+	}
+
+	// Group columns by table
+	tableMap := make(map[string]*catalog.Table)
+	for _, col := range columns {
+		key := col.SchemaName + "." + col.TableName
+		tbl, exists := tableMap[key]
+		if !exists {
+			tbl = &catalog.Table{
+				Rel: &ast.TableName{
+					Schema: col.SchemaName,
+					Name:   col.TableName,
+				},
+			}
+			tableMap[key] = tbl
+			if schema, ok := schemaMap[col.SchemaName]; ok {
+				schema.Tables = append(schema.Tables, tbl)
+			}
+		}
+
+		dt, isArray, dims := parseType(col.DataType)
+		tbl.Columns = append(tbl.Columns, &catalog.Column{
+			Name:      col.ColumnName,
+			Type:      ast.TypeName{Name: dt},
+			IsNotNull: col.NotNull,
+			IsArray:   isArray || col.ArrayDims > 0,
+			ArrayDims: max(dims, col.ArrayDims),
+		})
+	}
+
+	// Group enum values by type
+	enumMap := make(map[string]*catalog.Enum)
+	for _, e := range enums {
+		key := e.SchemaName + "." + e.TypeName
+		enum, exists := enumMap[key]
+		if !exists {
+			enum = &catalog.Enum{
+				Name: e.TypeName,
+			}
+			enumMap[key] = enum
+			if schema, ok := schemaMap[e.SchemaName]; ok {
+				schema.Types = append(schema.Types, enum)
+			}
+		}
+		enum.Vals = append(enum.Vals, e.EnumValue)
+	}
+
+	return cat, nil
+}
+
+// EnsureConn initializes the database connection pool if not already done.
+// This is useful for database-only mode where we need to connect before analyzing queries.
+func (a *Analyzer) EnsureConn(ctx context.Context, migrations []string) error {
+	if a.pool != nil {
+		return nil
+	}
+
+	var uri string
+	if a.db.Managed {
+		if a.client == nil {
+			return fmt.Errorf("client is nil")
+		}
+		edb, err := a.client.CreateDatabase(ctx, &dbmanager.CreateDatabaseRequest{
+			Engine:     "postgresql",
+			Migrations: migrations,
+		})
+		if err != nil {
+			return err
+		}
+		uri = edb.Uri
+	} else if a.dbg.OnlyManagedDatabases {
+		return fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed")
+	} else {
+		uri = a.replacer.Replace(a.db.URI)
+	}
+
+	conf, err := pgxpool.ParseConfig(uri)
+	if err != nil {
+		return err
+	}
+	pool, err := pgxpool.NewWithConfig(ctx, conf)
+	if err != nil {
+		return err
+	}
+	a.pool = pool
+	return nil
+}
+
+// GetColumnNames implements the expander.ColumnGetter interface.
+// It prepares a query and returns the column names from the result set description.
+func (a *Analyzer) GetColumnNames(ctx context.Context, query string) ([]string, error) {
+	if a.pool == nil {
+		return nil, fmt.Errorf("database connection not initialized")
+	}
+
+	conn, err := a.pool.Acquire(ctx)
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Release()
+
+	desc, err := conn.Conn().Prepare(ctx, "", query)
+	if err != nil {
+		return nil, err
+	}
+
+	columns := make([]string, len(desc.Fields))
+	for i, field := range desc.Fields {
+		columns[i] = field.Name
+	}
+
+	return columns, nil
+}
diff --git a/internal/engine/sqlite/analyzer/analyze.go b/internal/engine/sqlite/analyzer/analyze.go
index 3b526816f0..3af9f99a30 100644
--- a/internal/engine/sqlite/analyzer/analyze.go
+++ b/internal/engine/sqlite/analyzer/analyze.go
@@ -14,6 +14,7 @@ import (
 	"github.com/sqlc-dev/sqlc/internal/opts"
 	"github.com/sqlc-dev/sqlc/internal/shfmt"
 	"github.com/sqlc-dev/sqlc/internal/sql/ast"
+	"github.com/sqlc-dev/sqlc/internal/sql/catalog"
 	"github.com/sqlc-dev/sqlc/internal/sql/named"
 	"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
 )
@@ -182,6 +183,144 @@ func (a *Analyzer) Close(_ context.Context) error {
 	return nil
 }
 
+// EnsureConn initializes the database connection if not already done.
+// This is useful for database-only mode where we need to connect before analyzing queries.
+func (a *Analyzer) EnsureConn(ctx context.Context, migrations []string) error {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+
+	if a.conn != nil {
+		return nil
+	}
+
+	var uri string
+	applyMigrations := a.db.Managed
+	if a.db.Managed {
+		// For managed databases, create an in-memory database
+		uri = ":memory:"
+	} else if a.dbg.OnlyManagedDatabases {
+		return fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed")
+	} else {
+		uri = a.replacer.Replace(a.db.URI)
+		// For in-memory databases, we need to apply migrations since the database starts empty
+		if isInMemoryDatabase(uri) {
+			applyMigrations = true
+		}
+	}
+
+	conn, err := sqlite3.Open(uri)
+	if err != nil {
+		return fmt.Errorf("failed to open sqlite database: %w", err)
+	}
+	a.conn = conn
+
+	// Apply migrations for managed or in-memory databases
+	if applyMigrations {
+		for _, m := range migrations {
+			if len(strings.TrimSpace(m)) == 0 {
+				continue
+			}
+			if err := a.conn.Exec(m); err != nil {
+				a.conn.Close()
+				a.conn = nil
+				return fmt.Errorf("migration failed: %s: %w", m, err)
+			}
+		}
+	}
+
+	return nil
+}
+
+// GetColumnNames implements the expander.ColumnGetter interface.
+// It prepares a query and returns the column names from the result set description.
+func (a *Analyzer) GetColumnNames(ctx context.Context, query string) ([]string, error) {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+
+	if a.conn == nil {
+		return nil, fmt.Errorf("database connection not initialized")
+	}
+
+	stmt, _, err := a.conn.Prepare(query)
+	if err != nil {
+		return nil, err
+	}
+	defer stmt.Close()
+
+	colCount := stmt.ColumnCount()
+	columns := make([]string, colCount)
+	for i := 0; i < colCount; i++ {
+		columns[i] = stmt.ColumnName(i)
+	}
+
+	return columns, nil
+}
+
+// IntrospectSchema queries the database to build a catalog containing
+// tables and columns for the database.
+func (a *Analyzer) IntrospectSchema(ctx context.Context, schemas []string) (*catalog.Catalog, error) {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+
+	if a.conn == nil {
+		return nil, fmt.Errorf("database connection not initialized")
+	}
+
+	// Build catalog
+	cat := &catalog.Catalog{
+		DefaultSchema: "main",
+	}
+
+	// Create default schema
+	mainSchema := &catalog.Schema{Name: "main"}
+	cat.Schemas = append(cat.Schemas, mainSchema)
+
+	// Query tables from sqlite_master
+	stmt, _, err := a.conn.Prepare("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
+	if err != nil {
+		return nil, fmt.Errorf("introspect tables: %w", err)
+	}
+
+	tableNames := []string{}
+	for stmt.Step() {
+		tableName := stmt.ColumnText(0)
+		tableNames = append(tableNames, tableName)
+	}
+	stmt.Close()
+
+	// For each table, get column information using PRAGMA table_info
+	for _, tableName := range tableNames {
+		tbl := &catalog.Table{
+			Rel: &ast.TableName{
+				Name: tableName,
+			},
+		}
+
+		pragmaStmt, _, err := a.conn.Prepare(fmt.Sprintf("PRAGMA table_info('%s')", tableName))
+		if err != nil {
+			return nil, fmt.Errorf("pragma table_info for %s: %w", tableName, err)
+		}
+
+		for pragmaStmt.Step() {
+			// PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk
+			colName := pragmaStmt.ColumnText(1)
+			colType := pragmaStmt.ColumnText(2)
+			notNull := pragmaStmt.ColumnInt(3) != 0
+
+			tbl.Columns = append(tbl.Columns, &catalog.Column{
+				Name:      colName,
+				Type:      ast.TypeName{Name: normalizeType(colType)},
+				IsNotNull: notNull,
+			})
+		}
+		pragmaStmt.Close()
+
+		mainSchema.Tables = append(mainSchema.Tables, tbl)
+	}
+
+	return cat, nil
+}
+
 // isInMemoryDatabase checks if a SQLite URI refers to an in-memory database
 func isInMemoryDatabase(uri string) bool {
 	if uri == ":memory:" || uri == "" {
diff --git a/internal/opts/experiment.go b/internal/opts/experiment.go
index 73ca5d7de0..00d4b1b6f1 100644
--- a/internal/opts/experiment.go
+++ b/internal/opts/experiment.go
@@ -25,9 +25,9 @@ import (
 // Experiment holds the state of all experimental features.
 // Add new experiments as boolean fields to this struct.
 type Experiment struct {
-	// Add experimental feature flags here as they are introduced.
-	// Example:
-	//   NewParser bool // Enable new SQL parser
+	// AnalyzerV2 enables the database-only analyzer mode (analyzer.database: only)
+	// which uses the database for all type resolution instead of parsing schema files.
+	AnalyzerV2 bool
 }
 
 // ExperimentFromEnv returns an Experiment initialized from the SQLCEXPERIMENT
@@ -75,10 +75,8 @@ func ExperimentFromString(val string) Experiment {
 // known experiment.
 func isKnownExperiment(name string) bool {
 	switch strings.ToLower(name) {
-	// Add experiment names here as they are introduced.
-	// Example:
-	//   case "newparser":
-	//     return true
+	case "analyzerv2":
+		return true
 	default:
 		return false
 	}
@@ -87,21 +85,17 @@ func isKnownExperiment(name string) bool {
 // setExperiment sets the experiment flag with the given name to the given value.
 func setExperiment(e *Experiment, name string, enabled bool) {
 	switch strings.ToLower(name) {
-	// Add experiment cases here as they are introduced.
-	// Example:
-	//   case "newparser":
-	//     e.NewParser = enabled
+	case "analyzerv2":
+		e.AnalyzerV2 = enabled
 	}
 }
 
 // Enabled returns a slice of all enabled experiment names.
 func (e Experiment) Enabled() []string {
 	var enabled []string
-	// Add enabled experiments here as they are introduced.
-	// Example:
-	//   if e.NewParser {
-	//     enabled = append(enabled, "newparser")
-	//   }
+	if e.AnalyzerV2 {
+		enabled = append(enabled, "analyzerv2")
+	}
 	return enabled
 }
 
diff --git a/internal/opts/experiment_test.go b/internal/opts/experiment_test.go
index 7845c0b13e..e9a8618e89 100644
--- a/internal/opts/experiment_test.go
+++ b/internal/opts/experiment_test.go
@@ -43,28 +43,26 @@ func TestExperimentFromString(t *testing.T) {
 			input: "foo,,bar",
 			want:  Experiment{},
 		},
-		// Add tests for specific experiments as they are introduced.
-		// Example:
-		// {
-		// 	name:  "enable newparser",
-		// 	input: "newparser",
-		// 	want:  Experiment{NewParser: true},
-		// },
-		// {
-		// 	name:  "disable newparser",
-		// 	input: "nonewparser",
-		// 	want:  Experiment{NewParser: false},
-		// },
-		// {
-		// 	name:  "enable then disable",
-		// 	input: "newparser,nonewparser",
-		// 	want:  Experiment{NewParser: false},
-		// },
-		// {
-		// 	name:  "case insensitive",
-		// 	input: "NewParser,NONEWPARSER",
-		// 	want:  Experiment{NewParser: false},
-		// },
+		{
+			name:  "enable analyzerv2",
+			input: "analyzerv2",
+			want:  Experiment{AnalyzerV2: true},
+		},
+		{
+			name:  "disable analyzerv2",
+			input: "noanalyzerv2",
+			want:  Experiment{AnalyzerV2: false},
+		},
+		{
+			name:  "enable then disable analyzerv2",
+			input: "analyzerv2,noanalyzerv2",
+			want:  Experiment{AnalyzerV2: false},
+		},
+		{
+			name:  "analyzerv2 case insensitive",
+			input: "AnalyzerV2",
+			want:  Experiment{AnalyzerV2: true},
+		},
 	}
 
 	for _, tt := range tests {
@@ -88,13 +86,11 @@ func TestExperimentEnabled(t *testing.T) {
 			exp:  Experiment{},
 			want: nil,
 		},
-		// Add tests for specific experiments as they are introduced.
-		// Example:
-		// {
-		// 	name: "newparser enabled",
-		// 	exp:  Experiment{NewParser: true},
-		// 	want: []string{"newparser"},
-		// },
+		{
+			name: "analyzerv2 enabled",
+			exp:  Experiment{AnalyzerV2: true},
+			want: []string{"analyzerv2"},
+		},
 	}
 
 	for _, tt := range tests {
@@ -124,13 +120,11 @@ func TestExperimentString(t *testing.T) {
 			exp:  Experiment{},
 			want: "",
 		},
-		// Add tests for specific experiments as they are introduced.
-		// Example:
-		// {
-		// 	name: "newparser enabled",
-		// 	exp:  Experiment{NewParser: true},
-		// 	want: "newparser",
-		// },
+		{
+			name: "analyzerv2 enabled",
+			exp:  Experiment{AnalyzerV2: true},
+			want: "analyzerv2",
+		},
 	}
 
 	for _, tt := range tests {
@@ -159,18 +153,16 @@ func TestIsKnownExperiment(t *testing.T) {
 			input: "",
 			want:  false,
 		},
-		// Add tests for specific experiments as they are introduced.
-		// Example:
-		// {
-		// 	name:  "newparser lowercase",
-		// 	input: "newparser",
-		// 	want:  true,
-		// },
-		// {
-		// 	name:  "newparser mixed case",
-		// 	input: "NewParser",
-		// 	want:  true,
-		// },
+		{
+			name:  "analyzerv2 lowercase",
+			input: "analyzerv2",
+			want:  true,
+		},
+		{
+			name:  "analyzerv2 mixed case",
+			input: "AnalyzerV2",
+			want:  true,
+		},
 	}
 
 	for _, tt := range tests {
diff --git a/internal/sql/ast/select_stmt.go b/internal/sql/ast/select_stmt.go
index 8c3606dd4d..62e6f1c9cf 100644
--- a/internal/sql/ast/select_stmt.go
+++ b/internal/sql/ast/select_stmt.go
@@ -37,9 +37,16 @@ func (n *SelectStmt) Format(buf *TrackedBuffer, d format.Dialect) {
 	}
 
 	if items(n.ValuesLists) {
-		buf.WriteString("VALUES (")
-		buf.astFormat(n.ValuesLists, d)
-		buf.WriteString(")")
+		buf.WriteString("VALUES ")
+		// ValuesLists is a list of rows, where each row is a List of values
+		for i, row := range n.ValuesLists.Items {
+			if i > 0 {
+				buf.WriteString(", ")
+			}
+			buf.WriteString("(")
+			buf.astFormat(row, d)
+			buf.WriteString(")")
+		}
 		return
 	}