You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
garble/reflect.go

481 lines
13 KiB
Go

package main
import (
"fmt"
"go/types"
"path/filepath"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/tools/go/ssa"
)
// reflectInspector carries the state used while looking for reflection use
// in one package.
type reflectInspector struct {
// pkg is the package being inspected; named types and fields are only
// recorded when they are declared in this package.
pkg *types.Package
// checkedAPIs holds the names of known reflect APIs whose call sites have
// already been examined, so checkFunction can skip them on later passes.
checkedAPIs map[string]bool
// result collects the findings: its ReflectAPIs and ReflectObjects maps
// are filled in by the methods of reflectInspector.
result pkgCache
}
// recordReflection records every use of reflection in ssaPkg, so that types
// which are used in reflection are not obfuscated.
//
// Finding a new reflect API can change the outcome for functions that were
// already inspected, so the package is scanned again until a pass adds no
// new entry to ReflectAPIs.
func (ri *reflectInspector) recordReflection(ssaPkg *ssa.Package) {
	if reflectSkipPkg[ssaPkg.Pkg.Path()] {
		return
	}

	for {
		apisBefore := len(ri.result.ReflectAPIs)

		// Collect the known APIs which this pass checks for the first time.
		pending := make(map[string]bool)
		for api := range ri.result.ReflectAPIs {
			if ri.checkedAPIs[api] {
				continue
			}
			pending[api] = true
		}

		ri.ignoreReflectedTypes(ssaPkg)

		// The pass above has covered all pending APIs; mark them as checked
		// to avoid checking them twice.
		maps.Copy(ri.checkedAPIs, pending)

		// No new reflect API was found, so nothing needs to be re-evaluated.
		if len(ri.result.ReflectAPIs) <= apisBefore {
			return
		}
	}
}
// ignoreReflectedTypes finds all functions, methods and interface declarations
// of ssaPkg and records their use of reflection.
func (ri *reflectInspector) ignoreReflectedTypes(ssaPkg *ssa.Package) {
	// Some packages, like go-spew, reach into reflect internals.
	// It's not particularly right of them to do that, and it's entirely
	// unsupported, but try to accommodate them for now.
	// At least it's enough to leave the rtype and Value types intact.
	if ri.pkg.Path() == "reflect" {
		reflectScope := ri.pkg.Scope()
		for _, name := range []string{"rtype", "Value"} {
			ri.recursivelyRecordUsedForReflect(reflectScope.Lookup(name).Type())
		}
	}

	prog := ssaPkg.Prog

	// checkMethods inspects every entry of a single method set.
	checkMethods := func(methods *types.MethodSet) {
		for i := 0; i < methods.Len(); i++ {
			sel := methods.At(i)
			if fn := prog.MethodValue(sel); fn != nil {
				ri.checkFunction(fn)
				continue
			}
			// No function was returned for the selection:
			// handle it as an interface declaration.
			ri.checkInterfaceMethod(sel.Obj().(*types.Func))
		}
	}

	for _, member := range ssaPkg.Members {
		switch member := member.(type) {
		case *ssa.Function:
			// These include not only top level functions, but also synthetic
			// functions like the initialization of global variables.
			ri.checkFunction(member)

		case *ssa.Type:
			// Methods aren't package members, only their receiver types are,
			// so the methods have to be found through the type's method sets.
			// Finding all methods really only works with both lookups:
			// the set of T and the set of *T.
			typ := member.Type()
			checkMethods(prog.MethodSets.MethodSet(typ))
			checkMethods(prog.MethodSets.MethodSet(types.NewPointer(typ)))
		}
	}
}
// checkMethodSignature flags method parameters whose type is an unnamed
// struct, or an array or slice of unnamed structs.
//
// Exported methods with unnamed structs as parameters may be "used" in
// interface declarations elsewhere, and these interfaces will break if any
// method uses reflection on the same parameter. Such parameters are therefore
// never obfuscated: they are treated like a parameter which is actually used
// in reflection, by setting their index in reflectParams and recording their
// type.
//
// Signatures without a receiver are left alone.
//
// See "UnnamedStructMethod" in the reflect.txtar test for an example.
func (ri *reflectInspector) checkMethodSignature(reflectParams map[int]bool, sig *types.Signature) {
	if sig.Recv() == nil {
		return
	}

	isUnnamedStruct := func(t types.Type) bool {
		_, ok := t.(*types.Struct)
		return ok
	}

	params := sig.Params()
	for i, n := 0, params.Len(); i < n; i++ {
		if reflectParams[i] {
			continue // already flagged
		}

		typ := params.At(i).Type()
		var flag bool
		switch t := typ.(type) {
		case *types.Struct:
			flag = true
		case *types.Array:
			flag = isUnnamedStruct(t.Elem())
		case *types.Slice:
			flag = isUnnamedStruct(t.Elem())
		}
		if !flag {
			continue
		}

		reflectParams[i] = true
		ri.recursivelyRecordUsedForReflect(typ)
	}
}
// checkInterfaceMethod checks the signature of an interface method for
// potential reflection use.
//
// It starts from the parameters already recorded for m in ReflectAPIs, lets
// checkMethodSignature add to them when m is exported, and stores the result
// back under m's full name if any parameter is flagged.
func (ri *reflectInspector) checkInterfaceMethod(m *types.Func) {
	name := m.FullName()

	flagged := make(map[int]bool)
	maps.Copy(flagged, ri.result.ReflectAPIs[name])

	sig := m.Type().(*types.Signature)
	if m.Exported() {
		ri.checkMethodSignature(flagged, sig)
	}

	if len(flagged) == 0 {
		return
	}
	ri.result.ReflectAPIs[name] = flagged
}
// checkFunction checks all call sites in a function declaration for use of
// reflection.
//
// For every call to a known reflect API (an entry of ReflectAPIs which is not
// yet in checkedAPIs), each argument passed to a parameter used in reflection
// is handed to recordArgReflected. If such an argument turns out to be related
// to one of fun's own parameters, that parameter's index is flagged, and fun is
// itself stored in ReflectAPIs under its full name.
//
// fun.Object() does not always yield a *types.Func. The reflection use inside
// such a function is still recorded, but there is no name to store it under,
// so no ReflectAPIs entry is written for it.
func (ri *reflectInspector) checkFunction(fun *ssa.Function) {
	// Debugging aid: fun.WriteTo(os.Stdout) dumps the SSA form, but it crashes
	// when fun.Synthetic is "loaded from gc object file".

	f, _ := fun.Object().(*types.Func)

	reflectParams := make(map[int]bool)
	if f != nil {
		maps.Copy(reflectParams, ri.result.ReflectAPIs[f.FullName()])
		if f.Exported() {
			ri.checkMethodSignature(reflectParams, fun.Signature)
		}
	}

	for _, block := range fun.Blocks {
		for _, inst := range block.Instrs {
			call, ok := inst.(*ssa.Call)
			if !ok {
				continue
			}

			callName := call.Call.Value.String()
			if m := call.Call.Method; m != nil {
				callName = m.FullName()
			}

			if ri.checkedAPIs[callName] {
				// only check APIs which were not already checked
				continue
			}

			// Record each call argument passed to a function parameter
			// which is used in reflection.
			for knownParam := range ri.result.ReflectAPIs[callName] {
				if len(call.Call.Args) <= knownParam {
					continue
				}

				arg := call.Call.Args[knownParam]
				reflectedParam := ri.recordArgReflected(arg, make(map[ssa.Value]bool))
				if reflectedParam == nil {
					continue
				}

				// Only parameters of fun itself make fun a reflect API.
				pos := slices.Index(fun.Params, reflectedParam)
				if pos < 0 {
					continue
				}
				reflectParams[pos] = true
			}
		}
	}

	// Without a *types.Func there is no full name to key the entry with;
	// calling f.FullName() on a nil f would dereference a nil pointer.
	if f == nil || len(reflectParams) == 0 {
		return
	}
	ri.result.ReflectAPIs[f.FullName()] = reflectParams
}
// recordArgReflected finds the type(s) of a function argument which is being
// used in reflection, and excludes these types from obfuscation.
//
// It also checks whether the argument has any relation to a function
// parameter, and returns that parameter if one is found; otherwise it
// returns nil. visited tracks the values seen so far during the walk.
func (ri *reflectInspector) recordArgReflected(val ssa.Value, visited map[ssa.Value]bool) *ssa.Parameter {
	// Visit every value only once, otherwise there will be infinite recursion.
	if visited[val] {
		return nil
	}
	visited[val] = true

	switch v := val.(type) {
	case *ssa.Parameter:
		// This only finds the parameters which are reached directly;
		// relatedParam is used for the more in-depth analysis.
		ri.recursivelyRecordUsedForReflect(v.Type())
		return v

	case *ssa.Const, *ssa.Global:
		// TODO: globals might need logic similar to *ssa.Alloc, however
		// reassigning a function param to a global variable and then
		// reflecting it is probably unlikely to occur.
		ri.recursivelyRecordUsedForReflect(v.Type())
		return nil

	case *ssa.Alloc:
		ri.recursivelyRecordUsedForReflect(v.Type())

		for _, ref := range *v.Referrers() {
			idx, ok := ref.(*ssa.IndexAddr)
			if !ok {
				continue
			}
			ri.recordArgReflected(idx, visited)
		}

		// Check if the found alloc gets tainted by function parameters.
		// relatedParam needs to revisit nodes, so it gets an empty map.
		return relatedParam(v, make(map[ssa.Value]bool))

	case *ssa.IndexAddr:
		// Follow the values stored through this element address as well.
		for _, ref := range *v.Referrers() {
			if store, ok := ref.(*ssa.Store); ok {
				ri.recordArgReflected(store.Val, visited)
			}
		}
		return ri.recordArgReflected(v.X, visited)

	// The remaining instructions simply forward to their operand.
	case *ssa.Slice:
		return ri.recordArgReflected(v.X, visited)
	case *ssa.MakeInterface:
		return ri.recordArgReflected(v.X, visited)
	case *ssa.UnOp:
		return ri.recordArgReflected(v.X, visited)
	case *ssa.FieldAddr:
		return ri.recordArgReflected(v.X, visited)
	}

	return nil
}
// relatedParam checks if a route to a function parameter can be constructed
// from an ssa.Value, and returns the parameter if it found one.
//
// The search follows the operand of UnOp and FieldAddr values, then the
// referrers of val: FieldAddr and UnOp referrers are followed directly,
// Store referrers through both their stored value and their address.
// It returns nil when no parameter is reachable.
func relatedParam(val ssa.Value, visited map[ssa.Value]bool) *ssa.Parameter {
	// Every value should only be visited once to prevent infinite recursion.
	if visited[val] {
		return nil
	}
	visited[val] = true

	switch v := val.(type) {
	case *ssa.Parameter:
		// a parameter has been found
		return v
	case *ssa.UnOp:
		if found := relatedParam(v.X, visited); found != nil {
			return found
		}
	case *ssa.FieldAddr:
		if found := relatedParam(v.X, visited); found != nil {
			return found
		}
	}

	referrers := val.Referrers()
	if referrers == nil {
		return nil
	}

	for _, instr := range *referrers {
		switch instr := instr.(type) {
		case *ssa.FieldAddr:
			if found := relatedParam(instr, visited); found != nil {
				return found
			}
		case *ssa.UnOp:
			if found := relatedParam(instr, visited); found != nil {
				return found
			}
		case *ssa.Store:
			// Try the stored value first, then the address stored to.
			for _, operand := range []ssa.Value{instr.Val, instr.Addr} {
				if found := relatedParam(operand, visited); found != nil {
					return found
				}
			}
		}
	}
	return nil
}
// recursivelyRecordUsedForReflect calls recordUsedForReflect on any named
// types and fields under t.
//
// Only the names declared in the current package are recorded. This is to ensure
// that reflection detection only happens within the package declaring a type.
// Detecting it in downstream packages could result in inconsistencies.
func (ri *reflectInspector) recursivelyRecordUsedForReflect(t types.Type) {
	ri.walkUsedForReflect(t, make(map[*types.TypeName]bool))
}

