Last active
July 16, 2023 15:30
-
-
Save funwithbots/c7c4595fa8e9c6764a5e8a539b440e1b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package database | |
/* | |
Bill Shaw, Copyright 2023 | |
https://github.com/funwithbots | |
This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published | |
by the Free Software Foundation, either version 3 of the License, or any later version. | |
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of | |
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. | |
The full text of the GNU GPL v3.0 is available at https://www.gnu.org/licenses/.
*/ | |
import ( | |
"context" | |
"database/sql" | |
"errors" | |
"fmt" | |
"log" | |
"strings" | |
) | |
// BulkInsert inserts data in batches inside a single transaction.
//
// sql is the statement prefix up to and including the VALUES keyword, for
// example "INSERT INTO t (a, b) VALUES"; one "(?, ...)" placeholder group per
// row is appended to it, so the driver must accept "?" placeholders. Every row
// must contain the same, non-zero number of columns and batchSize must be
// positive.
//
// A statement sized for batchSize rows is prepared first and reused for every
// full batch; when the last batch is shorter, a second statement sized for the
// remaining rows is prepared for it.
//
// On success the transaction is committed and the number of rows inserted is
// returned. On any error the transaction is rolled back and 0 is returned
// together with an error naming the failing step (and batch, for exec errors).
func BulkInsert(ctx context.Context, db *sql.DB, sql string, batchSize int, data [][]interface{}) (int, error) {
	if len(data) == 0 {
		return 0, errors.New("no data provided")
	}
	columns := len(data[0])
	if columns == 0 {
		return 0, errors.New("no columns provided")
	}
	if batchSize <= 0 {
		// batchPattern cannot build a statement for zero rows, and a zero step
		// would never advance through data.
		return 0, fmt.Errorf("invalid batch size %d: must be positive", batchSize)
	}
	// A ragged row would make the argument count disagree with the
	// placeholder count, so reject it before opening a transaction.
	for i, row := range data {
		if len(row) != columns {
			return 0, fmt.Errorf("row %d has %d columns, want %d", i, len(row), columns)
		}
	}
	rows := len(data)

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return 0, fmt.Errorf("begin transaction: %w", err)
	}
	// Undoes everything on any early return; after a successful Commit this
	// Rollback does nothing and just reports that the transaction is done.
	defer tx.Rollback()

	stmt, err := tx.PrepareContext(ctx, sql+batchPattern(batchSize, columns))
	if err != nil {
		return 0, fmt.Errorf("prepare batch statement: %w", err)
	}
	// stmt is swapped for the final partial batch, so close whichever
	// statement is current at return rather than only the first one.
	defer func() { _ = stmt.Close() }()

	for start, batchID := 0, 1; start < rows; batchID++ {
		if remaining := rows - start; remaining < batchSize {
			// Final, shorter batch: it needs a statement with fewer placeholder groups.
			batchSize = remaining
			_ = stmt.Close()
			last, err := tx.PrepareContext(ctx, sql+batchPattern(batchSize, columns))
			if err != nil {
				return 0, fmt.Errorf("prepare final batch statement: %w", err)
			}
			stmt = last
		}

		// Flatten this batch's rows into a single argument list matching the
		// statement's placeholders.
		vals := make([]interface{}, 0, batchSize*columns)
		for _, row := range data[start : start+batchSize] {
			vals = append(vals, row...)
		}
		if _, err := stmt.ExecContext(ctx, vals...); err != nil {
			return 0, fmt.Errorf("exec batch %d: %w", batchID, err)
		}
		start += batchSize
		fmt.Print(".") // progress indicator: one dot per batch
	}

	if err := tx.Commit(); err != nil {
		return 0, fmt.Errorf("commit: %w", err)
	}
	log.Printf("\nInserted %d rows\n", rows)
	return rows, nil
}

// batchPattern returns the placeholder list appended after VALUES for a
// multi-row insert: a leading space followed by batchSize comma-separated
// groups of columns "?" placeholders, e.g. batchPattern(2, 3) returns
// " (?, ?, ?),(?, ?, ?)". It returns "" when either count is not positive
// instead of panicking inside strings.Repeat.
func batchPattern(batchSize, columns int) string {
	if batchSize <= 0 || columns <= 0 {
		return ""
	}
	group := "(" + strings.Repeat("?, ", columns-1) + "?)"
	return " " + strings.Repeat(group+",", batchSize-1) + group
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package database_test | |
import ( | |
"context" | |
"strings" | |
"testing" | |
"github.com/DATA-DOG/go-sqlmock" | |
"go-bricklink-cli/pkg/database" | |
) | |
func TestBulkInsert(t *testing.T) { | |
tests := []struct { | |
name string | |
batchSize int | |
rows int | |
want int | |
wantErr bool | |
}{ | |
{ | |
name: "count", | |
rows: 99, | |
batchSize: 100, | |
want: 99, | |
wantErr: false, | |
}, | |
{ | |
name: "count", | |
rows: 100, | |
batchSize: 100, | |
want: 100, | |
wantErr: false, | |
}, | |
{ | |
name: "count", | |
rows: 101, | |
batchSize: 100, | |
want: 101, | |
wantErr: false, | |
}, | |
{ | |
name: "count", | |
rows: 0, | |
batchSize: 100, | |
want: 0, | |
wantErr: true, | |
}, | |
} | |
sql := "INSERT INTO test (id, name) VALUES" | |
sqlMatch := strings.Join(strings.Fields(sql)[:3], " ") | |
row := []interface{}{1, "one"} | |
for _, tt := range tests { | |
// mock database connection | |
db, mock, err := sqlmock.New() | |
if err != nil { | |
t.Fatalf("error creating mock database: %v", err) | |
} | |
defer db.Close() | |
// set mock database expectations | |
mock.ExpectBegin() | |
mock.ExpectPrepare(sqlMatch) | |
size := tt.batchSize | |
if tt.batchSize > tt.rows { | |
size = tt.rows | |
} else { | |
mock.ExpectExec("INSERT INTO test").WillReturnResult(sqlmock.NewResult(1, int64(size))) | |
} | |
if tt.batchSize != tt.rows { | |
mock.ExpectPrepare(sqlMatch) | |
mock.ExpectExec("INSERT INTO test").WillReturnResult(sqlmock.NewResult(1, int64(tt.rows%tt.batchSize))) | |
} | |
mock.ExpectCommit() | |
// run test | |
data := make([][]interface{}, tt.rows) | |
for i := 0; i < tt.rows; i++ { | |
data[i] = row | |
} | |
got, err := database.BulkInsert(context.Background(), db, sql, tt.batchSize, data) | |
if (err != nil) != tt.wantErr { | |
t.Errorf("BulkInsert() error = %v, wantErr %v", err, tt.wantErr) | |
return | |
} | |
if got != tt.want { | |
t.Errorf("BulkInsert() got = %v, want %v", got, tt.want) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment