Last active
April 15, 2024 17:36
-
-
Save rjnienaber/e7a833542d00a430f1c784ee842fa7fb to your computer and use it in GitHub Desktop.
repository pattern in Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package controllers | |
import ( | |
"fmt" | |
"net/http" | |
"strconv" | |
"github.com/gin-gonic/gin" | |
"github.com/rjnienaber/server-sent-events/internal/errorcodes" | |
"github.com/rjnienaber/server-sent-events/internal/repositories" | |
"github.com/rjnienaber/server-sent-events/internal/views" | |
) | |
type examController struct { | |
scoreMesageFinder repositories.ScoreMessageFinder | |
} | |
func RegisterExamRoutes(engine *gin.Engine, repo repositories.ScoreMessageFinder) { | |
controller := examController{scoreMesageFinder: repo} | |
exams := engine.Group("/exams") | |
exams.GET("", controller.getAll) | |
exams.GET("/:id", controller.getById) | |
} | |
func (c examController) getById(g *gin.Context) { | |
idStr := g.Params.ByName("id") | |
params := map[string]string{"examId": idStr} | |
id, err := strconv.Atoi(idStr) | |
if err != nil { | |
msg := fmt.Sprintf("invalid id received '%s'", idStr) | |
writeError(g, http.StatusBadRequest, errorcodes.InvalidExamId, msg, ¶ms) | |
return | |
} | |
scores, err := c.scoreMesageFinder.FindByExamId(id) | |
if err != nil { | |
msg := fmt.Sprintf("could not retrieve scores for exam id '%s'", idStr) | |
writeError(g, http.StatusInternalServerError, errorcodes.InternalError, msg, ¶ms) | |
return | |
} | |
if len(scores) == 0 { | |
msg := fmt.Sprintf("no scores found for exam id '%s'", idStr) | |
writeError(g, http.StatusNotFound, errorcodes.NoExamScoresFound, msg, ¶ms) | |
return | |
} | |
g.JSON(http.StatusOK, views.ExamViewBuilder(id, scores)) | |
} | |
func (c examController) getAll(g *gin.Context) { | |
exams, err := c.scoreMesageFinder.GroupByExam() | |
if err != nil { | |
g.String(http.StatusBadRequest, err.Error()) | |
return | |
} | |
examViews := views.ExamsViewBuilder(exams) | |
g.JSON(http.StatusOK, examViews) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package controllers | |
import ( | |
"bytes" | |
"encoding/json" | |
"errors" | |
"net/http/httptest" | |
"testing" | |
"github.com/gin-gonic/gin" | |
"github.com/rjnienaber/server-sent-events/internal/models" | |
"github.com/stretchr/testify/assert" | |
) | |
type mockFinder struct { | |
findByExamId func(examId int) (scores []models.ScoreMessage, err error) | |
} | |
func (m mockFinder) FindByExamId(examId int) (scores []models.ScoreMessage, err error) { | |
if m.findByExamId != nil { | |
f := m.findByExamId | |
return f(examId) | |
} | |
panic("implement me") | |
} | |
func (m mockFinder) GroupByExam() (examScoreGroup []models.ExamScores, err error) { | |
panic("implement me") | |
} | |
func assertBasicResponseDetails(t *testing.T, recorder *httptest.ResponseRecorder, statusCode int) { | |
assert.Equal(t, statusCode, recorder.Code) | |
assert.Equal(t, []string{"application/json; charset=utf-8"}, recorder.Result().Header["Content-Type"]) | |
} | |
func createTestContext(params map[string]string) (recorder *httptest.ResponseRecorder, context *gin.Context) { | |
recorder = httptest.NewRecorder() | |
gin.SetMode(gin.ReleaseMode) | |
context, _ = gin.CreateTestContext(recorder) | |
for key, value := range params { | |
context.Params = []gin.Param{{Key: key, Value: value}} | |
} | |
return | |
} | |
func TestExamControllerReturns400WhenInvalidID(t *testing.T) { | |
recorder, context := createTestContext(map[string]string{"id": "abcd"}) | |
controller := examController{} | |
controller.getById(context) | |
assertBasicResponseDetails(t, recorder, 400) | |
expected := `{"statusCode":400,"code":"invalid_exam_id","message":"invalid id received 'abcd'","params":{"examId":"abcd"}}` | |
assert.Equal(t, expected, recorder.Body.String()) | |
} | |
func TestExamControllerReturns500WhenDatabaseAccessFails(t *testing.T) { | |
recorder, context := createTestContext(map[string]string{"id": "999"}) | |
controller := examController{mockFinder{ | |
findByExamId: func(examId int) (scores []models.ScoreMessage, err error) { | |
return nil, errors.New("database access failed") | |
}, | |
}} | |
controller.getById(context) | |
assertBasicResponseDetails(t, recorder, 500) | |
expected := `{"statusCode":500,"code":"internal_error","message":"could not retrieve scores for exam id '999'","params":{"examId":"999"}}` | |
assert.Equal(t, expected, recorder.Body.String()) | |
} | |
func TestExamControllerReturns404WhenExamIsNotFound(t *testing.T) { | |
recorder, context := createTestContext(map[string]string{"id": "123"}) | |
controller := examController{mockFinder{ | |
findByExamId: func(examId int) (scores []models.ScoreMessage, err error) { | |
return | |
}, | |
}} | |
controller.getById(context) | |
assertBasicResponseDetails(t, recorder, 404) | |
expected := `{"statusCode":404,"code":"no_exam_scores_found","message":"no scores found for exam id '123'","params":{"examId":"123"}}` | |
assert.Equal(t, expected, recorder.Body.String()) | |
} | |
func TestExamControllerReturnsJSONWhenExamIsFound(t *testing.T) { | |
recorder, context := createTestContext(map[string]string{"id": "456"}) | |
controller := examController{mockFinder{ | |
findByExamId: func(examId int) (scores []models.ScoreMessage, err error) { | |
return []models.ScoreMessage{{StudentId: "John.Doe", Exam: 456, Score: 0.75}}, nil | |
}, | |
}} | |
controller.getById(context) | |
assertBasicResponseDetails(t, recorder, 200) | |
expectedBytes := []byte(`{ | |
"exam": { | |
"id": 456, | |
"scores": [{ | |
"id": "John.Doe", | |
"score": "0.750000", | |
"_links": { | |
"exam": { | |
"href": "/exams/456" | |
}, | |
"self": { | |
"href": "/students/John.Doe" | |
} | |
} | |
}], | |
"average": "0.750000", | |
"_links": { | |
"self": { | |
"href": "/exams/456" | |
} | |
} | |
} | |
}`) | |
expectedBuffer := new(bytes.Buffer) | |
err := json.Compact(expectedBuffer, expectedBytes) | |
assert.NoError(t, err) | |
assert.Equal(t, expectedBuffer.String(), recorder.Body.String()) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package repositories | |
import ( | |
"log" | |
"os" | |
"time" | |
"github.com/rjnienaber/server-sent-events/internal/models" | |
"gorm.io/driver/sqlite" | |
"gorm.io/gorm" | |
"gorm.io/gorm/logger" | |
) | |
type Repositories struct { | |
ScoreMessages ScoreMessageRepository | |
logLevel logger.LogLevel | |
dialector *gorm.Dialector | |
} | |
type Option func(svc *Repositories) | |
// create logger here so we can control the log level from the beginning | |
// should be the same as logger.Default | |
func createLogger(logLevel logger.LogLevel) logger.Interface { | |
return logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ | |
SlowThreshold: 200 * time.Millisecond, | |
LogLevel: logLevel, | |
Colorful: true, | |
}) | |
} | |
func openDatabaseConnection(dialector gorm.Dialector, logger logger.Interface) (*gorm.DB, error) { | |
db, err := gorm.Open(dialector, &gorm.Config{Logger: logger}) | |
if err != nil { | |
return nil, err | |
} | |
err = db.AutoMigrate(&models.ScoreMessage{}) | |
if err != nil { | |
return nil, err | |
} | |
return db, nil | |
} | |
func NewRepositories(opts ...Option) (Repositories, error) { | |
repos := Repositories{} | |
for _, opt := range opts { | |
opt(&repos) | |
} | |
dbLogger := createLogger(repos.logLevel) | |
// no connector given, use in memory sqlite db | |
if repos.dialector == nil { | |
WithSqlite("file::memory:")(&repos) | |
} | |
db, err := openDatabaseConnection(*repos.dialector, dbLogger) | |
if err != nil { | |
return Repositories{}, err | |
} | |
repos.ScoreMessages = ScoreMessageRepository{db: db} | |
return repos, nil | |
} | |
func WithLogLevel(logLevel logger.LogLevel) Option { | |
return func(repos *Repositories) { | |
repos.logLevel = logLevel | |
} | |
} | |
func WithSqlite(dsn string) Option { | |
return func(repos *Repositories) { | |
conn := sqlite.Open(dsn) | |
repos.dialector = &conn | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package repositories | |
import ( | |
"testing" | |
"github.com/stretchr/testify/assert" | |
"gorm.io/gorm/logger" | |
) | |
func TestDatabaseReturnsErrorOnInvalidDsn(t *testing.T) { | |
_, err := NewRepositories(WithLogLevel(logger.Silent), WithSqlite("asdfasfd/asdfasfda")) | |
assert.EqualError(t, err, "unable to open database file: no such file or directory") | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package repositories | |
import ( | |
"sort" | |
"github.com/rjnienaber/server-sent-events/internal/models" | |
"gorm.io/gorm" | |
) | |
type ScoreMessageFinder interface { | |
FindByExamId(examId int) (scores []models.ScoreMessage, err error) | |
GroupByExam() (examScores []models.ExamScores, err error) | |
} | |
type ScoreMessageSaver interface { | |
Save(score models.ScoreMessage) error | |
} | |
type ScoreMessageRepository struct { | |
db *gorm.DB | |
} | |
func (r ScoreMessageRepository) FindAll() (scores []models.ScoreMessage, err error) { | |
tx := r.db.Find(&scores) | |
return scores, tx.Error | |
} | |
func (r ScoreMessageRepository) Save(score models.ScoreMessage) error { | |
return r.db.Save(&score).Error | |
} | |
func (r ScoreMessageRepository) SaveAll(scores []models.ScoreMessage) error { | |
return r.db.Save(&scores).Error | |
} | |
func (r ScoreMessageRepository) FindByExamId(examId int) (scores []models.ScoreMessage, err error) { | |
err = r.db.Where("exam = ?", examId).Find(&scores).Error | |
return | |
} | |
func (r ScoreMessageRepository) FindByStudentId(studentId string) (scores []models.ScoreMessage, err error) { | |
err = r.db.Where("student_id = ?", studentId).Find(&scores).Error | |
return | |
} | |
func (r ScoreMessageRepository) GroupByExam() (examScores []models.ExamScores, err error) { | |
scores, err := r.FindAll() | |
if err != nil { | |
return | |
} | |
examIds := []int{} | |
groups := map[int][]models.ScoreMessage{} | |
for _, score := range scores { | |
if examScores, ok := groups[score.Exam]; ok { | |
groups[score.Exam] = append(examScores, score) | |
} else { | |
groups[score.Exam] = []models.ScoreMessage{score} | |
examIds = append(examIds, score.Exam) | |
} | |
} | |
sort.Ints(examIds) | |
for _, examId := range examIds { | |
scores := groups[examId] | |
examScores = append(examScores, models.ExamScores{Exam: examId, Scores: scores}) | |
} | |
return | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package repositories | |
import ( | |
"testing" | |
"github.com/rjnienaber/server-sent-events/internal/models" | |
"github.com/stretchr/testify/assert" | |
) | |
func createRepository(t *testing.T) ScoreMessageRepository { | |
repos, err := NewRepositories() | |
assert.NoError(t, err) | |
return repos.ScoreMessages | |
} | |
func TestScoreMessagesSaveAndFind(t *testing.T) { | |
repo := createRepository(t) | |
msg := models.ScoreMessage{StudentId: "John.Doe", Exam: 123, Score: 0.75} | |
err := repo.Save(msg) | |
assert.NoError(t, err) | |
records, err := repo.FindAll() | |
assert.NoError(t, err) | |
assert.Len(t, records, 1) | |
assert.Equal(t, msg.StudentId, records[0].StudentId) | |
assert.Equal(t, msg.Exam, records[0].Exam) | |
assert.Equal(t, msg.Score, records[0].Score) | |
} | |
func TestScoreMessagesFindByExamId(t *testing.T) { | |
repo := createRepository(t) | |
msg := models.ScoreMessage{StudentId: "John.Doe", Exam: 123, Score: 0.75} | |
err := repo.Save(msg) | |
assert.NoError(t, err) | |
records, err := repo.FindByExamId(123) | |
assert.NoError(t, err) | |
assert.Len(t, records, 1) | |
assert.Equal(t, msg.StudentId, records[0].StudentId) | |
assert.Equal(t, msg.Exam, records[0].Exam) | |
assert.Equal(t, msg.Score, records[0].Score) | |
} | |
func TestScoreMessagesFindByStudentId(t *testing.T) { | |
repo := createRepository(t) | |
msg := models.ScoreMessage{StudentId: "John.Doe", Exam: 123, Score: 0.75} | |
err := repo.Save(msg) | |
assert.NoError(t, err) | |
records, err := repo.FindByStudentId("John.Doe") | |
assert.NoError(t, err) | |
assert.Len(t, records, 1) | |
assert.Equal(t, msg.StudentId, records[0].StudentId) | |
assert.Equal(t, msg.Exam, records[0].Exam) | |
assert.Equal(t, msg.Score, records[0].Score) | |
} | |
func TestScoreMessagesGroupByExam(t *testing.T) { | |
repo := createRepository(t) | |
msgs := []models.ScoreMessage{ | |
{StudentId: "John.Doe", Exam: 123, Score: 0.75}, | |
{StudentId: "Jane.Doe", Exam: 456, Score: 0.8}, | |
{StudentId: "Josiah.Doe", Exam: 123, Score: 0.55}, | |
} | |
err := repo.SaveAll(msgs) | |
assert.NoError(t, err) | |
examScores, err := repo.GroupByExam() | |
assert.NoError(t, err) | |
assert.Len(t, examScores, 2) | |
examOne := examScores[0] | |
assert.Equal(t, 123, examOne.Exam) | |
assert.Len(t, examOne.Scores, 2) | |
assert.Equal(t, "John.Doe", examOne.Scores[0].StudentId) | |
assert.Equal(t, "Josiah.Doe", examOne.Scores[1].StudentId) | |
examTwo := examScores[1] | |
assert.Equal(t, 456, examTwo.Exam) | |
assert.Len(t, examTwo.Scores, 1) | |
assert.Equal(t, "Jane.Doe", examTwo.Scores[0].StudentId) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment