| // Copyright 2018 The LUCI Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package lib |
| |
| import ( |
| "context" |
| "encoding/json" |
| "io" |
| "log" |
| "os" |
| "sync" |
| "time" |
| |
| "github.com/maruel/subcommands" |
| |
| "go.chromium.org/luci/common/api/swarming/swarming/v1" |
| "go.chromium.org/luci/common/errors" |
| "go.chromium.org/luci/common/sync/parallel" |
| "go.chromium.org/luci/common/system/signals" |
| ) |
| |
| // CmdSpawnTasks returns an object for the `spawn-tasks` subcommand. |
| func CmdSpawnTasks(authFlags AuthFlags) *subcommands.Command { |
| return &subcommands.Command{ |
| UsageLine: "spawn-tasks <options>", |
| ShortDesc: "Spawns a set of Swarming tasks", |
| LongDesc: "Spawns a set of Swarming tasks given a JSON file.", |
| CommandRun: func() subcommands.CommandRun { |
| r := &spawnTasksRun{} |
| r.Init(authFlags) |
| return r |
| }, |
| } |
| } |
| |
| type spawnTasksRun struct { |
| commonFlags |
| jsonInput string |
| jsonOutput string |
| cancelExtraTasks bool |
| } |
| |
| func (c *spawnTasksRun) Init(authFlags AuthFlags) { |
| c.commonFlags.Init(authFlags) |
| c.Flags.StringVar(&c.jsonInput, "json-input", "", "(required) Read Swarming task requests from this file.") |
| c.Flags.StringVar(&c.jsonOutput, "json-output", "", "Write details about the triggered task(s) to this file as json.") |
| // TODO(https://crbug.com/997221): Remove this option. |
| c.Flags.BoolVar(&c.cancelExtraTasks, "cancel-extra-tasks", false, "Cancel extra spawned tasks.") |
| } |
| |
| func (c *spawnTasksRun) Parse(args []string) error { |
| if err := c.commonFlags.Parse(); err != nil { |
| return err |
| } |
| if c.jsonInput == "" { |
| return errors.Reason("input JSON file is required").Err() |
| } |
| return nil |
| } |
| |
| func (c *spawnTasksRun) Run(a subcommands.Application, args []string, env subcommands.Env) int { |
| if err := c.Parse(args); err != nil { |
| printError(a, err) |
| return 1 |
| } |
| if err := c.main(a, args, env); err != nil { |
| printError(a, err) |
| return 1 |
| } |
| return 0 |
| } |
| |
| func (c *spawnTasksRun) main(a subcommands.Application, args []string, env subcommands.Env) error { |
| start := time.Now() |
| ctx, cancel := context.WithCancel(c.defaultFlags.MakeLoggingContext(os.Stderr)) |
| defer cancel() |
| defer signals.HandleInterrupt(cancel)() |
| |
| tasksFile, err := os.Open(c.jsonInput) |
| if err != nil { |
| return errors.Annotate(err, "failed to open tasks file").Err() |
| } |
| defer tasksFile.Close() |
| requests, err := processTasksStream(tasksFile) |
| if err != nil { |
| return err |
| } |
| |
| service, err := c.createSwarmingClient(ctx) |
| if err != nil { |
| return err |
| } |
| results, merr := createNewTasks(ctx, service, requests) |
| |
| var output io.Writer |
| if c.jsonOutput != "" { |
| file, err := os.Create(c.jsonOutput) |
| if err != nil { |
| return err |
| } |
| defer file.Close() |
| output = file |
| } else { |
| output = os.Stdout |
| } |
| |
| data := TriggerResults{Tasks: results} |
| b, err := json.MarshalIndent(&data, "", " ") |
| if err != nil { |
| return errors.Annotate(err, "marshalling trigger result").Err() |
| } |
| |
| if _, err = output.Write(b); err != nil { |
| return errors.Annotate(err, "writing json output").Err() |
| } |
| |
| log.Printf("Duration: %s\n", time.Since(start).Round(time.Millisecond)) |
| return merr |
| } |
| |
| type tasksInput struct { |
| Requests []*swarming.SwarmingRpcsNewTaskRequest `json:"requests"` |
| } |
| |
| func sendSizeBytes(p *swarming.SwarmingRpcsTaskProperties) { |
| if p != nil && p.CasInputRoot != nil && p.CasInputRoot.Digest != nil { |
| p.CasInputRoot.Digest.ForceSendFields = append(p.CasInputRoot.Digest.ForceSendFields, "SizeBytes") |
| } |
| } |
| |
| func processTasksStream(tasks io.Reader) ([]*swarming.SwarmingRpcsNewTaskRequest, error) { |
| dec := json.NewDecoder(tasks) |
| dec.DisallowUnknownFields() |
| |
| requests := tasksInput{} |
| if err := dec.Decode(&requests); err != nil { |
| return nil, errors.Annotate(err, "decoding tasks file").Err() |
| } |
| |
| // Populate the tasks with information about the current envirornment |
| // if they're not already set. |
| currentUser := os.Getenv(UserEnvVar) |
| parentTaskID := os.Getenv(TaskIDEnvVar) |
| for _, request := range requests.Requests { |
| if request.User == "" { |
| request.User = currentUser |
| } |
| if request.ParentTaskId == "" { |
| request.ParentTaskId = parentTaskID |
| } |
| } |
| |
| // Allow to send 0 size bytes for input digest. |
| for _, request := range requests.Requests { |
| sendSizeBytes(request.Properties) |
| for _, slice := range request.TaskSlices { |
| sendSizeBytes(slice.Properties) |
| } |
| } |
| |
| return requests.Requests, nil |
| } |
| |
| func createNewTasks(ctx context.Context, service swarmingService, requests []*swarming.SwarmingRpcsNewTaskRequest) ([]*swarming.SwarmingRpcsTaskRequestMetadata, error) { |
| var mu sync.Mutex |
| results := make([]*swarming.SwarmingRpcsTaskRequestMetadata, 0, len(requests)) |
| err := parallel.WorkPool(8, func(gen chan<- func() error) { |
| for _, request := range requests { |
| request := request |
| gen <- func() error { |
| result, err := service.NewTask(ctx, request) |
| if err != nil { |
| return err |
| } |
| mu.Lock() |
| defer mu.Unlock() |
| results = append(results, result) |
| return nil |
| } |
| } |
| }) |
| return results, err |
| } |