Creating new visitors

In the previous notebook, we relied heavily on the FindNodes visitor, which looks through a given IR tree and returns a list of matching instances of a specified Node type. Although this functionality is sufficient for most use cases, there may be scenarios that require the implementation of bespoke visitors.

For node types that could appear in a nested structure, for example Loop or Conditional, we may be interested in knowing at what depth they appear in a given IR tree. The following illustrates how this can be achieved by building a new FindNodesDepth visitor based on FindNodes.

Dataclass to store return values

The default return value for FindNodes is a list of nodes. For FindNodesDepth, we would also like to return the depth of each node. We can create a new dataclass (essentially a C-style struct) called DepthNode to store both pieces of information:

[1]:
from loki import Node
from dataclasses import dataclass

@dataclass
class DepthNode:
    """Store a node object and its depth in a C-style struct."""

    node: Node
    depth: int
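
To illustrate what the dataclass decorator provides, here is a minimal sketch that runs without Loki installed. FakeNode and DepthNodeSketch are hypothetical names introduced only for this example; FakeNode stands in for a real Loki Node.

```python
from dataclasses import dataclass, fields


class FakeNode:
    """Hypothetical stand-in for loki.Node, used only for illustration."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f'FakeNode({self.name!r})'


@dataclass
class DepthNodeSketch:
    """Same layout as DepthNode, but typed against the stand-in node class."""

    node: FakeNode
    depth: int


# The decorator generates __init__, __repr__ and __eq__ from the annotations
loop = FakeNode('k-loop')
entry = DepthNodeSketch(loop, 1)
same_entry = DepthNodeSketch(loop, 1)

# Field names are available in declaration order
field_names = [f.name for f in fields(entry)]
print(entry, field_names, entry == same_entry)
```

Because the fields are accessed by name (entry.node, entry.depth), code that consumes the visitor's results stays readable compared to indexing into plain tuples.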

Modifying initialization method

FindNodes has two operating modes. The first (and default) mode is to look through a given IR tree and return a list of matching instances of a specified node type. The second, which is enabled by passing mode='scope' when creating the visitor, returns the InternalNode, i.e. the Scope, in which a specified node appears.

For our new visitor, we are only interested in the default operating mode of FindNodes. Therefore, let us define a new initialization method for our FindNodesDepth class:

[2]:
from loki import FindNodes

class FindNodesDepth(FindNodes):
    """Visitor that computes node-depth relative to subroutine body. Returns list of DepthNode objects."""

    def __init__(self, match, greedy=False):
        super().__init__(match, mode='type', greedy=greedy)

Modifying the visit_Node method

To achieve the desired functionality of our new visitor, we will need a new visit_Node method. We start from a copy of FindNodes.visit_Node and make only a few changes to it:

[3]:
def visit_Node(self, o, **kwargs):
    """
    Add the node to the returned list if it matches the criteria and increment depth
    before visiting all children.
    """

    ret = kwargs.pop('ret', self.default_retval())
    depth = kwargs.pop('depth', 0)
    if self.rule(self.match, o):
        ret.append(DepthNode(o, depth))
        if self.greedy:
            return ret
    for i in o.children:
        ret = self.visit(i, depth=depth+1, ret=ret, **kwargs)
    return ret or self.default_retval()

FindNodesDepth.visit_Node = visit_Node

The first change to visit_Node is the addition of a line that sets depth. When the visitor is first called on the root of the IR tree, no depth keyword is present, so depth is initialized to 0. When visit_Node is reached recursively, the depth passed down by the parent call is retrieved instead. The second and final change is the addition of a depth keyword argument to the recursive call to visit for the children of node o. As recursion signifies moving down one level in the IR tree, depth+1 is passed as the argument.
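
The depth-threading pattern described above can be sketched independently of Loki. In the following example, ToyNode and find_with_depth are hypothetical names, and the function is a simplified stand-in for the visitor rather than Loki's actual implementation; it only mirrors the two changes (a default depth of 0 and depth + 1 on recursion) together with the greedy early return.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    """Hypothetical tree node; `kind` plays the role of the Loki node type."""

    kind: str
    children: list = field(default_factory=list)


def find_with_depth(node, kind, depth=0, greedy=False, found=None):
    """Collect (node, depth) pairs for all nodes of the given kind."""
    if found is None:
        found = []
    if node.kind == kind:
        found.append((node, depth))
        if greedy:
            # Do not descend below a match
            return found
    for child in node.children:
        # One level further down the tree, so pass depth + 1
        find_with_depth(child, kind, depth + 1, greedy, found)
    return found


# body -> loop -> (loop -> stmt, stmt)
tree = ToyNode('body', [
    ToyNode('loop', [
        ToyNode('loop', [ToyNode('stmt')]),
        ToyNode('stmt'),
    ]),
])

loop_depths = [depth for _, depth in find_with_depth(tree, 'loop')]
stmt_depths = [depth for _, depth in find_with_depth(tree, 'stmt')]
greedy_loop_depths = [depth for _, depth in find_with_depth(tree, 'loop', greedy=True)]
print(loop_depths, stmt_depths, greedy_loop_depths)
```

The root is assigned depth 0, so the outer loop is found at depth 1 and the inner loop at depth 2. With greedy=True, the search stops at the first match along each branch and the inner loop is never reached.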

Having now fully defined our new visitor, we can test it on the following routine containing nested loops:

[4]:
from loki import Sourcefile
from loki import fgen

source = Sourcefile.from_file('src/loop_fuse.F90')
routine = source['loop_fuse_v1']
print(fgen(routine.body))

DO k=1,n
  DO j=1,n
    DO i=1,n
      var_out(i, j, k) = var_in(i, j, k)
    END DO
    DO i=1,n
      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)
    END DO
  END DO

  CALL some_kernel(n, var_out(1, 1, k))

  DO j=1,n
    DO i=1,n
      var_out(i, j, k) = var_out(i, j, k) + 1._JPRB
    END DO
    DO i=1,n
      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)
    END DO
  END DO
END DO

loop_fuse_v1 contains a total of 7 loops, with a maximum nesting depth of 3. Let us see if our new visitor can identify the loops and their depth correctly:

[5]:
from loki import Loop

loops = FindNodesDepth(Loop).visit(routine.body)

for k, loop in enumerate(loops):
    print(k, loop.node, loop.depth)

depths = [1, 2, 3, 3, 2, 3, 3]
assert depths == [loop.depth for loop in loops]
0 Loop:: k=1:n 1
1 Loop:: j=1:n 2
2 Loop:: i=1:n 3
3 Loop:: i=1:n 3
4 Loop:: j=1:n 2
5 Loop:: i=1:n 3
6 Loop:: i=1:n 3

As the output shows, the depths of all 7 loops were identified correctly. Note that the subroutine body itself is assigned a depth of 0, and because the outermost k-loop is a child of the subroutine body, it has a depth of 1.
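
Once the depths are available, simple summaries such as the maximum nesting depth follow directly. The sketch below uses plain (loop variable, depth) tuples copied from the output above instead of real DepthNode objects, so that it runs without Loki; the variable names are illustrative only. With the actual visitor result, the same idea would presumably read max(loop.depth for loop in loops).

```python
from collections import Counter

# (loop variable, depth) pairs copied from the output above
results = [('k', 1), ('j', 2), ('i', 3), ('i', 3), ('j', 2), ('i', 3), ('i', 3)]

# Deepest nesting level found in the routine
max_depth = max(depth for _, depth in results)

# Number of loops found at each depth
loops_per_depth = Counter(depth for _, depth in results)

# Loop variables of the innermost loops
innermost = [var for var, depth in results if depth == max_depth]

print(max_depth, dict(loops_per_depth), innermost)
```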

We can also use our new visitor to find the depth of the Assignment statements within the bodies of the loops:

[6]:
from loki import Assignment

assigns = FindNodesDepth(Assignment).visit(routine.body)

for k, assign in enumerate(assigns):
    print(f'{k} {str(assign.node):<60}{assign.depth}')

depths = [4, 4, 4, 4]
assert depths == [assign.depth for assign in assigns]
0 Assignment:: var_out(i, j, k) = var_in(i, j, k)             4
1 Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k)    4
2 Assignment:: var_out(i, j, k) = var_out(i, j, k) + 1._JPRB  4
3 Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k)    4

All the Assignment statements and their respective depths are identified correctly. We can do a similar test on nested if statements:

[7]:
from loki import Subroutine
from loki import Conditional

fcode = """
subroutine nested_conditionals(i,j,k,h)

    logical,intent(in) :: i,j,k,h

    if(i)then
      if(j)then

        if(k)then
          ! do something
        else
          ! do something else
        endif

        if(h)then
          ! also test h
        endif

      endif
    endif

end subroutine nested_conditionals
"""

routine = Subroutine.from_source(fcode)

conds = FindNodesDepth(Conditional).visit(routine.body)
for k, cond in enumerate(conds):
    print(k, cond.node.condition, cond.depth)

depths = [1, 2, 3, 3]
assert depths == [cond.depth for cond in conds]
0 i 1
1 j 2
2 k 3
3 h 3