Skip to content

Commit aaff9e7

Browse files
authored
grpc: better RFC 3986 compliant target parsing (#4817)
1 parent 45097a8 commit aaff9e7

13 files changed

+314
-244
lines changed

clientconn.go

+144-40
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"errors"
2424
"fmt"
2525
"math"
26+
"net/url"
2627
"reflect"
2728
"strings"
2829
"sync"
@@ -37,7 +38,6 @@ import (
3738
"google.golang.org/grpc/internal/backoff"
3839
"google.golang.org/grpc/internal/channelz"
3940
"google.golang.org/grpc/internal/grpcsync"
40-
"google.golang.org/grpc/internal/grpcutil"
4141
iresolver "google.golang.org/grpc/internal/resolver"
4242
"google.golang.org/grpc/internal/transport"
4343
"google.golang.org/grpc/keepalive"
@@ -248,38 +248,15 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
248248
}
249249

250250
// Determine the resolver to use.
251-
cc.parsedTarget = grpcutil.ParseTarget(cc.target, cc.dopts.copts.Dialer != nil)
252-
channelz.Infof(logger, cc.channelzID, "parsed scheme: %q", cc.parsedTarget.Scheme)
253-
resolverBuilder := cc.getResolver(cc.parsedTarget.Scheme)
254-
if resolverBuilder == nil {
255-
// If resolver builder is still nil, the parsed target's scheme is
256-
// not registered. Fallback to default resolver and set Endpoint to
257-
// the original target.
258-
channelz.Infof(logger, cc.channelzID, "scheme %q not registered, fallback to default scheme", cc.parsedTarget.Scheme)
259-
cc.parsedTarget = resolver.Target{
260-
Scheme: resolver.GetDefaultScheme(),
261-
Endpoint: target,
262-
}
263-
resolverBuilder = cc.getResolver(cc.parsedTarget.Scheme)
264-
if resolverBuilder == nil {
265-
return nil, fmt.Errorf("could not get resolver for default scheme: %q", cc.parsedTarget.Scheme)
266-
}
251+
resolverBuilder, err := cc.parseTargetAndFindResolver()
252+
if err != nil {
253+
return nil, err
267254
}
268-
269-
creds := cc.dopts.copts.TransportCredentials
270-
if creds != nil && creds.Info().ServerName != "" {
271-
cc.authority = creds.Info().ServerName
272-
} else if cc.dopts.insecure && cc.dopts.authority != "" {
273-
cc.authority = cc.dopts.authority
274-
} else if strings.HasPrefix(cc.target, "unix:") || strings.HasPrefix(cc.target, "unix-abstract:") {
275-
cc.authority = "localhost"
276-
} else if strings.HasPrefix(cc.parsedTarget.Endpoint, ":") {
277-
cc.authority = "localhost" + cc.parsedTarget.Endpoint
278-
} else {
279-
// Use endpoint from "scheme://authority/endpoint" as the default
280-
// authority for ClientConn.
281-
cc.authority = cc.parsedTarget.Endpoint
255+
cc.authority, err = determineAuthority(cc.parsedTarget.Endpoint, cc.target, cc.dopts)
256+
if err != nil {
257+
return nil, err
282258
}
259+
channelz.Infof(logger, cc.channelzID, "Channel authority set to %q", cc.authority)
283260

284261
if cc.dopts.scChan != nil && !scSet {
285262
// Blocking wait for the initial service config.
@@ -902,10 +879,7 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool {
902879
// ac.state is Ready, try to find the connected address.
903880
var curAddrFound bool
904881
for _, a := range addrs {
905-
// a.ServerName takes precedent over ClientConn authority, if present.
906-
if a.ServerName == "" {
907-
a.ServerName = ac.cc.authority
908-
}
882+
a.ServerName = ac.cc.getServerName(a)
909883
if reflect.DeepEqual(ac.curAddr, a) {
910884
curAddrFound = true
911885
break
@@ -919,6 +893,26 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool {
919893
return curAddrFound
920894
}
921895

896+
// getServerName determines the serverName to be used in the connection
897+
// handshake. The default value for the serverName is the authority on the
898+
// ClientConn, which either comes from the user's dial target or through an
899+
// authority override specified using the WithAuthority dial option. Name
900+
// resolvers can specify a per-address override for the serverName through the
901+
// resolver.Address.ServerName field which is used only if the WithAuthority
902+
// dial option was not used. The rationale is that per-address authority
903+
// overrides specified by the name resolver can represent a security risk, while
904+
// an override specified by the user is more dependable since they probably know
905+
// what they are doing.
906+
func (cc *ClientConn) getServerName(addr resolver.Address) string {
907+
if cc.dopts.authority != "" {
908+
return cc.dopts.authority
909+
}
910+
if addr.ServerName != "" {
911+
return addr.ServerName
912+
}
913+
return cc.authority
914+
}
915+
922916
func getMethodConfig(sc *ServiceConfig, method string) MethodConfig {
923917
if sc == nil {
924918
return MethodConfig{}
@@ -1275,11 +1269,7 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
12751269
prefaceReceived := grpcsync.NewEvent()
12761270
connClosed := grpcsync.NewEvent()
12771271

1278-
// addr.ServerName takes precedent over ClientConn authority, if present.
1279-
if addr.ServerName == "" {
1280-
addr.ServerName = ac.cc.authority
1281-
}
1282-
1272+
addr.ServerName = ac.cc.getServerName(addr)
12831273
hctx, hcancel := context.WithCancel(ac.ctx)
12841274
hcStarted := false // protected by ac.mu
12851275

@@ -1621,3 +1611,117 @@ func (cc *ClientConn) connectionError() error {
16211611
defer cc.lceMu.Unlock()
16221612
return cc.lastConnectionError
16231613
}
1614+
1615+
func (cc *ClientConn) parseTargetAndFindResolver() (resolver.Builder, error) {
1616+
channelz.Infof(logger, cc.channelzID, "original dial target is: %q", cc.target)
1617+
1618+
var rb resolver.Builder
1619+
parsedTarget, err := parseTarget(cc.target)
1620+
if err != nil {
1621+
channelz.Infof(logger, cc.channelzID, "dial target %q parse failed: %v", cc.target, err)
1622+
} else {
1623+
channelz.Infof(logger, cc.channelzID, "parsed dial target is: %+v", parsedTarget)
1624+
rb = cc.getResolver(parsedTarget.Scheme)
1625+
if rb != nil {
1626+
cc.parsedTarget = parsedTarget
1627+
return rb, nil
1628+
}
1629+
}
1630+
1631+
// We are here because the user's dial target did not contain a scheme or
1632+
// specified an unregistered scheme. We should fallback to the default
1633+
// scheme, except when a custom dialer is specified in which case, we should
1634+
// always use passthrough scheme.
1635+
defScheme := resolver.GetDefaultScheme()
1636+
if cc.dopts.copts.Dialer != nil {
1637+
defScheme = "passthrough"
1638+
}
1639+
channelz.Infof(logger, cc.channelzID, "fallback to scheme %q", defScheme)
1640+
canonicalTarget := defScheme + ":///" + cc.target
1641+
1642+
parsedTarget, err = parseTarget(canonicalTarget)
1643+
if err != nil {
1644+
channelz.Infof(logger, cc.channelzID, "dial target %q parse failed: %v", canonicalTarget, err)
1645+
return nil, err
1646+
}
1647+
channelz.Infof(logger, cc.channelzID, "parsed dial target is: %+v", parsedTarget)
1648+
rb = cc.getResolver(parsedTarget.Scheme)
1649+
if rb == nil {
1650+
return nil, fmt.Errorf("could not get resolver for default scheme: %q", parsedTarget.Scheme)
1651+
}
1652+
cc.parsedTarget = parsedTarget
1653+
return rb, nil
1654+
}
1655+
1656+
// parseTarget uses RFC 3986 semantics to parse the given target into a
1657+
// resolver.Target struct containing scheme, authority and endpoint. Query
1658+
// params are stripped from the endpoint.
1659+
func parseTarget(target string) (resolver.Target, error) {
1660+
u, err := url.Parse(target)
1661+
if err != nil {
1662+
return resolver.Target{}, err
1663+
}
1664+
// For targets of the form "[scheme]://[authority]/endpoint, the endpoint
1665+
// value returned from url.Parse() contains a leading "/". Although this is
1666+
// in accordance with RFC 3986, we do not want to break existing resolver
1667+
// implementations which expect the endpoint without the leading "/". So, we
1668+
// end up stripping the leading "/" here. But this will result in an
1669+
// incorrect parsing for something like "unix:///path/to/socket". Since we
1670+
// own the "unix" resolver, we can workaround in the unix resolver by using
1671+
// the `URL` field instead of the `Endpoint` field.
1672+
endpoint := u.Path
1673+
if endpoint == "" {
1674+
endpoint = u.Opaque
1675+
}
1676+
endpoint = strings.TrimPrefix(endpoint, "/")
1677+
return resolver.Target{
1678+
Scheme: u.Scheme,
1679+
Authority: u.Host,
1680+
Endpoint: endpoint,
1681+
URL: *u,
1682+
}, nil
1683+
}
1684+
1685+
// Determine channel authority. The order of precedence is as follows:
1686+
// - user specified authority override using `WithAuthority` dial option
1687+
// - creds' notion of server name for the authentication handshake
1688+
// - endpoint from dial target of the form "scheme://[authority]/endpoint"
1689+
func determineAuthority(endpoint, target string, dopts dialOptions) (string, error) {
1690+
// Historically, we had two options for users to specify the serverName or
1691+
// authority for a channel. One was through the transport credentials
1692+
// (either in its constructor, or through the OverrideServerName() method).
1693+
// The other option (for cases where WithInsecure() dial option was used)
1694+
// was to use the WithAuthority() dial option.
1695+
//
1696+
// A few things have changed since:
1697+
// - `insecure` package with an implementation of the `TransportCredentials`
1698+
// interface for the insecure case
1699+
// - WithAuthority() dial option support for secure credentials
1700+
authorityFromCreds := ""
1701+
if creds := dopts.copts.TransportCredentials; creds != nil && creds.Info().ServerName != "" {
1702+
authorityFromCreds = creds.Info().ServerName
1703+
}
1704+
authorityFromDialOption := dopts.authority
1705+
if (authorityFromCreds != "" && authorityFromDialOption != "") && authorityFromCreds != authorityFromDialOption {
1706+
return "", fmt.Errorf("ClientConn's authority from transport creds %q and dial option %q don't match", authorityFromCreds, authorityFromDialOption)
1707+
}
1708+
1709+
switch {
1710+
case authorityFromDialOption != "":
1711+
return authorityFromDialOption, nil
1712+
case authorityFromCreds != "":
1713+
return authorityFromCreds, nil
1714+
case strings.HasPrefix(target, "unix:") || strings.HasPrefix(target, "unix-abstract:"):
1715+
// TODO: remove when the unix resolver implements optional interface to
1716+
// return channel authority.
1717+
return "localhost", nil
1718+
case strings.HasPrefix(endpoint, ":"):
1719+
return "localhost" + endpoint, nil
1720+
default:
1721+
// TODO: Define an optional interface on the resolver builder to return
1722+
// the channel authority given the user's dial target. For resolvers
1723+
// which don't implement this interface, we will use the endpoint from
1724+
// "scheme://authority/endpoint" as the default authority.
1725+
return endpoint, nil
1726+
}
1727+
}

clientconn_authority_test.go

+16-4
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,9 @@ func (s) TestClientConnAuthority(t *testing.T) {
5959
wantAuthority: "authority-override",
6060
},
6161
{
62-
name: "override-via-creds-and-WithAuthority",
63-
target: "Non-Existent.Server:8080",
64-
// WithAuthority override works only for insecure creds.
65-
opts: []DialOption{WithTransportCredentials(creds), WithAuthority("authority-override")},
62+
name: "override-via-creds-and-WithAuthority",
63+
target: "Non-Existent.Server:8080",
64+
opts: []DialOption{WithTransportCredentials(creds), WithAuthority(serverNameOverride)},
6665
wantAuthority: serverNameOverride,
6766
},
6867
{
@@ -120,3 +119,16 @@ func (s) TestClientConnAuthority(t *testing.T) {
120119
})
121120
}
122121
}
122+
123+
func (s) TestClientConnAuthority_CredsAndDialOptionMismatch(t *testing.T) {
124+
serverNameOverride := "over.write.server.name"
125+
creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), serverNameOverride)
126+
if err != nil {
127+
t.Fatalf("credentials.NewClientTLSFromFile(_, %q) failed: %v", err, serverNameOverride)
128+
}
129+
opts := []DialOption{WithTransportCredentials(creds), WithAuthority("authority-override")}
130+
if cc, err := Dial("Non-Existent.Server:8000", opts...); err == nil {
131+
cc.Close()
132+
t.Fatal("grpc.Dial() succeeded when expected to fail")
133+
}
134+
}

0 commit comments

Comments
 (0)