diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go index 7208e69334a..074701eb5de 100644 --- a/src/crypto/tls/common.go +++ b/src/crypto/tls/common.go @@ -1804,13 +1804,28 @@ func fipsAllowChain(chain []*x509.Certificate) bool { return true } -// anyUnexpiredChain reports if at least one of verifiedChains is still -// unexpired. If verifiedChains is empty, it returns false. -func anyUnexpiredChain(verifiedChains [][]*x509.Certificate, now time.Time) bool { +// anyValidVerifiedChain reports if at least one of the chains in verifiedChains +// is valid, as indicated by none of the certificates being expired and the root +// being in opts.Roots (or in the system root pool if opts.Roots is nil). If +// verifiedChains is empty, it returns false. +func anyValidVerifiedChain(verifiedChains [][]*x509.Certificate, opts x509.VerifyOptions) bool { for _, chain := range verifiedChains { - if len(chain) != 0 && !slices.ContainsFunc(chain, func(cert *x509.Certificate) bool { - return now.Before(cert.NotBefore) || now.After(cert.NotAfter) // cert is expired + if len(chain) == 0 { + continue + } + if slices.ContainsFunc(chain, func(cert *x509.Certificate) bool { + return opts.CurrentTime.Before(cert.NotBefore) || opts.CurrentTime.After(cert.NotAfter) }) { + continue + } + // Since we already validated the chain, we only care that it is + // rooted in a CA in CAs, or in the system pool. On platforms where + // we control chain validation (e.g. not Windows or macOS) this is a + // simple lookup in the CertPool internal hash map. On other + // platforms, this may be more expensive, depending on how they + // implement verification of just root certificates. + root := chain[len(chain)-1] + if _, err := root.Verify(opts); err == nil { return true } } diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index 3752df09b6d..b1bbff23632 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -444,7 +444,12 @@ func (c *Conn) loadSession(hello *clientHelloMsg) ( // application from a faulty ClientSessionCache implementation. return nil, nil, nil, nil } - if !anyUnexpiredChain(session.verifiedChains, c.config.time()) { + opts := x509.VerifyOptions{ + CurrentTime: c.config.time(), + Roots: c.config.RootCAs, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + if !anyValidVerifiedChain(session.verifiedChains, opts) { // No valid chains, delete the entry. c.config.ClientSessionCache.Put(cacheKey, nil) return nil, nil, nil, nil diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index d4d05e54629..2f321f85946 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -523,8 +523,13 @@ func (hs *serverHandshakeState) checkForResumption() error { if sessionHasClientCerts && c.config.time().After(sessionState.peerCertificates[0].NotAfter) { return nil } + opts := x509.VerifyOptions{ + CurrentTime: c.config.time(), + Roots: c.config.ClientCAs, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } if sessionHasClientCerts && c.config.ClientAuth >= VerifyClientCertIfGiven && - !anyUnexpiredChain(sessionState.verifiedChains, c.config.time()) { + !anyValidVerifiedChain(sessionState.verifiedChains, opts) { return nil } diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go index 406dab48eef..249767b1e46 100644 --- a/src/crypto/tls/handshake_server_test.go +++ b/src/crypto/tls/handshake_server_test.go @@ -2243,3 +2243,217 @@ func testHandshakeChainExpiryResumption(t *testing.T, version uint16) { testExpiration("LeafExpiresBeforeRoot", now.Add(2*time.Hour), now.Add(3*time.Hour)) testExpiration("LeafExpiresAfterRoot", now.Add(2*time.Hour), now.Add(time.Hour)) } + +func TestHandshakeGetConfigForClientDifferentClientCAs(t *testing.T) { + t.Run("TLS1.2", func(t *testing.T) { + testHandshakeGetConfigForClientDifferentClientCAs(t, VersionTLS12) + }) + t.Run("TLS1.3", func(t *testing.T) { + testHandshakeGetConfigForClientDifferentClientCAs(t, VersionTLS13) + }) +} + +func testHandshakeGetConfigForClientDifferentClientCAs(t *testing.T, version uint16) { + now := time.Now() + tmpl := &x509.Certificate{ + Subject: pkix.Name{CommonName: "root"}, + NotBefore: now.Add(-time.Hour * 24), + NotAfter: now.Add(time.Hour * 24), + IsCA: true, + BasicConstraintsValid: true, + } + rootDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey) + if err != nil { + t.Fatalf("CreateCertificate: %v", err) + } + rootA, err := x509.ParseCertificate(rootDER) + if err != nil { + t.Fatalf("ParseCertificate: %v", err) + } + rootDER, err = x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey) + if err != nil { + t.Fatalf("CreateCertificate: %v", err) + } + rootB, err := x509.ParseCertificate(rootDER) + if err != nil { + t.Fatalf("ParseCertificate: %v", err) + } + + tmpl = &x509.Certificate{ + Subject: pkix.Name{}, + DNSNames: []string{"example.com"}, + NotBefore: now.Add(-time.Hour * 24), + NotAfter: now.Add(time.Hour * 24), + KeyUsage: x509.KeyUsageDigitalSignature, + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, rootA, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey) + if err != nil { + t.Fatalf("CreateCertificate: %v", err) + } + + serverConfig := testConfig.Clone() + serverConfig.MaxVersion = version + serverConfig.Certificates = []Certificate{{ + Certificate: [][]byte{certDER}, + PrivateKey: testECDSAPrivateKey, + }} + serverConfig.Time = func() time.Time { + return now + } + serverConfig.ClientCAs = x509.NewCertPool() + serverConfig.ClientCAs.AddCert(rootA) + serverConfig.ClientAuth = RequireAndVerifyClientCert + switchConfig := false + serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) { + if !switchConfig { + return nil, nil + } + cfg := serverConfig.Clone() + cfg.ClientCAs = x509.NewCertPool() + cfg.ClientCAs.AddCert(rootB) + return cfg, nil + } + serverConfig.InsecureSkipVerify = false + serverConfig.ServerName = "example.com" + + clientConfig := testConfig.Clone() + clientConfig.MaxVersion = version + clientConfig.Certificates = []Certificate{{ + Certificate: [][]byte{certDER}, + PrivateKey: testECDSAPrivateKey, + }} + clientConfig.ClientSessionCache = NewLRUClientSessionCache(32) + clientConfig.RootCAs = x509.NewCertPool() + clientConfig.RootCAs.AddCert(rootA) + clientConfig.Time = func() time.Time { + return now + } + clientConfig.InsecureSkipVerify = false + clientConfig.ServerName = "example.com" + + testResume := func(t *testing.T, sc, cc *Config, expectResume bool) { + t.Helper() + ss, cs, err := testHandshake(t, cc, sc) + if err != nil { + t.Fatalf("handshake: %v", err) + } + if cs.DidResume != expectResume { + t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume) + } + if ss.DidResume != expectResume { + t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume) + } + } + + testResume(t, serverConfig, clientConfig, false) + testResume(t, serverConfig, clientConfig, true) + + // Cause GetConfigForClient to return a config cloned from the base config, + // but with a different ClientCAs pool. This should cause resumption to fail. + switchConfig = true + + testResume(t, serverConfig, clientConfig, false) + testResume(t, serverConfig, clientConfig, true) +} + +func TestHandshakeChangeRootCAsResumption(t *testing.T) { + t.Run("TLS1.2", func(t *testing.T) { + testHandshakeChangeRootCAsResumption(t, VersionTLS12) + }) + t.Run("TLS1.3", func(t *testing.T) { + testHandshakeChangeRootCAsResumption(t, VersionTLS13) + }) +} + +func testHandshakeChangeRootCAsResumption(t *testing.T, version uint16) { + now := time.Now() + tmpl := &x509.Certificate{ + Subject: pkix.Name{CommonName: "root"}, + NotBefore: now.Add(-time.Hour * 24), + NotAfter: now.Add(time.Hour * 24), + IsCA: true, + BasicConstraintsValid: true, + } + rootDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey) + if err != nil { + t.Fatalf("CreateCertificate: %v", err) + } + rootA, err := x509.ParseCertificate(rootDER) + if err != nil { + t.Fatalf("ParseCertificate: %v", err) + } + rootDER, err = x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey) + if err != nil { + t.Fatalf("CreateCertificate: %v", err) + } + rootB, err := x509.ParseCertificate(rootDER) + if err != nil { + t.Fatalf("ParseCertificate: %v", err) + } + + tmpl = &x509.Certificate{ + Subject: pkix.Name{}, + DNSNames: []string{"example.com"}, + NotBefore: now.Add(-time.Hour * 24), + NotAfter: now.Add(time.Hour * 24), + KeyUsage: x509.KeyUsageDigitalSignature, + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, rootA, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey) + if err != nil { + t.Fatalf("CreateCertificate: %v", err) + } + + serverConfig := testConfig.Clone() + serverConfig.MaxVersion = version + serverConfig.Certificates = []Certificate{{ + Certificate: [][]byte{certDER}, + PrivateKey: testECDSAPrivateKey, + }} + serverConfig.Time = func() time.Time { + return now + } + serverConfig.ClientCAs = x509.NewCertPool() + serverConfig.ClientCAs.AddCert(rootA) + serverConfig.ClientAuth = RequireAndVerifyClientCert + serverConfig.InsecureSkipVerify = false + serverConfig.ServerName = "example.com" + + clientConfig := testConfig.Clone() + clientConfig.MaxVersion = version + clientConfig.Certificates = []Certificate{{ + Certificate: [][]byte{certDER}, + PrivateKey: testECDSAPrivateKey, + }} + clientConfig.ClientSessionCache = NewLRUClientSessionCache(32) + clientConfig.RootCAs = x509.NewCertPool() + clientConfig.RootCAs.AddCert(rootA) + clientConfig.Time = func() time.Time { + return now + } + clientConfig.InsecureSkipVerify = false + clientConfig.ServerName = "example.com" + + testResume := func(t *testing.T, sc, cc *Config, expectResume bool) { + t.Helper() + ss, cs, err := testHandshake(t, cc, sc) + if err != nil { + t.Fatalf("handshake: %v", err) + } + if cs.DidResume != expectResume { + t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume) + } + if ss.DidResume != expectResume { + t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume) + } + } + + testResume(t, serverConfig, clientConfig, false) + testResume(t, serverConfig, clientConfig, true) + + clientConfig = clientConfig.Clone() + clientConfig.RootCAs = x509.NewCertPool() + clientConfig.RootCAs.AddCert(rootB) + + testResume(t, serverConfig, clientConfig, false) + testResume(t, serverConfig, clientConfig, true) +} diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go index 3dba595331d..a4b205908ed 100644 --- a/src/crypto/tls/handshake_server_tls13.go +++ b/src/crypto/tls/handshake_server_tls13.go @@ -15,6 +15,7 @@ import ( "crypto/internal/hpke" "crypto/rsa" "crypto/tls/internal/fips140tls" + "crypto/x509" "errors" "fmt" "hash" @@ -409,8 +410,13 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { if sessionHasClientCerts && c.config.time().After(sessionState.peerCertificates[0].NotAfter) { continue } + opts := x509.VerifyOptions{ + CurrentTime: c.config.time(), + Roots: c.config.ClientCAs, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } if sessionHasClientCerts && c.config.ClientAuth >= VerifyClientCertIfGiven && - !anyUnexpiredChain(sessionState.verifiedChains, c.config.time()) { + !anyValidVerifiedChain(sessionState.verifiedChains, opts) { continue }