Skip to content

Instantly share code, notes, and snippets.

@bokwoon95
Created February 28, 2020 04:33
Show Gist options
  • Save bokwoon95/b56fea8582f4dde265d1b803e88bdf94 to your computer and use it in GitHub Desktop.
Save bokwoon95/b56fea8582f4dde265d1b803e88bdf94 to your computer and use it in GitHub Desktop.
package main
import (
"bytes"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"sync"
"github.com/jmoiron/sqlx"
_ "github.com/joho/godotenv/autoload"
_ "github.com/lib/pq"
)
// Package-level configuration shared by the functions in this file.
var (
	// databaseURL is the connection string in use; setupDatabase assigns it
	// and loadfile passes it to psql.
	databaseURL string

	// urlFlag is the value of the -databaseurl command line flag.
	urlFlag = flag.String("databaseurl", "", "")

	// cleanFlag is the value of the -clean command line flag. main consults
	// it to decide whether init and data files are loaded.
	cleanFlag = flag.Bool("clean", false, "")

	// directories lists the directories that are globbed for *.sql files
	// when no file arguments are given on the command line.
	directories = []string{
		".",
		"skylab/sql",
		"skylab/applicants/sql",
		"skylab/students/sql",
		"skylab/advisers/sql",
		"skylab/mentors/sql",
		"skylab/admins/sql",
	}
)
// SortedFiles categorizes sql file paths into different slices, one slice
// per kind of file. See sortfiles for the filename rules that decide which
// slice a path lands in.
type SortedFiles struct {
// paths whose base name is exactly "init.sql"
initfiles []string
// paths whose base name starts with "v_"
viewfiles []string
// paths whose base name starts with "test_"
testfiles []string
// paths whose base name starts with "data"
datafiles []string
// every other *.sql path that is not skipped by sortfiles
functionfiles []string
}
// main connects to the database, collects the sql files to load (from the
// command line arguments if any were given, otherwise from the default
// directories) and loads them category by category: init files, then views,
// then functions, then data files. Any failure terminates the program via
// log.Fatalln.
func main() {
	flag.Parse()
	log.SetFlags(log.LstdFlags | log.Llongfile)

	plpgsqlCheck, numberOfTables, err := setupDatabase()
	if err != nil {
		log.Fatalln(err.Error())
	}
	// A database without any tables is always treated as a clean load.
	if numberOfTables == 0 {
		*cleanFlag = true
	}

	var sorted SortedFiles
	if explicit := flag.Args(); len(explicit) > 0 {
		sorted = sortfiles(explicit)
	} else if sorted, err = sortdirs(directories); err != nil {
		log.Fatalln(err.Error())
	}

	loader := &SqloadError{}
	clean := *cleanFlag
	// Init files are loaded on a clean run, or when they are the only
	// files that were selected.
	if clean || onlyInitfilesPresent(sorted) {
		loader.loadfiles(sorted.initfiles)
	}
	loader.loadfiles(sorted.viewfiles)
	// Function files are additionally screened by plpgsqlCheck once loaded.
	loader.loadfiles(sorted.functionfiles, plpgsqlCheck)
	if clean {
		loader.loadfiles(sorted.datafiles)
	}
	// loader.loadfiles(sorted.testfiles) // sql test files are not checked for now
	if loader.Err != nil {
		log.Fatalln(loader.Err.Error())
	}
}
// onlyInitfilesPresent reports whether sorted contains at least one init
// file and no file of any other category.
func onlyInitfilesPresent(sorted SortedFiles) bool {
	if len(sorted.initfiles) == 0 {
		return false
	}
	others := len(sorted.viewfiles) +
		len(sorted.functionfiles) +
		len(sorted.datafiles) +
		len(sorted.testfiles)
	return others == 0
}
// sortdirs globs every directory for *.sql files, runs sortfiles on the
// matches and merges the per-directory results, preserving the order of
// directories. It returns an error if a glob pattern is rejected.
func sortdirs(directories []string) (sorted SortedFiles, err error) {
	for _, dir := range directories {
		matches, globErr := filepath.Glob(dir + "/*.sql")
		if globErr != nil {
			return sorted, fmt.Errorf("Error while globbing directory %s: %w", dir, globErr)
		}
		found := sortfiles(matches)
		sorted.initfiles = append(sorted.initfiles, found.initfiles...)
		sorted.viewfiles = append(sorted.viewfiles, found.viewfiles...)
		sorted.functionfiles = append(sorted.functionfiles, found.functionfiles...)
		sorted.datafiles = append(sorted.datafiles, found.datafiles...)
		sorted.testfiles = append(sorted.testfiles, found.testfiles...)
	}
	return sorted, nil
}
// sortfiles distributes file paths over the slices of a SortedFiles value
// according to the base name of each path:
//
//	not ending in .sql, or starting with temp -> ignored
//	init.sql                                  -> initfiles
//	v_*                                       -> viewfiles
//	test_*                                    -> testfiles (pgTap)
//	data*                                     -> datafiles
//	anything else                             -> functionfiles
//
// The rules are tried in the order listed and the first match wins.
func sortfiles(files []string) (sorted SortedFiles) {
	for _, path := range files {
		name := filepath.Base(path)
		if !strings.HasSuffix(name, ".sql") || strings.HasPrefix(name, "temp") {
			continue
		}
		// Default to the function bucket; a more specific rule overrides it.
		bucket := &sorted.functionfiles
		switch {
		case name == "init.sql":
			bucket = &sorted.initfiles
		case strings.HasPrefix(name, "v_"):
			bucket = &sorted.viewfiles
		case strings.HasPrefix(name, "test_"):
			bucket = &sorted.testfiles
		case strings.HasPrefix(name, "data"):
			bucket = &sorted.datafiles
		}
		*bucket = append(*bucket, path)
	}
	return sorted
}
// coalesce returns the first argument that is not the empty string, or the
// empty string when there is no such argument.
func coalesce(vals ...string) string {
	for i := 0; i < len(vals); i++ {
		if vals[i] == "" {
			continue
		}
		return vals[i]
	}
	return ""
}
// setupDatabase obtains the database URL from the '-databaseurl' flag or,
// failing that, from the DATABASE_URL environment variable, stores it in the
// package-level databaseURL, and checks that the database is reachable.
//
// It returns:
//   - plpgsqlCheck: a function that reads an sql file, extracts the name of
//     every "CREATE [OR REPLACE] FUNCTION <name>" statement in it and screens
//     each of those functions with plpgsql_check. Diagnostics are printed to
//     stdout and an error is returned for the first function with findings.
//   - numberOfTables: the number of visible tables outside the system
//     schemas.
//   - err: non-nil if the URL is empty, or the database cannot be opened,
//     pinged or queried. The other results are zero values in that case.
//
// On success the database handle is deliberately left open because the
// returned closure keeps using it; on failure it is closed before returning.
func setupDatabase() (plpgsqlCheck func(string) error, numberOfTables int, err error) {
	// Setup database connection
	databaseURL = coalesce(*urlFlag, os.Getenv("DATABASE_URL"))
	if databaseURL == "" {
		return nil, 0, fmt.Errorf("Database URL cannot be empty")
	}
	db, err := sqlx.Open("postgres", databaseURL)
	if err != nil {
		return nil, 0, fmt.Errorf("Unable to open the database: %w", err)
	}
	// Do not leak the handle when one of the checks below fails. err is the
	// named result, so it reflects whatever value is being returned.
	defer func() {
		if err != nil {
			_ = db.Close() // best effort; the original failure is what matters
		}
	}()
	if err = db.Ping(); err != nil {
		return nil, 0, fmt.Errorf("Unable to ping the database: %w", err)
	}
	err = db.QueryRowx(`
		SELECT
			COUNT(*)
		FROM
			pg_catalog.pg_class c
			LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
		WHERE
			c.relkind IN ('r','p','')
			AND n.nspname <> 'pg_catalog'
			AND n.nspname <> 'information_schema'
			AND n.nspname !~ '^pg_toast'
			AND pg_catalog.pg_table_is_visible(c.oid)
	`).Scan(&numberOfTables)
	if err != nil {
		return nil, 0, fmt.Errorf("Unable to get number of tables: %w", err)
	}

	// Compiled once here instead of on every call of the returned closure.
	// NOTE(review): the match is case-sensitive, so lower-case
	// "create function" statements are not picked up — confirm that is
	// intended.
	createFunction := regexp.MustCompile(`CREATE\s+(?:OR\s+REPLACE\s+)?FUNCTION\s+([\w.]+)`)

	// printReport prints every line that plpgsql_check_function produces for
	// the given function to stdout.
	printReport := func(function string) error {
		rows, err := db.Query("SELECT * FROM plpgsql_check_function($1)", function)
		if err != nil {
			return fmt.Errorf("Error occurred while trying to run plpgsql_check_function: %w", err)
		}
		defer rows.Close()
		for rows.Next() {
			var errstring string
			if err := rows.Scan(&errstring); err != nil {
				return fmt.Errorf("Error scanning row into errstring: %w", err)
			}
			fmt.Println(errstring)
		}
		// Next returning false can also mean the iteration was cut short.
		if err := rows.Err(); err != nil {
			return fmt.Errorf("Error occurred while reading plpgsql_check_function output: %w", err)
		}
		return nil
	}

	// Return a plpgsqlCheck function which checks every function in an sql
	// file with plpgsql_check_function
	plpgsqlCheck = func(filename string) error {
		body, err := ioutil.ReadFile(filename)
		if err != nil {
			return fmt.Errorf("Error trying to read filename %s: %w", filename, err)
		}
		for _, match := range createFunction.FindAllSubmatch(body, -1) {
			function := string(match[1])
			var errcount int
			err = db.QueryRow("SELECT COUNT(*) FROM plpgsql_check_function_tb($1)", function).Scan(&errcount)
			if err != nil {
				return fmt.Errorf("Error occurred while trying to run plpgsql_check_function_tb: %w", err)
			}
			if errcount == 0 {
				continue
			}
			// Show the human readable report, then stop at this function.
			if err := printReport(function); err != nil {
				return err
			}
			return fmt.Errorf("Error(s) in function %s", function)
		}
		return nil
	}
	return plpgsqlCheck, numberOfTables, nil
}
// SqloadError is a convenience struct that allows (*SqloadError).loadfiles
// to be called repeatedly without having to check for an error after each
// call, i.e.
//
// e.loadfiles(A_files)
// e.loadfiles(B_files)
// e.loadfiles(C_files)
// if e.Err != nil {
// handleError(e.Err)
// }
//
// This works because once (*SqloadError).Err is non-nil, subsequent
// (*SqloadError).loadfiles() calls return immediately and do nothing, so
// the error only has to be checked once at the very end.
type SqloadError struct {
// Err is the first error recorded by loadfiles, or nil if none occurred.
Err error
}
// loadfiles loads each of filenames in order through loadfile, forwarding
// the postprocessors, and prints "[<filename>]" before every load. The first
// failure is stored in e.Err and stops the loop. If e.Err is already set
// when loadfiles is called, nothing is loaded.
func (e *SqloadError) loadfiles(filenames []string, postprocessors ...func(string) error) {
	// A previously recorded error turns this call into a no-op.
	if e.Err != nil {
		return
	}
	for i := range filenames {
		name := filenames[i]
		fmt.Println("[" + name + "]")
		if e.Err = loadfile(name, postprocessors...); e.Err != nil {
			return
		}
	}
}
// loadfile loads an sql file into the database by running
//
//	psql -v ON_ERROR_STOP=1 -f <filename> <databaseURL>
//
// and, once psql has exited successfully, calls each postprocessor function
// on the filename in order.
//
// psql's stdout and stderr are streamed to this process's stdout and stderr
// while the command runs. stderr is additionally captured so that it can be
// quoted in the returned error when psql fails.
//
// It returns an error if the pipes cannot be created, psql cannot be
// started, psql exits unsuccessfully, its output cannot be copied, or a
// postprocessor fails; the underlying error is wrapped in every case.
func loadfile(filename string, postprocessors ...func(string) error) (err error) {
	cmd := exec.Command("psql", "-v", "ON_ERROR_STOP=1", "-f", filename, databaseURL)
	stdoutIn, err := cmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("Error in cmd.StdoutPipe(): %w", err)
	}
	stderrIn, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("Error in cmd.StderrPipe(): %w", err)
	}
	var stderrBuf bytes.Buffer
	stderr := io.MultiWriter(os.Stderr, &stderrBuf)
	if err = cmd.Start(); err != nil {
		return fmt.Errorf("Error in cmd.Start(): %w", err)
	}

	// Drain both pipes concurrently so that psql can never block on a full
	// pipe. Both copies must have finished before cmd.Wait is called,
	// because Wait closes the pipes.
	var wg sync.WaitGroup
	var errStdout, errStderr error
	wg.Add(1)
	go func() {
		defer wg.Done()
		_, errStdout = io.Copy(os.Stdout, stdoutIn)
	}()
	_, errStderr = io.Copy(stderr, stderrIn)
	wg.Wait()

	if err = cmd.Wait(); err != nil {
		if msg := bytes.TrimSpace(stderrBuf.Bytes()); len(msg) > 0 {
			return fmt.Errorf("Command encountered an error while loading %s: %w: %s", filename, err, msg)
		}
		return fmt.Errorf("Command encountered an error while loading %s: %w", filename, err)
	}
	if errStdout != nil {
		return fmt.Errorf("Failed to capture stdout or stderr: %w", errStdout)
	}
	if errStderr != nil {
		return fmt.Errorf("Failed to capture stdout or stderr: %w", errStderr)
	}

	// Call each postprocessor function on the file
	for _, fn := range postprocessors {
		if err = fn(filename); err != nil {
			return fmt.Errorf("Error when postprocessing file %s: %w", filename, err)
		}
	}
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment