blob: b677632384536cf591c0cae38dc34250ad34d956 [file] [log] [blame]
// Copyright 2026 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package entx
import (
"fmt"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
ent "go.chromium.org/infra/fleetconsole/internal/ent/generated"
)
// True returns a predicate that renders the bare SQL literal TRUE.
//
// It is handy wherever the builder demands a predicate but no real condition
// applies. Examples are an always-true ON clause for a JOIN (JOIN t ON TRUE),
// or a placeholder filter inside a larger modification block.
func True() *sql.Predicate {
	writeTrue := func(b *sql.Builder) {
		b.WriteString("TRUE")
	}
	return sql.P(writeTrue)
}
// CrossJoin joins t onto s with an always-true condition (JOIN t ON TRUE).
// That is semantically a CROSS JOIN. It returns the resulting selector so
// calls can be chained.
//
// Usage:
//
//	entx.CrossJoin(s, t)
func CrossJoin(s *sql.Selector, t sql.TableView) *sql.Selector {
	joined := s.Join(t)
	alwaysTrue := True()
	return joined.OnP(alwaysTrue)
}
// Lateral returns a table selector whose name is prefixed with the LATERAL
// keyword, for use as the right-hand side of a LATERAL join.
//
// NOTE(review): the result goes through sql.Table, so entgo's identifier
// quoting rules apply to "LATERAL <name>". This helper appears intended for
// function-call or parenthesized-subquery expressions, not plain table
// names; confirm against callers.
func Lateral(name string) *sql.SelectTable {
	qualified := "LATERAL " + name
	return sql.Table(qualified)
}
// CountIf returns an aggregate that counts only the rows for which p holds.
//
// The rendered SQL depends on the dialect:
//   - Postgres: COUNT(*) FILTER (WHERE <p>).
//   - Every other dialect: the portable COUNT(CASE WHEN <p> THEN 1 END).
//
// A nil predicate applies no filter and renders a plain COUNT(*) instead of
// panicking.
//
// The predicate is rendered immediately when the aggregate is evaluated
// against a selector, which typically happens before the selector itself is
// rendered. Its arguments are appended to that selector's argument list at
// that moment. Its placeholders are numbered starting after the selector's
// current count (s.Total()), not after the final rendered query.
//
// NOTE(review): for dialects with positional "?" placeholders, argument order
// must match textual order. This assumes the selector does not later add
// arguments for SQL text that appears before this aggregate in the final
// statement (e.g. a WITH clause). Confirm for such queries.
//
// NOTE(review): every evaluation mutates p (SetDialect, SetTotal, and
// re-rendering it via Query). Neither p nor the returned AggregateFunc should
// be evaluated from multiple goroutines concurrently.
func CountIf(p *sql.Predicate) ent.AggregateFunc {
	return func(s *sql.Selector) string {
		if p == nil {
			// No filter: every row counts.
			return "COUNT(*)"
		}
		// Render p with the selector's dialect. Its placeholders are numbered
		// after the arguments the selector already holds.
		p.SetDialect(s.Dialect())
		p.SetTotal(s.Total())
		condition, args := p.Query()
		// The aggregate is returned as plain SQL text, so p's arguments must
		// be registered with the selector separately. Passing an expression
		// with empty text to Arg appends the arguments without writing any
		// SQL into the selector's buffer.
		// NOTE(review): this relies on Arg also advancing s.Total() by
		// len(args) so later placeholders do not collide; confirm against the
		// entgo version in use.
		s.Arg(sql.Expr("", args...))
		if s.Dialect() == dialect.Postgres {
			return fmt.Sprintf("COUNT(*) FILTER (WHERE %s)", condition)
		}
		return fmt.Sprintf("COUNT(CASE WHEN %s THEN 1 END)", condition)
	}
}