diff --git a/user_repository.go b/user_repository.go index dee3616..fd9c5bf 100644 --- a/user_repository.go +++ b/user_repository.go @@ -12,12 +12,17 @@ import ( "github.com/google/uuid" ) -func NewUserRepository(db *sql.DB) *UserRepository { +type DB interface { + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +func NewUserRepository(db DB) *UserRepository { return &UserRepository{db: db} } type UserRepository struct { - db *sql.DB + db DB } func (r *UserRepository) ReadUser(ctx context.Context, userID uuid.UUID) (User, error) {