Merge pull request #6 from cpuguy83/ctx_termination
Backport "Honor context termination"
diff --git a/api/grpc/server/server.go b/api/grpc/server/server.go
index bb7ad02..f57a357 100644
--- a/api/grpc/server/server.go
+++ b/api/grpc/server/server.go
@@ -56,6 +56,7 @@
e.Runtime = c.Runtime
e.RuntimeArgs = c.RuntimeArgs
e.StartResponse = make(chan supervisor.StartResponse, 1)
+ e.Ctx = ctx
if c.Checkpoint != "" {
e.CheckpointDir = c.CheckpointDir
e.Checkpoint = &runtime.Checkpoint{
diff --git a/api/grpc/server/server_linux.go b/api/grpc/server/server_linux.go
index 1051f1f..a2eff3b 100644
--- a/api/grpc/server/server_linux.go
+++ b/api/grpc/server/server_linux.go
@@ -50,6 +50,7 @@
e.Stdout = r.Stdout
e.Stderr = r.Stderr
e.StartResponse = make(chan supervisor.StartResponse, 1)
+ e.Ctx = ctx
s.sv.SendTask(e)
if err := <-e.ErrorCh(); err != nil {
return nil, err
diff --git a/runtime/container.go b/runtime/container.go
index c4a2f7e..c26ce2c 100644
--- a/runtime/container.go
+++ b/runtime/container.go
@@ -14,6 +14,7 @@
"github.com/Sirupsen/logrus"
"github.com/docker/containerd/specs"
ocs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/net/context"
"golang.org/x/sys/unix"
)
@@ -24,9 +25,9 @@
// Path returns the path to the bundle
Path() string
// Start starts the init process of the container
- Start(checkpointPath string, s Stdio) (Process, error)
+ Start(ctx context.Context, checkpointPath string, s Stdio) (Process, error)
// Exec starts another process in an existing container
- Exec(string, specs.ProcessSpec, Stdio) (Process, error)
+ Exec(context.Context, string, specs.ProcessSpec, Stdio) (Process, error)
// Delete removes the container's state and any resources
Delete() error
// Processes returns all the containers processes that have been added
@@ -186,7 +187,7 @@
}
p, err := loadProcess(filepath.Join(root, id, pid), pid, c, s)
if err != nil {
- logrus.WithField("id", id).WithField("pid", pid).Debug("containerd: error loading process %s", err)
+ logrus.WithField("id", id).WithField("pid", pid).Debugf("containerd: error loading process %s", err)
continue
}
c.processes[pid] = p
@@ -395,7 +396,7 @@
return os.RemoveAll(filepath.Join(checkpointDir, name))
}
-func (c *container) Start(checkpointPath string, s Stdio) (Process, error) {
+func (c *container) Start(ctx context.Context, checkpointPath string, s Stdio) (Process, error) {
processRoot := filepath.Join(c.root, c.id, InitProcessID)
if err := os.Mkdir(processRoot, 0755); err != nil {
return nil, err
@@ -424,13 +425,13 @@
if err != nil {
return nil, err
}
- if err := c.createCmd(InitProcessID, cmd, p); err != nil {
+ if err := c.createCmd(ctx, InitProcessID, cmd, p); err != nil {
return nil, err
}
return p, nil
}
-func (c *container) Exec(pid string, pspec specs.ProcessSpec, s Stdio) (pp Process, err error) {
+func (c *container) Exec(ctx context.Context, pid string, pspec specs.ProcessSpec, s Stdio) (pp Process, err error) {
processRoot := filepath.Join(c.root, c.id, pid)
if err := os.Mkdir(processRoot, 0755); err != nil {
return nil, err
@@ -464,13 +465,13 @@
if err != nil {
return nil, err
}
- if err := c.createCmd(pid, cmd, p); err != nil {
+ if err := c.createCmd(ctx, pid, cmd, p); err != nil {
return nil, err
}
return p, nil
}
-func (c *container) createCmd(pid string, cmd *exec.Cmd, p *process) error {
+func (c *container) createCmd(ctx context.Context, pid string, cmd *exec.Cmd, p *process) error {
p.cmd = cmd
if err := cmd.Start(); err != nil {
close(p.cmdDoneCh)
@@ -509,10 +510,25 @@
close(p.cmdDoneCh)
}()
}()
- if err := c.waitForCreate(p, cmd); err != nil {
+
+ ch := make(chan error)
+ go func() {
+ if err := c.waitForCreate(p, cmd); err != nil {
+ ch <- err
+ return
+ }
+ c.processes[pid] = p
+ ch <- nil
+ }()
+ select {
+ case <-ctx.Done():
+ cmd.Process.Kill()
+ cmd.Wait()
+ <-ch
+ return ctx.Err()
+ case err := <-ch:
return err
}
- c.processes[pid] = p
return nil
}
diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go
index 3581493..4edf999 100644
--- a/runtime/runtime_test.go
+++ b/runtime/runtime_test.go
@@ -1,6 +1,7 @@
package runtime
import (
+ "context"
"flag"
"fmt"
"io"
@@ -163,7 +164,7 @@
}
func benchmarkStartContainer(b *testing.B, c Container, s Stdio, bundleName string) {
- p, err := c.Start("", s)
+ p, err := c.Start(context.Background(), "", s)
if err != nil {
b.Fatalf("Error starting container %v", err)
}
diff --git a/supervisor/add_process.go b/supervisor/add_process.go
index ab0f689..028ba05 100644
--- a/supervisor/add_process.go
+++ b/supervisor/add_process.go
@@ -5,6 +5,7 @@
"github.com/docker/containerd/runtime"
"github.com/docker/containerd/specs"
+ "golang.org/x/net/context"
)
// AddProcessTask holds everything necessary to add a process to a
@@ -18,6 +19,7 @@
Stdin string
ProcessSpec *specs.ProcessSpec
StartResponse chan StartResponse
+ Ctx context.Context
}
func (s *Supervisor) addProcess(t *AddProcessTask) error {
@@ -26,7 +28,7 @@
if !ok {
return ErrContainerNotFound
}
- process, err := ci.container.Exec(t.PID, *t.ProcessSpec, runtime.NewStdio(t.Stdin, t.Stdout, t.Stderr))
+ process, err := ci.container.Exec(t.Ctx, t.PID, *t.ProcessSpec, runtime.NewStdio(t.Stdin, t.Stdout, t.Stderr))
if err != nil {
return err
}
diff --git a/supervisor/create.go b/supervisor/create.go
index fa0defe..c78f100 100644
--- a/supervisor/create.go
+++ b/supervisor/create.go
@@ -5,6 +5,7 @@
"time"
"github.com/docker/containerd/runtime"
+ "golang.org/x/net/context"
)
// StartTask holds needed parameters to create a new container
@@ -22,6 +23,7 @@
CheckpointDir string
Runtime string
RuntimeArgs []string
+ Ctx context.Context
}
func (s *Supervisor) start(t *StartTask) error {
@@ -57,6 +59,7 @@
Stdin: t.Stdin,
Stdout: t.Stdout,
Stderr: t.Stderr,
+ Ctx: t.Ctx,
}
if t.Checkpoint != nil {
task.CheckpointPath = filepath.Join(t.CheckpointDir, t.Checkpoint.Name)
diff --git a/supervisor/worker.go b/supervisor/worker.go
index 250d396..e3d6a63 100644
--- a/supervisor/worker.go
+++ b/supervisor/worker.go
@@ -6,6 +6,7 @@
"github.com/Sirupsen/logrus"
"github.com/docker/containerd/runtime"
+ "golang.org/x/net/context"
)
// Worker interface
@@ -21,6 +22,7 @@
Stderr string
Err chan error
StartResponse chan StartResponse
+ Ctx context.Context
}
// NewWorker return a new initialized worker
@@ -41,7 +43,7 @@
defer w.wg.Done()
for t := range w.s.startTasks {
started := time.Now()
- process, err := t.Container.Start(t.CheckpointPath, runtime.NewStdio(t.Stdin, t.Stdout, t.Stderr))
+ process, err := t.Container.Start(t.Ctx, t.CheckpointPath, runtime.NewStdio(t.Stdin, t.Stdout, t.Stderr))
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,