Fix ast.Walk to respect WalkStop

Fixes #97
This commit is contained in:
Joe Schafer 2020-02-07 01:42:40 -08:00
parent 39db45a099
commit 3c340e9970
2 changed files with 68 additions and 6 deletions

View file

@ -468,20 +468,25 @@ type Walker func(n Node, entering bool) (WalkStatus, error)
// Walk walks a AST tree by the depth first search algorithm. // Walk walks a AST tree by the depth first search algorithm.
func Walk(n Node, walker Walker) error { func Walk(n Node, walker Walker) error {
_, err := walkHelper(n, walker)
return err
}
func walkHelper(n Node, walker Walker) (WalkStatus, error) {
status, err := walker(n, true) status, err := walker(n, true)
if err != nil || status == WalkStop { if err != nil || status == WalkStop {
return err return status, err
} }
if status != WalkSkipChildren { if status != WalkSkipChildren {
for c := n.FirstChild(); c != nil; c = c.NextSibling() { for c := n.FirstChild(); c != nil; c = c.NextSibling() {
if err = Walk(c, walker); err != nil { if st, err := walkHelper(c, walker); err != nil || st == WalkStop {
return err return WalkStop, err
} }
} }
} }
status, err = walker(n, false) status, err = walker(n, false)
if err != nil || status == WalkStop { if err != nil || status == WalkStop {
return err return WalkStop, err
} }
return nil return WalkContinue, nil
} }

View file

@ -1,6 +1,9 @@
package ast package ast
import "testing" import (
"reflect"
"testing"
)
func TestRemoveChildren(t *testing.T) { func TestRemoveChildren(t *testing.T) {
root := NewDocument() root := NewDocument()
@ -16,3 +19,57 @@ func TestRemoveChildren(t *testing.T) {
t.Logf("%+v", node2.PreviousSibling()) t.Logf("%+v", node2.PreviousSibling())
} }
func TestWalk(t *testing.T) {
tests := []struct {
name string
node Node
want []NodeKind
action map[NodeKind]WalkStatus
}{
{
"visits all in depth first order",
node(NewDocument(), node(NewHeading(1), NewText()), NewLink()),
[]NodeKind{KindDocument, KindHeading, KindText, KindLink},
map[NodeKind]WalkStatus{},
},
{
"stops after heading",
node(NewDocument(), node(NewHeading(1), NewText()), NewLink()),
[]NodeKind{KindDocument, KindHeading},
map[NodeKind]WalkStatus{KindHeading: WalkStop},
},
{
"skip children",
node(NewDocument(), node(NewHeading(1), NewText()), NewLink()),
[]NodeKind{KindDocument, KindHeading, KindLink},
map[NodeKind]WalkStatus{KindHeading: WalkSkipChildren},
},
}
for _, tt := range tests {
var kinds []NodeKind
collectKinds := func(n Node, entering bool) (WalkStatus, error) {
if entering {
kinds = append(kinds, n.Kind())
}
if status, ok := tt.action[n.Kind()]; ok {
return status, nil
}
return WalkContinue, nil
}
t.Run(tt.name, func(t *testing.T) {
if err := Walk(tt.node, collectKinds); err != nil {
t.Errorf("Walk() error = %v", err)
} else if !reflect.DeepEqual(kinds, tt.want) {
t.Errorf("Walk() expected = %v, got = %v", tt.want, kinds)
}
})
}
}
func node(n Node, children ...Node) Node {
for _, c := range children {
n.AppendChild(n, c)
}
return n
}