Skip to content

Instantly share code, notes, and snippets.

@funwithbots
Last active July 16, 2023 15:30
Show Gist options
  • Save funwithbots/c7c4595fa8e9c6764a5e8a539b440e1b to your computer and use it in GitHub Desktop.
Save funwithbots/c7c4595fa8e9c6764a5e8a539b440e1b to your computer and use it in GitHub Desktop.
package database
/*
Bill Shaw, Copyright 2023
https://github.com/funwithbots
This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published
by the Free Software Foundation, either version 3 of the License, or any later version.
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
The full text of the GNU GPL v3.0 is available at https://www.gnu.org/licenses/.
*/
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"strings"
)
// BulkInsert inserts data inside a single transaction using multi-row INSERT
// statements of up to batchSize rows each.
//
// sql is the statement prefix up to and including VALUES, e.g.
// "INSERT INTO t (a, b) VALUES"; a placeholder list of the form
// " (?, ?),(?, ?)" is appended to it, so the driver behind db must accept "?"
// placeholders. Every row of data must have the same, non-zero number of
// columns, and batchSize must be positive.
//
// All rows are committed together: if any batch fails the transaction is
// rolled back and (0, err) is returned. On success it returns len(data).
// As a progress indicator it prints one "." to stdout per executed batch and
// logs the final row count with the standard logger.
func BulkInsert(ctx context.Context, db *sql.DB, sql string, batchSize int, data [][]interface{}) (int, error) {
	if len(data) == 0 {
		return 0, errors.New("no data provided")
	}
	columns := len(data[0])
	if columns == 0 {
		return 0, errors.New("no columns provided")
	}
	if batchSize <= 0 {
		return 0, fmt.Errorf("invalid batch size %d: must be positive", batchSize)
	}
	// Validate every row up front: a ragged row would make a batch's argument
	// count disagree with its placeholders, or shift values into the wrong
	// columns, and would only be noticed mid-transaction.
	for i, row := range data {
		if len(row) != columns {
			return 0, fmt.Errorf("row %d has %d columns, want %d", i, len(row), columns)
		}
	}
	rows := len(data)

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return 0, fmt.Errorf("begin transaction: %w", err)
	}
	// Undo every batch if anything below fails. After a successful Commit this
	// returns sql.ErrTxDone, which is expected and deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	// Prepare the full-size statement first (even when all rows fit in the
	// final partial batch) so a malformed prefix is reported before any rows
	// are sent.
	stmt, err := tx.PrepareContext(ctx, sql+batchPattern(batchSize, columns))
	if err != nil {
		return 0, fmt.Errorf("prepare %d-row insert: %w", batchSize, err)
	}
	defer stmt.Close()

	full := rows - rows%batchSize // rows covered by complete batches
	batchID := 0
	for start := 0; start < full; start += batchSize {
		batchID++
		if err := execBatch(ctx, stmt, data[start:start+batchSize]); err != nil {
			return 0, fmt.Errorf("insert batch %d: %w", batchID, err)
		}
		fmt.Print(".")
	}

	if tail := rows - full; tail > 0 {
		// The trailing partial batch needs its own statement with fewer
		// placeholder groups.
		batchID++
		tailStmt, err := tx.PrepareContext(ctx, sql+batchPattern(tail, columns))
		if err != nil {
			return 0, fmt.Errorf("prepare final %d-row insert: %w", tail, err)
		}
		defer tailStmt.Close()
		if err := execBatch(ctx, tailStmt, data[full:]); err != nil {
			return 0, fmt.Errorf("insert batch %d: %w", batchID, err)
		}
		fmt.Print(".")
	}

	if err := tx.Commit(); err != nil {
		return 0, fmt.Errorf("commit: %w", err)
	}
	log.Printf("\nInserted %d rows\n", rows)
	return rows, nil
}

// execBatch executes stmt with the rows of batch flattened, row by row, into
// a single argument list. batch must be non-empty and stmt must have been
// prepared with len(batch) groups of len(batch[0]) placeholders.
func execBatch(ctx context.Context, stmt *sql.Stmt, batch [][]interface{}) error {
	args := make([]interface{}, 0, len(batch)*len(batch[0]))
	for _, row := range batch {
		args = append(args, row...)
	}
	_, err := stmt.ExecContext(ctx, args...)
	return err
}

// batchPattern returns the placeholder list appended after "VALUES" in a
// multi-row INSERT: a leading space followed by batchSize comma-separated
// groups of columns "?" placeholders. For example, batchPattern(2, 3) returns
// " (?, ?, ?),(?, ?, ?)". Both arguments must be positive; strings.Repeat
// panics on the negative counts that a zero argument would produce.
func batchPattern(batchSize, columns int) string {
	group := "(" + strings.Repeat("?, ", columns-1) + "?)"
	return " " + strings.Repeat(group+",", batchSize-1) + group
}
package database_test
import (
"context"
"strings"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"go-bricklink-cli/pkg/database"
)
// TestBulkInsert checks that BulkInsert issues the expected sequence of
// Begin/Prepare/Exec/Commit calls for several row/batch-size combinations and
// returns the number of rows inserted.
func TestBulkInsert(t *testing.T) {
	const query = "INSERT INTO test (id, name) VALUES"
	// sqlmock matches query expectations as regular expressions; matching only
	// "INSERT INTO test" keeps the generated placeholder list out of the pattern.
	sqlMatch := strings.Join(strings.Fields(query)[:3], " ")
	row := []interface{}{1, "one"}

	tests := []struct {
		name      string
		batchSize int
		rows      int
		want      int
		wantErr   bool
	}{
		{name: "fewer rows than batch", rows: 99, batchSize: 100, want: 99},
		{name: "exactly one batch", rows: 100, batchSize: 100, want: 100},
		{name: "one batch plus tail", rows: 101, batchSize: 100, want: 101},
		{name: "several batches plus tail", rows: 250, batchSize: 100, want: 250},
		{name: "no rows", rows: 0, batchSize: 100, want: 0, wantErr: true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db, mock, err := sqlmock.New()
			if err != nil {
				t.Fatalf("error creating mock database: %v", err)
			}
			defer db.Close()

			if !tt.wantErr {
				// BulkInsert prepares a full-size statement up front, runs it once
				// per complete batch, then prepares and runs a smaller statement
				// for any remaining rows.
				fullBatches := tt.rows / tt.batchSize
				tail := tt.rows % tt.batchSize
				mock.ExpectBegin()
				mock.ExpectPrepare(sqlMatch)
				for i := 0; i < fullBatches; i++ {
					mock.ExpectExec(sqlMatch).WillReturnResult(sqlmock.NewResult(1, int64(tt.batchSize)))
				}
				if tail > 0 {
					mock.ExpectPrepare(sqlMatch)
					mock.ExpectExec(sqlMatch).WillReturnResult(sqlmock.NewResult(1, int64(tail)))
				}
				mock.ExpectCommit()
			}

			data := make([][]interface{}, tt.rows)
			for i := range data {
				data[i] = row
			}

			got, err := database.BulkInsert(context.Background(), db, query, tt.batchSize, data)
			if (err != nil) != tt.wantErr {
				t.Fatalf("BulkInsert() error = %v, wantErr %v", err, tt.wantErr)
			}
			if got != tt.want {
				t.Errorf("BulkInsert() got = %v, want %v", got, tt.want)
			}
			if err := mock.ExpectationsWereMet(); err != nil {
				t.Errorf("unmet sqlmock expectations: %v", err)
			}
		})
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment