From 25644844299af8f987ffdea7aa4859ff00c37065 Mon Sep 17 00:00:00 2001
From: Geza Lore <gezalore@gmail.com>
Date: Tue, 13 Sep 2022 15:57:22 +0100
Subject: [PATCH] astgen: Rewrite in a more OOP way, in preparation for
 extensions

Rely less on strings and represent each AstNode class as an instance of
a new 'Node' class, with all associated properties kept together, rather
than distributed over multiple dictionaries or constructed at retrieval
time.

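No functional change intended. Note that write_macros no longer scans
V3AstNodes.h itself, so its check that each class calls its own
ASTGEN_SUPER_* macro is dropped from that function.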
---
 src/astgen | 362 ++++++++++++++++++++++++++++++-----------------------
 1 file changed, 203 insertions(+), 159 deletions(-)

diff --git a/src/astgen b/src/astgen
index 44529a0e4..f8860dce0 100755
--- a/src/astgen
+++ b/src/astgen
@@ -8,9 +8,125 @@ import re
 import sys
 # from pprint import pprint, pformat
 
-Types = []
-Classes = {}
-Children = {}
+
+class Node:
+    def __init__(self, name, superClass):
+        self._name = name
+        self._superClass = superClass
+        self._subClasses = []  # Initially list, but tuple after completion
+        self._allSuperClasses = None  # Computed on demand after completion
+        self._allSubClasses = None  # Computed on demand after completion
+        self._typeId = None  # Concrete type identifier number for leaf classes
+        self._typeIdMin = None  # Lowest type identifier number for class
+        self._typeIdMax = None  # Highest type identifier number for class
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def superClass(self):
+        return self._superClass
+
+    @property
+    def isCompleted(self):
+        return isinstance(self._subClasses, tuple)
+
+    # Pre-completion methods (only valid before complete() is called)
+    def addSubClass(self, subClass):
+        assert not self.isCompleted
+        self._subClasses.append(subClass)
+
+    # Computes derived properties over the entire class hierarchy.
+    # No more changes to the hierarchy are allowed once this has been called
+    def complete(self, typeId=0):
+        assert not self.isCompleted
+        # Sort sub-classes and convert to tuple, which marks completion
+        self._subClasses = tuple(sorted(self._subClasses,
+                                        key=lambda _: _.name))
+        # Leaves
+        if self.isLeaf:
+            self._typeId = typeId
+            return typeId + 1
+
+        # Non-leaves
+        for subClass in self._subClasses:
+            typeId = subClass.complete(typeId)
+        return typeId
+
+    # Post-completion methods (only valid after complete() has been called)
+    @property
+    def subClasses(self):
+        assert self.isCompleted
+        return self._subClasses
+
+    @property
+    def isRoot(self):
+        assert self.isCompleted
+        return self.superClass is None
+
+    @property
+    def isLeaf(self):
+        assert self.isCompleted
+        return not self.subClasses
+
+    @property
+    def allSuperClasses(self):
+        assert self.isCompleted
+        if self._allSuperClasses is None:
+            if self.superClass is None:
+                self._allSuperClasses = ()
+            else:
+                self._allSuperClasses = self.superClass.allSuperClasses + (
+                    self.superClass, )
+        return self._allSuperClasses
+
+    @property
+    def allSubClasses(self):
+        assert self.isCompleted
+        if self._allSubClasses is None:
+            if self.isLeaf:
+                self._allSubClasses = ()
+            else:
+                self._allSubClasses = self.subClasses + tuple(
+                    _ for subClass in self.subClasses
+                    for _ in subClass.allSubClasses)
+        return self._allSubClasses
+
+    @property
+    def typeId(self):
+        assert self.isCompleted
+        assert self.isLeaf
+        return self._typeId
+
+    @property
+    def typeIdMin(self):
+        assert self.isCompleted
+        if self.isLeaf:
+            return self.typeId
+        if self._typeIdMin is None:
+            self._typeIdMin = min(_.typeIdMin for _ in self.allSubClasses)
+        return self._typeIdMin
+
+    @property
+    def typeIdMax(self):
+        assert self.isCompleted
+        if self.isLeaf:
+            return self.typeId
+        if self._typeIdMax is None:
+            self._typeIdMax = max(_.typeIdMax for _ in self.allSubClasses)
+        return self._typeIdMax
+
+    def isSubClassOf(self, other):
+        assert self.isCompleted
+        if self is other:
+            return True
+        return self in other.allSubClasses
+
+
+Nodes = {}
+SortedNodes = None
+
 ClassRefs = {}
 Stages = {}
 
@@ -111,7 +227,7 @@ class Cpt:
                 self.error("Can't parse from function: " + func)
             typen = match.group(1)
             subnodes = match.group(2)
-            if not subclasses_of(typen):
+            if Nodes[typen].isRoot:
                 self.error("Unknown AstNode typen: " + typen + ": in " + func)
 
             mif = ""
@@ -166,7 +282,7 @@ class Cpt:
         elif match_skip:
             typen = match_skip.group(1)
             self.tree_skip_visit[typen] = 1
-            if typen not in Classes:
+            if typen not in Nodes:
                 self.error("Unknown node type: " + typen)
 
         else:
@@ -296,12 +412,13 @@ class Cpt:
         self.print(
             "    // Bottom class up, as more simple transforms are generally better\n"
         )
-        for typen in sorted(Classes.keys()):
+        for node in SortedNodes:
             out_for_type_sc = []
             out_for_type = []
-            bases = subclasses_of(typen)
-            bases.append(typen)
-            for base in bases:
+            classes = list(node.allSuperClasses)
+            classes.append(node)
+            for base in classes:
+                base = base.name
                 if base not in self.treeop:
                     continue
                 for typefunc in self.treeop[base]:
@@ -328,23 +445,23 @@ class Cpt:
             if len(out_for_type_sc) > 0:  # Short-circuited types
                 self.print(
                     "    // Generated by astgen with short-circuiting\n" +
-                    "    virtual void visit(Ast" + typen +
+                    "    virtual void visit(Ast" + node.name +
                     "* nodep) override {\n" +
                     "      iterateAndNextNull(nodep->lhsp());\n" +
                     "".join(out_for_type_sc))
                 if out_for_type[0]:
                     self.print("      iterateAndNextNull(nodep->rhsp());\n")
-                    if is_subclass_of(typen, "NodeTriop"):
+                    if node.isSubClassOf(Nodes["NodeTriop"]):
                         self.print(
                             "      iterateAndNextNull(nodep->thsp());\n")
                     self.print("".join(out_for_type) + "    }\n")
             elif len(out_for_type) > 0:  # Other types with something to print
-                skip = typen in self.tree_skip_visit
+                skip = node.name in self.tree_skip_visit
                 gen = "Gen" if skip else ""
                 override = "" if skip else " override"
                 self.print(
                     "    // Generated by astgen\n" + "    virtual void visit" +
-                    gen + "(Ast" + typen + "* nodep)" + override + " {\n" +
+                    gen + "(Ast" + node.name + "* nodep)" + override + " {\n" +
                     ("" if skip else "        iterateChildren(nodep);\n") +
                     ''.join(out_for_type) + "    }\n")
 
@@ -368,11 +485,13 @@ def read_types(filename):
                 if re.search(r'Ast', supern) or classn == "AstNode":
                     classn = re.sub(r'^Ast', '', classn)
                     supern = re.sub(r'^Ast', '', supern)
-                    Classes[classn] = supern
-                    if supern != '':
-                        if supern not in Children:
-                            Children[supern] = {}
-                        Children[supern][classn] = 1
+                    if supern:
+                        superClass = Nodes[supern]
+                        node = Node(classn, superClass)
+                        Nodes[supern].addSubClass(node)
+                    else:
+                        node = Node(classn, None)
+                    Nodes[classn] = node
 
 
 def read_stages(filename):
@@ -424,37 +543,6 @@ def open_file(filename):
     return fh
 
 
-def subclasses_of(typen):
-    cllist = []
-    subclass = Classes[typen]
-    while True:
-        if subclass not in Classes:
-            break
-        cllist.append(subclass)
-        subclass = Classes[subclass]
-
-    cllist.reverse()
-    return cllist
-
-
-def children_of(typen):
-    cllist = []
-    todo = []
-    todo.append(typen)
-    while len(todo) != 0:
-        subclass = todo.pop(0)
-        if subclass in Children:
-            for child in sorted(Children[subclass].keys()):
-                todo.append(child)
-                cllist.append(child)
-
-    return cllist
-
-
-def is_subclass_of(typen, what):
-    return typen == what or (typen in children_of(what))
-
-
 # ---------------------------------------------------------------------
 
 
@@ -468,20 +556,19 @@ def write_report(filename):
             fh.write("  " + classn + "\n")
 
         fh.write("\nClasses:\n")
-        for typen in sorted(Classes.keys()):
-            fh.write("  class Ast%-17s\n" % typen)
+        for node in SortedNodes:
+            fh.write("  class Ast%-17s\n" % node.name)
             fh.write("    parent: ")
-            for subclass in subclasses_of(typen):
-                if subclass != 'Node':
-                    fh.write("Ast%-12s " % subclass)
+            for superClass in node.allSuperClasses:
+                if not superClass.isRoot:
+                    fh.write("Ast%-12s " % superClass.name)
             fh.write("\n")
             fh.write("    childs:  ")
-            for subclass in children_of(typen):
-                if subclass != 'Node':
-                    fh.write("Ast%-12s " % subclass)
+            for subClass in node.allSubClasses:
+                fh.write("Ast%-12s " % subClass.name)
             fh.write("\n")
-            if ("Ast" + typen) in ClassRefs:  # pylint: disable=superfluous-parens
-                refs = ClassRefs["Ast" + typen]
+            if ("Ast" + node.name) in ClassRefs:  # pylint: disable=superfluous-parens
+                refs = ClassRefs["Ast" + node.name]
                 fh.write("    newed:  ")
                 for stage in sorted(refs['newed'].keys(),
                                     key=lambda val: Stages[val]
@@ -500,27 +587,27 @@ def write_report(filename):
 def write_classes(filename):
     with open_file(filename) as fh:
         fh.write("class AstNode;\n")
-        for typen in sorted(Classes.keys()):
-            fh.write("class Ast%-17s // " % (typen + ";"))
-            for subclass in subclasses_of(typen):
-                fh.write("Ast%-12s " % subclass)
+        for node in SortedNodes:
+            fh.write("class Ast%-17s // " % (node.name + ";"))
+            for superClass in node.allSuperClasses:
+                fh.write("Ast%-12s " % superClass.name)
             fh.write("\n")
 
 
 def write_visitor_decls(filename):
     with open_file(filename) as fh:
-        for typen in sorted(Classes.keys()):
-            if typen != "Node":
-                fh.write("virtual void visit(Ast" + typen + "*);\n")
+        for node in SortedNodes:
+            if not node.isRoot:
+                fh.write("virtual void visit(Ast" + node.name + "*);\n")
 
 
 def write_visitor_defns(filename):
     with open_file(filename) as fh:
-        for typen in sorted(Classes.keys()):
-            if typen != "Node":
-                base = Classes[typen]
-                fh.write("void VNVisitor::visit(Ast" + typen +
-                         "* nodep) { visit(static_cast<Ast" + base +
+        for node in SortedNodes:
+            base = node.superClass
+            if base is not None:
+                fh.write("void VNVisitor::visit(Ast" + node.name +
+                         "* nodep) { visit(static_cast<Ast" + base.name +
                          "*>(nodep)); }\n")
 
 
@@ -528,75 +615,51 @@ def write_impl(filename):
     with open_file(filename) as fh:
         fh.write("\n")
         fh.write("// For internal use. They assume argument is not nullptr.\n")
-        for typen in sorted(Classes.keys()):
+        for node in SortedNodes:
             fh.write("template<> inline bool AstNode::privateTypeTest<Ast" +
-                     typen + ">(const AstNode* nodep) { ")
-            if typen == "Node":
+                     node.name + ">(const AstNode* nodep) { ")
+            if node.isRoot:
                 fh.write("return true; ")
             else:
                 fh.write("return ")
-                if re.search(r'^Node', typen):
+                if not node.isLeaf:
                     fh.write(
                         "static_cast<int>(nodep->type()) >= static_cast<int>(VNType::first"
-                        + typen + ") && ")
+                        + node.name + ") && ")
                     fh.write(
                         "static_cast<int>(nodep->type()) <= static_cast<int>(VNType::last"
-                        + typen + "); ")
+                        + node.name + "); ")
                 else:
-                    fh.write("nodep->type() == VNType::at" + typen + "; ")
+                    fh.write("nodep->type() == VNType::at" + node.name + "; ")
             fh.write("}\n")
 
 
-def write_type_enum(fh, typen, idx, processed, kind, indent):
-    # Skip this if it has already been processed
-    if typen in processed:
-        return idx
-    # Mark processed
-    processed[typen] = 1
-
-    # The last used index
-    last = None
-
-    if not re.match(r'^Node', typen):
-        last = idx
-        if kind == "concrete-enum":
-            fh.write(" " * (indent * 4) + "at" + typen + " = " + str(idx) +
-                     ",\n")
-        elif kind == "concrete-ascii":
-            fh.write(" " * (indent * 4) + "\"" + typen.upper() + "\",\n")
-        idx += 1
-    elif kind == "abstract-enum":
-        fh.write(" " * (indent * 4) + "first" + typen + " = " + str(idx) +
-                 ",\n")
-
-    if typen in Children:
-        for child in sorted(Children[typen].keys()):
-            (idx, last) = write_type_enum(fh, child, idx, processed, kind,
-                                          indent)
-
-    if re.match(r'^Node', typen) and kind == "abstract-enum":
-        fh.write(" " * (indent * 4) + "last" + typen + " = " + str(last) +
-                 ",\n")
-
-    return [idx, last]
-
-
 def write_types(filename):
     with open_file(filename) as fh:
         fh.write("    enum en : uint16_t {\n")
-        (final, ignored) = write_type_enum(  # pylint: disable=W0612
-            fh, "Node", 0, {}, "concrete-enum", 2)
-        fh.write("        _ENUM_END = " + str(final) + "\n")
+        for node in sorted(filter(lambda _: _.isLeaf, SortedNodes),
+                           key=lambda _: _.typeId):
+            fh.write("        at" + node.name + " = " + str(node.typeId) +
+                     ",\n")
+        fh.write("        _ENUM_END = " + str(Nodes["Node"].typeIdMax + 1) +
+                 "\n")
         fh.write("    };\n")
 
         fh.write("    enum bounds : uint16_t {\n")
-        write_type_enum(fh, "Node", 0, {}, "abstract-enum", 2)
+        for node in sorted(filter(lambda _: not _.isLeaf, SortedNodes),
+                           key=lambda _: _.typeIdMin):
+            fh.write("        first" + node.name + " = " +
+                     str(node.typeIdMin) + ",\n")
+            fh.write("        last" + node.name + " = " + str(node.typeIdMax) +
+                     ",\n")
         fh.write("        _BOUNDS_END\n")
         fh.write("    };\n")
 
         fh.write("    const char* ascii() const {\n")
         fh.write("        static const char* const names[_ENUM_END + 1] = {\n")
-        write_type_enum(fh, "Node", 0, {}, "concrete-ascii", 3)
+        for node in sorted(filter(lambda _: _.isLeaf, SortedNodes),
+                           key=lambda _: _.typeId):
+            fh.write("            \"" + node.name.upper() + "\",\n")
         fh.write("            \"_ENUM_END\"\n")
         fh.write("        };\n")
         fh.write("        return names[m_e];\n")
@@ -605,45 +668,21 @@ def write_types(filename):
 
 def write_yystype(filename):
     with open_file(filename) as fh:
-        for typen in sorted(Classes.keys()):
-            fh.write("Ast{t}* {m}p;\n".format(t=typen,
-                                              m=typen[0].lower() + typen[1:]))
+        for node in SortedNodes:
+            fh.write("Ast{t}* {m}p;\n".format(t=node.name,
+                                              m=node.name[0].lower() +
+                                              node.name[1:]))
 
 
 def write_macros(filename):
     with open_file(filename) as fh:
-        typen = "None"
-        base = "None"
-
-        in_filename = "V3AstNodes.h"
-        ifile = Args.I + "/" + in_filename
-        with open(ifile) as ifh:
-            for (lineno, line) in enumerate(ifh, 1):
-                # Drop expanded macro definitions - but keep empty line so compiler
-                # message locations are accurate
-                line = re.sub(r'^\s*#(define|undef)\s+ASTGEN_.*$', '', line)
-
-                # Track current node type and base class
-                match = re.search(
-                    r'\s*class\s*Ast(\S+)\s*(final|VL_NOT_FINAL)?\s*:\s*(public)?\s*(AstNode\S*)',
-                    line)
-                if match:
-                    typen = match.group(1)
-                    base = match.group(4)
-                    if not typen.startswith("Node"):
-                        macro = "#define ASTGEN_SUPER_{t}(...) {b}(VNType::at{t}, __VA_ARGS__)\n" \
-                                .format(b=base, t=typen)
-                        fh.write(macro)
-
-                match = re.search(r"ASTGEN_SUPER_(\w+)", line)
-                if match:
-                    if typen != match.group(1):
-                        print((
-                            "V3AstNodes.h:{l} ERROR: class Ast{t} calls wrong superclass "
-                            +
-                            "constructor macro (should call ASTGEN_SUPER_{t})"
-                        ).format(l=lineno, t=typen))
-                        sys.exit(1)
+        for node in SortedNodes:
+            # Only care about leaf classes
+            if not node.isLeaf:
+                continue
+            fh.write(
+                "#define ASTGEN_SUPER_{t}(...) Ast{b}(VNType::at{t}, __VA_ARGS__)\n"
+                .format(t=node.name, b=node.superClass.name))
 
 
 ######################################################################
@@ -673,19 +712,24 @@ Args = parser.parse_args()
 
 read_types(Args.I + "/V3Ast.h")
 read_types(Args.I + "/V3AstNodes.h")
-for typen in sorted(Classes.keys()):
+
+# Compute derived properties over the whole AstNode hierarchy
+Nodes["Node"].complete()
+
+SortedNodes = tuple(map(lambda _: Nodes[_], sorted(Nodes.keys())))
+
+for node in SortedNodes:
     # Check all leaves are not AstNode* and non-leaves are AstNode*
-    children = children_of(typen)
-    if re.match(r'^Node', typen):
-        if len(children) == 0:
+    if re.match(r'^Node', node.name):
+        if node.isLeaf:
             sys.exit(
                 "%Error: Final AstNode subclasses must not be named AstNode*: Ast"
-                + typen)
+                + node.name)
     else:
-        if len(children) != 0:
+        if not node.isLeaf:
             sys.exit(
                 "%Error: Non-final AstNode subclasses must be named AstNode*: Ast"
-                + typen)
+                + node.name)
 
 read_stages(Args.I + "/Verilator.cpp")