server: reject unexpected auth hosts (#13738)

Added validation to ensure auth redirects stay on the same host as the original request. The fix is a single check in getAuthorizationToken comparing the realm URL's host against the request host. Added tests for the auth flow.

Co-Authored-By: Gecko Security <188164982+geckosecurity@users.noreply.github.com>

* gofmt

---------

Co-authored-by: Gecko Security <188164982+geckosecurity@users.noreply.github.com>
This commit is contained in:
Bruce MacDonald
2026-01-16 14:10:36 -05:00
committed by GitHub
parent aad3f03890
commit 7601f0e93e
4 changed files with 123 additions and 5 deletions

View File

@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
return redirectURL, nil
}
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
redirectURL, err := challenge.URL()
if err != nil {
return "", err
}
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
if redirectURL.Host != originalHost {
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
}
sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))

113
server/auth_test.go Normal file
View File

@@ -0,0 +1,113 @@
package server
import (
"context"
"strings"
"testing"
"time"
)
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
tests := []struct {
realm string
originalHost string
wantMismatch bool
}{
{"https://example.com/token", "example.com", false},
{"https://example.com/token", "other.com", true},
{"https://example.com/token", "localhost:8000", true},
{"https://localhost:5000/token", "localhost:5000", false},
{"https://localhost:5000/token", "localhost:6000", true},
}
for _, tt := range tests {
t.Run(tt.originalHost, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
if tt.wantMismatch && !isMismatch {
t.Errorf("expected domain mismatch error, got: %v", err)
}
if !tt.wantMismatch && isMismatch {
t.Errorf("unexpected domain mismatch error: %v", err)
}
})
}
}
func TestParseRegistryChallenge(t *testing.T) {
tests := []struct {
input string
wantRealm, wantService, wantScope string
}{
{
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
"https://auth.example.com/token", "registry", "repo:foo:pull",
},
{
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
"https://r.ollama.ai/v2/token", "ollama", "-",
},
{"", "", "", ""},
}
for _, tt := range tests {
result := parseRegistryChallenge(tt.input)
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
tt.input, result.Realm, result.Service, result.Scope,
tt.wantRealm, tt.wantService, tt.wantScope)
}
}
}
func TestRegistryChallengeURL(t *testing.T) {
challenge := registryChallenge{
Realm: "https://auth.example.com/token",
Service: "registry",
Scope: "repo:foo:pull repo:bar:push",
}
u, err := challenge.URL()
if err != nil {
t.Fatalf("URL() error: %v", err)
}
if u.Host != "auth.example.com" {
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
}
if u.Path != "/token" {
t.Errorf("path = %q, want %q", u.Path, "/token")
}
q := u.Query()
if q.Get("service") != "registry" {
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
}
if scopes := q["scope"]; len(scopes) != 2 {
t.Errorf("scope count = %d, want 2", len(scopes))
}
if q.Get("ts") == "" {
t.Error("missing ts")
}
if q.Get("nonce") == "" {
t.Error("missing nonce")
}
// Nonces should differ between calls
u2, _ := challenge.URL()
if q.Get("nonce") == u2.Query().Get("nonce") {
t.Error("nonce should be unique per call")
}
}
func TestRegistryChallengeURLInvalid(t *testing.T) {
challenge := registryChallenge{Realm: "://invalid"}
if _, err := challenge.URL(); err == nil {
t.Error("expected error for invalid URL")
}
}

View File

@@ -775,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}, base.Host)
}
if err := transfer.Download(ctx, transfer.DownloadOptions{
@@ -850,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}, base.Host)
}
return transfer.Upload(ctx, transfer.UploadOptions{
@@ -916,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
// Handle authentication error with one retry
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return nil, err
}

View File

@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return err
}