Skip to content

Commit

Permalink
Merge pull request #343 from redis/feat-parse-url
Browse files Browse the repository at this point in the history
feat: rueidis.ParseURL and rueidis.MustParseURL
  • Loading branch information
rueian authored Aug 24, 2023
2 parents 92d534f + 42b0121 commit 8d1fc7f
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 0 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,21 @@ client, err := rueidis.NewClient(rueidis.ClientOption{
})
```

### Redis URL

You can use `ParseURL` or `MustParseURL` to construct a `ClientOption`:

```go
// connect to a redis cluster
client, err = rueidis.NewClient(rueidis.MustParseURL("redis://127.0.0.1:7001"))
// connect to a redis node
client, err = rueidis.NewClient(rueidis.MustParseURL("redis://127.0.0.1:6379"))
// connect to a redis sentinel
client, err = rueidis.NewClient(rueidis.MustParseURL("redis://127.0.0.1:26379?master_set=my_master"))
```

The url must be started with either `redis://`, `rediss://` or `unix://`.

## Arbitrary Command

If you want to construct commands that are absent from the command builder, you can use `client.B().Arbitrary()`:
Expand Down
86 changes: 86 additions & 0 deletions url.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package rueidis

import (
"crypto/tls"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
)

// ParseURL parses a redis URL into ClientOption.
// https://github.com/redis/redis-specifications/blob/master/uri/redis.txt
// Example:
//
// redis://<user>:<password>@<host>:<port>/<db_number>
// unix://<user>:<password>@</path/to/redis.sock>?db=<db_number>
func ParseURL(str string) (opt ClientOption, err error) {
u, _ := url.Parse(str)
switch u.Scheme {
case "unix":
opt.DialFn = func(s string, dialer *net.Dialer, config *tls.Config) (conn net.Conn, err error) {
return dialer.Dial("unix", s)
}
opt.InitAddress = []string{strings.TrimSpace(u.Path)}
case "rediss":
opt.TLSConfig = &tls.Config{}
case "redis":
default:
return opt, fmt.Errorf("redis: invalid URL scheme: %s", u.Scheme)
}
if opt.InitAddress == nil {
host, port, _ := net.SplitHostPort(u.Host)
if host == "" {
host = u.Host
}
if host == "" {
host = "localhost"
}
if port == "" {
port = "6379"
}
opt.InitAddress = []string{net.JoinHostPort(host, port)}
}
if u.User != nil {
opt.Username = u.User.Username()
opt.Password, _ = u.User.Password()
}
if ps := strings.Split(u.Path, "/"); len(ps) == 2 {
if opt.SelectDB, err = strconv.Atoi(ps[1]); err != nil {
return opt, fmt.Errorf("redis: invalid database number: %q", ps[1])
}
} else if len(ps) > 2 {
return opt, fmt.Errorf("redis: invalid URL path: %s", u.Path)
}
q := u.Query()
if q.Has("db") {
if opt.SelectDB, err = strconv.Atoi(q.Get("db")); err != nil {
return opt, fmt.Errorf("redis: invalid database number: %q", q.Get("db"))
}
}
if q.Has("dial_timeout") {
if opt.Dialer.Timeout, err = time.ParseDuration(q.Get("dial_timeout")); err != nil {
return opt, fmt.Errorf("redis: invalid dial timeout: %q", q.Get("dial_timeout"))
}
}
if q.Has("write_timeout") {
if opt.Dialer.Timeout, err = time.ParseDuration(q.Get("write_timeout")); err != nil {
return opt, fmt.Errorf("redis: invalid write timeout: %q", q.Get("write_timeout"))
}
}
opt.AlwaysRESP2 = q.Get("protocol") == "2"
opt.DisableRetry = q.Get("max_retries") == "0"
opt.ClientName = q.Get("client_name")
opt.Sentinel.MasterSet = q.Get("master_set")
return
}

func MustParseURL(str string) ClientOption {
opt, err := ParseURL(str)
if err != nil {
panic(err)
}
return opt
}
76 changes: 76 additions & 0 deletions url_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package rueidis

import (
"strings"
"testing"
)

func TestParseURL(t *testing.T) {
if opt, err := ParseURL(""); !strings.HasPrefix(err.Error(), "redis: invalid URL scheme") {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("rediss://"); err != nil || opt.TLSConfig == nil {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("unix://"); err != nil || opt.DialFn == nil {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis://"); err != nil || opt.InitAddress[0] != "localhost:6379" {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis://localhost"); err != nil || opt.InitAddress[0] != "localhost:6379" {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis://myhost:1234"); err != nil || opt.InitAddress[0] != "myhost:1234" {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis://ooo:xxx@"); err != nil || opt.Username != "ooo" || opt.Password != "xxx" {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis:///1"); err != nil || opt.SelectDB != 1 {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis:///a"); !strings.HasPrefix(err.Error(), "redis: invalid database number") {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis:///1?db=a"); !strings.HasPrefix(err.Error(), "redis: invalid database number") {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis:////1"); !strings.HasPrefix(err.Error(), "redis: invalid URL path") {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis://?dial_timeout=a"); !strings.HasPrefix(err.Error(), "redis: invalid dial timeout") {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis://?write_timeout=a"); !strings.HasPrefix(err.Error(), "redis: invalid write timeout") {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis://?protocol=2"); !opt.AlwaysRESP2 {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis://?max_retries=0"); !opt.DisableRetry {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis://?client_name=0"); opt.ClientName != "0" {
t.Fatalf("unexpected %v %v", opt, err)
}
if opt, err := ParseURL("redis://?master_set=0"); opt.Sentinel.MasterSet != "0" {
t.Fatalf("unexpected %v %v", opt, err)
}
}

func TestMustParseURL(t *testing.T) {
defer func() {
if err := recover(); !strings.HasPrefix(err.(error).Error(), "redis: invalid URL path") {
t.Failed()
}
}()
MustParseURL("redis:////1")
}

func TestMustParseURLUnix(t *testing.T) {
opt := MustParseURL("unix://")
if conn, err := opt.DialFn("", &opt.Dialer, nil); !strings.Contains(err.Error(), "unix") {
t.Fatalf("unexpected %v %v", conn, err) // the error should be "dial unix: missing address"
}
}

0 comments on commit 8d1fc7f

Please sign in to comment.