Creating new visitors
In the previous notebook, we relied heavily on the FindNodes visitor, which looks through a given IR tree and returns a list of matching instances of a specified Node type. Although this functionality is sufficient for most use cases, there may be scenarios that require the implementation of bespoke visitors.
For node types that can appear in a nested structure, for example Loop or Conditional, we may be interested in knowing at what depth they appear in a given IR tree. The following illustrates how this can be achieved by building a new FindNodesDepth visitor based on FindNodes.
Dataclass to store return values
The default return value for FindNodes is a list of nodes. For FindNodesDepth, we would also like to return the depth of each node. We can create a new dataclass (essentially a C-style struct) called DepthNode to store both pieces of information:
[1]:
from loki import Node
from dataclasses import dataclass
@dataclass
class DepthNode:
    """Store node object and depth in C-style struct."""
    node: Node
    depth: int
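To see what the @dataclass decorator generates for such a struct, here is a minimal sketch that runs without Loki. ToyNode is a hypothetical stand-in for loki.Node; only DepthNode mirrors the cell above.

```python
from dataclasses import dataclass


@dataclass
class ToyNode:
    """Hypothetical stand-in for loki.Node, used only for illustration."""
    name: str


@dataclass
class DepthNode:
    """Store node object and depth in C-style struct."""
    node: ToyNode
    depth: int


# @dataclass generates __init__, __repr__ and __eq__ from the annotated fields
entry = DepthNode(ToyNode('k-loop'), 1)
print(entry)                         # DepthNode(node=ToyNode(name='k-loop'), depth=1)
print(entry.node.name, entry.depth)  # k-loop 1
print(entry == DepthNode(ToyNode('k-loop'), 1))  # True: fields are compared by value
```

No methods need to be written by hand: the class body only declares the two fields, which is what makes it a convenient return type for a visitor.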
Modifying the initialization method
FindNodes has two operating modes. The first (and default) mode looks through a given IR tree and returns a list of matching instances of a specified node type. The second, which is enabled by passing mode='scope' when creating the visitor, returns the InternalNode, i.e. the Scope, in which a specified node appears.
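The two modes differ in the rule used to decide whether a node matches. The sketch below is a simplified, hypothetical illustration rather than Loki's implementation: ToyNode, ToyLoop, RULES and find are invented names, 'type' matches nodes of a given class, and 'scope' matches nodes that directly contain a given node instance among their children.

```python
class ToyNode:
    """Hypothetical stand-in for an IR node with a tuple of children."""

    def __init__(self, name, *children):
        self.name = name
        self.children = children


class ToyLoop(ToyNode):
    """Hypothetical stand-in for a Loop node."""


# Simplified matching rules (an assumption, not Loki's code)
RULES = {
    'type': lambda match, o: isinstance(o, match),  # match is a class
    'scope': lambda match, o: match in o.children,  # match is a node instance
}


def find(o, match, mode='type'):
    """Depth-first collection of all nodes that satisfy the chosen rule."""
    ret = [o] if RULES[mode](match, o) else []
    for child in o.children:
        ret += find(child, match, mode)
    return ret


inner = ToyLoop('i')
outer = ToyLoop('j', inner)
body = ToyNode('body', outer)

print([n.name for n in find(body, ToyLoop)])              # ['j', 'i']
print([n.name for n in find(body, inner, mode='scope')])  # ['j']
```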
For our new visitor, we are only interested in the default operating mode of FindNodes. Therefore, let us define a new initialization method for our FindNodesDepth class:
[2]:
from loki import FindNodes
class FindNodesDepth(FindNodes):
    """Visitor that computes node-depth relative to subroutine body. Returns list of DepthNode objects."""
    def __init__(self, match, greedy=False):
        super().__init__(match, mode='type', greedy=greedy)
Modifying the visit_Node method
To achieve the desired functionality of our new visitor, we need a new visit_Node method. We start from a copy of FindNodes.visit_Node and make only a few changes to it:
[3]:
def visit_Node(self, o, **kwargs):
    """
    Add the node to the returned list if it matches the criteria and increment depth
    before visiting all children.
    """
    ret = kwargs.pop('ret', self.default_retval())
    depth = kwargs.pop('depth', 0)
    if self.rule(self.match, o):
        ret.append(DepthNode(o, depth))
        if self.greedy:
            return ret
    for i in o.children:
        ret = self.visit(i, depth=depth+1, ret=ret, **kwargs)
    return ret or self.default_retval()
FindNodesDepth.visit_Node = visit_Node
The first change to visit_Node is the addition of a line that sets depth. If visit_Node is called on the root of the IR tree, depth is initialized to 0; if, on the other hand, visit_Node is called recursively, the current depth of node o is retrieved from the keyword arguments. The second change is that a matching node is wrapped in a DepthNode together with its depth before being appended to ret, instead of being appended on its own. The third and final change is the addition of a depth keyword argument to the recursive call to visit for the children of node o. As recursion signifies moving down one level in the IR tree, depth+1 is passed as the argument.
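Because the cells in this notebook need Loki and the loop_fuse_v1 source, here is a minimal, Loki-free sketch of the same depth bookkeeping. ToyNode, ToyLoop and find_depth are hypothetical names, and the toy nodes hold their children directly. find_depth mirrors the visit_Node method above: depth defaults to 0, matches are wrapped in DepthNode, and children are visited at depth + 1.

```python
from dataclasses import dataclass


class ToyNode:
    """Hypothetical stand-in for an IR node with a tuple of children."""

    def __init__(self, name, *children):
        self.name = name
        self.children = children


class ToyLoop(ToyNode):
    """Hypothetical stand-in for a Loop node."""


@dataclass
class DepthNode:
    node: ToyNode
    depth: int


def find_depth(o, match, ret=None, depth=0):
    """Record a match together with its depth, then visit children at depth + 1."""
    ret = [] if ret is None else ret
    if isinstance(o, match):
        ret.append(DepthNode(o, depth))
    for child in o.children:
        ret = find_depth(child, match, ret=ret, depth=depth + 1)
    return ret


# Same nesting pattern as loop_fuse_v1 (the CALL is a non-loop leaf)
body = ToyNode('body',
               ToyLoop('k',
                       ToyLoop('j', ToyLoop('i'), ToyLoop('i')),
                       ToyNode('call'),
                       ToyLoop('j', ToyLoop('i'), ToyLoop('i'))))

found = find_depth(body, ToyLoop)
print([(d.node.name, d.depth) for d in found])
# [('k', 1), ('j', 2), ('i', 3), ('i', 3), ('j', 2), ('i', 3), ('i', 3)]
```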
Having now fully defined our new visitor, we can test it on the following routine containing nested loops:
[4]:
from loki import Sourcefile
from loki import fgen
source = Sourcefile.from_file('src/loop_fuse.F90')
routine = source['loop_fuse_v1']
print(fgen(routine.body))
DO k=1,n
  DO j=1,n
    DO i=1,n
      var_out(i, j, k) = var_in(i, j, k)
    END DO
    DO i=1,n
      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)
    END DO
  END DO
  CALL some_kernel(n, var_out(1, 1, k))
  DO j=1,n
    DO i=1,n
      var_out(i, j, k) = var_out(i, j, k) + 1._JPRB
    END DO
    DO i=1,n
      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)
    END DO
  END DO
END DO
loop_fuse_v1 contains a total of 7 loops, with a maximum nesting depth of 3. Let us see whether our new visitor can identify the loops and their depths correctly:
[5]:
from loki import Loop
loops = FindNodesDepth(Loop).visit(routine.body)
for k, loop in enumerate(loops):
    print(k, loop.node, loop.depth)
depths = [1, 2, 3, 3, 2, 3, 3]
assert depths == [loop.depth for loop in loops]
0 Loop:: k=1:n 1
1 Loop:: j=1:n 2
2 Loop:: i=1:n 3
3 Loop:: i=1:n 3
4 Loop:: j=1:n 2
5 Loop:: i=1:n 3
6 Loop:: i=1:n 3
As the output shows, the depths of all 7 loops are identified correctly. Note that the subroutine body itself is assigned a depth of 0; because the outermost k-loop is a child of the subroutine body, it has a depth of 1.
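The depth computed by FindNodesDepth counts every IR level, including the subroutine body. If what you need is a loop's nesting level instead, i.e. how many loops enclose it, one option is to increase the counter only when descending through a node of the counted type. The sketch below does this on hypothetical toy classes arranged like loop_fuse_v1; ToyNode, ToyLoop, ToyAssign and nesting_level are illustrative names, not Loki API.

```python
class ToyNode:
    """Hypothetical stand-in for an IR node with a tuple of children."""

    def __init__(self, name, *children):
        self.name = name
        self.children = children


class ToyLoop(ToyNode):
    """Hypothetical stand-in for a Loop node."""


class ToyAssign(ToyNode):
    """Hypothetical stand-in for an Assignment node."""


def nesting_level(o, match, count_type, ret=None, level=0):
    """Record each match together with the number of enclosing count_type nodes."""
    ret = [] if ret is None else ret
    if isinstance(o, match):
        ret.append((o.name, level))
    # Only nodes of count_type add a level for their children
    inner = level + 1 if isinstance(o, count_type) else level
    for child in o.children:
        nesting_level(child, match, count_type, ret=ret, level=inner)
    return ret


body = ToyNode('body',
               ToyLoop('k',
                       ToyLoop('j', ToyLoop('i', ToyAssign('a1')), ToyLoop('i', ToyAssign('a2'))),
                       ToyNode('call'),
                       ToyLoop('j', ToyLoop('i', ToyAssign('a3')), ToyLoop('i', ToyAssign('a4')))))

print(nesting_level(body, ToyLoop, ToyLoop))
# [('k', 0), ('j', 1), ('i', 2), ('i', 2), ('j', 1), ('i', 2), ('i', 2)]
print(nesting_level(body, ToyAssign, ToyLoop))
# [('a1', 3), ('a2', 3), ('a3', 3), ('a4', 3)]
```

The traversal is the same depth-first walk as before; the two conventions differ only in when the counter is incremented.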
We can also use our new visitor to find the depth of the Assignment statements within the bodies of the loops:
[6]:
from loki import Assignment
assigns = FindNodesDepth(Assignment).visit(routine.body)
for k, assign in enumerate(assigns):
    print(f'{k} {str(assign.node):<60}{assign.depth}')
depths = [4, 4, 4, 4]
assert depths == [assign.depth for assign in assigns]
0 Assignment:: var_out(i, j, k) = var_in(i, j, k) 4
1 Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k) 4
2 Assignment:: var_out(i, j, k) = var_out(i, j, k) + 1._JPRB 4
3 Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k) 4
All the Assignment statements and their respective depths are identified correctly. We can run a similar test on nested if statements:
[7]:
from loki import Subroutine
from loki import Conditional
fcode = """
subroutine nested_conditionals(i,j,k,h)
  logical,intent(in) :: i,j,k,h
  if(i)then
    if(j)then
      if(k)then
        ! do something
      else
        ! do something else
      endif
      if(h)then
        ! also test h
      endif
    endif
  endif
end subroutine nested_conditionals
"""
routine = Subroutine.from_source(fcode)
conds = FindNodesDepth(Conditional).visit(routine.body)
for k, cond in enumerate(conds):
    print(k, cond.node.condition, cond.depth)
depths = [1, 2, 3, 3]
assert depths == [cond.depth for cond in conds]
0 i 1
1 j 2
2 k 3
3 h 3
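The greedy flag of FindNodesDepth is passed through unchanged to FindNodes. In visit_Node, a greedy visitor returns immediately after recording a match, so nodes nested inside that match are never visited. For nested_conditionals, FindNodesDepth(Conditional, greedy=True) would therefore be expected to report only the outer if(i) at depth 1. The following Loki-free sketch, with hypothetical ToyNode, ToyCond and find_depth names, illustrates the difference:

```python
from dataclasses import dataclass


class ToyNode:
    """Hypothetical stand-in for an IR node with a tuple of children."""

    def __init__(self, name, *children):
        self.name = name
        self.children = children


class ToyCond(ToyNode):
    """Hypothetical stand-in for a Conditional node."""


@dataclass
class DepthNode:
    node: ToyNode
    depth: int


def find_depth(o, match, ret=None, depth=0, greedy=False):
    """Collect matches with their depth; with greedy=True, skip the inside of a match."""
    ret = [] if ret is None else ret
    if isinstance(o, match):
        ret.append(DepthNode(o, depth))
        if greedy:
            return ret  # do not descend into a matched node
    for child in o.children:
        ret = find_depth(child, match, ret=ret, depth=depth + 1, greedy=greedy)
    return ret


# Same nesting pattern as nested_conditionals: if(i) > if(j) > [if(k), if(h)]
body = ToyNode('body', ToyCond('i', ToyCond('j', ToyCond('k'), ToyCond('h'))))

print([(d.node.name, d.depth) for d in find_depth(body, ToyCond)])
# [('i', 1), ('j', 2), ('k', 3), ('h', 3)]
print([(d.node.name, d.depth) for d in find_depth(body, ToyCond, greedy=True)])
# [('i', 1)]
```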