Created
February 28, 2020 04:33
-
-
Save bokwoon95/b56fea8582f4dde265d1b803e88bdf94 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"flag" | |
"fmt" | |
"io" | |
"io/ioutil" | |
"log" | |
"os" | |
"os/exec" | |
"path/filepath" | |
"regexp" | |
"strings" | |
"sync" | |
"github.com/jmoiron/sqlx" | |
_ "github.com/joho/godotenv/autoload" | |
_ "github.com/lib/pq" | |
) | |
// Package-level state shared by the functions in this file.
var (
	// databaseURL is the connection string passed to psql by loadfile. It
	// is assigned by setupDatabase.
	databaseURL string

	// urlFlag holds the value of the -databaseurl command line flag.
	urlFlag = flag.String("databaseurl", "", "")

	// cleanFlag holds the value of the -clean command line flag. main also
	// switches it on when the database contains no tables.
	cleanFlag = flag.Bool("clean", false, "")

	// directories lists the folders that are globbed for *.sql files when
	// no file arguments are given on the command line.
	directories = []string{
		".",
		"skylab/sql",
		"skylab/applicants/sql",
		"skylab/students/sql",
		"skylab/advisers/sql",
		"skylab/mentors/sql",
		"skylab/admins/sql",
	}
)
// SortedFiles holds sql file paths split into one slice per category:
// init files, view files, test files, data files and function files (the
// catch-all category used by sortfiles).
type SortedFiles struct {
	initfiles, viewfiles, testfiles, datafiles, functionfiles []string
}
func main() { | |
flag.Parse() | |
log.SetFlags(log.LstdFlags | log.Llongfile) | |
plpgsqlCheck, numberOfTables, err := setupDatabase() | |
if err != nil { | |
log.Fatalln(err.Error()) | |
} | |
if numberOfTables == 0 { | |
*cleanFlag = true | |
} | |
var sorted SortedFiles | |
if args := flag.Args(); len(args) == 0 { | |
sorted, err = sortdirs(directories) | |
if err != nil { | |
log.Fatalln(err.Error()) | |
} | |
} else { | |
sorted = sortfiles(args) | |
} | |
e := &SqloadError{} | |
if *cleanFlag || onlyInitfilesPresent(sorted) { | |
e.loadfiles(sorted.initfiles) | |
} | |
e.loadfiles(sorted.viewfiles) | |
e.loadfiles(sorted.functionfiles, plpgsqlCheck) | |
if *cleanFlag { | |
e.loadfiles(sorted.datafiles) | |
} | |
// e.loadfiles(sorted.testfiles) // sql test files are not checked for now | |
if err := e.Err; err != nil { | |
log.Fatalln(err.Error()) | |
} | |
} | |
func onlyInitfilesPresent(sorted SortedFiles) bool { | |
return len(sorted.initfiles) != 0 && | |
len(sorted.viewfiles) == 0 && | |
len(sorted.functionfiles) == 0 && | |
len(sorted.datafiles) == 0 && | |
len(sorted.testfiles) == 0 | |
} | |
// Call sortfiles on each directory | |
func sortdirs(directories []string) (sorted SortedFiles, err error) { | |
for _, directory := range directories { | |
sqlFilepaths, err := filepath.Glob(directory + "/*.sql") | |
if err != nil { | |
return sorted, fmt.Errorf("Error while globbing directory %s: %w", directory, err) | |
} | |
newsorted := sortfiles(sqlFilepaths) | |
sorted.initfiles = append(sorted.initfiles, newsorted.initfiles...) | |
sorted.viewfiles = append(sorted.viewfiles, newsorted.viewfiles...) | |
sorted.testfiles = append(sorted.testfiles, newsorted.testfiles...) | |
sorted.datafiles = append(sorted.datafiles, newsorted.datafiles...) | |
sorted.functionfiles = append(sorted.functionfiles, newsorted.functionfiles...) | |
} | |
return sorted, nil | |
} | |
// Sort files into different slices based on their filename prefix | |
// init.sql is the init file | |
// v_* are view files | |
// test_* are test files (pgTap) | |
// data* are data files | |
// temp* are temporary files, will be ignored | |
func sortfiles(files []string) (sorted SortedFiles) { | |
for _, file := range files { | |
basename := filepath.Base(file) | |
switch { | |
case !strings.HasSuffix(basename, ".sql"), strings.HasPrefix(basename, "temp"): | |
continue | |
case basename == "init.sql": | |
sorted.initfiles = append(sorted.initfiles, file) | |
case strings.HasPrefix(basename, "v_"): | |
sorted.viewfiles = append(sorted.viewfiles, file) | |
case strings.HasPrefix(basename, "test_"): | |
sorted.testfiles = append(sorted.testfiles, file) | |
case strings.HasPrefix(basename, "data"): | |
sorted.datafiles = append(sorted.datafiles, file) | |
default: | |
sorted.functionfiles = append(sorted.functionfiles, file) | |
} | |
} | |
return sorted | |
} | |
// coalesce returns the first argument that is not the empty string, or ""
// when every argument is empty or no arguments are given.
func coalesce(vals ...string) string {
	for i := 0; i < len(vals); i++ {
		if candidate := vals[i]; len(candidate) > 0 {
			return candidate
		}
	}
	return ""
}
// Obtain databaseURL from '-databaseurl' flag or from DATABASE_URL environment | |
// variable and check if the database is reachable. Return a plpgsqlCheck | |
// function that screens all functions in a file using plpgsql_check | |
func setupDatabase() (plpgsqlCheck func(string) error, numberOfTables int, err error) { | |
// Setup database connection | |
databaseURL = coalesce(*urlFlag, os.Getenv("DATABASE_URL")) | |
if databaseURL == "" { | |
return plpgsqlCheck, numberOfTables, fmt.Errorf("Database URL cannot be empty") | |
} | |
db, err := sqlx.Open("postgres", databaseURL) | |
if err != nil { | |
return plpgsqlCheck, numberOfTables, fmt.Errorf("Unable to open the database: %w", err) | |
} | |
err = db.Ping() | |
if err != nil { | |
return plpgsqlCheck, numberOfTables, fmt.Errorf("Unable to ping the database: %w", err) | |
} | |
err = db.QueryRowx(` | |
SELECT | |
COUNT(*) | |
FROM | |
pg_catalog.pg_class c | |
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace | |
WHERE | |
c.relkind IN ('r','p','') | |
AND n.nspname <> 'pg_catalog' | |
AND n.nspname <> 'information_schema' | |
AND n.nspname !~ '^pg_toast' | |
AND pg_catalog.pg_table_is_visible(c.oid) | |
`).Scan(&numberOfTables) | |
if err != nil { | |
return plpgsqlCheck, numberOfTables, fmt.Errorf("Unable to get number of tables: %w", err) | |
} | |
// Return a plpgsqlCheck function which checks every function in an sql | |
// file with plpgsql_check_function | |
return func(filename string) error { | |
body, err := ioutil.ReadFile(filename) | |
if err != nil { | |
return fmt.Errorf("Error trying to read filename %s: %w", filename, err) | |
} | |
matches := regexp.MustCompile(`CREATE\s+(?:OR\s+REPLACE\s+)?FUNCTION\s+([\w.]+)`).FindAllSubmatch(body, -1) | |
var functions []string | |
for _, match := range matches { | |
functions = append(functions, string(match[1])) | |
} | |
for _, function := range functions { | |
var errcount int | |
err = db.QueryRow("SELECT COUNT(*) FROM plpgsql_check_function_tb($1)", function).Scan(&errcount) | |
if err != nil { | |
return fmt.Errorf("Error occurred while trying to run plpgsql_check_function_tb: %w", err) | |
} | |
if errcount > 0 { | |
rows, err := db.Query("SELECT * FROM plpgsql_check_function($1)", function) | |
if err != nil { | |
return fmt.Errorf("Error occurred while trying to run plpgsql_check_function: %w", err) | |
} | |
for rows.Next() { | |
var errstring string | |
err = rows.Scan(&errstring) | |
if err != nil { | |
rows.Close() | |
return fmt.Errorf("Error scanning row into errstring: %w", err) | |
} | |
fmt.Println(errstring) | |
} | |
return fmt.Errorf("Error(s) in function %s", function) | |
} | |
} | |
return err | |
}, numberOfTables, nil | |
} | |
// SqloadError remembers the first error hit by (*SqloadError).loadfiles so
// that several loads can be chained with one error check at the end:
//
//	e := &SqloadError{}
//	e.loadfiles(A_files)
//	e.loadfiles(B_files)
//	e.loadfiles(C_files)
//	if e.Err != nil {
//		handleError(e.Err)
//	}
//
// Once Err is non-nil, every later loadfiles call returns immediately
// without loading anything.
type SqloadError struct {
	// Err is the error from the first file that failed to load, or nil if
	// nothing has failed so far.
	Err error
}
// Load a bunch of filenames | |
func (e *SqloadError) loadfiles(filenames []string, postprocessors ...func(string) error) { | |
// If error is present, do nothing and return | |
if e.Err != nil { | |
return | |
} | |
// Otherwise range over the filenames and call loadfile() for each one | |
for _, filename := range filenames { | |
fmt.Println("[" + filename + "]") | |
err := loadfile(filename, postprocessors...) | |
if err != nil { | |
e.Err = err | |
return | |
} | |
} | |
} | |
// Load an sql file into the database using psql. Once done, call each | |
// preprocessor function on the filename | |
func loadfile(filename string, postprocessors ...func(string) error) (err error) { | |
cmdName := "psql" | |
cmdArgs := []string{"-v", "ON_ERROR_STOP=1", "-f", filename, databaseURL} | |
cmd := exec.Command(cmdName, cmdArgs...) | |
var stdoutBuf, stderrBuf bytes.Buffer | |
stdoutIn, _ := cmd.StdoutPipe() | |
stderrIn, _ := cmd.StderrPipe() | |
stdout := io.MultiWriter(os.Stdout, &stdoutBuf) | |
stderr := io.MultiWriter(os.Stderr, &stderrBuf) | |
err = cmd.Start() | |
if err != nil { | |
return fmt.Errorf("Error in cmd.Start(): %w", err) | |
} | |
var wg sync.WaitGroup | |
wg.Add(1) | |
var errStdout, errStderr error | |
go func() { | |
_, errStdout = io.Copy(stdout, stdoutIn) | |
wg.Done() | |
}() | |
_, errStderr = io.Copy(stderr, stderrIn) | |
wg.Wait() | |
err = cmd.Wait() | |
if err != nil { | |
return fmt.Errorf("Command encountered an error") | |
} | |
if errStdout != nil || errStderr != nil { | |
return fmt.Errorf("Failed to capture stdout or stderr") | |
} | |
// Call each preprocessor function on the file | |
for _, fn := range postprocessors { | |
err = fn(filename) | |
if err != nil { | |
return fmt.Errorf("Error when preprocessing file %s: %w", filename, err) | |
} | |
} | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment