blob: 5cd30b7f036067d1ab3c2dde35db27aa90f042e9 [file] [log] [blame]
// Copyright 2013 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pretty
import (
"encoding"
"fmt"
"reflect"
"sort"
)
// isZeroVal reports whether val is deeply equal to the zero value of its own
// type. Values that cannot be turned back into an interface (for example ones
// read through unexported struct fields) are conservatively reported as
// non-zero, so callers never skip them.
func isZeroVal(val reflect.Value) bool {
	if val.CanInterface() {
		zero := reflect.Zero(val.Type())
		return reflect.DeepEqual(val.Interface(), zero.Interface())
	}
	return false
}
// pointerTracker is a helper for tracking pointer chasing to detect cycles.
//
// addrs holds the reference addresses currently being followed on the active
// recursion path (maintained by track/untrack). ids holds the addresses that
// have been handed a numeric ID by keep, with lastID being the most recently
// allocated one (IDs start at 1).
type pointerTracker struct {
	addrs  map[uintptr]int // addrs[address] = times address is on the current path
	lastID int             // last ID handed out by keep; 0 means none yet
	ids    map[uintptr]int // ids[address] = id
}
// track records that the reference at ptr (a pointer, slice, map, etc.) is
// now being followed, bumping its count on the current path. The addrs map is
// created on first use. Each call must be balanced by a matching untrack.
func (p *pointerTracker) track(ptr uintptr) {
	if p.addrs == nil {
		p.addrs = map[uintptr]int{}
	}
	p.addrs[ptr] = p.addrs[ptr] + 1
}
// untrack undoes one prior track(ptr) call as recursion backtracks over the
// reference. Once the count for ptr drops to zero the entry is removed, so
// seen reports false for it again.
func (p *pointerTracker) untrack(ptr uintptr) {
	remaining := p.addrs[ptr] - 1
	if remaining == 0 {
		delete(p.addrs, ptr)
		return
	}
	p.addrs[ptr] = remaining
}
// seen reports whether ptr is currently being followed somewhere up the
// active recursion path, i.e. whether it has an entry in addrs.
func (p *pointerTracker) seen(ptr uintptr) bool {
	_, onPath := p.addrs[ptr]
	return onPath
}
// keep returns the ID associated with ptr, allocating the next sequential ID
// (starting at 1) the first time an address is kept. Repeated calls for the
// same address return the same ID.
func (p *pointerTracker) keep(ptr uintptr) int {
	if existing, ok := p.ids[ptr]; ok {
		return existing
	}
	if p.ids == nil {
		p.ids = make(map[uintptr]int)
	}
	p.lastID++
	p.ids[ptr] = p.lastID
	return p.lastID
}
// id returns the ID previously allocated for ptr by keep, and whether such an
// ID exists. It is a pure lookup: indexing a nil map is valid in Go and yields
// the zero value, so no map needs to be created here.
func (p *pointerTracker) id(ptr uintptr) (int, bool) {
	id, ok := p.ids[ptr]
	return id, ok
}
// reflector adds local state to the recursive reflection logic. It pairs the
// active *Config (whose options such as Formatter, PrintStringers,
// PrintTextMarshalers, IncludeUnexported and SkipZeroFields are read by
// val2node) with an optional *pointerTracker; a nil pointerTracker disables
// cycle detection in follow.
type reflector struct {
	*Config
	*pointerTracker
}
// follow renders val, which was reached through a possibly-recursive
// reference whose address is ptr.
//
// With tracking disabled (nil pointerTracker) it simply renders val. With
// tracking enabled, an address that is already being followed further up the
// current path is rendered as a ref marker (allocating an ID via keep) instead
// of being expanded again. After val has been rendered, if any descendant
// produced such a reference to ptr, the result is wrapped in a target marker
// carrying the same ID.
func (r *reflector) follow(ptr uintptr, val reflect.Value) node {
	tracker := r.pointerTracker
	switch {
	case tracker == nil:
		// Cycle detection is off: expand unconditionally.
		return r.val2node(val)
	case tracker.seen(ptr):
		// An ancestor on this path is already expanding ptr: back-reference it.
		return ref{tracker.keep(ptr)}
	}

	// Mark ptr as on-path for the duration of this subtree only.
	tracker.track(ptr)
	defer tracker.untrack(ptr)

	child := r.val2node(val)
	if id, referenced := tracker.id(ptr); referenced {
		return target{id, child}
	}
	return child
}
// val2node converts val into the node tree that the rest of the package
// renders.
//
// If val can be converted to an interface, the custom rendering hooks are
// consulted first, in this order:
//   - a Formatter registered for val's exact type (a formatter registered as
//     nil disables the remaining hooks for that type);
//   - fmt.Stringer, when PrintStringers is set;
//   - encoding.TextMarshaler, when PrintTextMarshalers is set (a MarshalText
//     error falls back to the structural rendering below).
//
// Otherwise val is rendered according to its kind. Pointers, slice elements
// and map values are expanded through follow so reference cycles can be
// detected. Map entries are ordered by the compact rendering of their keys so
// output is deterministic, and struct fields honour IncludeUnexported and
// SkipZeroFields.
//
// If a hook panics while being invoked on a nil pointer (directly, or held in
// a non-nil interface), the value is rendered as "nil" instead of crashing.
func (r *reflector) val2node(val reflect.Value) (ret node) {
	if !val.IsValid() {
		return rawVal("nil")
	}

	if val.CanInterface() {
		// The hooks below may be invoked on a nil pointer. Nil can be a valid
		// receiver, so the hooks are still called, but any method that
		// dereferences its receiver (and every value-receiver method reached
		// through a nil pointer) panics. Render such values as plain nil,
		// which is also what the structural rendering below produces.
		inner := val
		if inner.Kind() == reflect.Interface && !inner.IsNil() {
			inner = inner.Elem()
		}
		if inner.Kind() == reflect.Ptr && inner.IsNil() {
			defer func() {
				if recover() != nil {
					ret = rawVal("nil")
				}
			}()
		}

		v := val.Interface()
		// fn, not formatter: avoid shadowing the package-level formatter type
		// used for map keys below.
		if fn, ok := r.Formatter[val.Type()]; ok {
			if fn != nil {
				res := reflect.ValueOf(fn).Call([]reflect.Value{val})
				return rawVal(res[0].Interface().(string))
			}
		} else {
			if s, ok := v.(fmt.Stringer); ok && r.PrintStringers {
				return stringVal(s.String())
			}
			if t, ok := v.(encoding.TextMarshaler); ok && r.PrintTextMarshalers {
				if raw, err := t.MarshalText(); err == nil { // if NOT an error
					return stringVal(string(raw))
				}
			}
		}
	}

	switch val.Kind() {
	case reflect.Ptr:
		if val.IsNil() {
			return rawVal("nil")
		}
		return r.follow(val.Pointer(), val.Elem())
	case reflect.Interface:
		if val.IsNil() {
			return rawVal("nil")
		}
		return r.val2node(val.Elem())
	case reflect.String:
		return stringVal(val.String())
	case reflect.Slice:
		length := val.Len()
		n := make(list, 0, length)
		// Every element is followed via the slice's backing-array address so
		// a slice that (indirectly) contains itself is detected as a cycle.
		ptr := val.Pointer()
		for i := 0; i < length; i++ {
			n = append(n, r.follow(ptr, val.Index(i)))
		}
		return n
	case reflect.Array:
		length := val.Len()
		n := make(list, 0, length)
		for i := 0; i < length; i++ {
			n = append(n, r.val2node(val.Index(i)))
		}
		return n
	case reflect.Map:
		// Render each key compactly; that string is both the sort key that
		// gives stable iteration order and the label emitted for the entry.
		keys := val.MapKeys()
		pairs := make([]mapPair, 0, len(keys))
		for _, key := range keys {
			pairs = append(pairs, mapPair{
				key:   new(formatter).compactString(r.val2node(key)),
				value: val.MapIndex(key),
			})
		}
		sort.Sort(byKey(pairs))

		// Process the values into the final representation, following them
		// via the map's address for cycle detection.
		ptr, n := val.Pointer(), make(keyvals, 0, len(pairs))
		for _, pair := range pairs {
			n = append(n, keyval{
				key: pair.key,
				val: r.follow(ptr, pair.value),
			})
		}
		return n
	case reflect.Struct:
		typ := val.Type()
		fields := typ.NumField()
		n := make(keyvals, 0, fields)
		for i := 0; i < fields; i++ {
			sf := typ.Field(i)
			if !r.IncludeUnexported && sf.PkgPath != "" {
				continue // unexported field
			}
			field := val.Field(i)
			if r.SkipZeroFields && isZeroVal(field) {
				continue
			}
			n = append(n, keyval{sf.Name, r.val2node(field)})
		}
		return n
	case reflect.Bool:
		if val.Bool() {
			return rawVal("true")
		}
		return rawVal("false")
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rawVal(fmt.Sprintf("%d", val.Int()))
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return rawVal(fmt.Sprintf("%d", val.Uint()))
	case reflect.Uintptr:
		return rawVal(fmt.Sprintf("0x%X", val.Uint()))
	case reflect.Float32, reflect.Float64:
		return rawVal(fmt.Sprintf("%v", val.Float()))
	case reflect.Complex64, reflect.Complex128:
		return rawVal(fmt.Sprintf("%v", val.Complex()))
	}

	// Fall back to the default %#v if we can
	if val.CanInterface() {
		return rawVal(fmt.Sprintf("%#v", val.Interface()))
	}
	return rawVal(val.String())
}
// mapPair couples a map entry's value with the compact string rendering of
// its key; that string is what the entries are ordered by.
type mapPair struct {
	key   string
	value reflect.Value
}

// byKey implements sort.Interface, ordering map pairs lexically by their
// rendered key.
type byKey []mapPair

func (pairs byKey) Len() int           { return len(pairs) }
func (pairs byKey) Swap(a, b int)      { pairs[b], pairs[a] = pairs[a], pairs[b] }
func (pairs byKey) Less(a, b int) bool { return pairs[a].key < pairs[b].key }