@@ -131,6 +131,17 @@ func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config,
131131 return configPath , & conf , nil
132132}
133133
134+ // GenerateInputs holds in-memory config and optional file contents so generate can run
135+ // without reading from disk (e.g. in tests). For production, Generate reads from FS and
136+ // fills FileContents before calling generate.
137+ type GenerateInputs struct {
138+ Config * config.Config
139+ ConfigPath string
140+ Dir string
141+ FileContents map [string ][]byte // path -> content; keys match paths used when reading (e.g. filepath.Join(dir, "schema.sql"))
142+ }
143+
144+ // Generate runs codegen for the given directory and config file, reading all input from disk.
134145func Generate (ctx context.Context , dir , filename string , o * Options ) (map [string ]string , error ) {
135146 e := o .Env
136147 stderr := o .Stderr
@@ -151,27 +162,76 @@ func Generate(ctx context.Context, dir, filename string, o *Options) (map[string
151162 return nil , err
152163 }
153164
154- // Comment on why these two methods exist
155165 if conf .Cloud .Project != "" && e .Remote && ! e .NoRemote {
156166 return remoteGenerate (ctx , configPath , conf , dir , stderr )
157167 }
158168
169+ inputs := & GenerateInputs {Config : conf , ConfigPath : configPath , Dir : dir }
170+ inputs .FileContents , err = loadFileContentsFromFS (conf , dir )
171+ if err != nil {
172+ return nil , err
173+ }
174+ return generate (ctx , inputs , o )
175+ }
176+
177+ // generate performs codegen using in-memory inputs. It is used by Generate (with contents
178+ // loaded from disk) and by tests (with pre-filled Config and FileContents, no temp files).
179+ func generate (ctx context.Context , inputs * GenerateInputs , o * Options ) (map [string ]string , error ) {
159180 g := & generator {
160- dir : dir ,
181+ dir : inputs . Dir ,
161182 output : map [string ]string {},
162183 }
163-
164- if err := processQuerySets (ctx , g , conf , dir , o ); err != nil {
184+ if o != nil && o .CodegenHandlerOverride != nil {
185+ g .codegenHandlerOverride = o .CodegenHandlerOverride
186+ }
187+ if err := processQuerySets (ctx , g , inputs , o ); err != nil {
165188 return nil , err
166189 }
167-
168190 return g .output , nil
169191}
170192
193+ // loadFileContentsFromFS reads all schema and query files referenced in conf into a map
194+ // path -> content, using dir to resolve paths.
195+ func loadFileContentsFromFS (conf * config.Config , dir string ) (map [string ][]byte , error ) {
196+ out := make (map [string ][]byte )
197+ for _ , pkg := range conf .SQL {
198+ for _ , rel := range pkg .Schema {
199+ path := filepath .Join (dir , rel )
200+ files , err := sqlpath .Glob ([]string {path })
201+ if err != nil {
202+ return nil , err
203+ }
204+ for _ , f := range files {
205+ b , err := os .ReadFile (f )
206+ if err != nil {
207+ return nil , err
208+ }
209+ out [f ] = b
210+ }
211+ }
212+ for _ , rel := range pkg .Queries {
213+ path := filepath .Join (dir , rel )
214+ files , err := sqlpath .Glob ([]string {path })
215+ if err != nil {
216+ return nil , err
217+ }
218+ for _ , f := range files {
219+ b , err := os .ReadFile (f )
220+ if err != nil {
221+ return nil , err
222+ }
223+ out [f ] = b
224+ }
225+ }
226+ }
227+ return out , nil
228+ }
229+
171230type generator struct {
172- m sync.Mutex
173- dir string
174- output map [string ]string
231+ m sync.Mutex
232+ dir string
233+ output map [string ]string
234+ codegenHandlerOverride grpc.ClientConnInterface
175235}
176236
177237func (g * generator ) Pairs (ctx context.Context , conf * config.Config ) []OutputPair {
@@ -200,7 +260,7 @@ func (g *generator) Pairs(ctx context.Context, conf *config.Config) []OutputPair
200260}
201261
202262func (g * generator ) ProcessResult (ctx context.Context , combo config.CombinedSettings , sql OutputPair , result * compiler.Result ) error {
203- out , resp , err := codegen (ctx , combo , sql , result )
263+ out , resp , err := codegen (ctx , combo , sql , result , g . codegenHandlerOverride )
204264 if err != nil {
205265 return err
206266 }
@@ -333,52 +393,54 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
333393 return c .Result (), false
334394}
335395
336- func codegen (ctx context.Context , combo config.CombinedSettings , sql OutputPair , result * compiler.Result ) (string , * plugin.GenerateResponse , error ) {
396+ func codegen (ctx context.Context , combo config.CombinedSettings , sql OutputPair , result * compiler.Result , codegenOverride grpc. ClientConnInterface ) (string , * plugin.GenerateResponse , error ) {
337397 defer trace .StartRegion (ctx , "codegen" ).End ()
338398 req := codeGenRequest (result , combo )
339399 var handler grpc.ClientConnInterface
340400 var out string
341401 switch {
342402 case sql .Plugin != nil :
343403 out = sql .Plugin .Out
344- plug , err := findPlugin (combo .Global , sql .Plugin .Plugin )
345- if err != nil {
346- return "" , nil , fmt .Errorf ("plugin not found: %s" , err )
347- }
404+ if codegenOverride != nil {
405+ handler = codegenOverride
406+ } else {
407+ plug , err := findPlugin (combo .Global , sql .Plugin .Plugin )
408+ if err != nil {
409+ return "" , nil , fmt .Errorf ("plugin not found: %s" , err )
410+ }
348411
349- switch {
350- case plug .Process != nil :
351- handler = & process.Runner {
352- Cmd : plug .Process .Cmd ,
353- Dir : combo .Dir ,
354- Env : plug .Env ,
355- Format : plug .Process .Format ,
412+ switch {
413+ case plug .Process != nil :
414+ handler = & process.Runner {
415+ Cmd : plug .Process .Cmd ,
416+ Dir : combo .Dir ,
417+ Env : plug .Env ,
418+ Format : plug .Process .Format ,
419+ }
420+ case plug .WASM != nil :
421+ handler = & wasm.Runner {
422+ URL : plug .WASM .URL ,
423+ SHA256 : plug .WASM .SHA256 ,
424+ Env : plug .Env ,
425+ }
426+ default :
427+ return "" , nil , fmt .Errorf ("unsupported plugin type" )
356428 }
357- case plug .WASM != nil :
358- handler = & wasm.Runner {
359- URL : plug .WASM .URL ,
360- SHA256 : plug .WASM .SHA256 ,
361- Env : plug .Env ,
429+ global , found := combo .Global .Options [plug .Name ]
430+ if found {
431+ opts , err := convert .YAMLtoJSON (global )
432+ if err != nil {
433+ return "" , nil , fmt .Errorf ("invalid global options: %w" , err )
434+ }
435+ req .GlobalOptions = opts
362436 }
363- default :
364- return "" , nil , fmt .Errorf ("unsupported plugin type" )
365437 }
366-
367438 opts , err := convert .YAMLtoJSON (sql .Plugin .Options )
368439 if err != nil {
369440 return "" , nil , fmt .Errorf ("invalid plugin options: %w" , err )
370441 }
371442 req .PluginOptions = opts
372443
373- global , found := combo .Global .Options [plug .Name ]
374- if found {
375- opts , err := convert .YAMLtoJSON (global )
376- if err != nil {
377- return "" , nil , fmt .Errorf ("invalid global options: %w" , err )
378- }
379- req .GlobalOptions = opts
380- }
381-
382444 case sql .Gen .Go != nil :
383445 out = combo .Go .Out
384446 handler = ext .HandleFunc (golang .Generate )
0 commit comments