Skip to content

Commit

Permalink
Merge pull request #326 from redis/feat-resp3-ftaggregate
Browse files Browse the repository at this point in the history
feat: supports new RESP3 FT.AGGREGATE and RedisMessage.AsFtAggregate
  • Loading branch information
rueian authored Aug 4, 2023
2 parents 0745617 + 18249b9 commit 09bba2a
Show file tree
Hide file tree
Showing 5 changed files with 311 additions and 29 deletions.
69 changes: 69 additions & 0 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,24 @@ func (r RedisResult) AsFtSearch() (total int64, docs []FtSearchDoc, err error) {
return
}

func (r RedisResult) AsFtAggregate() (total int64, docs []map[string]string, err error) {
if r.err != nil {
err = r.err
} else {
total, docs, err = r.val.AsFtAggregate()
}
return
}

func (r RedisResult) AsFtAggregateCursor() (cursor, total int64, docs []map[string]string, err error) {
if r.err != nil {
err = r.err
} else {
cursor, total, docs, err = r.val.AsFtAggregateCursor()
}
return
}

func (r RedisResult) AsGeosearch() (locations []GeoLocation, err error) {
if r.err != nil {
err = r.err
Expand Down Expand Up @@ -987,6 +1005,57 @@ func (m *RedisMessage) AsFtSearch() (total int64, docs []FtSearchDoc, err error)
panic(fmt.Sprintf("redis message type %s is not a FT.SEARCH response", typeNames[typ]))
}

func (m *RedisMessage) AsFtAggregate() (total int64, docs []map[string]string, err error) {
if err = m.Error(); err != nil {
return 0, nil, err
}
if m.IsMap() {
for i := 0; i < len(m.values); i += 2 {
switch m.values[i].string {
case "total_results":
total = m.values[i+1].integer
case "results":
records := m.values[i+1].values
docs = make([]map[string]string, len(records))
for d, record := range records {
for j := 0; j < len(record.values); j += 2 {
switch record.values[j].string {
case "extra_attributes":
docs[d], _ = record.values[j+1].AsStrMap()
}
}
}
case "error":
for _, e := range m.values[i+1].values {
e := e
return 0, nil, (*RedisError)(&e)
}
}
}
return
}
if len(m.values) > 0 {
total = m.values[0].integer
docs = make([]map[string]string, len(m.values)-1)
for d, record := range m.values[1:] {
docs[d], _ = record.AsStrMap()
}
return
}
typ := m.typ
panic(fmt.Sprintf("redis message type %s is not a FT.AGGREGATE response", typeNames[typ]))
}

func (m *RedisMessage) AsFtAggregateCursor() (cursor, total int64, docs []map[string]string, err error) {
if m.IsArray() && len(m.values) == 2 && (m.values[0].IsArray() || m.values[0].IsMap()) {
total, docs, err = m.values[0].AsFtAggregate()
cursor = m.values[1].integer
} else {
total, docs, err = m.AsFtAggregate()
}
return
}

type GeoLocation struct {
Name string
Longitude, Latitude, Dist float64
Expand Down
235 changes: 235 additions & 0 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,217 @@ func TestRedisResult(t *testing.T) {
}
})

