From 3c340e997071af149c13623ed3211ed6f675a6db Mon Sep 17 00:00:00 2001 From: Joe Schafer Date: Fri, 7 Feb 2020 01:42:40 -0800 Subject: [PATCH] Fix ast.Walk to respect WalkStop Fixes #97 --- ast/ast.go | 15 ++++++++----- ast/ast_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index 3ba6447..66059e9 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -468,20 +468,25 @@ type Walker func(n Node, entering bool) (WalkStatus, error) // Walk walks a AST tree by the depth first search algorithm. func Walk(n Node, walker Walker) error { + _, err := walkHelper(n, walker) + return err +} + +func walkHelper(n Node, walker Walker) (WalkStatus, error) { status, err := walker(n, true) if err != nil || status == WalkStop { - return err + return status, err } if status != WalkSkipChildren { for c := n.FirstChild(); c != nil; c = c.NextSibling() { - if err = Walk(c, walker); err != nil { - return err + if st, err := walkHelper(c, walker); err != nil || st == WalkStop { + return WalkStop, err } } } status, err = walker(n, false) if err != nil || status == WalkStop { - return err + return WalkStop, err } - return nil + return WalkContinue, nil } diff --git a/ast/ast_test.go b/ast/ast_test.go index 77a80f6..684fbc3 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -1,6 +1,9 @@ package ast -import "testing" +import ( + "reflect" + "testing" +) func TestRemoveChildren(t *testing.T) { root := NewDocument() @@ -16,3 +19,57 @@ func TestRemoveChildren(t *testing.T) { t.Logf("%+v", node2.PreviousSibling()) } + +func TestWalk(t *testing.T) { + tests := []struct { + name string + node Node + want []NodeKind + action map[NodeKind]WalkStatus + }{ + { + "visits all in depth first order", + node(NewDocument(), node(NewHeading(1), NewText()), NewLink()), + []NodeKind{KindDocument, KindHeading, KindText, KindLink}, + map[NodeKind]WalkStatus{}, + }, + { + "stops after heading", + node(NewDocument(), node(NewHeading(1), NewText()), NewLink()), + []NodeKind{KindDocument, KindHeading}, + map[NodeKind]WalkStatus{KindHeading: WalkStop}, + }, + { + "skip children", + node(NewDocument(), node(NewHeading(1), NewText()), NewLink()), + []NodeKind{KindDocument, KindHeading, KindLink}, + map[NodeKind]WalkStatus{KindHeading: WalkSkipChildren}, + }, + } + for _, tt := range tests { + var kinds []NodeKind + collectKinds := func(n Node, entering bool) (WalkStatus, error) { + if entering { + kinds = append(kinds, n.Kind()) + } + if status, ok := tt.action[n.Kind()]; ok { + return status, nil + } + return WalkContinue, nil + } + t.Run(tt.name, func(t *testing.T) { + if err := Walk(tt.node, collectKinds); err != nil { + t.Errorf("Walk() error = %v", err) + } else if !reflect.DeepEqual(kinds, tt.want) { + t.Errorf("Walk() expected = %v, got = %v", tt.want, kinds) + } + }) + } +} + +func node(n Node, children ...Node) Node { + for _, c := range children { + n.AppendChild(n, c) + } + return n +}