Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions compiler/dag/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,12 @@ func (seq *Seq) Delete(from, to int) {

type (
AggregateOp struct {
Kind string `json:"kind" unpack:""`
Limit int `json:"limit"`
Keys []Assignment `json:"keys"`
Aggs []Assignment `json:"aggs"`
InputSortDir int `json:"input_sort_dir,omitempty"`
PartialsIn bool `json:"partials_in,omitempty"`
PartialsOut bool `json:"partials_out,omitempty"`
Kind string `json:"kind" unpack:""`
Limit int `json:"limit"`
Keys []Assignment `json:"keys"`
Aggs []Assignment `json:"aggs"`
PartialsIn bool `json:"partials_in,omitempty"`
PartialsOut bool `json:"partials_out,omitempty"`
}
CombineOp struct {
Kind string `json:"kind" unpack:""`
Expand Down
23 changes: 0 additions & 23 deletions compiler/optimizer/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,29 +332,6 @@ func (o *Optimizer) propagateSortKeyOp(op dag.Op, parents []order.SortKeys) ([]o
}
switch op := op.(type) {
case *dag.AggregateOp:
if parent.IsNil() {
return []order.SortKeys{nil}, nil
}
//XXX handle only primary sortKey for now
sortKey := parent.Primary()
for _, k := range op.Keys {
if groupingKey := fieldOf(k.LHS); groupingKey.Equal(sortKey.Key) {
rhsExpr := k.RHS
rhs := fieldOf(rhsExpr)
if rhs.Equal(sortKey.Key) || orderPreservingCall(rhsExpr, groupingKey) {
op.InputSortDir = int(sortKey.Order.Direction())
// Currently, the aggregate operator will sort its
// output according to the primary key, but we
// should relax this and do an analysis here as
// to whether the sort is necessary for the
// downstream consumer.
return []order.SortKeys{parent}, nil
}
}
}
// We'll leave this as unknown for now in spite of the aggregate
// and not try to optimize downstream of the first aggregate
// unless there is an excplicit sort encountered.
return []order.SortKeys{nil}, nil
case *dag.ForkOp:
var keys []order.SortKeys
Expand Down
23 changes: 0 additions & 23 deletions compiler/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/brimdata/super/compiler/semantic"
"github.com/brimdata/super/compiler/srcfiles"
"github.com/brimdata/super/dbid"
"github.com/brimdata/super/order"
"github.com/brimdata/super/runtime"
"github.com/brimdata/super/runtime/exec"
"github.com/brimdata/super/runtime/vam/op"
Expand Down Expand Up @@ -140,25 +139,3 @@ func VectorFilterCompile(rctx *runtime.Context, query string, env *exec.Environm
}
return rungen.NewBuilder(rctx, env).BuildVamToSeqFilter(f.Expr, poolID, commitID)
}

// XXX currently used only by aggregate test, need to deprecate
func CompileWithSortKey(rctx *runtime.Context, ast *parser.AST, r sio.Reader, sortKey order.SortKey) (*exec.Query, error) {
env := exec.NewEnvironment(nil, nil)
main, err := Analyze(rctx, ast, env, true)
if err != nil {
return nil, err
}
scan, ok := main.Body[0].(*dag.DefaultScan)
if !ok {
return nil, errors.New("CompileWithSortKey: expected a reader")
}
scan.SortKeys = order.SortKeys{sortKey}
if err := Optimize(rctx, main, env, 0); err != nil {
return nil, err
}
outputs, debugs, meter, err := Build(rctx, main, env, []sio.Reader{r})
if err != nil {
return nil, err
}
return exec.NewQuery(rctx, bundleOutputs(rctx, outputs, debugs), meter), nil
}
4 changes: 1 addition & 3 deletions compiler/rungen/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"

"github.com/brimdata/super/compiler/dag"
"github.com/brimdata/super/order"
"github.com/brimdata/super/pkg/field"
"github.com/brimdata/super/runtime/sam/expr"
"github.com/brimdata/super/runtime/sam/op/aggregate"
Expand All @@ -21,11 +20,10 @@ func (b *Builder) compileAggregate(parent sbuf.Puller, a *dag.AggregateOp) (sbuf
if err != nil {
return nil, err
}
dir := order.Direction(a.InputSortDir)
if len(keys) == 0 {
return aggregate.NewScalar(b.rctx, parent, names, reducers, a.PartialsIn, a.PartialsOut)
}
return aggregate.New(b.rctx, parent, keys, names, reducers, a.Limit, dir, a.PartialsIn, a.PartialsOut)
return aggregate.New(b.rctx, parent, keys, names, reducers, a.Limit, a.PartialsIn, a.PartialsOut)
}

func (b *Builder) compileAggAssignments(assignments []dag.Assignment) (field.List, []*expr.Aggregator, error) {
Expand Down
3 changes: 0 additions & 3 deletions compiler/sfmt/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,6 @@ func (c *canonDAG) op(p dag.Op) {
if p.PartialsOut {
c.write(" partials-out")
}
if p.InputSortDir != 0 {
c.write(" sort-dir %d", p.InputSortDir)
}
c.ret()
c.open()
c.assignments(p.Aggs)
Expand Down
6 changes: 3 additions & 3 deletions compiler/ztests/par-ts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,16 @@ outputs:
| scatter
(
seqscan ...
| aggregate partials-out sort-dir 1
| aggregate partials-out
count:=count() by y:=y,ts:=bucket(ts, 1h)
)
(
seqscan ...
| aggregate partials-out sort-dir 1
| aggregate partials-out
count:=count() by y:=y,ts:=bucket(ts, 1h)
)
| merge ts asc nulls last
| aggregate partials-in sort-dir 1
| aggregate partials-in
count:=count() by y:=y,ts:=ts
| output main
<PUT COUNTDISTINCT UNIQ>
Expand Down
15 changes: 0 additions & 15 deletions compiler/ztests/sem-aggregate-input-dir.yaml

This file was deleted.

121 changes: 13 additions & 108 deletions runtime/sam/op/aggregate/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/binary"
"errors"
"slices"
"sync"

"github.com/brimdata/super"
Expand Down Expand Up @@ -54,12 +53,7 @@ type Aggregator struct {
recordTypes map[int]*super.TypeRecord
table map[string]*Row
limit int
valueCompare expr.CompareFn // to compare primary group keys for early key output
keyCompare expr.CompareFn // compare the first key (used when input sorted)
keysComparator *expr.Comparator // compare all keys
maxTableKey *super.Value
maxSpillKey *super.Value
inputDir order.Direction
spiller *spill.MergeSort
partialsIn bool
partialsOut bool
Expand All @@ -71,26 +65,17 @@ type Row struct {
reducers valRow
}

func NewAggregator(ctx context.Context, sctx *super.Context, keyRefs, keyExprs, aggRefs []expr.Evaluator, aggs []*expr.Aggregator, builder *super.RecordBuilder, limit int, inputDir order.Direction, partialsIn, partialsOut bool) (*Aggregator, error) {
func NewAggregator(ctx context.Context, sctx *super.Context, keyRefs, keyExprs, aggRefs []expr.Evaluator, aggs []*expr.Aggregator, builder *super.RecordBuilder, limit int, partialsIn, partialsOut bool) (*Aggregator, error) {
if limit == 0 {
limit = DefaultLimit
}
var keyCompare, valueCompare expr.CompareFn
nkeys := len(keyExprs)
o, ok := inputDir.Which()
if ok && nkeys > 0 {
keySortExpr := expr.NewSortExpr(keyRefs[0], o, o.NullsMax(true))
keyCompare = expr.NewComparator(keySortExpr).WithMissingAsNull().Compare
valueCompare = expr.NewValueCompareFn(o, o.NullsMax(true))
}
var sortExprs []expr.SortExpr
for _, e := range keyRefs {
sortExprs = append(sortExprs, expr.NewSortExpr(e, o, o.NullsMax(true)))
sortExprs = append(sortExprs, expr.NewSortExpr(e, order.Asc, order.NullsLast))
}
return &Aggregator{
ctx: ctx,
sctx: sctx,
inputDir: inputDir,
limit: limit,
keyTypes: super.NewTypeVectorTable(),
outTypes: super.NewTypeVectorTable(),
Expand All @@ -99,19 +84,17 @@ func NewAggregator(ctx context.Context, sctx *super.Context, keyRefs, keyExprs,
aggRefs: aggRefs,
aggs: aggs,
builder: builder,
typeCache: make([]super.Type, nkeys+len(aggs)),
typeCache: make([]super.Type, len(keyExprs)+len(aggs)),
keyCache: make(scode.Bytes, 0, 128),
table: make(map[string]*Row),
recordTypes: make(map[int]*super.TypeRecord),
keyCompare: keyCompare,
keysComparator: expr.NewComparator(sortExprs...).WithMissingAsNull(),
valueCompare: valueCompare,
partialsIn: partialsIn,
partialsOut: partialsOut,
}, nil
}

func New(rctx *runtime.Context, parent sbuf.Puller, keys []expr.Assignment, aggNames field.List, aggs []*expr.Aggregator, limit int, inputSortDir order.Direction, partialsIn, partialsOut bool) (sbuf.Puller, error) {
func New(rctx *runtime.Context, parent sbuf.Puller, keys []expr.Assignment, aggNames field.List, aggs []*expr.Aggregator, limit int, partialsIn, partialsOut bool) (*Op, error) {
names := make(field.List, 0, len(keys)+len(aggNames))
for _, e := range keys {
p, ok := e.LHS.Path()
Expand All @@ -135,7 +118,7 @@ func New(rctx *runtime.Context, parent sbuf.Puller, keys []expr.Assignment, aggN
keyRefs = append(keyRefs, expr.NewDottedExpr(rctx.Sctx, names[i]))
keyExprs = append(keyExprs, keys[i].RHS)
}
agg, err := NewAggregator(rctx.Context, rctx.Sctx, keyRefs, keyExprs, valRefs, aggs, builder, limit, inputSortDir, partialsIn, partialsOut)
agg, err := NewAggregator(rctx.Context, rctx.Sctx, keyRefs, keyExprs, valRefs, aggs, builder, limit, partialsIn, partialsOut)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -220,31 +203,6 @@ func (o *Op) run() {
return
}
}
if o.agg.inputDir == 0 {
batch.Unref()
continue
}
// sorted input: see if we have any completed keys we can emit.
for {
res, err := o.agg.nextResult(false, batch)
if err != nil {
if _, ok := o.sendResult(nil, err); !ok {
return
}
break
}
if res == nil {
break
}
slices.SortStableFunc(res.Values(), o.agg.keyCompare)
done, ok := o.sendResult(res, nil)
if !ok {
return
}
if done {
break
}
}
batch.Unref()
}
}
Expand Down Expand Up @@ -318,14 +276,11 @@ func (a *Aggregator) Consume(batch sbuf.Batch, this super.Value) error {
types := a.typeCache[:0]
keyBytes := a.keyCache[:0]
var prim super.Value
for i, keyExpr := range a.keyExprs {
for _, keyExpr := range a.keyExprs {
key := keyExpr.Eval(this).SuperDeunion()
if key.IsQuiet() {
return nil
}
if i == 0 && a.inputDir != 0 {
prim = a.updateMaxTableKey(key)
}
types = append(types, key.Type())
// Append each value to the key as a flat value, independent
// of whether this is a primitive or container.
Expand Down Expand Up @@ -361,7 +316,7 @@ func (a *Aggregator) Consume(batch sbuf.Batch, this super.Value) error {
}

func (a *Aggregator) spillTable(eof bool, ref sbuf.Batch) error {
batch, err := a.readTable(true, true, ref)
batch, err := a.readTable(true, ref)
if err != nil || batch == nil {
return err
}
Expand All @@ -373,33 +328,7 @@ func (a *Aggregator) spillTable(eof bool, ref sbuf.Batch) error {
}
recs := batch.Values()
// Note that this will sort recs according to g.keysComparator.
if err := a.spiller.Spill(a.ctx, recs); err != nil {
return err
}
if !eof && a.inputDir != 0 {
val := a.keyExprs[0].Eval(recs[len(recs)-1])
if !val.IsError() {
// pass volatile super.Value since updateMaxSpillKey will make
// a copy if needed.
a.updateMaxSpillKey(val)
}
}
return nil
}

// updateMaxTableKey is called with a volatile super.Value to update the
// max value seen in the table for the streaming logic when the input is sorted.
func (a *Aggregator) updateMaxTableKey(val super.Value) super.Value {
if a.maxTableKey == nil || a.valueCompare(val, *a.maxTableKey) > 0 {
a.maxTableKey = val.Copy().Ptr()
}
return *a.maxTableKey
}

func (a *Aggregator) updateMaxSpillKey(v super.Value) {
if a.maxSpillKey == nil || a.valueCompare(v, *a.maxSpillKey) > 0 {
a.maxSpillKey = v.Copy().Ptr()
}
return a.spiller.Spill(a.ctx, recs)
}

// Results returns a batch of aggregation result records. Upon eof,
Expand All @@ -408,36 +337,20 @@ func (a *Aggregator) updateMaxSpillKey(v super.Value) {
// before eof, and keys that are completed will returned.
func (a *Aggregator) nextResult(eof bool, batch sbuf.Batch) (sbuf.Batch, error) {
if a.spiller == nil {
return a.readTable(eof, a.partialsOut, batch)
return a.readTable(a.partialsOut, batch)
}
if eof {
// EOF: spill in-memory table before merging all files for output.
if err := a.spillTable(true, batch); err != nil {
return nil, err
}
}
return a.readSpills(eof, batch)
return a.readSpills(batch)
}

func (a *Aggregator) readSpills(eof bool, batch sbuf.Batch) (sbuf.Batch, error) {
func (a *Aggregator) readSpills(batch sbuf.Batch) (sbuf.Batch, error) {
recs := make([]super.Value, 0, op.BatchLen)
if !eof && a.inputDir == 0 {
return nil, nil
}
for len(recs) < op.BatchLen {
if !eof && a.inputDir != 0 {
rec, err := a.spiller.Peek()
if err != nil {
return nil, err
}
if rec == nil {
break
}
keyVal := a.keyExprs[0].Eval(*rec)
if !keyVal.IsError() && a.valueCompare(keyVal, *a.maxSpillKey) >= 0 {
break
}
}
rec, err := a.nextResultFromSpills()
if err != nil {
return nil, err
Expand Down Expand Up @@ -511,19 +424,11 @@ func (a *Aggregator) nextResultFromSpills() (*super.Value, error) {
}

// readTable returns a slice of records from the in-memory aggregate
// table. If flush is true, the entire table is returned. If flush is
// false and input is sorted only completed keys are returned.
// If partialsOut is true, it returns partial aggregation results as
// table. If partialsOut is true, it returns partial aggregation results as
// defined by each agg.Function.ResultAsPartial() method.
func (a *Aggregator) readTable(flush, partialsOut bool, batch sbuf.Batch) (sbuf.Batch, error) {
func (a *Aggregator) readTable(partialsOut bool, batch sbuf.Batch) (sbuf.Batch, error) {
var recs []super.Value
for key, row := range a.table {
if !flush && a.valueCompare == nil {
panic("internal bug: tried to fetch completed tuples on non-sorted input")
}
if !flush && a.valueCompare(row.groupval, *a.maxTableKey) >= 0 {
continue
}
// To build the output record, we spin over the key values
// and append them with the buidler, then spin over the aggregations
// and append each value. The builder is already set up with
Expand Down
Loading
Loading