t.Run("AsFtAggregate", func(t *testing.T) {
if _, _, err := (RedisResult{err: errors.New("other")}).AsFtAggregate(); err == nil {
t.Fatal("AsFtAggregate not failed as expected")
}
if _, _, err := (RedisResult{val: RedisMessage{typ: '-'}}).AsFtAggregate(); err == nil {
t.Fatal("AsFtAggregate not failed as expected")
}
if n, ret, _ := (RedisResult{val: RedisMessage{typ: '*', values: []RedisMessage{
{typ: ':', integer: 3},
{typ: '*', values: []RedisMessage{
{typ: '+', string: "k1"},
{typ: '+', string: "v1"},
{typ: '+', string: "kk"},
{typ: '+', string: "vv"},
}},
{typ: '*', values: []RedisMessage{
{typ: '+', string: "k2"},
{typ: '+', string: "v2"},
{typ: '+', string: "kk"},
{typ: '+', string: "vv"},
}},
}}}).AsFtAggregate(); n != 3 || !reflect.DeepEqual([]map[string]string{
{"k1": "v1", "kk": "vv"},
{"k2": "v2", "kk": "vv"},
}, ret) {
t.Fatal("AsFtAggregate not get value as expected")
}
if n, ret, _ := (RedisResult{val: RedisMessage{typ: '*', values: []RedisMessage{
{typ: ':', integer: 3},
{typ: '*', values: []RedisMessage{
{typ: '+', string: "k1"},
{typ: '+', string: "v1"},
{typ: '+', string: "kk"},
{typ: '+', string: "vv"},
}},
}}}).AsFtAggregate(); n != 3 || !reflect.DeepEqual([]map[string]string{
{"k1": "v1", "kk": "vv"},
}, ret) {
t.Fatal("AsFtAggregate not get value as expected")
}
if n, ret, _ := (RedisResult{val: RedisMessage{typ: '*', values: []RedisMessage{
{typ: ':', integer: 3},
}}}).AsFtAggregate(); n != 3 || !reflect.DeepEqual([]map[string]string{}, ret) {
t.Fatal("AsFtAggregate not get value as expected")
}
})

t.Run("AsFtAggregate RESP3", func(t *testing.T) {
if n, ret, _ := (RedisResult{val: RedisMessage{typ: '%', values: []RedisMessage{
{typ: '+', string: "total_results"},
{typ: ':', integer: 3},
{typ: '+', string: "results"},
{typ: '*', values: []RedisMessage{
{typ: '%', values: []RedisMessage{
{typ: '+', string: "extra_attributes"},
{typ: '%', values: []RedisMessage{
{typ: '+', string: "$"},
{typ: '+', string: "1"},
}},
}},
{typ: '%', values: []RedisMessage{
{typ: '+', string: "extra_attributes"},
{typ: '%', values: []RedisMessage{
{typ: '+', string: "$"},
{typ: '+', string: "2"},
}},
}},
}},
{typ: '+', string: "error"},
{typ: '*', values: []RedisMessage{}},
}}}).AsFtAggregate(); n != 3 || !reflect.DeepEqual([]map[string]string{
{"$": "1"},
{"$": "2"},
}, ret) {
t.Fatal("AsFtAggregate not get value as expected")
}
if _, _, err := (RedisResult{val: RedisMessage{typ: '%', values: []RedisMessage{
{typ: '+', string: "total_results"},
{typ: ':', integer: 3},
{typ: '+', string: "results"},
{typ: '*', values: []RedisMessage{
{typ: '%', values: []RedisMessage{
{typ: '+', string: "extra_attributes"},
{typ: '%', values: []RedisMessage{
{typ: '+', string: "$"},
{typ: '+', string: "1"},
}},
}},
}},
{typ: '+', string: "error"},
{typ: '*', values: []RedisMessage{
{typ: '+', string: "mytimeout"},
}},
}}}).AsFtAggregate(); err == nil || err.Error() != "mytimeout" {
t.Fatal("AsFtAggregate not get value as expected")
}
})

