| // Copyright 2026 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package entx |
| |
| import ( |
| "fmt" |
| |
| "entgo.io/ent/dialect" |
| "entgo.io/ent/dialect/sql" |
| |
| ent "go.chromium.org/infra/fleetconsole/internal/ent/generated" |
| ) |
| |
| // True returns a standard SQL TRUE predicate. |
| // This is primarily used for constructing custom ON clauses in JOINs (e.g., ON TRUE) |
| // or for "no-op" filters in complex modification blocks where a predicate is required. |
| func True() *sql.Predicate { |
| return sql.P(func(b *sql.Builder) { |
| b.WriteString("TRUE") |
| }) |
| } |
| |
| // CrossJoin adds a CROSS JOIN (implemented as JOIN ON TRUE) to the selector. |
| // |
| // Usage: |
| // |
| // entx.CrossJoin(s, t) |
| func CrossJoin(s *sql.Selector, t sql.TableView) *sql.Selector { |
| return s.Join(t).OnP(True()) |
| } |
| |
| // Lateral creates a table selector for use in LATERAL joins. |
| func Lateral(name string) *sql.SelectTable { |
| return sql.Table(fmt.Sprintf("LATERAL %s", name)) |
| } |
| |
// CountIf returns an aggregate function that counts the rows matching p.
//
// On Postgres it renders COUNT(*) FILTER (WHERE <p>). For every other dialect
// it renders the portable COUNT(CASE WHEN <p> THEN 1 END).
//
// Side effects: each time the returned function runs, it sets p's dialect
// and placeholder offset from the selector s, and appends p's arguments to s.
// Invoking it more than once against the same selector would therefore
// append those arguments again.
func CountIf(p *sql.Predicate) ent.AggregateFunc {
	return func(s *sql.Selector) string {
		// Sync the predicate's internal builder with the selector so that it
		// uses the same placeholder style. Setting the total makes it number
		// its placeholders after the s.Total() arguments the selector already
		// counts (e.g. $5, $6 on Postgres when four are already present).
		// NOTE(review): this assumes no other arguments are added to s ahead
		// of this SQL fragment after this call. Confirm against the order in
		// which the caller builds the final query.
		p.SetDialect(s.Dialect())
		p.SetTotal(s.Total())

		// Compile the predicate to SQL text and its bound arguments.
		condition, args := p.Query()

		// Inject the arguments into the main query.
		// sql.Expr("", args...) is an expression with empty SQL text. Passing
		// it to s.Arg is relied upon to append args to the selector's argument
		// list without writing any SQL text into its buffer. This depends on
		// how ent's Builder.Arg handles Querier values, so recheck it on ent
		// upgrades.
		s.Arg(sql.Expr("", args...))

		if s.Dialect() == dialect.Postgres {
			// Postgres supports the aggregate FILTER clause natively.
			return fmt.Sprintf("COUNT(*) FILTER (WHERE %s)", condition)
		}

		// Portable fallback: the CASE has no ELSE, so it yields NULL for
		// non-matching rows, and COUNT(expr) skips NULLs.
		return fmt.Sprintf("COUNT(CASE WHEN %s THEN 1 END)", condition)
	}
}