// walkUsedForReflect implements recursivelyRecordUsedForReflect.
//
// seen holds the named types already entered during this traversal.
// recordUsedForReflect does not record the names of function-local types, so
// usedForReflect alone cannot stop the recursion for a local type which refers
// to itself, such as `type T struct{ Next *T }`; seen guarantees termination.
func (ri *reflectInspector) walkUsedForReflect(t types.Type, seen map[*types.TypeName]bool) {
	switch t := t.(type) {
	case *types.Named:
		obj := t.Obj()
		if obj.Pkg() == nil || obj.Pkg() != ri.pkg {
			return // not from the specified package
		}
		if seen[obj] || usedForReflect(ri.result, obj) {
			return // prevent endless recursion
		}
		seen[obj] = true
		ri.recordUsedForReflect(obj)

		// Record the underlying type, too.
		ri.walkUsedForReflect(t.Underlying(), seen)

	case *types.Struct:
		for i := 0; i < t.NumFields(); i++ {
			field := t.Field(i)

			// This check is similar to the one in *types.Named.
			// It's necessary for unnamed struct types,
			// as they aren't named but still have named fields.
			if field.Pkg() == nil || field.Pkg() != ri.pkg {
				return // not from the specified package
			}

			// Record the field itself, too.
			ri.recordUsedForReflect(field)

			ri.walkUsedForReflect(field.Type(), seen)
		}

	case interface{ Elem() types.Type }:
		// Get past pointers, slices, etc.
		ri.walkUsedForReflect(t.Elem(), seen)
	}
}
// recordedObjectString returns the string under which obj is recorded.
//
// Objects declared directly in their package's scope are identified as
// "pkgpath.Name". All other objects additionally carry the base name of the
// file and the line they are declared at.
//
// TODO: consider caching recordedObjectString via a map,
// if that shows an improvement in our benchmark
func recordedObjectString(obj types.Object) objectString {
	pkg := obj.Pkg()

	// For top-level exported names, "pkgpath.Name" is unique.
	if pkg.Scope() == obj.Parent() {
		return pkg.Path() + "." + obj.Name()
	}

	// Note that fields are never top-level.
	//
	// For exported fields, "pkgpath.Field" is not unique,
	// because two exported top-level types could share "Field".
	//
	// Moreover, note that not all fields belong to named struct types;
	// an API could be exposing:
	//
	//	var usedInReflection = struct{Field string}
	//
	// For now, a hack: assume that packages don't declare the same field
	// more than once in the same line. This works in practice, but one
	// could craft Go code to break this assumption.
	// Also note that the compiler's object files include filenames and line
	// numbers, but not column numbers nor byte offsets.
	// TODO(mvdan): give this another think, and add tests involving anon types.
	declPos := fset.Position(obj.Pos())
	return fmt.Sprintf("%s.%s - %s:%d", pkg.Path(), obj.Name(),
		filepath.Base(declPos.Filename), declPos.Line)
}
// recordUsedForReflect records the objects whose names we cannot obfuscate due
// to reflection. We currently record named types and fields.
//
// It panics if obj belongs to a package other than the one being inspected.
func (ri *reflectInspector) recordUsedForReflect(obj types.Object) {
	if obj.Pkg().Path() != ri.pkg.Path() {
		panic("called recordUsedForReflect with a foreign object")
	}

	field, ok := obj.(*types.Var)
	if !ok || !field.IsField() {
		// Not a field: only names declared in the package scope are recorded,
		// as we don't need to record the local type names.
		if obj.Pkg().Scope() == obj.Parent() {
			ri.result.ReflectObjects[recordedObjectString(obj)] = struct{}{}
		}
		return
	}

	ri.result.ReflectObjects[recordedObjectString(field)] = struct{}{}

	// For an embedded field of a named type, the type itself may need to be
	// recorded as well.
	if !field.Embedded() {
		return
	}
	named, ok := field.Type().(*types.Named)
	if !ok {
		return
	}
	embedded := named.Obj()

	// Skip types which aren't from the specified package,
	// as well as types which aren't local, i.e. declared in the package scope.
	embeddedPkg := embedded.Pkg()
	if embeddedPkg == nil || embeddedPkg != ri.pkg || embeddedPkg.Scope() == embedded.Parent() {
		return
	}
	ri.result.ReflectObjects[recordedObjectString(embedded)] = struct{}{}
}
// usedForReflect reports whether obj has an entry in cache.ReflectObjects,
// looked up by its recordedObjectString.
func usedForReflect(cache pkgCache, obj types.Object) bool {
	_, recorded := cache.ReflectObjects[recordedObjectString(obj)]
	return recorded
}