1
0
mirror of https://github.com/jesseduffield/lazygit.git synced 2025-06-17 00:18:05 +02:00

Show github pull request status against branch

This commit is contained in:
Jesse Duffield
2024-06-03 22:12:09 +10:00
parent 26c3e0d333
commit bcb70119bb
81 changed files with 2837 additions and 453 deletions

View File

@ -345,6 +345,9 @@ git:
# length. Set to 40 to disable truncation.
truncateCopiedCommitHashesTo: 12
# If true and if if `gh` is installed and on version >=2, we will use `gh` to display pull requests against branches.
enableGithubCli: true
# Periodic update checks
update:
# One of: 'prompt' (default) | 'background' | 'never'

10
go.mod
View File

@ -6,6 +6,7 @@ require (
github.com/adrg/xdg v0.4.0
github.com/atotto/clipboard v0.1.4
github.com/aybabtme/humanlog v0.4.1
github.com/cli/go-gh/v2 v2.9.0
github.com/cloudfoundry/jibber_jabber v0.0.0-20151120183258-bcc4c8345a21
github.com/creack/pty v1.1.11
github.com/gdamore/tcell/v2 v2.7.4
@ -46,6 +47,7 @@ require (
require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/cli/safeexec v1.0.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/emirpasic/gods v1.12.0 // indirect
github.com/fatih/color v1.9.0 // indirect
@ -62,8 +64,8 @@ require (
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-colorable v0.1.11 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/onsi/ginkgo v1.10.3 // indirect
github.com/onsi/gomega v1.7.1 // indirect
@ -73,8 +75,8 @@ require (
github.com/sergi/go-diff v1.1.0 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/xanzy/ssh-agent v0.2.1 // indirect
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/term v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect

24
go.sum
View File

@ -59,6 +59,10 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/cli/go-gh/v2 v2.9.0 h1:D3lTjEneMYl54M+WjZ+kRPrR5CEJ5BHS05isBPOV3LI=
github.com/cli/go-gh/v2 v2.9.0/go.mod h1:MeRoKzXff3ygHu7zP+NVTT+imcHW6p3tpuxHAzRM2xE=
github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=
github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudfoundry/jibber_jabber v0.0.0-20151120183258-bcc4c8345a21 h1:tuijfIjZyjZaHq9xDUh0tNitwXshJpbLkqMOJv4H3do=
github.com/cloudfoundry/jibber_jabber v0.0.0-20151120183258-bcc4c8345a21/go.mod h1:po7NpZ/QiTKzBKyrsEAxwnTamCoh8uDk/egRpQ7siIc=
@ -228,13 +232,14 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-colorable v0.1.0/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs=
github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mgutz/str v1.2.0 h1:4IzWSdIz9qPQWLfKZ0rJcV0jcUDpxvP4JVZ4GXQyvSw=
@ -326,8 +331,9 @@ golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -400,8 +406,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -467,12 +473,12 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

47
pkg/app/app_test.go Normal file
View File

@ -0,0 +1,47 @@
package app
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsValidGhVersion(t *testing.T) {
type scenario struct {
versionStr string
expectedResult bool
}
scenarios := []scenario{
{
"",
false,
},
{
`gh version 1.0.0 (2020-08-23)
https://github.com/cli/cli/releases/tag/v1.0.0`,
false,
},
{
`gh version 2.0.0 (2021-08-23)
https://github.com/cli/cli/releases/tag/v2.0.0`,
true,
},
{
`gh version 1.1.0 (2021-10-14)
https://github.com/cli/cli/releases/tag/v1.1.0
A new release of gh is available: 1.1.0 → v2.2.0
To upgrade, run: brew update && brew upgrade gh
https://github.com/cli/cli/releases/tag/v2.2.0`,
false,
},
}
for _, s := range scenarios {
t.Run(s.versionStr, func(t *testing.T) {
result := isGhVersionValid(s.versionStr)
assert.Equal(t, result, s.expectedResult)
})
}
}

View File

@ -38,6 +38,8 @@ type GitCommand struct {
Worktree *git_commands.WorktreeCommands
Version *git_commands.GitVersion
RepoPaths *git_commands.RepoPaths
GitHub *git_commands.GitHubCommands
HostingService *git_commands.HostingService
Loaders Loaders
}
@ -133,6 +135,8 @@ func NewGitCommandAux(
bisectCommands := git_commands.NewBisectCommands(gitCommon)
worktreeCommands := git_commands.NewWorktreeCommands(gitCommon)
blameCommands := git_commands.NewBlameCommands(gitCommon)
gitHubCommands := git_commands.NewGitHubCommand(gitCommon)
hostingServiceCommands := git_commands.NewHostingServiceCommand(gitCommon)
branchLoader := git_commands.NewBranchLoader(cmn, gitCommon, cmd, branchCommands.CurrentBranchInfo, configCommands)
commitFileLoader := git_commands.NewCommitFileLoader(cmn, cmd)
@ -164,6 +168,8 @@ func NewGitCommandAux(
WorkingTree: workingTreeCommands,
Worktree: worktreeCommands,
Version: version,
GitHub: gitHubCommands,
HostingService: hostingServiceCommands,
Loaders: Loaders{
BranchLoader: branchLoader,
CommitFileLoader: commitFileLoader,

View File

@ -0,0 +1,432 @@
package git_commands
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/cli/go-gh/v2/pkg/auth"
gogit "github.com/jesseduffield/go-git/v5"
"github.com/jesseduffield/lazygit/pkg/commands/models"
"github.com/samber/lo"
"golang.org/x/sync/errgroup"
)
type GitHubCommands struct {
*GitCommon
}
func NewGitHubCommand(gitCommon *GitCommon) *GitHubCommands {
return &GitHubCommands{
GitCommon: gitCommon,
}
}
// https://github.com/cli/cli/issues/2300
func (self *GitHubCommands) BaseRepo() error {
cmdArgs := NewGitCmd("config").
Arg("--local", "--get-regexp", ".gh-resolved").
ToArgv()
return self.cmd.New(cmdArgs).DontLog().Run()
}
// Ex: git config --local --add "remote.origin.gh-resolved" "jesseduffield/lazygit"
func (self *GitHubCommands) SetBaseRepo(repository string) (string, error) {
cmdArgs := NewGitCmd("config").
Arg("--local", "--add", "remote.origin.gh-resolved", repository).
ToArgv()
return self.cmd.New(cmdArgs).DontLog().RunWithOutput()
}
type Response struct {
Data RepositoryQuery `json:"data"`
}
type RepositoryQuery struct {
Repository map[string]PullRequest `json:"repository"`
}
type PullRequest struct {
Edges []PullRequestEdge `json:"edges"`
}
type PullRequestEdge struct {
Node PullRequestNode `json:"node"`
}
type PullRequestNode struct {
Title string `json:"title"`
HeadRefName string `json:"headRefName"`
Number int `json:"number"`
Url string `json:"url"`
HeadRepositoryOwner GithubRepositoryOwner `json:"headRepositoryOwner"`
State string `json:"state"`
}
type GithubRepositoryOwner struct {
Login string `json:"login"`
}
func fetchPullRequestsQuery(branches []string, owner string, repo string) string {
var queries []string
for i, branch := range branches {
// We're making a sub-query per branch, and arbitrarily labelling each subquery
// as a1, a2, etc.
fieldName := fmt.Sprintf("a%d", i+1)
// TODO: scope down by remote too if we can (right now if you search for master, you can get multiple results back, and all from forks)
queries = append(queries, fmt.Sprintf(`%s: pullRequests(first: 1, headRefName: "%s") {
edges {
node {
title
headRefName
state
number
url
headRepositoryOwner {
login
}
}
}
}`, fieldName, branch))
}
queryString := fmt.Sprintf(`{
repository(owner: "%s", name: "%s") {
%s
}
}`, owner, repo, strings.Join(queries, "\n"))
return queryString
}
// FetchRecentPRs fetches recent pull requests using GraphQL.
func (self *GitHubCommands) FetchRecentPRs(branches []string) ([]*models.GithubPullRequest, error) {
repoOwner, repoName, err := self.GetBaseRepoOwnerAndName()
if err != nil {
return nil, err
}
t := time.Now()
var g errgroup.Group
results := make(chan []*models.GithubPullRequest)
// We want at most 5 concurrent requests, but no less than 10 branches per request
concurrency := 5
minBranchesPerRequest := 10
branchesPerRequest := max(len(branches)/concurrency, minBranchesPerRequest)
for i := 0; i < len(branches); i += branchesPerRequest {
end := i + branchesPerRequest
if end > len(branches) {
end = len(branches)
}
branchChunk := branches[i:end]
// Launch a goroutine for each chunk of branches
g.Go(func() error {
prs, err := self.FetchRecentPRsAux(repoOwner, repoName, branchChunk)
if err != nil {
return err
}
results <- prs
return nil
})
}
// Close the results channel when all goroutines are done
go func() {
g.Wait()
close(results)
}()
// Collect results from all goroutines
var allPRs []*models.GithubPullRequest
for prs := range results {
allPRs = append(allPRs, prs...)
}
if err := g.Wait(); err != nil {
return nil, err
}
self.Log.Warnf("Fetched PRs in %s", time.Since(t))
return allPRs, nil
}
func (self *GitHubCommands) FetchRecentPRsAux(repoOwner string, repoName string, branches []string) ([]*models.GithubPullRequest, error) {
queryString := fetchPullRequestsQuery(branches, repoOwner, repoName)
escapedQueryString := strconv.Quote(queryString)
body := fmt.Sprintf(`{"query": %s}`, escapedQueryString)
req, err := http.NewRequest("POST", "https://api.github.com/graphql", bytes.NewBuffer([]byte(body)))
if err != nil {
return nil, err
}
defaultHost, _ := auth.DefaultHost()
token, _ := auth.TokenForHost(defaultHost)
if token == "" {
return nil, fmt.Errorf("No token found for GitHub")
}
req.Header.Set("Authorization", "token "+token)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyStr := new(bytes.Buffer)
bodyStr.ReadFrom(resp.Body)
return nil, fmt.Errorf("GraphQL query failed with status: %s. Body: %s", resp.Status, bodyStr.String())
}
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var result Response
err = json.Unmarshal(bodyBytes, &result)
if err != nil {
return nil, err
}
prs := []*models.GithubPullRequest{}
for _, repoQuery := range result.Data.Repository {
for _, edge := range repoQuery.Edges {
node := edge.Node
pr := &models.GithubPullRequest{
HeadRefName: node.HeadRefName,
Number: node.Number,
State: node.State,
Url: node.Url,
HeadRepositoryOwner: models.GithubRepositoryOwner{
Login: node.HeadRepositoryOwner.Login,
},
}
prs = append(prs, pr)
}
}
return prs, nil
}
// returns a map from branch name to pull request
func GenerateGithubPullRequestMap(
prs []*models.GithubPullRequest,
branches []*models.Branch,
remotes []*models.Remote,
) map[string]*models.GithubPullRequest {
res := map[string]*models.GithubPullRequest{}
if len(prs) == 0 {
return res
}
remotesToOwnersMap := getRemotesToOwnersMap(remotes)
if len(remotesToOwnersMap) == 0 {
return res
}
// A PR can be identified by two things: the owner e.g. 'jesseduffield' and the
// branch name e.g. 'feature/my-feature'. The owner might be different
// to the owner of the repo if the PR is from a fork of that repo.
type prKey struct {
owner string
branchName string
}
prByKey := map[prKey]models.GithubPullRequest{}
for _, pr := range prs {
prByKey[prKey{owner: pr.UserName(), branchName: pr.BranchName()}] = *pr
}
for _, branch := range branches {
if !branch.IsTrackingRemote() {
continue
}
// TODO: support branches whose UpstreamRemote contains a full git
// URL rather than just a remote name.
owner, foundRemoteOwner := remotesToOwnersMap[branch.UpstreamRemote]
if !foundRemoteOwner {
continue
}
pr, hasPr := prByKey[prKey{owner: owner, branchName: branch.UpstreamBranch}]
if !hasPr {
continue
}
res[branch.Name] = &pr
}
return res
}
func getRemotesToOwnersMap(remotes []*models.Remote) map[string]string {
res := map[string]string{}
for _, remote := range remotes {
if len(remote.Urls) == 0 {
continue
}
res[remote.Name] = getRepoInfoFromURL(remote.Urls[0]).Owner
}
return res
}
type RepoInformation struct {
Owner string
Repository string
}
// TODO: move this into hosting_service.go
func getRepoInfoFromURL(url string) RepoInformation {
isHTTP := strings.HasPrefix(url, "http")
if isHTTP {
splits := strings.Split(url, "/")
owner := strings.Join(splits[3:len(splits)-1], "/")
repo := strings.TrimSuffix(splits[len(splits)-1], ".git")
return RepoInformation{
Owner: owner,
Repository: repo,
}
}
tmpSplit := strings.Split(url, ":")
splits := strings.Split(tmpSplit[1], "/")
owner := strings.Join(splits[0:len(splits)-1], "/")
repo := strings.TrimSuffix(splits[len(splits)-1], ".git")
return RepoInformation{
Owner: owner,
Repository: repo,
}
}
// return <installed>, <valid version>
func (self *GitHubCommands) DetermineGitHubCliState() (bool, bool) {
output, err := self.cmd.New([]string{"gh", "--version"}).DontLog().RunWithOutput()
if err != nil {
// assuming a failure here means that it's not installed
return false, false
}
if !isGhVersionValid(output) {
return true, false
}
return true, true
}
func isGhVersionValid(versionStr string) bool {
// output should be something like:
// gh version 2.0.0 (2021-08-23)
// https://github.com/cli/cli/releases/tag/v2.0.0
re := regexp.MustCompile(`[^\d]+([\d\.]+)`)
matches := re.FindStringSubmatch(versionStr)
if len(matches) == 0 {
return false
}
ghVersion := matches[1]
majorVersion, err := strconv.Atoi(ghVersion[0:1])
if err != nil {
return false
}
if majorVersion < 2 {
return false
}
return true
}
func (self *GitHubCommands) InGithubRepo() bool {
remotes, err := self.repo.Remotes()
if err != nil {
self.Log.Error(err)
return false
}
if len(remotes) == 0 {
return false
}
remote := GetMainRemote(remotes)
if len(remote.Config().URLs) == 0 {
return false
}
url := remote.Config().URLs[0]
return strings.Contains(url, "github.com")
}
func GetMainRemote(remotes []*gogit.Remote) *gogit.Remote {
for _, remote := range remotes {
if remote.Config().Name == "origin" {
return remote
}
}
// need to sort remotes by name so that this is deterministic
return lo.MinBy(remotes, func(a, b *gogit.Remote) bool {
return a.Config().Name < b.Config().Name
})
}
func GetSuggestedRemoteName(remotes []*models.Remote) string {
if len(remotes) == 0 {
return "origin"
}
for _, remote := range remotes {
if remote.Name == "origin" {
return remote.Name
}
}
return remotes[0].Name
}
func (self *GitHubCommands) GetBaseRepoOwnerAndName() (string, string, error) {
remotes, err := self.repo.Remotes()
if err != nil {
return "", "", err
}
if len(remotes) == 0 {
return "", "", fmt.Errorf("No remotes found")
}
firstRemote := remotes[0]
if len(firstRemote.Config().URLs) == 0 {
return "", "", fmt.Errorf("No URLs found for remote")
}
url := firstRemote.Config().URLs[0]
repoInfo := getRepoInfoFromURL(url)
return repoInfo.Owner, repoInfo.Repository, nil
}

View File

@ -0,0 +1,34 @@
package git_commands
import "github.com/jesseduffield/lazygit/pkg/commands/hosting_service"
// a hosting service is something like github, gitlab, bitbucket etc
type HostingService struct {
*GitCommon
}
func NewHostingServiceCommand(gitCommon *GitCommon) *HostingService {
return &HostingService{
GitCommon: gitCommon,
}
}
func (self *HostingService) GetPullRequestURL(from string, to string) (string, error) {
return self.getHostingServiceMgr(self.config.GetRemoteURL()).GetPullRequestURL(from, to)
}
func (self *HostingService) GetCommitURL(commitSha string) (string, error) {
return self.getHostingServiceMgr(self.config.GetRemoteURL()).GetCommitURL(commitSha)
}
func (self *HostingService) GetRepoNameFromRemoteURL(remoteURL string) (string, error) {
return self.getHostingServiceMgr(remoteURL).GetRepoName()
}
// getting this on every request rather than storing it in state in case our remoteURL changes
// from one invocation to the next. Note however that we're currently caching config
// results so we might want to invalidate the cache here if it becomes a problem.
func (self *HostingService) getHostingServiceMgr(remoteURL string) *hosting_service.HostingServiceMgr {
configServices := self.UserConfig.Services
return hosting_service.NewHostingServiceMgr(self.Log, self.Tr, remoteURL, configServices)
}

View File

@ -6,7 +6,11 @@ var defaultUrlRegexStrings = []string{
`^(?:https?|ssh)://[^/]+/(?P<owner>.*)/(?P<repo>.*?)(?:\.git)?$`,
`^.*?@.*:(?P<owner>.*)/(?P<repo>.*?)(?:\.git)?$`,
}
var defaultRepoURLTemplate = "https://{{.webDomain}}/{{.owner}}/{{.repo}}"
var (
defaultRepoURLTemplate = "https://{{.webDomain}}/{{.owner}}/{{.repo}}"
defaultRepoNameTemplate = "{{.owner}}/{{.repo}}"
)
// we've got less type safety using go templates but this lends itself better to
// users adding custom service definitions in their config
@ -17,6 +21,7 @@ var githubServiceDef = ServiceDefinition{
commitURL: "/commit/{{.CommitHash}}",
regexStrings: defaultUrlRegexStrings,
repoURLTemplate: defaultRepoURLTemplate,
repoNameTemplate: defaultRepoNameTemplate,
}
var bitbucketServiceDef = ServiceDefinition{
@ -29,6 +34,7 @@ var bitbucketServiceDef = ServiceDefinition{
`^.*@.*:(?P<owner>.*)/(?P<repo>.*?)(?:\.git)?$`,
},
repoURLTemplate: defaultRepoURLTemplate,
repoNameTemplate: defaultRepoNameTemplate,
}
var gitLabServiceDef = ServiceDefinition{
@ -38,6 +44,7 @@ var gitLabServiceDef = ServiceDefinition{
commitURL: "/-/commit/{{.CommitHash}}",
regexStrings: defaultUrlRegexStrings,
repoURLTemplate: defaultRepoURLTemplate,
repoNameTemplate: defaultRepoNameTemplate,
}
var azdoServiceDef = ServiceDefinition{
@ -50,6 +57,8 @@ var azdoServiceDef = ServiceDefinition{
`^https://.*@dev.azure.com/(?P<org>.*?)/(?P<project>.*?)/_git/(?P<repo>.*?)(?:\.git)?$`,
},
repoURLTemplate: "https://{{.webDomain}}/{{.org}}/{{.project}}/_git/{{.repo}}",
// TODO: verify this is actually correct
repoNameTemplate: "{{.org}}/{{.project}}/{{.repo}}",
}
var bitbucketServerServiceDef = ServiceDefinition{
@ -62,6 +71,8 @@ var bitbucketServerServiceDef = ServiceDefinition{
`^https://.*/scm/(?P<project>.*)/(?P<repo>.*?)(?:\.git)?$`,
},
repoURLTemplate: "https://{{.webDomain}}/projects/{{.project}}/repos/{{.repo}}",
// TODO: verify this is actually correct
repoNameTemplate: "{{.project}}/{{.repo}}",
}
var giteaServiceDef = ServiceDefinition{

View File

@ -62,6 +62,18 @@ func (self *HostingServiceMgr) GetCommitURL(commitHash string) (string, error) {
return pullRequestURL, nil
}
// e.g. 'jesseduffield/lazygit'
func (self *HostingServiceMgr) GetRepoName() (string, error) {
gitService, err := self.getService()
if err != nil {
return "", err
}
repoName := gitService.repoName
return repoName, nil
}
func (self *HostingServiceMgr) getService() (*Service, error) {
serviceDomain, err := self.getServiceDomain(self.remoteURL)
if err != nil {
@ -73,8 +85,14 @@ func (self *HostingServiceMgr) getService() (*Service, error) {
return nil, err
}
repoName, err := serviceDomain.serviceDefinition.getRepoNameFromRemoteURL(self.remoteURL)
if err != nil {
return nil, err
}
return &Service{
repoURL: repoURL,
repoName: repoName,
ServiceDefinition: serviceDomain.serviceDefinition,
}, nil
}
@ -146,23 +164,44 @@ type ServiceDefinition struct {
// can expect 'webdomain' to be passed in. Otherwise, you get to pick what we match in the regex
repoURLTemplate string
repoNameTemplate string
}
func (self ServiceDefinition) getRepoURLFromRemoteURL(url string, webDomain string) (string, error) {
matches, err := self.parseRemoteUrl(url)
if err != nil {
return "", err
}
matches["webDomain"] = webDomain
return utils.ResolvePlaceholderString(self.repoURLTemplate, matches), nil
}
func (self ServiceDefinition) getRepoNameFromRemoteURL(url string) (string, error) {
matches, err := self.parseRemoteUrl(url)
if err != nil {
return "", err
}
return utils.ResolvePlaceholderString(self.repoNameTemplate, matches), nil
}
func (self ServiceDefinition) parseRemoteUrl(url string) (map[string]string, error) {
for _, regexStr := range self.regexStrings {
re := regexp.MustCompile(regexStr)
input := utils.FindNamedMatches(re, url)
if input != nil {
input["webDomain"] = webDomain
return utils.ResolvePlaceholderString(self.repoURLTemplate, input), nil
matches := utils.FindNamedMatches(re, url)
if matches != nil {
return matches, nil
}
}
return "", errors.New("Failed to parse repo information from url")
return nil, errors.New("Failed to parse repo information from url")
}
type Service struct {
repoURL string
// e.g. 'jesseduffield/lazygit'
repoName string
ServiceDefinition
}

View File

@ -0,0 +1,24 @@
package models
// TODO: see if I need to store the head repo name in case it differs from the base repo
type GithubPullRequest struct {
HeadRefName string `json:"headRefName"`
Number int `json:"number"`
State string `json:"state"` // "MERGED", "OPEN", "CLOSED"
Url string `json:"url"`
HeadRepositoryOwner GithubRepositoryOwner `json:"headRepositoryOwner"`
}
func (pr *GithubPullRequest) UserName() string {
// e.g. 'jesseduffield'
return pr.HeadRepositoryOwner.Login
}
func (pr *GithubPullRequest) BranchName() string {
// e.g. 'feature/my-feature'
return pr.HeadRefName
}
type GithubRepositoryOwner struct {
Login string `json:"login"`
}

View File

@ -244,6 +244,8 @@ type GitConfig struct {
// When copying commit hashes to the clipboard, truncate them to this
// length. Set to 40 to disable truncation.
TruncateCopiedCommitHashesTo int `yaml:"truncateCopiedCommitHashesTo"`
// If true and if if `gh` is installed and on version >=2, we will use `gh` to display pull requests against branches.
EnableGithubCli bool `yaml:"enableGithubCli"`
}
type PagerType string
@ -748,6 +750,7 @@ func GetDefaultConfig() *UserConfig {
CommitPrefixes: map[string]CommitPrefixConfig(nil),
ParseEmoji: false,
TruncateCopiedCommitHashesTo: 12,
EnableGithubCli: true,
},
Refresher: RefresherConfig{
RefreshInterval: 10,

View File

@ -30,7 +30,8 @@ func (self *BackgroundRoutineMgr) startBackgroundRoutines() {
if userConfig.Git.AutoFetch {
fetchInterval := userConfig.Refresher.FetchInterval
if fetchInterval > 0 {
go utils.Safe(self.startBackgroundFetch)
refreshInterval := self.gui.UserConfig.Refresher.FetchInterval
go utils.Safe(func() { self.startBackgroundFetch(refreshInterval) })
} else {
self.gui.c.Log.Errorf(
"Value of config option 'refresher.fetchInterval' (%d) is invalid, disabling auto-fetch",
@ -73,19 +74,15 @@ func (self *BackgroundRoutineMgr) startBackgroundRoutines() {
}
}
func (self *BackgroundRoutineMgr) startBackgroundFetch() {
func (self *BackgroundRoutineMgr) startBackgroundFetch(refreshInterval int) {
self.gui.waitForIntro.Wait()
isNew := self.gui.IsNewRepo
userConfig := self.gui.UserConfig
if !isNew {
time.After(time.Duration(userConfig.Refresher.FetchInterval) * time.Second)
}
err := self.backgroundFetch()
if err != nil && strings.Contains(err.Error(), "exit status 128") && isNew {
_ = self.gui.c.Alert(self.gui.c.Tr.NoAutomaticGitFetchTitle, self.gui.c.Tr.NoAutomaticGitFetchBody)
} else {
self.goEvery(time.Second*time.Duration(userConfig.Refresher.FetchInterval), self.gui.stopChan, func() error {
self.goEvery(time.Second*time.Duration(refreshInterval), self.gui.stopChan, func() error {
err := self.backgroundFetch()
self.gui.c.Render()
return err
@ -129,7 +126,7 @@ func (self *BackgroundRoutineMgr) goEvery(interval time.Duration, stop chan stru
func (self *BackgroundRoutineMgr) backgroundFetch() (err error) {
err = self.gui.git.Sync.FetchBackground()
_ = self.gui.c.Refresh(types.RefreshOptions{Scope: []types.RefreshableView{types.BRANCHES, types.COMMITS, types.REMOTES, types.TAGS}, Mode: types.ASYNC})
_ = self.gui.c.Refresh(types.RefreshOptions{Scope: []types.RefreshableView{types.BRANCHES, types.COMMITS, types.REMOTES, types.TAGS, types.PULL_REQUESTS}, Mode: types.ASYNC})
return err
}

View File

@ -28,6 +28,8 @@ func NewBranchesContext(c *ContextCommon) *BranchesContext {
return presentation.GetBranchListDisplayStrings(
viewModel.GetItems(),
c.State().GetItemOperation,
c.Model().PullRequests,
c.Model().Remotes,
c.State().GetRepoState().GetScreenMode() != types.SCREEN_NORMAL,
c.Modes().Diffing.Ref,
c.Views().Branches.Width(),

View File

@ -69,6 +69,7 @@ func (gui *Gui) resetHelpersAndControllers() {
mergeConflictsHelper,
worktreeHelper,
searchHelper,
suggestionsHelper,
)
diffHelper := helpers.NewDiffHelper(helperCommon)
cherryPickHelper := helpers.NewCherryPickHelper(

View File

@ -27,7 +27,6 @@ type Helpers struct {
MergeAndRebase *MergeAndRebaseHelper
MergeConflicts *MergeConflictsHelper
CherryPick *CherryPickHelper
Host *HostHelper
PatchBuilding *PatchBuildingHelper
Staging *StagingHelper
GPG *GpgHelper
@ -52,6 +51,7 @@ type Helpers struct {
Search *SearchHelper
Worktree *WorktreeHelper
SubCommits *SubCommitsHelper
Host *HostHelper
}
func NewStubHelpers() *Helpers {

View File

@ -29,6 +29,7 @@ type RefreshHelper struct {
mergeConflictsHelper *MergeConflictsHelper
worktreeHelper *WorktreeHelper
searchHelper *SearchHelper
suggestionsHelper *SuggestionsHelper
}
func NewRefreshHelper(
@ -40,6 +41,7 @@ func NewRefreshHelper(
mergeConflictsHelper *MergeConflictsHelper,
worktreeHelper *WorktreeHelper,
searchHelper *SearchHelper,
suggestionsHelper *SuggestionsHelper,
) *RefreshHelper {
return &RefreshHelper{
c: c,
@ -50,6 +52,7 @@ func NewRefreshHelper(
mergeConflictsHelper: mergeConflictsHelper,
worktreeHelper: worktreeHelper,
searchHelper: searchHelper,
suggestionsHelper: suggestionsHelper,
}
}
@ -93,6 +96,7 @@ func (self *RefreshHelper) Refresh(options types.RefreshOptions) error {
types.STATUS,
types.BISECT_INFO,
types.STAGING,
types.PULL_REQUESTS,
})
} else {
scopeSet = set.NewFromSlice(options.Scope)
@ -119,6 +123,10 @@ func (self *RefreshHelper) Refresh(options types.RefreshOptions) error {
}
}
if scopeSet.Includes(types.PULL_REQUESTS) {
refresh("pull requests", func() { _ = self.refreshGithubPullRequests() })
}
includeWorktreesWithBranches := false
if scopeSet.Includes(types.COMMITS) || scopeSet.Includes(types.BRANCHES) || scopeSet.Includes(types.REFLOG) || scopeSet.Includes(types.BISECT_INFO) {
// whenever we change commits, we should update branches because the upstream/downstream
@ -770,3 +778,130 @@ func (self *RefreshHelper) refreshView(context types.Context) error {
self.searchHelper.ReApplySearch(context)
return err
}
func (self *RefreshHelper) refreshGithubPullRequests() error {
self.c.Mutexes().RefreshingPullRequestsMutex.Lock()
defer self.c.Mutexes().RefreshingPullRequestsMutex.Unlock()
if !self.c.UserConfig.Git.EnableGithubCli {
return nil
}
if !self.c.Git().GitHub.InGithubRepo() {
self.c.Model().PullRequests = []*models.GithubPullRequest{}
return nil
}
switch self.c.State().GetGitHubCliState() {
case types.UNKNOWN:
state := self.determineGithubCliState()
self.c.State().SetGitHubCliState(state)
if state != types.VALID {
if state == types.INVALID_VERSION {
// todo: i18n
self.c.LogAction("gh version is too old (must be version 2 or greater), so pull requests will not be shown against branches.")
}
return nil
}
case types.VALID:
// continue on
default:
return nil
}
if err := self.c.Git().GitHub.BaseRepo(); err != nil {
ok, err := self.promptForBaseGithubRepo()
if err != nil {
return err
}
if !ok {
return nil
}
}
if err := self.setGithubPullRequests(); err != nil {
self.c.LogAction(fmt.Sprintf("Error fetching pull requests from GitHub: %s", err.Error()))
}
return nil
}
func (self *RefreshHelper) promptForBaseGithubRepo() (bool, error) {
err := self.refreshRemotes()
if err != nil {
return false, err
}
switch len(self.c.Model().Remotes) {
case 0:
return false, nil
case 1:
remote := self.c.Model().Remotes[0]
if len(remote.Urls) == 0 {
return false, nil
}
repoName, err := self.c.Git().HostingService.GetRepoNameFromRemoteURL(remote.Urls[0])
if err != nil {
self.c.Log.Error(err)
return false, nil
}
_, err = self.c.Git().GitHub.SetBaseRepo(repoName)
if err != nil {
self.c.Log.Error(err)
}
return true, nil
default:
_ = self.c.Prompt(types.PromptOpts{
Title: self.c.Tr.SelectRemoteRepository,
InitialContent: "",
FindSuggestionsFunc: self.suggestionsHelper.GetRemoteRepoSuggestionsFunc(),
HandleConfirm: func(repository string) error {
return self.c.WithWaitingStatus(self.c.Tr.LcSelectingRemote, func(gocui.Task) error {
// `repository` is something like 'jesseduffield/lazygit'
_, err := self.c.Git().GitHub.SetBaseRepo(repository)
if err != nil {
return err
}
return self.refreshGithubPullRequests()
})
},
})
return false, nil
}
}
func (self *RefreshHelper) determineGithubCliState() types.GitHubCliState {
installed, validVersion := self.c.Git().GitHub.DetermineGitHubCliState()
if validVersion {
return types.VALID
} else if installed {
return types.INVALID_VERSION
} else {
return types.NOT_INSTALLED
}
}
func (self *RefreshHelper) setGithubPullRequests() error {
branches := lo.Filter(self.c.Model().Branches, func(branch *models.Branch, _ int) bool {
return branch.IsTrackingRemote()
})
branchNames := lo.Map(branches, func(branch *models.Branch, _ int) string {
return branch.UpstreamBranch
})
prs, err := self.c.Git().GitHub.FetchRecentPRs(branchNames)
if err != nil {
return err
}
self.c.Model().PullRequests = prs
return self.c.PostRefreshUpdate(self.c.Contexts().Branches)
}

View File

@ -75,6 +75,30 @@ func (self *SuggestionsHelper) getBranchNames() []string {
})
}
func (self *SuggestionsHelper) GetRemoteRepoSuggestionsFunc() func(string) []*types.Suggestion {
repoNames := self.getRemoteRepoNames()
return FilterFunc(repoNames, self.c.UserConfig.Gui.UseFuzzySearch())
}
func (self *SuggestionsHelper) getRemoteRepoNames() []string {
remotes := self.c.Model().Remotes
result := make([]string, 0, len(remotes))
for _, remote := range remotes {
if len(remote.Urls) == 0 {
continue
}
repoName, err := self.c.Git().HostingService.GetRepoNameFromRemoteURL(remote.Urls[0])
if err != nil {
self.c.Log.Error(err)
continue
}
result = append(result, repoName)
}
return result
}
func (self *SuggestionsHelper) GetBranchNameSuggestionsFunc() func(string) []*types.Suggestion {
branchNames := self.getBranchNames()

View File

@ -4,6 +4,7 @@ import (
"errors"
"strings"
"github.com/jesseduffield/lazygit/pkg/commands/git_commands"
"github.com/jesseduffield/lazygit/pkg/commands/models"
"github.com/jesseduffield/lazygit/pkg/gui/types"
)
@ -67,19 +68,5 @@ func (self *UpstreamHelper) PromptForUpstreamWithoutInitialContent(_ *models.Bra
}
func (self *UpstreamHelper) GetSuggestedRemote() string {
return getSuggestedRemote(self.c.Model().Remotes)
}
func getSuggestedRemote(remotes []*models.Remote) string {
if len(remotes) == 0 {
return "origin"
}
for _, remote := range remotes {
if remote.Name == "origin" {
return remote.Name
}
}
return remotes[0].Name
return git_commands.GetSuggestedRemoteName(self.c.Model().Remotes)
}

View File

@ -140,6 +140,7 @@ type Gui struct {
integrationTest integrationTypes.IntegrationTest
afterLayoutFuncs chan func() error
gitHubCliState types.GitHubCliState
}
type StateAccessor struct {
@ -209,6 +210,14 @@ func (self *StateAccessor) ClearItemOperation(item types.HasUrn) {
delete(self.gui.itemOperations, item.URN())
}
func (self *StateAccessor) GetGitHubCliState() types.GitHubCliState {
return self.gui.gitHubCliState
}
func (self *StateAccessor) SetGitHubCliState(value types.GitHubCliState) {
self.gui.gitHubCliState = value
}
// we keep track of some stuff from one render to the next to see if certain
// things have changed
type PrevLayout struct {
@ -380,6 +389,7 @@ func (gui *Gui) resetState(startArgs appTypes.StartArgs) types.Context {
FilesTrie: patricia.NewTrie(),
Authors: map[string]*models.Author{},
MainBranches: git_commands.NewMainBranches(gui.UserConfig.Git.MainBranches, gui.os.Cmd),
PullRequests: make([]*models.GithubPullRequest, 0),
},
Modes: &types.Modes{
Filtering: filtering.New(startArgs.FilterPath, ""),
@ -486,6 +496,7 @@ func NewGui(
RefreshingFilesMutex: &deadlock.Mutex{},
RefreshingBranchesMutex: &deadlock.Mutex{},
RefreshingStatusMutex: &deadlock.Mutex{},
RefreshingPullRequestsMutex: &deadlock.Mutex{},
LocalCommitsMutex: &deadlock.Mutex{},
SubCommitsMutex: &deadlock.Mutex{},
AuthorsMutex: &deadlock.Mutex{},

View File

@ -2,6 +2,7 @@ package presentation
import (
"fmt"
"strconv"
"strings"
"time"
@ -23,6 +24,8 @@ var branchPrefixColorCache = make(map[string]style.TextStyle)
func GetBranchListDisplayStrings(
branches []*models.Branch,
getItemOperation func(item types.HasUrn) types.ItemOperation,
pullRequests []*models.GithubPullRequest,
remotes []*models.Remote,
fullDescription bool,
diffName string,
viewWidth int,
@ -30,9 +33,15 @@ func GetBranchListDisplayStrings(
userConfig *config.UserConfig,
worktrees []*models.Worktree,
) [][]string {
prs := git_commands.GenerateGithubPullRequestMap(
pullRequests,
branches,
remotes,
)
return lo.Map(branches, func(branch *models.Branch, _ int) []string {
diffed := branch.Name == diffName
return getBranchDisplayStrings(branch, getItemOperation(branch), fullDescription, diffed, viewWidth, tr, userConfig, worktrees, time.Now())
return getBranchDisplayStrings(branch, getItemOperation(branch), fullDescription, diffed, viewWidth, tr, userConfig, worktrees, time.Now(), prs)
})
}
@ -47,6 +56,7 @@ func getBranchDisplayStrings(
userConfig *config.UserConfig,
worktrees []*models.Worktree,
now time.Time,
prs map[string]*models.GithubPullRequest,
) []string {
checkedOutByWorkTree := git_commands.CheckedOutByOtherWorktree(b, worktrees)
showCommitHash := fullDescription || userConfig.Gui.ShowBranchCommitHash
@ -88,6 +98,7 @@ func getBranchDisplayStrings(
if checkedOutByWorkTree {
coloredName = fmt.Sprintf("%s %s", coloredName, style.FgDefault.Sprint(worktreeIcon))
}
if len(branchStatus) > 0 {
coloredName = fmt.Sprintf("%s %s", coloredName, branchStatus)
}
@ -98,18 +109,30 @@ func getBranchDisplayStrings(
}
res := make([]string, 0, 6)
res = append(res, recencyColor.Sprint(b.Recency))
pr, hasPr := prs[b.Name]
if hasPr {
if icons.IsIconEnabled() {
res = append(res, nameTextStyle.Sprint(icons.IconForBranch(b)))
res = append(res, prColor(pr.State).Sprint(icons.IconForBranch(b)))
} else {
res = append(res, prColor(pr.State).Sprint("⬤"))
}
} else {
if icons.IsIconEnabled() {
res = append(res, style.FgDefault.Sprint(icons.IconForBranch(b)))
} else {
res = append(res, style.FgDefault.Sprint("⬤"))
}
}
res = append(res, coloredName)
if showCommitHash {
res = append(res, utils.ShortHash(b.CommitHash))
}
res = append(res, coloredName)
if fullDescription {
res = append(
res,
@ -192,3 +215,24 @@ func BranchStatus(
func SetCustomBranches(customBranchColors map[string]string) {
branchPrefixColorCache = utils.SetCustomColors(customBranchColors)
}
func coloredPrNumber(pr *models.GithubPullRequest, hasPr bool) string {
if hasPr {
return prColor(pr.State).Sprint("#" + strconv.Itoa(pr.Number))
}
return ("")
}
func prColor(state string) style.TextStyle {
switch state {
case "OPEN":
return style.FgGreen
case "CLOSED":
return style.FgRed
case "MERGED":
return style.FgMagenta
default:
return style.FgDefault
}
}

View File

@ -295,6 +295,7 @@ type Model struct {
SubCommits []*models.Commit
Remotes []*models.Remote
Worktrees []*models.Worktree
PullRequests []*models.GithubPullRequest
// FilteredReflogCommits are the ones that appear in the reflog panel.
// when in filtering mode we only include the ones that match the given path
@ -327,6 +328,7 @@ type Mutexes struct {
RefreshingFilesMutex *deadlock.Mutex
RefreshingBranchesMutex *deadlock.Mutex
RefreshingStatusMutex *deadlock.Mutex
RefreshingPullRequestsMutex *deadlock.Mutex
LocalCommitsMutex *deadlock.Mutex
SubCommitsMutex *deadlock.Mutex
AuthorsMutex *deadlock.Mutex
@ -370,6 +372,8 @@ type IStateAccessor interface {
GetItemOperation(item HasUrn) ItemOperation
SetItemOperation(item HasUrn, operation ItemOperation)
ClearItemOperation(item HasUrn)
GetGitHubCliState() GitHubCliState
SetGitHubCliState(GitHubCliState)
}
type IRepoStateAccessor interface {
@ -406,3 +410,13 @@ const (
SCREEN_HALF
SCREEN_FULL
)
// for keeping track of whether our github CLI is installed and on a valid version
type GitHubCliState int
const (
UNKNOWN GitHubCliState = iota
VALID
NOT_INSTALLED
INVALID_VERSION
)

View File

@ -22,6 +22,7 @@ const (
COMMIT_FILES
// not actually a view. Will refactor this later
BISECT_INFO
PULL_REQUESTS
)
type RefreshMode int

View File

@ -523,6 +523,8 @@ type TranslationSet struct {
PrevScreenMode string
StartSearch string
StartFilter string
SelectRemoteRepository string
LcSelectingRemote string
Panel string
Keybindings string
KeybindingsLegend string
@ -610,6 +612,8 @@ type TranslationSet struct {
EnterSubmoduleTooltip string
Enter string
CopySubmoduleNameToClipboard string
MinGhVersionError string
FailedToObtainGhVersionError string
RemoveSubmodule string
RemoveSubmoduleTooltip string
RemoveSubmodulePrompt string
@ -1583,6 +1587,8 @@ func EnglishTranslationSet() *TranslationSet {
Enter: "Enter",
EnterSubmoduleTooltip: "Enter submodule. After entering the submodule, you can press `{{.escape}}` to escape back to the parent repo.",
CopySubmoduleNameToClipboard: "Copy submodule name to clipboard",
MinGhVersionError: "GH version must be at least 2.0. Please upgrade your gh version. Alternatively raise an issue at https://github.com/jesseduffield/lazygit/issues for lazygit to be more backwards compatible.",
FailedToObtainGhVersionError: "Failed to obtain gh version. Output from running 'gh --version' was: %s",
RemoveSubmodule: "Remove submodule",
RemoveSubmodulePrompt: "Are you sure you want to remove submodule '%s' and its corresponding directory? This is irreversible.",
RemoveSubmoduleTooltip: "Remove the selected submodule and its corresponding directory.",

View File

@ -675,6 +675,11 @@
"type": "integer",
"description": "When copying commit hashes to the clipboard, truncate them to this\nlength. Set to 40 to disable truncation.",
"default": 12
},
"enableGithubCli": {
"type": "boolean",
"description": "If true and if if `gh` is installed and on version \u003e=2, we will use `gh` to display pull requests against branches.",
"default": true
}
},
"additionalProperties": false,

View File

@ -20,3 +20,4 @@ git:
# TODO: add tests which explicitly test auto-refresh functionality
autoRefresh: false
autoFetch: false
enableGithubCli: false

21
vendor/github.com/cli/go-gh/v2/LICENSE generated vendored Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 GitHub Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,70 @@
package set
var exists = struct{}{}
type stringSet struct {
v []string
m map[string]struct{}
}
func NewStringSet() *stringSet {
s := &stringSet{}
s.m = make(map[string]struct{})
s.v = []string{}
return s
}
func (s *stringSet) Add(value string) {
if s.Contains(value) {
return
}
s.m[value] = exists
s.v = append(s.v, value)
}
func (s *stringSet) AddValues(values []string) {
for _, v := range values {
s.Add(v)
}
}
func (s *stringSet) Remove(value string) {
if !s.Contains(value) {
return
}
delete(s.m, value)
s.v = sliceWithout(s.v, value)
}
func sliceWithout(s []string, v string) []string {
idx := -1
for i, item := range s {
if item == v {
idx = i
break
}
}
if idx < 0 {
return s
}
return append(s[:idx], s[idx+1:]...)
}
func (s *stringSet) RemoveValues(values []string) {
for _, v := range values {
s.Remove(v)
}
}
func (s *stringSet) Contains(value string) bool {
_, c := s.m[value]
return c
}
func (s *stringSet) Len() int {
return len(s.m)
}
func (s *stringSet) ToSlice() []string {
return s.v
}

View File

@ -0,0 +1,214 @@
// Package yamlmap is a wrapper of gopkg.in/yaml.v3 for interacting
// with yaml data as if it were a map.
package yamlmap
import (
"errors"
"gopkg.in/yaml.v3"
)
const (
modified = "modifed"
)
type Map struct {
*yaml.Node
}
var ErrNotFound = errors.New("not found")
var ErrInvalidYaml = errors.New("invalid yaml")
var ErrInvalidFormat = errors.New("invalid format")
func StringValue(value string) *Map {
return &Map{&yaml.Node{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: value,
}}
}
func MapValue() *Map {
return &Map{&yaml.Node{
Kind: yaml.MappingNode,
Tag: "!!map",
}}
}
func NullValue() *Map {
return &Map{&yaml.Node{
Kind: yaml.ScalarNode,
Tag: "!!null",
}}
}
func Unmarshal(data []byte) (*Map, error) {
var root yaml.Node
err := yaml.Unmarshal(data, &root)
if err != nil {
return nil, ErrInvalidYaml
}
if len(root.Content) == 0 {
return MapValue(), nil
}
if root.Content[0].Kind != yaml.MappingNode {
return nil, ErrInvalidFormat
}
return &Map{root.Content[0]}, nil
}
func Marshal(m *Map) ([]byte, error) {
return yaml.Marshal(m.Node)
}
func (m *Map) AddEntry(key string, value *Map) {
keyNode := &yaml.Node{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: key,
}
m.Content = append(m.Content, keyNode, value.Node)
m.SetModified()
}
func (m *Map) Empty() bool {
return m.Content == nil || len(m.Content) == 0
}
func (m *Map) FindEntry(key string) (*Map, error) {
// Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...].
// When iterating over the content slice we only want to compare the keys of the yamlMap.
for i, v := range m.Content {
if i%2 != 0 {
continue
}
if v.Value == key {
if i+1 < len(m.Content) {
return &Map{m.Content[i+1]}, nil
}
}
}
return nil, ErrNotFound
}
func (m *Map) Keys() []string {
// Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...].
// When iterating over the content slice we only want to select the keys of the yamlMap.
keys := []string{}
for i, v := range m.Content {
if i%2 != 0 {
continue
}
keys = append(keys, v.Value)
}
return keys
}
func (m *Map) RemoveEntry(key string) error {
// Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...].
// When iterating over the content slice we only want to compare the keys of the yamlMap.
// If we find they key to remove, remove the key and its value from the content slice.
found, skipNext := false, false
newContent := []*yaml.Node{}
for i, v := range m.Content {
if skipNext {
skipNext = false
continue
}
if i%2 != 0 || v.Value != key {
newContent = append(newContent, v)
} else {
found = true
skipNext = true
m.SetModified()
}
}
if !found {
return ErrNotFound
}
m.Content = newContent
return nil
}
func (m *Map) SetEntry(key string, value *Map) {
// Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...].
// When iterating over the content slice we only want to compare the keys of the yamlMap.
// If we find they key to set, set the next item in the content slice to the new value.
m.SetModified()
for i, v := range m.Content {
if i%2 != 0 || v.Value != key {
continue
}
if v.Value == key {
if i+1 < len(m.Content) {
m.Content[i+1] = value.Node
return
}
}
}
m.AddEntry(key, value)
}
// Note: This is a hack to introduce the concept of modified/unmodified
// on top of gopkg.in/yaml.v3. This works by setting the Value property
// of a MappingNode to a specific value and then later checking if the
// node's Value property is that specific value. When a MappingNode gets
// output as a string the Value property is not used, thus changing it
// has no impact for our purposes.
func (m *Map) SetModified() {
// Can not mark a non-mapping node as modified
if m.Node.Kind != yaml.MappingNode && m.Node.Tag == "!!null" {
m.Node.Kind = yaml.MappingNode
m.Node.Tag = "!!map"
}
if m.Node.Kind == yaml.MappingNode {
m.Node.Value = modified
}
}
// Traverse map using BFS to set all nodes as unmodified.
func (m *Map) SetUnmodified() {
i := 0
queue := []*yaml.Node{m.Node}
for {
if i > (len(queue) - 1) {
break
}
q := queue[i]
i = i + 1
if q.Kind != yaml.MappingNode {
continue
}
q.Value = ""
queue = append(queue, q.Content...)
}
}
// Traverse map using BFS to searach for any nodes that have been modified.
func (m *Map) IsModified() bool {
i := 0
queue := []*yaml.Node{m.Node}
for {
if i > (len(queue) - 1) {
break
}
q := queue[i]
i = i + 1
if q.Kind != yaml.MappingNode {
continue
}
if q.Value == modified {
return true
}
queue = append(queue, q.Content...)
}
return false
}
func (m *Map) String() string {
data, err := Marshal(m)
if err != nil {
return ""
}
return string(data)
}

194
vendor/github.com/cli/go-gh/v2/pkg/auth/auth.go generated vendored Normal file
View File

@ -0,0 +1,194 @@
// Package auth is a set of functions for retrieving authentication tokens
// and authenticated hosts.
package auth
import (
"fmt"
"os"
"os/exec"
"strconv"
"strings"
"github.com/cli/go-gh/v2/internal/set"
"github.com/cli/go-gh/v2/pkg/config"
"github.com/cli/safeexec"
)
const (
codespaces = "CODESPACES"
defaultSource = "default"
ghEnterpriseToken = "GH_ENTERPRISE_TOKEN"
ghHost = "GH_HOST"
ghToken = "GH_TOKEN"
github = "github.com"
githubEnterpriseToken = "GITHUB_ENTERPRISE_TOKEN"
githubToken = "GITHUB_TOKEN"
hostsKey = "hosts"
localhost = "github.localhost"
oauthToken = "oauth_token"
)
// TokenForHost retrieves an authentication token and the source of that token for the specified
// host. The source can be either an environment variable, configuration file, or the system
// keyring. In the latter case, this shells out to "gh auth token" to obtain the token.
//
// Returns "", "default" if no applicable token is found.
func TokenForHost(host string) (string, string) {
if token, source := TokenFromEnvOrConfig(host); token != "" {
return token, source
}
ghExe := os.Getenv("GH_PATH")
if ghExe == "" {
ghExe, _ = safeexec.LookPath("gh")
}
if ghExe != "" {
if token, source := tokenFromGh(ghExe, host); token != "" {
return token, source
}
}
return "", defaultSource
}
// TokenFromEnvOrConfig retrieves an authentication token from environment variables or the config
// file as fallback, but does not support reading the token from system keyring. Most consumers
// should use TokenForHost.
func TokenFromEnvOrConfig(host string) (string, string) {
cfg, _ := config.Read(nil)
return tokenForHost(cfg, host)
}
func tokenForHost(cfg *config.Config, host string) (string, string) {
host = normalizeHostname(host)
if IsEnterprise(host) {
if token := os.Getenv(ghEnterpriseToken); token != "" {
return token, ghEnterpriseToken
}
if token := os.Getenv(githubEnterpriseToken); token != "" {
return token, githubEnterpriseToken
}
if isCodespaces, _ := strconv.ParseBool(os.Getenv(codespaces)); isCodespaces {
if token := os.Getenv(githubToken); token != "" {
return token, githubToken
}
}
if cfg != nil {
token, _ := cfg.Get([]string{hostsKey, host, oauthToken})
return token, oauthToken
}
}
if token := os.Getenv(ghToken); token != "" {
return token, ghToken
}
if token := os.Getenv(githubToken); token != "" {
return token, githubToken
}
if cfg != nil {
token, _ := cfg.Get([]string{hostsKey, host, oauthToken})
return token, oauthToken
}
return "", defaultSource
}
func tokenFromGh(path string, host string) (string, string) {
cmd := exec.Command(path, "auth", "token", "--secure-storage", "--hostname", host)
result, err := cmd.Output()
if err != nil {
return "", "gh"
}
return strings.TrimSpace(string(result)), "gh"
}
// KnownHosts retrieves a list of hosts that have corresponding
// authentication tokens, either from environment variables
// or from the configuration file.
// Returns an empty string slice if no hosts are found.
func KnownHosts() []string {
cfg, _ := config.Read(nil)
return knownHosts(cfg)
}
func knownHosts(cfg *config.Config) []string {
hosts := set.NewStringSet()
if host := os.Getenv(ghHost); host != "" {
hosts.Add(host)
}
if token, _ := tokenForHost(cfg, github); token != "" {
hosts.Add(github)
}
if cfg != nil {
keys, err := cfg.Keys([]string{hostsKey})
if err == nil {
hosts.AddValues(keys)
}
}
return hosts.ToSlice()
}
// DefaultHost retrieves an authenticated host and the source of host.
// The source can be either an environment variable or from the
// configuration file.
// Returns "github.com", "default" if no viable host is found.
func DefaultHost() (string, string) {
cfg, _ := config.Read(nil)
return defaultHost(cfg)
}
func defaultHost(cfg *config.Config) (string, string) {
if host := os.Getenv(ghHost); host != "" {
return host, ghHost
}
if cfg != nil {
keys, err := cfg.Keys([]string{hostsKey})
if err == nil && len(keys) == 1 {
return keys[0], hostsKey
}
}
return github, defaultSource
}
// TenancyHost is the domain name of a tenancy GitHub instance.
const tenancyHost = "ghe.com"
// IsEnterprise determines if a provided host is a GitHub Enterprise Server instance,
// rather than GitHub.com or a tenancy GitHub instance.
func IsEnterprise(host string) bool {
normalizedHost := normalizeHostname(host)
return normalizedHost != github && normalizedHost != localhost && !IsTenancy(normalizedHost)
}
// IsTenancy determines if a provided host is a tenancy GitHub instance,
// rather than GitHub.com or a GitHub Enterprise Server instance.
func IsTenancy(host string) bool {
normalizedHost := normalizeHostname(host)
return strings.HasSuffix(normalizedHost, "."+tenancyHost)
}
func normalizeHostname(host string) string {
hostname := strings.ToLower(host)
if strings.HasSuffix(hostname, "."+github) {
return github
}
if strings.HasSuffix(hostname, "."+localhost) {
return localhost
}
// This has been copied over from the cli/cli NormalizeHostname function
// to ensure compatible behaviour but we don't fully understand when or
// why it would be useful here. We can't see what harm will come of
// duplicating the logic.
if before, found := cutSuffix(hostname, "."+tenancyHost); found {
idx := strings.LastIndex(before, ".")
return fmt.Sprintf("%s.%s", before[idx+1:], tenancyHost)
}
return hostname
}
// Backport strings.CutSuffix from Go 1.20.
func cutSuffix(s, suffix string) (string, bool) {
if !strings.HasSuffix(s, suffix) {
return s, false
}
return s[:len(s)-len(suffix)], true
}

336
vendor/github.com/cli/go-gh/v2/pkg/config/config.go generated vendored Normal file
View File

@ -0,0 +1,336 @@
// Package config is a set of types for interacting with the gh configuration files.
// Note: This package is intended for use only in gh, any other use cases are subject
// to breakage and non-backwards compatible updates.
package config
import (
"errors"
"io"
"os"
"path/filepath"
"runtime"
"sync"
"github.com/cli/go-gh/v2/internal/yamlmap"
)
const (
appData = "AppData"
ghConfigDir = "GH_CONFIG_DIR"
localAppData = "LocalAppData"
xdgConfigHome = "XDG_CONFIG_HOME"
xdgDataHome = "XDG_DATA_HOME"
xdgStateHome = "XDG_STATE_HOME"
xdgCacheHome = "XDG_CACHE_HOME"
)
var (
cfg *Config
once sync.Once
loadErr error
)
// Config is a in memory representation of the gh configuration files.
// It can be thought of as map where entries consist of a key that
// correspond to either a string value or a map value, allowing for
// multi-level maps.
type Config struct {
entries *yamlmap.Map
mu sync.RWMutex
}
// Get a string value from a Config.
// The keys argument is a sequence of key values so that nested
// entries can be retrieved. A undefined string will be returned
// if trying to retrieve a key that corresponds to a map value.
// Returns "", KeyNotFoundError if any of the keys can not be found.
func (c *Config) Get(keys []string) (string, error) {
c.mu.RLock()
defer c.mu.RUnlock()
m := c.entries
for _, key := range keys {
var err error
m, err = m.FindEntry(key)
if err != nil {
return "", &KeyNotFoundError{key}
}
}
return m.Value, nil
}
// Keys enumerates a Config's keys.
// The keys argument is a sequence of key values so that nested
// map values can be have their keys enumerated.
// Returns nil, KeyNotFoundError if any of the keys can not be found.
func (c *Config) Keys(keys []string) ([]string, error) {
c.mu.RLock()
defer c.mu.RUnlock()
m := c.entries
for _, key := range keys {
var err error
m, err = m.FindEntry(key)
if err != nil {
return nil, &KeyNotFoundError{key}
}
}
return m.Keys(), nil
}
// Remove an entry from a Config.
// The keys argument is a sequence of key values so that nested
// entries can be removed. Removing an entry that has nested
// entries removes those also.
// Returns KeyNotFoundError if any of the keys can not be found.
func (c *Config) Remove(keys []string) error {
c.mu.Lock()
defer c.mu.Unlock()
m := c.entries
for i := 0; i < len(keys)-1; i++ {
var err error
key := keys[i]
m, err = m.FindEntry(key)
if err != nil {
return &KeyNotFoundError{key}
}
}
err := m.RemoveEntry(keys[len(keys)-1])
if err != nil {
return &KeyNotFoundError{keys[len(keys)-1]}
}
return nil
}
// Set a string value in a Config.
// The keys argument is a sequence of key values so that nested
// entries can be set. If any of the keys do not exist they will
// be created. If the string value to be set is empty it will be
// represented as null not an empty string when written.
//
// var c *Config
// c.Set([]string{"key"}, "")
// Write(c) // writes `key: ` not `key: ""`
func (c *Config) Set(keys []string, value string) {
c.mu.Lock()
defer c.mu.Unlock()
m := c.entries
for i := 0; i < len(keys)-1; i++ {
key := keys[i]
entry, err := m.FindEntry(key)
if err != nil {
entry = yamlmap.MapValue()
m.AddEntry(key, entry)
}
m = entry
}
val := yamlmap.StringValue(value)
if value == "" {
val = yamlmap.NullValue()
}
m.SetEntry(keys[len(keys)-1], val)
}
func (c *Config) deepCopy() *Config {
return ReadFromString(c.entries.String())
}
// Read gh configuration files from the local file system and
// returns a Config. A copy of the fallback configuration will
// be returned when there are no configuration files to load.
// If there are no configuration files and no fallback configuration
// an empty configuration will be returned.
var Read = func(fallback *Config) (*Config, error) {
once.Do(func() {
cfg, loadErr = load(generalConfigFile(), hostsConfigFile(), fallback)
})
return cfg, loadErr
}
// ReadFromString takes a yaml string and returns a Config.
func ReadFromString(str string) *Config {
m, _ := mapFromString(str)
if m == nil {
m = yamlmap.MapValue()
}
return &Config{entries: m}
}
// Write gh configuration files to the local file system.
// It will only write gh configuration files that have been modified
// since last being read.
func Write(c *Config) error {
c.mu.Lock()
defer c.mu.Unlock()
hosts, err := c.entries.FindEntry("hosts")
if err == nil && hosts.IsModified() {
err := writeFile(hostsConfigFile(), []byte(hosts.String()))
if err != nil {
return err
}
hosts.SetUnmodified()
}
if c.entries.IsModified() {
// Hosts gets written to a different file above so remove it
// before writing and add it back in after writing.
hostsMap, hostsErr := c.entries.FindEntry("hosts")
if hostsErr == nil {
_ = c.entries.RemoveEntry("hosts")
}
err := writeFile(generalConfigFile(), []byte(c.entries.String()))
if err != nil {
return err
}
c.entries.SetUnmodified()
if hostsErr == nil {
c.entries.AddEntry("hosts", hostsMap)
}
}
return nil
}
func load(generalFilePath, hostsFilePath string, fallback *Config) (*Config, error) {
generalMap, err := mapFromFile(generalFilePath)
if err != nil && !os.IsNotExist(err) {
if errors.Is(err, yamlmap.ErrInvalidYaml) ||
errors.Is(err, yamlmap.ErrInvalidFormat) {
return nil, &InvalidConfigFileError{Path: generalFilePath, Err: err}
}
return nil, err
}
if generalMap == nil {
generalMap = yamlmap.MapValue()
}
hostsMap, err := mapFromFile(hostsFilePath)
if err != nil && !os.IsNotExist(err) {
if errors.Is(err, yamlmap.ErrInvalidYaml) ||
errors.Is(err, yamlmap.ErrInvalidFormat) {
return nil, &InvalidConfigFileError{Path: hostsFilePath, Err: err}
}
return nil, err
}
if hostsMap != nil && !hostsMap.Empty() {
generalMap.AddEntry("hosts", hostsMap)
generalMap.SetUnmodified()
}
if generalMap.Empty() && fallback != nil {
return fallback.deepCopy(), nil
}
return &Config{entries: generalMap}, nil
}
func generalConfigFile() string {
return filepath.Join(ConfigDir(), "config.yml")
}
func hostsConfigFile() string {
return filepath.Join(ConfigDir(), "hosts.yml")
}
func mapFromFile(filename string) (*yamlmap.Map, error) {
data, err := readFile(filename)
if err != nil {
return nil, err
}
return yamlmap.Unmarshal(data)
}
func mapFromString(str string) (*yamlmap.Map, error) {
return yamlmap.Unmarshal([]byte(str))
}
// Config path precedence: GH_CONFIG_DIR, XDG_CONFIG_HOME, AppData (windows only), HOME.
func ConfigDir() string {
var path string
if a := os.Getenv(ghConfigDir); a != "" {
path = a
} else if b := os.Getenv(xdgConfigHome); b != "" {
path = filepath.Join(b, "gh")
} else if c := os.Getenv(appData); runtime.GOOS == "windows" && c != "" {
path = filepath.Join(c, "GitHub CLI")
} else {
d, _ := os.UserHomeDir()
path = filepath.Join(d, ".config", "gh")
}
return path
}
// State path precedence: XDG_STATE_HOME, LocalAppData (windows only), HOME.
func StateDir() string {
var path string
if a := os.Getenv(xdgStateHome); a != "" {
path = filepath.Join(a, "gh")
} else if b := os.Getenv(localAppData); runtime.GOOS == "windows" && b != "" {
path = filepath.Join(b, "GitHub CLI")
} else {
c, _ := os.UserHomeDir()
path = filepath.Join(c, ".local", "state", "gh")
}
return path
}
// Data path precedence: XDG_DATA_HOME, LocalAppData (windows only), HOME.
func DataDir() string {
var path string
if a := os.Getenv(xdgDataHome); a != "" {
path = filepath.Join(a, "gh")
} else if b := os.Getenv(localAppData); runtime.GOOS == "windows" && b != "" {
path = filepath.Join(b, "GitHub CLI")
} else {
c, _ := os.UserHomeDir()
path = filepath.Join(c, ".local", "share", "gh")
}
return path
}
// Cache path precedence: XDG_CACHE_HOME, LocalAppData (windows only), HOME, legacy gh-cli-cache.
func CacheDir() string {
if a := os.Getenv(xdgCacheHome); a != "" {
return filepath.Join(a, "gh")
} else if b := os.Getenv(localAppData); runtime.GOOS == "windows" && b != "" {
return filepath.Join(b, "GitHub CLI")
} else if c, err := os.UserHomeDir(); err == nil {
return filepath.Join(c, ".cache", "gh")
} else {
// Note that this has a minor security issue because /tmp is world-writeable.
// As such, it is possible for other users on a shared system to overwrite cached data.
// The practical risk of this is low, but it's worth calling out as a risk.
// I've included this here for backwards compatibility but we should consider removing it.
return filepath.Join(os.TempDir(), "gh-cli-cache")
}
}
func readFile(filename string) ([]byte, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer f.Close()
data, err := io.ReadAll(f)
if err != nil {
return nil, err
}
return data, nil
}
func writeFile(filename string, data []byte) (writeErr error) {
if writeErr = os.MkdirAll(filepath.Dir(filename), 0771); writeErr != nil {
return
}
var file *os.File
if file, writeErr = os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600); writeErr != nil {
return
}
defer func() {
if err := file.Close(); writeErr == nil && err != nil {
writeErr = err
}
}()
_, writeErr = file.Write(data)
return
}

32
vendor/github.com/cli/go-gh/v2/pkg/config/errors.go generated vendored Normal file
View File

@ -0,0 +1,32 @@
package config
import (
"fmt"
)
// InvalidConfigFileError represents an error when trying to read a config file.
type InvalidConfigFileError struct {
Path string
Err error
}
// Allow InvalidConfigFileError to satisfy error interface.
func (e *InvalidConfigFileError) Error() string {
return fmt.Sprintf("invalid config file %s: %s", e.Path, e.Err)
}
// Allow InvalidConfigFileError to be unwrapped.
func (e *InvalidConfigFileError) Unwrap() error {
return e.Err
}
// KeyNotFoundError represents an error when trying to find a config key
// that does not exist.
type KeyNotFoundError struct {
Key string
}
// Allow KeyNotFoundError to satisfy error interface.
func (e *KeyNotFoundError) Error() string {
return fmt.Sprintf("could not find key %q", e.Key)
}

25
vendor/github.com/cli/safeexec/LICENSE generated vendored Normal file
View File

@ -0,0 +1,25 @@
BSD 2-Clause License
Copyright (c) 2020, GitHub Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

40
vendor/github.com/cli/safeexec/README.md generated vendored Normal file
View File

@ -0,0 +1,40 @@
# safeexec
A Go module that provides a safer alternative to `exec.LookPath()` on Windows.
The following, relatively common approach to running external commands has a subtle vulnerability on Windows:
```go
import "os/exec"
func gitStatus() error {
// On Windows, this will result in `.\git.exe` or `.\git.bat` being executed
// if either were found in the current working directory.
cmd := exec.Command("git", "status")
return cmd.Run()
}
```
Searching the current directory (surprising behavior) before searching folders listed in the PATH environment variable (expected behavior) seems to be intended in Go and unlikely to be changed: https://github.com/golang/go/issues/38736
Since Go does not provide a version of [`exec.LookPath()`](https://golang.org/pkg/os/exec/#LookPath) that only searches PATH and does not search the current working directory, this module provides a `LookPath` function that works consistently across platforms.
Example use:
```go
import (
"os/exec"
"github.com/cli/safeexec"
)
func gitStatus() error {
gitBin, err := safeexec.LookPath("git")
if err != nil {
return err
}
cmd := exec.Command(gitBin, "status")
return cmd.Run()
}
```
## TODO
Ideally, this module would also provide `exec.Command()` and `exec.CommandContext()` equivalents that delegate to the patched version of `LookPath`. However, this doesn't seem possible since `LookPath` may return an error, while `exec.Command/CommandContext()` themselves do not return an error. In the standard library, the resulting `exec.Cmd` struct stores the LookPath error in a private field, but that functionality isn't available to us.

9
vendor/github.com/cli/safeexec/lookpath.go generated vendored Normal file
View File

@ -0,0 +1,9 @@
// +build !windows
package safeexec
import "os/exec"
func LookPath(file string) (string, error) {
return exec.LookPath(file)
}

120
vendor/github.com/cli/safeexec/lookpath_windows.go generated vendored Normal file
View File

@ -0,0 +1,120 @@
// Copyright (c) 2009 The Go Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Package safeexec provides alternatives for exec package functions to avoid
// accidentally executing binaries found in the current working directory on
// Windows.
package safeexec
import (
"os"
"os/exec"
"path/filepath"
"strings"
)
func chkStat(file string) error {
d, err := os.Stat(file)
if err != nil {
return err
}
if d.IsDir() {
return os.ErrPermission
}
return nil
}
func hasExt(file string) bool {
i := strings.LastIndex(file, ".")
if i < 0 {
return false
}
return strings.LastIndexAny(file, `:\/`) < i
}
func findExecutable(file string, exts []string) (string, error) {
if len(exts) == 0 {
return file, chkStat(file)
}
if hasExt(file) {
if chkStat(file) == nil {
return file, nil
}
}
for _, e := range exts {
if f := file + e; chkStat(f) == nil {
return f, nil
}
}
return "", os.ErrNotExist
}
// LookPath searches for an executable named file in the
// directories named by the PATH environment variable.
// If file contains a slash, it is tried directly and the PATH is not consulted.
// LookPath also uses PATHEXT environment variable to match
// a suitable candidate.
// The result may be an absolute path or a path relative to the current directory.
func LookPath(file string) (string, error) {
var exts []string
x := os.Getenv(`PATHEXT`)
if x != "" {
for _, e := range strings.Split(strings.ToLower(x), `;`) {
if e == "" {
continue
}
if e[0] != '.' {
e = "." + e
}
exts = append(exts, e)
}
} else {
exts = []string{".com", ".exe", ".bat", ".cmd"}
}
if strings.ContainsAny(file, `:\/`) {
if f, err := findExecutable(file, exts); err == nil {
return f, nil
} else {
return "", &exec.Error{file, err}
}
}
// https://github.com/golang/go/issues/38736
// if f, err := findExecutable(filepath.Join(".", file), exts); err == nil {
// return f, nil
// }
path := os.Getenv("path")
for _, dir := range filepath.SplitList(path) {
if f, err := findExecutable(filepath.Join(dir, file), exts); err == nil {
return f, nil
}
}
return "", &exec.Error{file, exec.ErrNotFound}
}

View File

@ -42,7 +42,6 @@ loop:
continue
}
var buf bytes.Buffer
for {
c, err := er.ReadByte()
if err != nil {
@ -51,7 +50,6 @@ loop:
if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '@' {
break
}
buf.Write([]byte(string(c)))
}
}

View File

@ -1,6 +1,7 @@
//go:build (darwin || freebsd || openbsd || netbsd || dragonfly) && !appengine
// +build darwin freebsd openbsd netbsd dragonfly
//go:build (darwin || freebsd || openbsd || netbsd || dragonfly || hurd) && !appengine && !tinygo
// +build darwin freebsd openbsd netbsd dragonfly hurd
// +build !appengine
// +build !tinygo
package isatty

View File

@ -1,5 +1,6 @@
//go:build appengine || js || nacl || wasm
// +build appengine js nacl wasm
//go:build (appengine || js || nacl || tinygo || wasm) && !windows
// +build appengine js nacl tinygo wasm
// +build !windows
package isatty

View File

@ -1,6 +1,7 @@
//go:build (linux || aix || zos) && !appengine
//go:build (linux || aix || zos) && !appengine && !tinygo
// +build linux aix zos
// +build !appengine
// +build !tinygo
package isatty

View File

@ -13,7 +13,10 @@
// golang.org/x/crypto/chacha20poly1305).
package cast5 // import "golang.org/x/crypto/cast5"
import "errors"
import (
"errors"
"math/bits"
)
const BlockSize = 8
const KeySize = 16
@ -241,19 +244,19 @@ func (c *Cipher) keySchedule(in []byte) {
// These are the three 'f' functions. See RFC 2144, section 2.2.
func f1(d, m uint32, r uint8) uint32 {
t := m + d
I := (t << r) | (t >> (32 - r))
I := bits.RotateLeft32(t, int(r))
return ((sBox[0][I>>24] ^ sBox[1][(I>>16)&0xff]) - sBox[2][(I>>8)&0xff]) + sBox[3][I&0xff]
}
func f2(d, m uint32, r uint8) uint32 {
t := m ^ d
I := (t << r) | (t >> (32 - r))
I := bits.RotateLeft32(t, int(r))
return ((sBox[0][I>>24] - sBox[1][(I>>16)&0xff]) + sBox[2][(I>>8)&0xff]) ^ sBox[3][I&0xff]
}
func f3(d, m uint32, r uint8) uint32 {
t := m - d
I := (t << r) | (t >> (32 - r))
I := bits.RotateLeft32(t, int(r))
return ((sBox[0][I>>24] + sBox[1][(I>>16)&0xff]) ^ sBox[2][(I>>8)&0xff]) - sBox[3][I&0xff]
}

View File

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.11 && gc && !purego
// +build go1.11,gc,!purego
//go:build gc && !purego
// +build gc,!purego
package chacha20

View File

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.11 && gc && !purego
// +build go1.11,gc,!purego
//go:build gc && !purego
// +build gc,!purego
#include "textflag.h"

View File

@ -12,7 +12,7 @@ import (
"errors"
"math/bits"
"golang.org/x/crypto/internal/subtle"
"golang.org/x/crypto/internal/alias"
)
const (
@ -189,7 +189,7 @@ func (s *Cipher) XORKeyStream(dst, src []byte) {
panic("chacha20: output smaller than input")
}
dst = dst[:len(src)]
if subtle.InexactOverlap(dst, src) {
if alias.InexactOverlap(dst, src) {
panic("chacha20: invalid buffer overlap")
}

View File

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (!arm64 && !s390x && !ppc64le) || (arm64 && !go1.11) || !gc || purego
// +build !arm64,!s390x,!ppc64le arm64,!go1.11 !gc purego
//go:build (!arm64 && !s390x && !ppc64le) || !gc || purego
// +build !arm64,!s390x,!ppc64le !gc purego
package chacha20

View File

@ -5,71 +5,18 @@
// Package curve25519 provides an implementation of the X25519 function, which
// performs scalar multiplication on the elliptic curve known as Curve25519.
// See RFC 7748.
//
// Starting in Go 1.20, this package is a wrapper for the X25519 implementation
// in the crypto/ecdh package.
package curve25519 // import "golang.org/x/crypto/curve25519"
import (
"crypto/subtle"
"errors"
"strconv"
"golang.org/x/crypto/curve25519/internal/field"
)
// ScalarMult sets dst to the product scalar * point.
//
// Deprecated: when provided a low-order point, ScalarMult will set dst to all
// zeroes, irrespective of the scalar. Instead, use the X25519 function, which
// will return an error.
func ScalarMult(dst, scalar, point *[32]byte) {
var e [32]byte
copy(e[:], scalar[:])
e[0] &= 248
e[31] &= 127
e[31] |= 64
var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
x1.SetBytes(point[:])
x2.One()
x3.Set(&x1)
z3.One()
swap := 0
for pos := 254; pos >= 0; pos-- {
b := e[pos/8] >> uint(pos&7)
b &= 1
swap ^= int(b)
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
swap = int(b)
tmp0.Subtract(&x3, &z3)
tmp1.Subtract(&x2, &z2)
x2.Add(&x2, &z2)
z2.Add(&x3, &z3)
z3.Multiply(&tmp0, &x2)
z2.Multiply(&z2, &tmp1)
tmp0.Square(&tmp1)
tmp1.Square(&x2)
x3.Add(&z3, &z2)
z2.Subtract(&z3, &z2)
x2.Multiply(&tmp1, &tmp0)
tmp1.Subtract(&tmp1, &tmp0)
z2.Square(&z2)
z3.Mult32(&tmp1, 121666)
x3.Square(&x3)
tmp0.Add(&tmp0, &z3)
z3.Multiply(&x1, &z2)
z2.Multiply(&tmp1, &tmp0)
}
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
z2.Invert(&z2)
x2.Multiply(&x2, &z2)
copy(dst[:], x2.Bytes())
scalarMult(dst, scalar, point)
}
// ScalarBaseMult sets dst to the product scalar * base where base is the
@ -78,7 +25,7 @@ func ScalarMult(dst, scalar, point *[32]byte) {
// It is recommended to use the X25519 function with Basepoint instead, as
// copying into fixed size arrays can lead to unexpected bugs.
func ScalarBaseMult(dst, scalar *[32]byte) {
ScalarMult(dst, scalar, &basePoint)
scalarBaseMult(dst, scalar)
}
const (
@ -91,21 +38,10 @@ const (
// Basepoint is the canonical Curve25519 generator.
var Basepoint []byte
var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
var basePoint = [32]byte{9}
func init() { Basepoint = basePoint[:] }
func checkBasepoint() {
if subtle.ConstantTimeCompare(Basepoint, []byte{
0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}) != 1 {
panic("curve25519: global Basepoint value was modified")
}
}
// X25519 returns the result of the scalar multiplication (scalar * point),
// according to RFC 7748, Section 5. scalar, point and the return value are
// slices of 32 bytes.
@ -121,26 +57,3 @@ func X25519(scalar, point []byte) ([]byte, error) {
var dst [32]byte
return x25519(&dst, scalar, point)
}
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
var in [32]byte
if l := len(scalar); l != 32 {
return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
}
if l := len(point); l != 32 {
return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
}
copy(in[:], scalar)
if &point[0] == &Basepoint[0] {
checkBasepoint()
ScalarBaseMult(dst, &in)
} else {
var base, zero [32]byte
copy(base[:], point)
ScalarMult(dst, &in, &base)
if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
return nil, errors.New("bad input point: low order point")
}
}
return dst[:], nil
}

View File

@ -0,0 +1,105 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.20
package curve25519
import (
"crypto/subtle"
"errors"
"strconv"
"golang.org/x/crypto/curve25519/internal/field"
)
func scalarMult(dst, scalar, point *[32]byte) {
var e [32]byte
copy(e[:], scalar[:])
e[0] &= 248
e[31] &= 127
e[31] |= 64
var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
x1.SetBytes(point[:])
x2.One()
x3.Set(&x1)
z3.One()
swap := 0
for pos := 254; pos >= 0; pos-- {
b := e[pos/8] >> uint(pos&7)
b &= 1
swap ^= int(b)
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
swap = int(b)
tmp0.Subtract(&x3, &z3)
tmp1.Subtract(&x2, &z2)
x2.Add(&x2, &z2)
z2.Add(&x3, &z3)
z3.Multiply(&tmp0, &x2)
z2.Multiply(&z2, &tmp1)
tmp0.Square(&tmp1)
tmp1.Square(&x2)
x3.Add(&z3, &z2)
z2.Subtract(&z3, &z2)
x2.Multiply(&tmp1, &tmp0)
tmp1.Subtract(&tmp1, &tmp0)
z2.Square(&z2)
z3.Mult32(&tmp1, 121666)
x3.Square(&x3)
tmp0.Add(&tmp0, &z3)
z3.Multiply(&x1, &z2)
z2.Multiply(&tmp1, &tmp0)
}
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
z2.Invert(&z2)
x2.Multiply(&x2, &z2)
copy(dst[:], x2.Bytes())
}
func scalarBaseMult(dst, scalar *[32]byte) {
checkBasepoint()
scalarMult(dst, scalar, &basePoint)
}
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
var in [32]byte
if l := len(scalar); l != 32 {
return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
}
if l := len(point); l != 32 {
return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
}
copy(in[:], scalar)
if &point[0] == &Basepoint[0] {
scalarBaseMult(dst, &in)
} else {
var base, zero [32]byte
copy(base[:], point)
scalarMult(dst, &in, &base)
if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
return nil, errors.New("bad input point: low order point")
}
}
return dst[:], nil
}
func checkBasepoint() {
if subtle.ConstantTimeCompare(Basepoint, []byte{
0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}) != 1 {
panic("curve25519: global Basepoint value was modified")
}
}

View File

@ -0,0 +1,46 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.20
package curve25519
import "crypto/ecdh"
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
curve := ecdh.X25519()
pub, err := curve.NewPublicKey(point)
if err != nil {
return nil, err
}
priv, err := curve.NewPrivateKey(scalar)
if err != nil {
return nil, err
}
out, err := priv.ECDH(pub)
if err != nil {
return nil, err
}
copy(dst[:], out)
return dst[:], nil
}
func scalarMult(dst, scalar, point *[32]byte) {
if _, err := x25519(dst, scalar[:], point[:]); err != nil {
// The only error condition for x25519 when the inputs are 32 bytes long
// is if the output would have been the all-zero value.
for i := range dst {
dst[i] = 0
}
}
}
func scalarBaseMult(dst, scalar *[32]byte) {
curve := ecdh.X25519()
priv, err := curve.NewPrivateKey(scalar[:])
if err != nil {
panic("curve25519: internal error: scalarBaseMult was not 32 bytes")
}
copy(dst[:], priv.PublicKey().Bytes())
}

View File

@ -245,7 +245,7 @@ func feSquareGeneric(v, a *Element) {
v.carryPropagate()
}
// carryPropagate brings the limbs below 52 bits by applying the reduction
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry. TODO inline
func (v *Element) carryPropagateGeneric() *Element {
c0 := v.l0 >> 51

View File

@ -1,71 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ed25519 implements the Ed25519 signature algorithm. See
// https://ed25519.cr.yp.to/.
//
// These functions are also compatible with the “Ed25519” function defined in
// RFC 8032. However, unlike RFC 8032's formulation, this package's private key
// representation includes a public key suffix to make multiple signing
// operations with the same key more efficient. This package refers to the RFC
// 8032 private key as the “seed”.
//
// Beginning with Go 1.13, the functionality of this package was moved to the
// standard library as crypto/ed25519. This package only acts as a compatibility
// wrapper.
package ed25519
import (
"crypto/ed25519"
"io"
)
const (
// PublicKeySize is the size, in bytes, of public keys as used in this package.
PublicKeySize = 32
// PrivateKeySize is the size, in bytes, of private keys as used in this package.
PrivateKeySize = 64
// SignatureSize is the size, in bytes, of signatures generated and verified by this package.
SignatureSize = 64
// SeedSize is the size, in bytes, of private key seeds. These are the private key representations used by RFC 8032.
SeedSize = 32
)
// PublicKey is the type of Ed25519 public keys.
//
// This type is an alias for crypto/ed25519's PublicKey type.
// See the crypto/ed25519 package for the methods on this type.
type PublicKey = ed25519.PublicKey
// PrivateKey is the type of Ed25519 private keys. It implements crypto.Signer.
//
// This type is an alias for crypto/ed25519's PrivateKey type.
// See the crypto/ed25519 package for the methods on this type.
type PrivateKey = ed25519.PrivateKey
// GenerateKey generates a public/private key pair using entropy from rand.
// If rand is nil, crypto/rand.Reader will be used.
func GenerateKey(rand io.Reader) (PublicKey, PrivateKey, error) {
return ed25519.GenerateKey(rand)
}
// NewKeyFromSeed calculates a private key from a seed. It will panic if
// len(seed) is not SeedSize. This function is provided for interoperability
// with RFC 8032. RFC 8032's private keys correspond to seeds in this
// package.
func NewKeyFromSeed(seed []byte) PrivateKey {
return ed25519.NewKeyFromSeed(seed)
}
// Sign signs the message with privateKey and returns a signature. It will
// panic if len(privateKey) is not PrivateKeySize.
func Sign(privateKey PrivateKey, message []byte) []byte {
return ed25519.Sign(privateKey, message)
}
// Verify reports whether sig is a valid signature of message by publicKey. It
// will panic if len(publicKey) is not PublicKeySize.
func Verify(publicKey PublicKey, message, sig []byte) bool {
return ed25519.Verify(publicKey, message, sig)
}

View File

@ -5,9 +5,8 @@
//go:build !purego
// +build !purego
// Package subtle implements functions that are often useful in cryptographic
// code but require careful thought to use correctly.
package subtle // import "golang.org/x/crypto/internal/subtle"
// Package alias implements memory aliasing tests.
package alias
import "unsafe"

View File

@ -5,9 +5,8 @@
//go:build purego
// +build purego
// Package subtle implements functions that are often useful in cryptographic
// code but require careful thought to use correctly.
package subtle // import "golang.org/x/crypto/internal/subtle"
// Package alias implements memory aliasing tests.
package alias
// This is the Google App Engine standard variant based on reflect
// because the unsafe package and cgo are disallowed.

View File

@ -156,7 +156,7 @@ func (r *openpgpReader) Read(p []byte) (n int, err error) {
n, err = r.b64Reader.Read(p)
r.currentCRC = crc24(r.currentCRC, p[:n])
if err == io.EOF && r.lReader.crcSet && r.lReader.crc != uint32(r.currentCRC&crc24Mask) {
if err == io.EOF && r.lReader.crcSet && r.lReader.crc != r.currentCRC&crc24Mask {
return 0, ArmorCorrupt
}

View File

@ -61,7 +61,7 @@ type Key struct {
type KeyRing interface {
// KeysById returns the set of keys that have the given key id.
KeysById(id uint64) []Key
// KeysByIdAndUsage returns the set of keys with the given id
// KeysByIdUsage returns the set of keys with the given id
// that also meet the key usage given by requiredUsage.
// The requiredUsage is expressed as the bitwise-OR of
// packet.KeyFlag* values.
@ -183,7 +183,7 @@ func (el EntityList) KeysById(id uint64) (keys []Key) {
return
}
// KeysByIdAndUsage returns the set of keys with the given id that also meet
// KeysByIdUsage returns the set of keys with the given id that also meet
// the key usage given by requiredUsage. The requiredUsage is expressed as
// the bitwise-OR of packet.KeyFlag* values.
func (el EntityList) KeysByIdUsage(id uint64, requiredUsage byte) (keys []Key) {

View File

@ -60,7 +60,7 @@ func (c *Compressed) parse(r io.Reader) error {
return err
}
// compressedWriterCloser represents the serialized compression stream
// compressedWriteCloser represents the serialized compression stream
// header and the compressor. Its Close() method ensures that both the
// compressor and serialized stream header are closed. Its Write()
// method writes to the compressor.

View File

@ -7,7 +7,6 @@ package packet
import (
"bytes"
"io"
"io/ioutil"
"golang.org/x/crypto/openpgp/errors"
)
@ -26,7 +25,7 @@ type OpaquePacket struct {
}
func (op *OpaquePacket) parse(r io.Reader) (err error) {
op.Contents, err = ioutil.ReadAll(r)
op.Contents, err = io.ReadAll(r)
return
}

View File

@ -13,7 +13,6 @@ import (
"crypto/rsa"
"crypto/sha1"
"io"
"io/ioutil"
"math/big"
"strconv"
"time"
@ -133,7 +132,7 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
}
}
pk.encryptedData, err = ioutil.ReadAll(r)
pk.encryptedData, err = io.ReadAll(r)
if err != nil {
return
}

View File

@ -236,7 +236,7 @@ func (w *seMDCWriter) Close() (err error) {
return w.w.Close()
}
// noOpCloser is like an ioutil.NopCloser, but for an io.Writer.
// noOpCloser is like an io.NopCloser, but for an io.Writer.
type noOpCloser struct {
w io.Writer
}

View File

@ -9,7 +9,6 @@ import (
"image"
"image/jpeg"
"io"
"io/ioutil"
)
const UserAttrImageSubpacket = 1
@ -56,7 +55,7 @@ func NewUserAttribute(contents ...*OpaqueSubpacket) *UserAttribute {
func (uat *UserAttribute) parse(r io.Reader) (err error) {
// RFC 4880, section 5.13
b, err := ioutil.ReadAll(r)
b, err := io.ReadAll(r)
if err != nil {
return
}

View File

@ -6,7 +6,6 @@ package packet
import (
"io"
"io/ioutil"
"strings"
)
@ -66,7 +65,7 @@ func NewUserId(name, comment, email string) *UserId {
func (uid *UserId) parse(r io.Reader) (err error) {
// RFC 4880, section 5.11
b, err := ioutil.ReadAll(r)
b, err := io.ReadAll(r)
if err != nil {
return
}

View File

@ -268,7 +268,7 @@ func HashIdToString(id byte) (name string, ok bool) {
return "", false
}
// HashIdToHash returns an OpenPGP hash id which corresponds the given Hash.
// HashToHashId returns an OpenPGP hash id which corresponds the given Hash.
func HashToHashId(h crypto.Hash) (id byte, ok bool) {
for _, m := range hashToHashIdMapping {
if m.hash == h {

View File

@ -402,7 +402,7 @@ func (s signatureWriter) Close() error {
return s.encryptedData.Close()
}
// noOpCloser is like an ioutil.NopCloser, but for an io.Writer.
// noOpCloser is like an io.NopCloser, but for an io.Writer.
// TODO: we have two of these in OpenPGP packages alone. This probably needs
// to be promoted somewhere more common.
type noOpCloser struct {

View File

@ -16,6 +16,7 @@ import (
"bytes"
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
@ -26,7 +27,6 @@ import (
"math/big"
"sync"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
)
@ -93,7 +93,7 @@ type ExtendedAgent interface {
type ConstraintExtension struct {
// ExtensionName consist of a UTF-8 string suffixed by the
// implementation domain following the naming scheme defined
// in Section 4.2 of [RFC4251], e.g. "foo@example.com".
// in Section 4.2 of RFC 4251, e.g. "foo@example.com".
ExtensionName string
// ExtensionDetails contains the actual content of the extended
// constraint.
@ -226,7 +226,9 @@ var ErrExtensionUnsupported = errors.New("agent: extension unsupported")
type extensionAgentMsg struct {
ExtensionType string `sshtype:"27"`
Contents []byte
// NOTE: this matches OpenSSH's PROTOCOL.agent, not the IETF draft [PROTOCOL.agent],
// so that it matches what OpenSSH actually implements in the wild.
Contents []byte `ssh:"rest"`
}
// Key represents a protocol 2 public key as defined in
@ -729,7 +731,7 @@ func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string
if err != nil {
return err
}
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
return errors.New("agent: signer and cert have different public key")
}

View File

@ -7,6 +7,7 @@ package agent
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"encoding/binary"
@ -16,11 +17,10 @@ import (
"log"
"math/big"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
)
// Server wraps an Agent and uses it to implement the agent side of
// server wraps an Agent and uses it to implement the agent side of
// the SSH-agent, wire protocol.
type server struct {
agent Agent

View File

@ -16,8 +16,9 @@ import (
// Certificate algorithm names from [PROTOCOL.certkeys]. These values can appear
// in Certificate.Type, PublicKey.Type, and ClientConfig.HostKeyAlgorithms.
// Unlike key algorithm names, these are not passed to AlgorithmSigner and don't
// appear in the Signature.Format field.
// Unlike key algorithm names, these are not passed to AlgorithmSigner nor
// returned by MultiAlgorithmSigner and don't appear in the Signature.Format
// field.
const (
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com"
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com"
@ -251,14 +252,21 @@ type algorithmOpenSSHCertSigner struct {
// private key is held by signer. It returns an error if the public key in cert
// doesn't match the key used by signer.
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
return nil, errors.New("ssh: signer and cert have different public key")
}
if algorithmSigner, ok := signer.(AlgorithmSigner); ok {
switch s := signer.(type) {
case MultiAlgorithmSigner:
return &multiAlgorithmSigner{
AlgorithmSigner: &algorithmOpenSSHCertSigner{
&openSSHCertSigner{cert, signer}, s},
supportedAlgorithms: s.Algorithms(),
}, nil
case AlgorithmSigner:
return &algorithmOpenSSHCertSigner{
&openSSHCertSigner{cert, signer}, algorithmSigner}, nil
} else {
&openSSHCertSigner{cert, signer}, s}, nil
default:
return &openSSHCertSigner{cert, signer}, nil
}
}
@ -432,7 +440,9 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
}
// SignCert signs the certificate with an authority, setting the Nonce,
// SignatureKey, and Signature fields.
// SignatureKey, and Signature fields. If the authority implements the
// MultiAlgorithmSigner interface the first algorithm in the list is used. This
// is useful if you want to sign with a specific algorithm.
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
c.Nonce = make([]byte, 32)
if _, err := io.ReadFull(rand, c.Nonce); err != nil {
@ -440,8 +450,20 @@ func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
}
c.SignatureKey = authority.PublicKey()
if v, ok := authority.(MultiAlgorithmSigner); ok {
if len(v.Algorithms()) == 0 {
return errors.New("the provided authority has no signature algorithm")
}
// Use the first algorithm in the list.
sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), v.Algorithms()[0])
if err != nil {
return err
}
c.Signature = sig
return nil
} else if v, ok := authority.(AlgorithmSigner); ok && v.PublicKey().Type() == KeyAlgoRSA {
// Default to KeyAlgoRSASHA512 for ssh-rsa signers.
if v, ok := authority.(AlgorithmSigner); ok && v.PublicKey().Type() == KeyAlgoRSA {
// TODO: consider using KeyAlgoRSASHA256 as default.
sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), KeyAlgoRSASHA512)
if err != nil {
return err

View File

@ -15,7 +15,6 @@ import (
"fmt"
"hash"
"io"
"io/ioutil"
"golang.org/x/crypto/chacha20"
"golang.org/x/crypto/internal/poly1305"
@ -115,7 +114,8 @@ var cipherModes = map[string]*cipherMode{
"arcfour": {16, 0, streamCipherMode(0, newRC4)},
// AEAD ciphers
gcmCipherID: {16, 12, newGCMCipher},
gcm128CipherID: {16, 12, newGCMCipher},
gcm256CipherID: {32, 12, newGCMCipher},
chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
// CBC mode is insecure and so is not included in the default config.
@ -497,7 +497,7 @@ func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error)
// data, to make distinguishing between
// failing MAC and failing length check more
// difficult.
io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage))
io.CopyN(io.Discard, r, int64(c.oracleCamouflage))
}
}
return p, err

View File

@ -71,7 +71,9 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
for auth := AuthMethod(new(noneAuth)); auth != nil; {
ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
if err != nil {
return err
// We return the error later if there is no other method left to
// try.
ok = authFailure
}
if ok == authSuccess {
// success
@ -101,6 +103,12 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
}
}
}
if auth == nil && err != nil {
// We have an error and there are no other authentication methods to
// try, so we return it.
return err
}
}
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried)
}
@ -217,21 +225,45 @@ func (cb publicKeyCallback) method() string {
return "publickey"
}
func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (as AlgorithmSigner, algo string) {
func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) {
var as MultiAlgorithmSigner
keyFormat := signer.PublicKey().Type()
// Like in sendKexInit, if the public key implements AlgorithmSigner we
// assume it supports all algorithms, otherwise only the key format one.
as, ok := signer.(AlgorithmSigner)
if !ok {
return algorithmSignerWrapper{signer}, keyFormat
// If the signer implements MultiAlgorithmSigner we use the algorithms it
// support, if it implements AlgorithmSigner we assume it supports all
// algorithms, otherwise only the key format one.
switch s := signer.(type) {
case MultiAlgorithmSigner:
as = s
case AlgorithmSigner:
as = &multiAlgorithmSigner{
AlgorithmSigner: s,
supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)),
}
default:
as = &multiAlgorithmSigner{
AlgorithmSigner: algorithmSignerWrapper{signer},
supportedAlgorithms: []string{underlyingAlgo(keyFormat)},
}
}
getFallbackAlgo := func() (string, error) {
// Fallback to use if there is no "server-sig-algs" extension or a
// common algorithm cannot be found. We use the public key format if the
// MultiAlgorithmSigner supports it, otherwise we return an error.
if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) {
return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v",
underlyingAlgo(keyFormat), keyFormat, as.Algorithms())
}
return keyFormat, nil
}
extPayload, ok := extensions["server-sig-algs"]
if !ok {
// If there is no "server-sig-algs" extension, fall back to the key
// format algorithm.
return as, keyFormat
// If there is no "server-sig-algs" extension use the fallback
// algorithm.
algo, err := getFallbackAlgo()
return as, algo, err
}
// The server-sig-algs extension only carries underlying signature
@ -245,15 +277,22 @@ func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (as Alg
}
}
keyAlgos := algorithmsForKeyFormat(keyFormat)
// Filter algorithms based on those supported by MultiAlgorithmSigner.
var keyAlgos []string
for _, algo := range algorithmsForKeyFormat(keyFormat) {
if contains(as.Algorithms(), underlyingAlgo(algo)) {
keyAlgos = append(keyAlgos, algo)
}
}
algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos)
if err != nil {
// If there is no overlap, try the key anyway with the key format
// algorithm, to support servers that fail to list all supported
// algorithms.
return as, keyFormat
// If there is no overlap, return the fallback algorithm to support
// servers that fail to list all supported algorithms.
algo, err := getFallbackAlgo()
return as, algo, err
}
return as, algo
return as, algo, nil
}
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
@ -267,10 +306,17 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
return authFailure, nil, err
}
var methods []string
var errSigAlgo error
for _, signer := range signers {
pub := signer.PublicKey()
as, algo := pickSignatureAlgorithm(signer, extensions)
as, algo, err := pickSignatureAlgorithm(signer, extensions)
if err != nil && errSigAlgo == nil {
// If we cannot negotiate a signature algorithm store the first
// error so we can return it to provide a more meaningful message if
// no other signers work.
errSigAlgo = err
continue
}
ok, err := validateKey(pub, algo, user, c)
if err != nil {
return authFailure, nil, err
@ -317,22 +363,12 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
// contain the "publickey" method, do not attempt to authenticate with any
// other keys. According to RFC 4252 Section 7, the latter can occur when
// additional authentication methods are required.
if success == authSuccess || !containsMethod(methods, cb.method()) {
if success == authSuccess || !contains(methods, cb.method()) {
return success, methods, err
}
}
return authFailure, methods, nil
}
func containsMethod(methods []string, method string) bool {
for _, m := range methods {
if m == method {
return true
}
}
return false
return authFailure, methods, errSigAlgo
}
// validateKey validates the key provided is acceptable to the server.

View File

@ -10,6 +10,7 @@ import (
"fmt"
"io"
"math"
"strings"
"sync"
_ "crypto/sha1"
@ -27,7 +28,7 @@ const (
// supportedCiphers lists ciphers we support but might not recommend.
var supportedCiphers = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"aes128-gcm@openssh.com",
"aes128-gcm@openssh.com", gcm256CipherID,
chacha20Poly1305ID,
"arcfour256", "arcfour128", "arcfour",
aes128cbcID,
@ -36,7 +37,7 @@ var supportedCiphers = []string{
// preferredCiphers specifies the default preference for ciphers.
var preferredCiphers = []string{
"aes128-gcm@openssh.com",
"aes128-gcm@openssh.com", gcm256CipherID,
chacha20Poly1305ID,
"aes128-ctr", "aes192-ctr", "aes256-ctr",
}
@ -48,7 +49,8 @@ var supportedKexAlgos = []string{
// P384 and P521 are not constant-time yet, but since we don't
// reuse ephemeral keys, using them for ECDH should be OK.
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA256, kexAlgoDH14SHA1, kexAlgoDH1SHA1,
kexAlgoDH14SHA256, kexAlgoDH16SHA512, kexAlgoDH14SHA1,
kexAlgoDH1SHA1,
}
// serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden
@ -58,8 +60,9 @@ var serverForbiddenKexAlgos = map[string]struct{}{
kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests
}
// preferredKexAlgos specifies the default preference for key-exchange algorithms
// in preference order.
// preferredKexAlgos specifies the default preference for key-exchange
// algorithms in preference order. The diffie-hellman-group16-sha512 algorithm
// is disabled by default because it is a bit slower than the others.
var preferredKexAlgos = []string{
kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
@ -69,12 +72,12 @@ var preferredKexAlgos = []string{
// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
// of authenticating servers) in preference order.
var supportedHostKeyAlgos = []string{
CertAlgoRSASHA512v01, CertAlgoRSASHA256v01,
CertAlgoRSASHA256v01, CertAlgoRSASHA512v01,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSASHA512, KeyAlgoRSASHA256,
KeyAlgoRSASHA256, KeyAlgoRSASHA512,
KeyAlgoRSA, KeyAlgoDSA,
KeyAlgoED25519,
@ -84,7 +87,7 @@ var supportedHostKeyAlgos = []string{
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
// because they have reached the end of their useful life.
var supportedMACs = []string{
"hmac-sha2-256-etm@openssh.com", "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96",
"hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96",
}
var supportedCompressions = []string{compressionNone}
@ -118,6 +121,27 @@ func algorithmsForKeyFormat(keyFormat string) []string {
}
}
// isRSA returns whether algo is a supported RSA algorithm, including certificate
// algorithms.
func isRSA(algo string) bool {
algos := algorithmsForKeyFormat(KeyAlgoRSA)
return contains(algos, underlyingAlgo(algo))
}
// supportedPubKeyAuthAlgos specifies the supported client public key
// authentication algorithms. Note that this doesn't include certificate types
// since those use the underlying algorithm. This list is sent to the client if
// it supports the server-sig-algs extension. Order is irrelevant.
var supportedPubKeyAuthAlgos = []string{
KeyAlgoED25519,
KeyAlgoSKED25519, KeyAlgoSKECDSA256,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA,
KeyAlgoDSA,
}
var supportedPubKeyAuthAlgosList = strings.Join(supportedPubKeyAuthAlgos, ",")
// unexpectedMessageError results when the SSH message that we received didn't
// match what we wanted.
func unexpectedMessageError(expected, got uint8) error {
@ -153,7 +177,7 @@ func (a *directionAlgorithms) rekeyBytes() int64 {
// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
// 128.
switch a.Cipher {
case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcmCipherID, aes128cbcID:
case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID:
return 16 * (1 << 32)
}
@ -163,7 +187,8 @@ func (a *directionAlgorithms) rekeyBytes() int64 {
}
var aeadCiphers = map[string]bool{
gcmCipherID: true,
gcm128CipherID: true,
gcm256CipherID: true,
chacha20Poly1305ID: true,
}
@ -246,16 +271,16 @@ type Config struct {
// unspecified, a size suitable for the chosen cipher is used.
RekeyThreshold uint64
// The allowed key exchanges algorithms. If unspecified then a
// default set of algorithms is used.
// The allowed key exchanges algorithms. If unspecified then a default set
// of algorithms is used. Unsupported values are silently ignored.
KeyExchanges []string
// The allowed cipher algorithms. If unspecified then a sensible
// default is used.
// The allowed cipher algorithms. If unspecified then a sensible default is
// used. Unsupported values are silently ignored.
Ciphers []string
// The allowed MAC algorithms. If unspecified then a sensible default
// is used.
// The allowed MAC algorithms. If unspecified then a sensible default is
// used. Unsupported values are silently ignored.
MACs []string
}
@ -272,7 +297,7 @@ func (c *Config) SetDefaults() {
var ciphers []string
for _, c := range c.Ciphers {
if cipherModes[c] != nil {
// reject the cipher if we have no cipherModes definition
// Ignore the cipher if we have no cipherModes definition.
ciphers = append(ciphers, c)
}
}
@ -281,10 +306,26 @@ func (c *Config) SetDefaults() {
if c.KeyExchanges == nil {
c.KeyExchanges = preferredKexAlgos
}
var kexs []string
for _, k := range c.KeyExchanges {
if kexAlgoMap[k] != nil {
// Ignore the KEX if we have no kexAlgoMap definition.
kexs = append(kexs, k)
}
}
c.KeyExchanges = kexs
if c.MACs == nil {
c.MACs = supportedMACs
}
var macs []string
for _, m := range c.MACs {
if macModes[m] != nil {
// Ignore the MAC if we have no macModes definition.
macs = append(macs, m)
}
}
c.MACs = macs
if c.RekeyThreshold == 0 {
// cipher specific default

View File

@ -97,7 +97,7 @@ func (c *connection) Close() error {
return c.sshConn.conn.Close()
}
// sshconn provides net.Conn metadata, but disallows direct reads and
// sshConn provides net.Conn metadata, but disallows direct reads and
// writes.
type sshConn struct {
conn net.Conn

View File

@ -13,6 +13,7 @@ others.
References:
[PROTOCOL]: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL?rev=HEAD
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1

View File

@ -63,6 +63,8 @@ type handshakeTransport struct {
sentInitPacket []byte
sentInitMsg *kexInitMsg
pendingPackets [][]byte // Used when a key exchange is in progress.
writePacketsLeft uint32
writeBytesLeft int64
// If the read loop wants to schedule a kex, it pings this
// channel, and the write loop will send out a kex
@ -72,6 +74,7 @@ type handshakeTransport struct {
// If the other side requests or confirms a kex, its kexInit
// packet is sent here for the write loop to find it.
startKex chan *pendingKex
kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits
// data for host key checking
hostKeyCallback HostKeyCallback
@ -86,12 +89,10 @@ type handshakeTransport struct {
// Algorithms agreed in the last key exchange.
algorithms *algorithms
// Counters exclusively owned by readLoop.
readPacketsLeft uint32
readBytesLeft int64
writePacketsLeft uint32
writeBytesLeft int64
// The session ID or nil if first kex did not complete yet.
sessionID []byte
}
@ -108,7 +109,8 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
clientVersion: clientVersion,
incoming: make(chan []byte, chanSize),
requestKex: make(chan struct{}, 1),
startKex: make(chan *pendingKex, 1),
startKex: make(chan *pendingKex),
kexLoopDone: make(chan struct{}),
config: config,
}
@ -340,16 +342,17 @@ write:
t.mu.Unlock()
}
// drain startKex channel. We don't service t.requestKex
// because nobody does blocking sends there.
go func() {
for init := range t.startKex {
init.done <- t.writeError
}
}()
// Unblock reader.
t.conn.Close()
// drain startKex channel. We don't service t.requestKex
// because nobody does blocking sends there.
for request := range t.startKex {
request.done <- t.getWriteError()
}
// Mark that the loop is done so that Close can return.
close(t.kexLoopDone)
}
// The protocol uses uint32 for packet counters, so we can't let them
@ -458,19 +461,24 @@ func (t *handshakeTransport) sendKexInit() error {
isServer := len(t.hostKeys) > 0
if isServer {
for _, k := range t.hostKeys {
// If k is an AlgorithmSigner, presume it supports all signature algorithms
// associated with the key format. (Ideally AlgorithmSigner would have a
// method to advertise supported algorithms, but it doesn't. This means that
// adding support for a new algorithm is a breaking change, as we will
// immediately negotiate it even if existing implementations don't support
// it. If that ever happens, we'll have to figure something out.)
// If k is not an AlgorithmSigner, we can only assume it only supports the
// algorithms that matches the key format. (This means that Sign can't pick
// a different default.)
// If k is a MultiAlgorithmSigner, we restrict the signature
// algorithms. If k is a AlgorithmSigner, presume it supports all
// signature algorithms associated with the key format. If k is not
// an AlgorithmSigner, we can only assume it only supports the
// algorithms that matches the key format. (This means that Sign
// can't pick a different default).
keyFormat := k.PublicKey().Type()
if _, ok := k.(AlgorithmSigner); ok {
switch s := k.(type) {
case MultiAlgorithmSigner:
for _, algo := range algorithmsForKeyFormat(keyFormat) {
if contains(s.Algorithms(), underlyingAlgo(algo)) {
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
}
}
case AlgorithmSigner:
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...)
} else {
default:
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
}
}
@ -545,7 +553,16 @@ func (t *handshakeTransport) writePacket(p []byte) error {
}
func (t *handshakeTransport) Close() error {
return t.conn.Close()
// Close the connection. This should cause the readLoop goroutine to wake up
// and close t.startKex, which will shut down kexLoop if running.
err := t.conn.Close()
// Wait for the kexLoop goroutine to complete.
// At that point we know that the readLoop goroutine is complete too,
// because kexLoop itself waits for readLoop to close the startKex channel.
<-t.kexLoopDone
return err
}
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
@ -615,7 +632,8 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
return err
}
if t.sessionID == nil {
firstKeyExchange := t.sessionID == nil
if firstKeyExchange {
t.sessionID = result.H
}
result.SessionID = t.sessionID
@ -626,6 +644,28 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
// On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO
// message with the server-sig-algs extension if the client supports it. See
// RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9.
if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
extInfo := &extInfoMsg{
NumExtensions: 2,
Payload: make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1),
}
extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs"))
extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...)
extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList))
extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...)
extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com"))
extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...)
extInfo.Payload = appendInt(extInfo.Payload, 1)
extInfo.Payload = append(extInfo.Payload, "0"...)
if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
return err
}
}
if packet, err := t.conn.readPacket(); err != nil {
return err
} else if packet[0] != msgNewKeys {
@ -654,9 +694,16 @@ func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, a
func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
for _, k := range hostKeys {
if s, ok := k.(MultiAlgorithmSigner); ok {
if !contains(s.Algorithms(), underlyingAlgo(algo)) {
continue
}
}
if algo == k.PublicKey().Type() {
return algorithmSignerWrapper{k}
}
k, ok := k.(AlgorithmSigner)
if !ok {
continue

View File

@ -23,6 +23,7 @@ const (
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
kexAlgoDH14SHA256 = "diffie-hellman-group14-sha256"
kexAlgoDH16SHA512 = "diffie-hellman-group16-sha512"
kexAlgoECDH256 = "ecdh-sha2-nistp256"
kexAlgoECDH384 = "ecdh-sha2-nistp384"
kexAlgoECDH521 = "ecdh-sha2-nistp521"
@ -430,6 +431,17 @@ func init() {
hashFunc: crypto.SHA256,
}
// This is the group called diffie-hellman-group16-sha512 in RFC
// 8268 and Oakley Group 16 in RFC 3526.
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH16SHA512] = &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
pMinus1: new(big.Int).Sub(p, bigOne),
hashFunc: crypto.SHA512,
}
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}

View File

@ -11,13 +11,16 @@ import (
"crypto/cipher"
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/md5"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/pem"
"errors"
@ -26,7 +29,6 @@ import (
"math/big"
"strings"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh/internal/bcrypt_pbkdf"
)
@ -184,7 +186,7 @@ func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey
return "", nil, nil, "", nil, io.EOF
}
// ParseAuthorizedKeys parses a public key from an authorized_keys
// ParseAuthorizedKey parses a public key from an authorized_keys
// file used in OpenSSH according to the sshd(8) manual page.
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
for len(in) > 0 {
@ -295,6 +297,18 @@ func MarshalAuthorizedKey(key PublicKey) []byte {
return b.Bytes()
}
// MarshalPrivateKey returns a PEM block with the private key serialized in the
// OpenSSH format.
func MarshalPrivateKey(key crypto.PrivateKey, comment string) (*pem.Block, error) {
return marshalOpenSSHPrivateKey(key, comment, unencryptedOpenSSHMarshaler)
}
// MarshalPrivateKeyWithPassphrase returns a PEM block holding the encrypted
// private key serialized in the OpenSSH format.
func MarshalPrivateKeyWithPassphrase(key crypto.PrivateKey, comment string, passphrase []byte) (*pem.Block, error) {
return marshalOpenSSHPrivateKey(key, comment, passphraseProtectedOpenSSHMarshaler(passphrase))
}
// PublicKey represents a public key using an unspecified algorithm.
//
// Some PublicKeys provided by this package also implement CryptoPublicKey.
@ -321,7 +335,7 @@ type CryptoPublicKey interface {
// A Signer can create signatures that verify against a public key.
//
// Some Signers provided by this package also implement AlgorithmSigner.
// Some Signers provided by this package also implement MultiAlgorithmSigner.
type Signer interface {
// PublicKey returns the associated PublicKey.
PublicKey() PublicKey
@ -336,9 +350,9 @@ type Signer interface {
// An AlgorithmSigner is a Signer that also supports specifying an algorithm to
// use for signing.
//
// An AlgorithmSigner can't advertise the algorithms it supports, so it should
// be prepared to be invoked with every algorithm supported by the public key
// format.
// An AlgorithmSigner can't advertise the algorithms it supports, unless it also
// implements MultiAlgorithmSigner, so it should be prepared to be invoked with
// every algorithm supported by the public key format.
type AlgorithmSigner interface {
Signer
@ -349,6 +363,75 @@ type AlgorithmSigner interface {
SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error)
}
// MultiAlgorithmSigner is an AlgorithmSigner that also reports the algorithms
// supported by that signer.
type MultiAlgorithmSigner interface {
AlgorithmSigner
// Algorithms returns the available algorithms in preference order. The list
// must not be empty, and it must not include certificate types.
Algorithms() []string
}
// NewSignerWithAlgorithms returns a signer restricted to the specified
// algorithms. The algorithms must be set in preference order. The list must not
// be empty, and it must not include certificate types. An error is returned if
// the specified algorithms are incompatible with the public key type.
func NewSignerWithAlgorithms(signer AlgorithmSigner, algorithms []string) (MultiAlgorithmSigner, error) {
if len(algorithms) == 0 {
return nil, errors.New("ssh: please specify at least one valid signing algorithm")
}
var signerAlgos []string
supportedAlgos := algorithmsForKeyFormat(underlyingAlgo(signer.PublicKey().Type()))
if s, ok := signer.(*multiAlgorithmSigner); ok {
signerAlgos = s.Algorithms()
} else {
signerAlgos = supportedAlgos
}
for _, algo := range algorithms {
if !contains(supportedAlgos, algo) {
return nil, fmt.Errorf("ssh: algorithm %q is not supported for key type %q",
algo, signer.PublicKey().Type())
}
if !contains(signerAlgos, algo) {
return nil, fmt.Errorf("ssh: algorithm %q is restricted for the provided signer", algo)
}
}
return &multiAlgorithmSigner{
AlgorithmSigner: signer,
supportedAlgorithms: algorithms,
}, nil
}
type multiAlgorithmSigner struct {
AlgorithmSigner
supportedAlgorithms []string
}
func (s *multiAlgorithmSigner) Algorithms() []string {
return s.supportedAlgorithms
}
func (s *multiAlgorithmSigner) isAlgorithmSupported(algorithm string) bool {
if algorithm == "" {
algorithm = underlyingAlgo(s.PublicKey().Type())
}
for _, algo := range s.supportedAlgorithms {
if algorithm == algo {
return true
}
}
return false
}
func (s *multiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
if !s.isAlgorithmSupported(algorithm) {
return nil, fmt.Errorf("ssh: algorithm %q is not supported: %v", algorithm, s.supportedAlgorithms)
}
return s.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
}
type rsaPublicKey rsa.PublicKey
func (r *rsaPublicKey) Type() string {
@ -512,6 +595,10 @@ func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
return k.SignWithAlgorithm(rand, data, k.PublicKey().Type())
}
func (k *dsaPrivateKey) Algorithms() []string {
return []string{k.PublicKey().Type()}
}
func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
if algorithm != "" && algorithm != k.PublicKey().Type() {
return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
@ -961,13 +1048,16 @@ func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
return s.SignWithAlgorithm(rand, data, s.pubKey.Type())
}
func (s *wrappedSigner) Algorithms() []string {
return algorithmsForKeyFormat(s.pubKey.Type())
}
func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
if algorithm == "" {
algorithm = s.pubKey.Type()
}
supportedAlgos := algorithmsForKeyFormat(s.pubKey.Type())
if !contains(supportedAlgos, algorithm) {
if !contains(s.Algorithms(), algorithm) {
return nil, fmt.Errorf("ssh: unsupported signature algorithm %q for key format %q", algorithm, s.pubKey.Type())
}
@ -1087,9 +1177,9 @@ func (*PassphraseMissingError) Error() string {
return "ssh: this private key is passphrase protected"
}
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
// supports RSA (PKCS#1), PKCS#8, DSA (OpenSSL), and ECDSA private keys. If the
// private key is encrypted, it will return a PassphraseMissingError.
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It supports
// RSA, DSA, ECDSA, and Ed25519 private keys in PKCS#1, PKCS#8, OpenSSL, and OpenSSH
// formats. If the private key is encrypted, it will return a PassphraseMissingError.
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
block, _ := pem.Decode(pemBytes)
if block == nil {
@ -1241,20 +1331,54 @@ func passphraseProtectedOpenSSHKey(passphrase []byte) openSSHDecryptFunc {
}
}
type openSSHDecryptFunc func(CipherName, KdfName, KdfOpts string, PrivKeyBlock []byte) ([]byte, error)
// parseOpenSSHPrivateKey parses an OpenSSH private key, using the decrypt
// function to unwrap the encrypted portion. unencryptedOpenSSHKey can be used
// as the decrypt function to parse an unencrypted private key. See
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key.
func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.PrivateKey, error) {
const magic = "openssh-key-v1\x00"
if len(key) < len(magic) || string(key[:len(magic)]) != magic {
return nil, errors.New("ssh: invalid openssh private key format")
func unencryptedOpenSSHMarshaler(privKeyBlock []byte) ([]byte, string, string, string, error) {
key := generateOpenSSHPadding(privKeyBlock, 8)
return key, "none", "none", "", nil
}
remaining := key[len(magic):]
var w struct {
func passphraseProtectedOpenSSHMarshaler(passphrase []byte) openSSHEncryptFunc {
return func(privKeyBlock []byte) ([]byte, string, string, string, error) {
salt := make([]byte, 16)
if _, err := rand.Read(salt); err != nil {
return nil, "", "", "", err
}
opts := struct {
Salt []byte
Rounds uint32
}{salt, 16}
// Derive key to encrypt the private key block.
k, err := bcrypt_pbkdf.Key(passphrase, salt, int(opts.Rounds), 32+aes.BlockSize)
if err != nil {
return nil, "", "", "", err
}
// Add padding matching the block size of AES.
keyBlock := generateOpenSSHPadding(privKeyBlock, aes.BlockSize)
// Encrypt the private key using the derived secret.
dst := make([]byte, len(keyBlock))
key, iv := k[:32], k[32:]
block, err := aes.NewCipher(key)
if err != nil {
return nil, "", "", "", err
}
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(dst, keyBlock)
return dst, "aes256-ctr", "bcrypt", string(Marshal(opts)), nil
}
}
const privateKeyAuthMagic = "openssh-key-v1\x00"
type openSSHDecryptFunc func(CipherName, KdfName, KdfOpts string, PrivKeyBlock []byte) ([]byte, error)
type openSSHEncryptFunc func(PrivKeyBlock []byte) (ProtectedKeyBlock []byte, cipherName, kdfName, kdfOptions string, err error)
type openSSHEncryptedPrivateKey struct {
CipherName string
KdfName string
KdfOpts string
@ -1263,6 +1387,50 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
PrivKeyBlock []byte
}
type openSSHPrivateKey struct {
Check1 uint32
Check2 uint32
Keytype string
Rest []byte `ssh:"rest"`
}
type openSSHRSAPrivateKey struct {
N *big.Int
E *big.Int
D *big.Int
Iqmp *big.Int
P *big.Int
Q *big.Int
Comment string
Pad []byte `ssh:"rest"`
}
type openSSHEd25519PrivateKey struct {
Pub []byte
Priv []byte
Comment string
Pad []byte `ssh:"rest"`
}
type openSSHECDSAPrivateKey struct {
Curve string
Pub []byte
D *big.Int
Comment string
Pad []byte `ssh:"rest"`
}
// parseOpenSSHPrivateKey parses an OpenSSH private key, using the decrypt
// function to unwrap the encrypted portion. unencryptedOpenSSHKey can be used
// as the decrypt function to parse an unencrypted private key. See
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key.
func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.PrivateKey, error) {
if len(key) < len(privateKeyAuthMagic) || string(key[:len(privateKeyAuthMagic)]) != privateKeyAuthMagic {
return nil, errors.New("ssh: invalid openssh private key format")
}
remaining := key[len(privateKeyAuthMagic):]
var w openSSHEncryptedPrivateKey
if err := Unmarshal(remaining, &w); err != nil {
return nil, err
}
@ -1284,13 +1452,7 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
return nil, err
}
pk1 := struct {
Check1 uint32
Check2 uint32
Keytype string
Rest []byte `ssh:"rest"`
}{}
var pk1 openSSHPrivateKey
if err := Unmarshal(privKeyBlock, &pk1); err != nil || pk1.Check1 != pk1.Check2 {
if w.CipherName != "none" {
return nil, x509.IncorrectPasswordError
@ -1300,18 +1462,7 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
switch pk1.Keytype {
case KeyAlgoRSA:
// https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773
key := struct {
N *big.Int
E *big.Int
D *big.Int
Iqmp *big.Int
P *big.Int
Q *big.Int
Comment string
Pad []byte `ssh:"rest"`
}{}
var key openSSHRSAPrivateKey
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
@ -1337,13 +1488,7 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
return pk, nil
case KeyAlgoED25519:
key := struct {
Pub []byte
Priv []byte
Comment string
Pad []byte `ssh:"rest"`
}{}
var key openSSHEd25519PrivateKey
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
@ -1360,14 +1505,7 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
copy(pk, key.Priv)
return &pk, nil
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
key := struct {
Curve string
Pub []byte
D *big.Int
Comment string
Pad []byte `ssh:"rest"`
}{}
var key openSSHECDSAPrivateKey
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
@ -1415,6 +1553,131 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
}
}
func marshalOpenSSHPrivateKey(key crypto.PrivateKey, comment string, encrypt openSSHEncryptFunc) (*pem.Block, error) {
var w openSSHEncryptedPrivateKey
var pk1 openSSHPrivateKey
// Random check bytes.
var check uint32
if err := binary.Read(rand.Reader, binary.BigEndian, &check); err != nil {
return nil, err
}
pk1.Check1 = check
pk1.Check2 = check
w.NumKeys = 1
// Use a []byte directly on ed25519 keys.
if k, ok := key.(*ed25519.PrivateKey); ok {
key = *k
}
switch k := key.(type) {
case *rsa.PrivateKey:
E := new(big.Int).SetInt64(int64(k.PublicKey.E))
// Marshal public key:
// E and N are in reversed order in the public and private key.
pubKey := struct {
KeyType string
E *big.Int
N *big.Int
}{
KeyAlgoRSA,
E, k.PublicKey.N,
}
w.PubKey = Marshal(pubKey)
// Marshal private key.
key := openSSHRSAPrivateKey{
N: k.PublicKey.N,
E: E,
D: k.D,
Iqmp: k.Precomputed.Qinv,
P: k.Primes[0],
Q: k.Primes[1],
Comment: comment,
}
pk1.Keytype = KeyAlgoRSA
pk1.Rest = Marshal(key)
case ed25519.PrivateKey:
pub := make([]byte, ed25519.PublicKeySize)
priv := make([]byte, ed25519.PrivateKeySize)
copy(pub, k[32:])
copy(priv, k)
// Marshal public key.
pubKey := struct {
KeyType string
Pub []byte
}{
KeyAlgoED25519, pub,
}
w.PubKey = Marshal(pubKey)
// Marshal private key.
key := openSSHEd25519PrivateKey{
Pub: pub,
Priv: priv,
Comment: comment,
}
pk1.Keytype = KeyAlgoED25519
pk1.Rest = Marshal(key)
case *ecdsa.PrivateKey:
var curve, keyType string
switch name := k.Curve.Params().Name; name {
case "P-256":
curve = "nistp256"
keyType = KeyAlgoECDSA256
case "P-384":
curve = "nistp384"
keyType = KeyAlgoECDSA384
case "P-521":
curve = "nistp521"
keyType = KeyAlgoECDSA521
default:
return nil, errors.New("ssh: unhandled elliptic curve " + name)
}
pub := elliptic.Marshal(k.Curve, k.PublicKey.X, k.PublicKey.Y)
// Marshal public key.
pubKey := struct {
KeyType string
Curve string
Pub []byte
}{
keyType, curve, pub,
}
w.PubKey = Marshal(pubKey)
// Marshal private key.
key := openSSHECDSAPrivateKey{
Curve: curve,
Pub: pub,
D: k.D,
Comment: comment,
}
pk1.Keytype = keyType
pk1.Rest = Marshal(key)
default:
return nil, fmt.Errorf("ssh: unsupported key type %T", k)
}
var err error
// Add padding and encrypt the key if necessary.
w.PrivKeyBlock, w.CipherName, w.KdfName, w.KdfOpts, err = encrypt(Marshal(pk1))
if err != nil {
return nil, err
}
b := Marshal(w)
block := &pem.Block{
Type: "OPENSSH PRIVATE KEY",
Bytes: append([]byte(privateKeyAuthMagic), b...),
}
return block, nil
}
func checkOpenSSHKeyPadding(pad []byte) error {
for i, b := range pad {
if int(b) != i+1 {
@ -1424,6 +1687,13 @@ func checkOpenSSHKeyPadding(pad []byte) error {
return nil
}
func generateOpenSSHPadding(block []byte, blockSize int) []byte {
for i, l := 0, len(block); (l+i)%blockSize != 0; i++ {
block = append(block, byte(i+1))
}
return block
}
// FingerprintLegacyMD5 returns the user presentation of the key's
// fingerprint as described by RFC 4716 section 4.
func FingerprintLegacyMD5(pubKey PublicKey) string {

View File

@ -142,7 +142,7 @@ func keyEq(a, b ssh.PublicKey) bool {
return bytes.Equal(a.Marshal(), b.Marshal())
}
// IsAuthorityForHost can be used as a callback in ssh.CertChecker
// IsHostAuthority can be used as a callback in ssh.CertChecker
func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool {
h, p, err := net.SplitHostPort(address)
if err != nil {

View File

@ -10,6 +10,7 @@ import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"hash"
)
@ -46,9 +47,15 @@ func (t truncatingMAC) Size() int {
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
var macModes = map[string]*macMode{
"hmac-sha2-512-etm@openssh.com": {64, true, func(key []byte) hash.Hash {
return hmac.New(sha512.New, key)
}},
"hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}},
"hmac-sha2-512": {64, false, func(key []byte) hash.Hash {
return hmac.New(sha512.New, key)
}},
"hmac-sha2-256": {32, false, func(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}},

View File

@ -68,7 +68,7 @@ type kexInitMsg struct {
// See RFC 4253, section 8.
// Diffie-Helman
// Diffie-Hellman
const msgKexDHInit = 30
type kexDHInitMsg struct {
@ -349,6 +349,20 @@ type userAuthGSSAPIError struct {
LanguageTag string
}
// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9
const msgPing = 192
type pingMsg struct {
Data string `sshtype:"192"`
}
// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9
const msgPong = 193
type pongMsg struct {
Data string `sshtype:"193"`
}
// typeTags returns the possible type bytes for the given reflect.Type, which
// should be a struct. The possible values are separated by a '|' character.
func typeTags(structType reflect.Type) (tags []byte) {

View File

@ -231,6 +231,12 @@ func (m *mux) onePacket() error {
return m.handleChannelOpen(packet)
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
return m.handleGlobalPacket(packet)
case msgPing:
var msg pingMsg
if err := Unmarshal(packet, &msg); err != nil {
return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err)
}
return m.sendMessage(pongMsg(msg))
}
// assume a channel packet.

View File

@ -68,8 +68,16 @@ type ServerConfig struct {
// NoClientAuth is true if clients are allowed to connect without
// authenticating.
// To determine NoClientAuth at runtime, set NoClientAuth to true
// and the optional NoClientAuthCallback to a non-nil value.
NoClientAuth bool
// NoClientAuthCallback, if non-nil, is called when a user
// attempts to authenticate with auth method "none".
// NoClientAuth must also be set to true for this be used, or
// this func is unused.
NoClientAuthCallback func(ConnMetadata) (*Permissions, error)
// MaxAuthTries specifies the maximum number of authentication attempts
// permitted per connection. If set to a negative number, the number of
// attempts are unlimited. If set to zero, the number of attempts are limited
@ -283,15 +291,6 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
return perms, err
}
func isAcceptableAlgo(algo string) bool {
switch algo {
case KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoSKECDSA256, KeyAlgoED25519, KeyAlgoSKED25519,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01:
return true
}
return false
}
func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
if addr == nil {
return errors.New("ssh: no address known for client, but source-address match required")
@ -371,6 +370,25 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c
return authErr, perms, nil
}
// isAlgoCompatible checks if the signature format is compatible with the
// selected algorithm taking into account edge cases that occur with old
// clients.
func isAlgoCompatible(algo, sigFormat string) bool {
// Compatibility for old clients.
//
// For certificate authentication with OpenSSH 7.2-7.7 signature format can
// be rsa-sha2-256 or rsa-sha2-512 for the algorithm
// ssh-rsa-cert-v01@openssh.com.
//
// With gpg-agent < 2.2.6 the algorithm can be rsa-sha2-256 or rsa-sha2-512
// for signature format ssh-rsa.
if isRSA(algo) && isRSA(sigFormat) {
return true
}
// Standard case: the underlying algorithm must match the signature format.
return underlyingAlgo(algo) == sigFormat
}
// ServerAuthError represents server authentication errors and is
// sometimes returned by NewServerConn. It appends any authentication
// errors that may occur, and is returned if all of the authentication
@ -455,8 +473,12 @@ userAuthLoop:
switch userAuthReq.Method {
case "none":
if config.NoClientAuth {
if config.NoClientAuthCallback != nil {
perms, authErr = config.NoClientAuthCallback(s)
} else {
authErr = nil
}
}
// allow initial attempt of 'none' without penalty
if authFailures == 0 {
@ -502,7 +524,7 @@ userAuthLoop:
return nil, parseError(msgUserAuthRequest)
}
algo := string(algoBytes)
if !isAcceptableAlgo(algo) {
if !contains(supportedPubKeyAuthAlgos, underlyingAlgo(algo)) {
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
break
}
@ -554,17 +576,26 @@ userAuthLoop:
if !ok || len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
// Ensure the declared public key algo is compatible with the
// decoded one. This check will ensure we don't accept e.g.
// ssh-rsa-cert-v01@openssh.com algorithm with ssh-rsa public
// key type. The algorithm and public key type must be
// consistent: both must be certificate algorithms, or neither.
if !contains(algorithmsForKeyFormat(pubKey.Type()), algo) {
authErr = fmt.Errorf("ssh: public key type %q not compatible with selected algorithm %q",
pubKey.Type(), algo)
break
}
// Ensure the public key algo and signature algo
// are supported. Compare the private key
// algorithm name that corresponds to algo with
// sig.Format. This is usually the same, but
// for certs, the names differ.
if !isAcceptableAlgo(sig.Format) {
if !contains(supportedPubKeyAuthAlgos, sig.Format) {
authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
break
}
if underlyingAlgo(algo) != sig.Format {
if !isAlgoCompatible(algo, sig.Format) {
authErr = fmt.Errorf("ssh: signature %q not compatible with selected algorithm %q", sig.Format, algo)
break
}

View File

@ -13,7 +13,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"sync"
)
@ -124,7 +123,7 @@ type Session struct {
// output and error.
//
// If either is nil, Run connects the corresponding file
// descriptor to an instance of ioutil.Discard. There is a
// descriptor to an instance of io.Discard. There is a
// fixed amount of buffering that is shared for the two streams.
// If either blocks it may eventually cause the remote
// command to block.
@ -506,7 +505,7 @@ func (s *Session) stdout() {
return
}
if s.Stdout == nil {
s.Stdout = ioutil.Discard
s.Stdout = io.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stdout, s.ch)
@ -519,7 +518,7 @@ func (s *Session) stderr() {
return
}
if s.Stderr == nil {
s.Stderr = ioutil.Discard
s.Stderr = io.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stderr, s.ch.Stderr())

View File

@ -17,7 +17,8 @@ import (
const debugTransport = false
const (
gcmCipherID = "aes128-gcm@openssh.com"
gcm128CipherID = "aes128-gcm@openssh.com"
gcm256CipherID = "aes256-gcm@openssh.com"
aes128cbcID = "aes128-cbc"
tripledescbcID = "3des-cbc"
)

View File

@ -289,7 +289,7 @@ func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter,
case AuthMethodNotRequired:
return nil
case AuthMethodUsernamePassword:
if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 {
if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) > 255 {
return errors.New("invalid username/password")
}
b := []byte{authUsernamePasswordVersion}

24
vendor/modules.txt vendored
View File

@ -14,6 +14,15 @@ github.com/bahlo/generic-list-go
# github.com/buger/jsonparser v1.1.1
## explicit; go 1.13
github.com/buger/jsonparser
# github.com/cli/go-gh/v2 v2.9.0
## explicit; go 1.21
github.com/cli/go-gh/v2/internal/set
github.com/cli/go-gh/v2/internal/yamlmap
github.com/cli/go-gh/v2/pkg/auth
github.com/cli/go-gh/v2/pkg/config
# github.com/cli/safeexec v1.0.0
## explicit; go 1.15
github.com/cli/safeexec
# github.com/cloudfoundry/jibber_jabber v0.0.0-20151120183258-bcc4c8345a21
## explicit
github.com/cloudfoundry/jibber_jabber
@ -212,11 +221,11 @@ github.com/lucasb-eyer/go-colorful
## explicit; go 1.12
github.com/mailru/easyjson/buffer
github.com/mailru/easyjson/jwriter
# github.com/mattn/go-colorable v0.1.11
## explicit; go 1.13
# github.com/mattn/go-colorable v0.1.13
## explicit; go 1.15
github.com/mattn/go-colorable
# github.com/mattn/go-isatty v0.0.14
## explicit; go 1.12
# github.com/mattn/go-isatty v0.0.20
## explicit; go 1.15
github.com/mattn/go-isatty
# github.com/mattn/go-runewidth v0.0.15
## explicit; go 1.9
@ -284,16 +293,15 @@ github.com/xanzy/ssh-agent
# github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778
## explicit; go 1.15
github.com/xo/terminfo
# golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
# golang.org/x/crypto v0.14.0
## explicit; go 1.17
golang.org/x/crypto/blowfish
golang.org/x/crypto/cast5
golang.org/x/crypto/chacha20
golang.org/x/crypto/curve25519
golang.org/x/crypto/curve25519/internal/field
golang.org/x/crypto/ed25519
golang.org/x/crypto/internal/alias
golang.org/x/crypto/internal/poly1305
golang.org/x/crypto/internal/subtle
golang.org/x/crypto/openpgp
golang.org/x/crypto/openpgp/armor
golang.org/x/crypto/openpgp/elgamal
@ -308,7 +316,7 @@ golang.org/x/crypto/ssh/knownhosts
## explicit; go 1.18
golang.org/x/exp/constraints
golang.org/x/exp/slices
# golang.org/x/net v0.7.0
# golang.org/x/net v0.17.0
## explicit; go 1.17
golang.org/x/net/context
golang.org/x/net/internal/socks