Created
July 7, 2022 20:08
-
-
Save ks2211/b05ffaac4875cc653612d1fd68725c87 to your computer and use it in GitHub Desktop.
ent-gosqlmock
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import (
	"context"
	"database/sql"
	"fmt"
	"testing"

	entsql "entgo.io/ent/dialect/sql"
	"github.com/DATA-DOG/go-sqlmock"
	_ "github.com/lib/pq" // registers the "postgres" driver used by sql.Open
	"yourproject.com/project/ent"
)
// CreateDBDriver opens a postgres *sql.DB for ent from the given connection
// string and applies fixed connection-pool limits.
//
// sql.Open may only validate its arguments without connecting, so an
// unreachable host or bad credentials can surface on first use rather than
// here. On error the returned *sql.DB is nil.
func CreateDBDriver(cfgString string) (*sql.DB, error) {
	const (
		maxIdleConns = 10
		maxOpenConns = 100
	)

	db, err := sql.Open("postgres", cfgString)
	if err != nil {
		return nil, fmt.Errorf("opening postgres driver: %w", err)
	}
	db.SetMaxIdleConns(maxIdleConns)
	db.SetMaxOpenConns(maxOpenConns)
	return db, nil
}
// DB holds ent client. | |
type DB struct { | |
client *ent.Client | |
} | |
// NewDB creates ent client. | |
func NewDB(driver *sql.DB) *DB { | |
conn := ent.NewClient(ent.Driver( | |
entsql.OpenDB("postgres", driver)), | |
) | |
conn.Debug() | |
return &DB{ | |
conn | |
} | |
} | |
// QueryEntity does a select on entities table. | |
func (d *DB) QueryEntity(ctx context.Context, id int64)(*ent.Entity, error) { | |
return d.client.Entity.Query(). | |
Where(entity.ID(id)). | |
Only(ctx) | |
} | |
func main() { | |
drv, _ := CreateDBDriver("user=blah ...") | |
db := NewDB(drv) | |
defer db.Close() | |
// do other stuff | |
} | |
// tests | |
func TestQueryEntity(t *testing.T) { | |
// create mock db driver and sql mocker | |
mockDb, mock, err := sqlmock.New() | |
if err != nil { | |
t.Fatalf("error create mock %v", err) | |
} | |
// create the db struct | |
db, err := NewDB(mockDb) | |
if err != nil { | |
t.Fatalf("error create db %v", err) | |
} | |
// mock rows | |
mockRows := sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "test") | |
// mock query | |
mock.ExpectQuery(`SELECT (.+) FROM "entities"`).WillReturnRows(mockRows) | |
// do query | |
dbEntity, err := db.QueryEntity(context.TODO()) | |
if err != nil { | |
t.Errorf("fail %v", err) | |
} | |
t.Log("pass") | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment