blob: 78d54cb05c5abe5fa7baca08a2b224d9974266f4 [file] [log] [blame]
// Copyright 2015 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package memory
import (
"context"
"fmt"
net_mail "net/mail"
"net/textproto"
"path/filepath"
"strings"
"sync"
"go.chromium.org/luci/gae/service/mail"
"go.chromium.org/luci/gae/service/user"
)
type mailData struct {
sync.Mutex
queue []*mail.TestMessage
admins []string
adminsPlain []string
}
// mailImpl is a contextual pointer to the current mailData.
type mailImpl struct {
context.Context
data *mailData
}
var _ mail.RawInterface = (*mailImpl)(nil)
// useMail adds a mail.RawInterface implementation to context, accessible
// by mail.Raw(c) or the exported mail methods.
func useMail(c context.Context) context.Context {
data := &mailData{
admins: []string{"admin@example.com"},
adminsPlain: []string{"admin@example.com"},
}
return mail.SetFactory(c, func(ic context.Context) mail.RawInterface {
return &mailImpl{ic, data}
})
}
func parseEmails(emails ...string) error {
for _, e := range emails {
if _, err := net_mail.ParseAddress(e); err != nil {
return fmt.Errorf("invalid email (%q): %s", e, err)
}
}
return nil
}
func checkMessage(msg *mail.TestMessage, adminsPlain []string, user string) error {
sender, err := net_mail.ParseAddress(msg.Sender)
if err != nil {
return fmt.Errorf("unparsable Sender address: %s: %s", msg.Sender, err)
}
senderOK := user != "" && sender.Address == user
if !senderOK {
for _, a := range adminsPlain {
if sender.Address == a {
senderOK = true
break
}
}
}
if !senderOK {
return fmt.Errorf("invalid Sender: %s", msg.Sender)
}
if len(msg.To) == 0 && len(msg.Cc) == 0 && len(msg.Bcc) == 0 {
return fmt.Errorf("one of To, Cc or Bcc must be non-empty")
}
if err := parseEmails(msg.To...); err != nil {
return err
}
if err := parseEmails(msg.Cc...); err != nil {
return err
}
if err := parseEmails(msg.Bcc...); err != nil {
return err
}
if len(msg.Body) == 0 && len(msg.HTMLBody) == 0 {
return fmt.Errorf("one of Body or HTMLBody must be non-empty")
}
if len(msg.Attachments) > 0 {
msg.MIMETypes = make([]string, len(msg.Attachments))
for i := range msg.Attachments {
n := msg.Attachments[i].Name
ext := strings.TrimLeft(strings.ToLower(filepath.Ext(n)), ".")
if badExtensions.Has(ext) {
return fmt.Errorf("illegal attachment extension for %q", n)
}
mimetype := extensionMapping[ext]
if mimetype == "" {
mimetype = "application/octet-stream"
}
msg.MIMETypes[i] = mimetype
}
}
fixKeys := map[string]string{}
for k := range msg.Headers {
canonK := textproto.CanonicalMIMEHeaderKey(k)
if !okHeaders.Has(canonK) {
return fmt.Errorf("disallowed header: %s", k)
}
if canonK != k {
fixKeys[k] = canonK
}
}
for k, canonK := range fixKeys {
vals := msg.Headers[k]
delete(msg.Headers, k)
msg.Headers[canonK] = vals
}
return nil
}
func (m *mailImpl) sendImpl(msg *mail.Message) error {
email := ""
if u := user.Current(m); u != nil {
email = u.Email
}
m.data.Lock()
adminsPlain := m.data.adminsPlain[:]
m.data.Unlock()
testMsg := &mail.TestMessage{Message: *msg}
if err := checkMessage(testMsg, adminsPlain, email); err != nil {
return err
}
m.data.Lock()
m.data.queue = append(m.data.queue, testMsg)
m.data.Unlock()
return nil
}
func (m *mailImpl) Send(msg *mail.Message) error {
return m.sendImpl(msg.Copy())
}
func (m *mailImpl) SendToAdmins(msg *mail.Message) error {
msg = msg.Copy()
m.data.Lock()
ads := m.data.admins[:]
m.data.Unlock()
msg.To = make([]string, len(ads))
copy(msg.To, ads)
return m.sendImpl(msg)
}
func (m *mailImpl) GetTestable() mail.Testable { return m }
func (m *mailImpl) SetAdminEmails(emails ...string) {
adminsPlain := make([]string, len(emails))
for i, e := range emails {
adr, err := net_mail.ParseAddress(e)
if err != nil {
panic(fmt.Errorf("invalid email (%q): %s", e, err))
}
adminsPlain[i] = adr.Address
}
m.data.Lock()
m.data.admins = emails
m.data.adminsPlain = adminsPlain
m.data.Unlock()
}
func (m *mailImpl) SentMessages() []*mail.TestMessage {
m.data.Lock()
msgs := m.data.queue[:]
m.data.Unlock()
ret := make([]*mail.TestMessage, len(msgs))
for i, m := range msgs {
ret[i] = m.Copy()
}
return ret
}
func (m *mailImpl) Reset() {
m.data.Lock()
m.data.queue = nil
m.data.Unlock()
}