t.Run("AsFtAggregate Cursor", func(t *testing.T) {
if _, _, _, err := (RedisResult{err: errors.New("other")}).AsFtAggregateCursor(); err == nil {
t.Fatal("AsFtAggregate not failed as expected")
}
if _, _, _, err := (RedisResult{val: RedisMessage{typ: '-'}}).AsFtAggregateCursor(); err == nil {
t.Fatal("AsFtAggregate not failed as expected")
}
if c, n, ret, _ := (RedisResult{val: RedisMessage{typ: '*', values: []RedisMessage{
{typ: '*', values: []RedisMessage{
{typ: ':', integer: 3},
{typ: '*', values: []RedisMessage{
{typ: '+', string: "k1"},
{typ: '+', string: "v1"},
{typ: '+', string: "kk"},
{typ: '+', string: "vv"},
}},
{typ: '*', values: []RedisMessage{
{typ: '+', string: "k2"},
{typ: '+', string: "v2"},
{typ: '+', string: "kk"},
{typ: '+', string: "vv"},
}},
}},
{typ: ':', integer: 1},
}}}).AsFtAggregateCursor(); c != 1 || n != 3 || !reflect.DeepEqual([]map[string]string{
{"k1": "v1", "kk": "vv"},
{"k2": "v2", "kk": "vv"},
}, ret) {
t.Fatal("AsFtAggregate not get value as expected")
}
if c, n, ret, _ := (RedisResult{val: RedisMessage{typ: '*', values: []RedisMessage{
{typ: '*', values: []RedisMessage{
{typ: ':', integer: 3},
{typ: '*', values: []RedisMessage{
{typ: '+', string: "k1"},
{typ: '+', string: "v1"},
{typ: '+', string: "kk"},
{typ: '+', string: "vv"},
}},
}},
{typ: ':', integer: 1},
}}}).AsFtAggregateCursor(); c != 1 || n != 3 || !reflect.DeepEqual([]map[string]string{
{"k1": "v1", "kk": "vv"},
}, ret) {
t.Fatal("AsFtAggregate not get value as expected")
}
if c, n, ret, _ := (RedisResult{val: RedisMessage{typ: '*', values: []RedisMessage{
{typ: '*', values: []RedisMessage{
{typ: ':', integer: 3},
}},
{typ: ':', integer: 1},
}}}).AsFtAggregateCursor(); c != 1 || n != 3 || !reflect.DeepEqual([]map[string]string{}, ret) {
t.Fatal("AsFtAggregate not get value as expected")
}
})

t.Run("AsFtAggregate Cursor RESP3", func(t *testing.T) {
if c, n, ret, _ := (RedisResult{val: RedisMessage{typ: '*', values: []RedisMessage{
{typ: '%', values: []RedisMessage{
{typ: '+', string: "total_results"},
{typ: ':', integer: 3},
{typ: '+', string: "results"},
{typ: '*', values: []RedisMessage{
{typ: '%', values: []RedisMessage{
{typ: '+', string: "extra_attributes"},
{typ: '%', values: []RedisMessage{
{typ: '+', string: "$"},
{typ: '+', string: "1"},
}},
}},
{typ: '%', values: []RedisMessage{
{typ: '+', string: "extra_attributes"},
{typ: '%', values: []RedisMessage{
{typ: '+', string: "$"},
{typ: '+', string: "2"},
}},
}},
}},
{typ: '+', string: "error"},
{typ: '*', values: []RedisMessage{}},
}},
{typ: ':', integer: 1},
}}}).AsFtAggregateCursor(); c != 1 || n != 3 || !reflect.DeepEqual([]map[string]string{
{"$": "1"},
{"$": "2"},
}, ret) {
t.Fatal("AsFtAggregate not get value as expected")
}
if _, _, _, err := (RedisResult{val: RedisMessage{typ: '*', values: []RedisMessage{
{typ: '%', values: []RedisMessage{
{typ: '+', string: "total_results"},
{typ: ':', integer: 3},
{typ: '+', string: "results"},
{typ: '*', values: []RedisMessage{
{typ: '%', values: []RedisMessage{
{typ: '+', string: "extra_attributes"},
{typ: '%', values: []RedisMessage{
{typ: '+', string: "$"},
{typ: '+', string: "1"},
}},
}},
}},
{typ: '+', string: "error"},
{typ: '*', values: []RedisMessage{
{typ: '+', string: "mytimeout"},
}},
}},
{typ: ':', integer: 1},
}}}).AsFtAggregateCursor(); err == nil || err.Error() != "mytimeout" {
t.Fatal("AsFtAggregate not get value as expected")
}
})

t.Run("AsGeosearch", func(t *testing.T) {
if _, err := (RedisResult{err: errors.New("other")}).AsGeosearch(); err == nil {
t.Fatal("AsGeosearch not failed as expected")
Expand Down Expand Up @@ -1354,6 +1565,30 @@ func TestRedisMessage(t *testing.T) {
(&RedisMessage{typ: '*', values: []RedisMessage{}}).AsFtSearch()
})

t.Run("AsFtAggregate", func(t *testing.T) {
if _, _, err := (&RedisMessage{typ: '_'}).AsFtAggregate(); err == nil {
t.Fatal("AsFtAggregate not failed as expected")
}
defer func() {
if !strings.Contains(recover().(string), fmt.Sprintf("redis message type %s is not a FT.AGGREGATE response", typeNames['*'])) {
t.Fatal("AsFtAggregate not panic as expected")
}
}()
(&RedisMessage{typ: '*', values: []RedisMessage{}}).AsFtAggregate()
})

t.Run("AsFtAggregateCursor", func(t *testing.T) {
if _, _, _, err := (&RedisMessage{typ: '_'}).AsFtAggregateCursor(); err == nil {
t.Fatal("AsFtAggregate not failed as expected")
}
defer func() {
if !strings.Contains(recover().(string), fmt.Sprintf("redis message type %s is not a FT.AGGREGATE response", typeNames['*'])) {
t.Fatal("AsFtAggregate not panic as expected")
}
}()
(&RedisMessage{typ: '*', values: []RedisMessage{}}).AsFtAggregateCursor()
})

t.Run("AsScanEntry", func(t *testing.T) {
if _, err := (RedisResult{err: errors.New("other")}).AsScanEntry(); err == nil {
t.Fatal("AsScanEntry not failed as expected")
Expand Down
28 changes: 3 additions & 25 deletions om/cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@ import (

var EndOfCursor = errors.New("end of cursor")

func newAggregateCursor(idx string, client rueidis.Client, resp []rueidis.RedisMessage) *AggregateCursor {
c := &AggregateCursor{client: client, idx: idx}
c.n, c.first, c.id = readAggregateResponse(resp)
return c
func newAggregateCursor(idx string, client rueidis.Client, first []map[string]string, cursor, total int64) *AggregateCursor {
return &AggregateCursor{client: client, idx: idx, first: first, id: cursor, n: total}
}

// AggregateCursor unifies the response of FT.AGGREGATE with or without WITHCURSOR
Expand All @@ -39,11 +37,7 @@ func (c *AggregateCursor) Read(ctx context.Context) (partial []map[string]string
if c.id == 0 {
return nil, EndOfCursor
}
resp, err := c.client.Do(ctx, c.client.B().FtCursorRead().Index(c.idx).CursorId(c.id).Build()).ToArray()
if err != nil {
return nil, err
}
_, partial, c.id = readAggregateResponse(resp)
c.id, _, partial, err = c.client.Do(ctx, c.client.B().FtCursorRead().Index(c.idx).CursorId(c.id).Build()).AsFtAggregateCursor()
return
}

Expand All @@ -54,19 +48,3 @@ func (c *AggregateCursor) Del(ctx context.Context) (err error) {
}
return c.client.Do(ctx, c.client.B().FtCursorDel().Index(c.idx).CursorId(c.id).Build()).Error()
}

func readAggregateResponse(resp []rueidis.RedisMessage) (n int64, partial []map[string]string, cursor int64) {
var results []rueidis.RedisMessage
if resp[0].IsArray() {
results, _ = resp[0].ToArray()
cursor, _ = resp[1].ToInt64()
} else {
results = resp
}
n, _ = results[0].ToInt64()
partial = make([]map[string]string, len(results[1:]))
for i, record := range results[1:] {
partial[i], _ = record.AsStrMap()
}
return
}
4 changes: 2 additions & 2 deletions om/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ func (r *HashRepository[T]) Search(ctx context.Context, cmdFn func(search FtSear

// Aggregate performs the FT.AGGREGATE and returns a *AggregateCursor for accessing the results
func (r *HashRepository[T]) Aggregate(ctx context.Context, cmdFn func(agg FtAggregateIndex) rueidis.Completed) (cursor *AggregateCursor, err error) {
resp, err := r.client.Do(ctx, cmdFn(r.client.B().FtAggregate().Index(r.idx))).ToArray()
cid, total, resp, err := r.client.Do(ctx, cmdFn(r.client.B().FtAggregate().Index(r.idx))).AsFtAggregateCursor()
if err != nil {
return nil, err
}
return newAggregateCursor(r.idx, r.client, resp), nil
return newAggregateCursor(r.idx, r.client, resp, cid, total), nil
}

// IndexName returns the index name used in the FT.CREATE
Expand Down
Loading

0 comments on commit 09bba2a

Please sign in to comment.