import sys
import unittest
from dataclasses import dataclass
from typing import *
sys.setrecursionlimit(10**6)
## binary trees are another example of mixed compound data:
BinTree = Union['BTNode',None]
@dataclass(frozen=True)
class BTNode:
val : Any
left : BinTree
right : BinTree
# examples of data
bt1 = None  # the empty tree
bt2 = BTNode(4, None, None)  # a single leaf
bt3 = BTNode(4, BTNode("thingy", None, None), None)  # root with only a left child
bt4 = BTNode(
    "swim",
    BTNode(
        "run",
        BTNode("walk", None, None),
        BTNode("banana", None, None),
    ),
    BTNode("sleep", None, None),
)  # 5 nodes; "banana" is the right child of the root's left child
# count the nodes in a binary tree:
def bt_nodes(bt : BinTree) -> int:
match bt:
case None:
return 0
case BTNode(_,l,r):
return bt_nodes(l) + bt_nodes(r) + 1
# does the tree contain the string "banana"?
def contains_banana(bt : BinTree) -> bool:
match bt:
case None:
return False
case BTNode(v,l,r):
return (v == "banana") or contains_banana(l) or contains_banana(r)
# flip the binary tree left to right
def mirror(bt : BinTree) -> BinTree:
match bt:
case None:
return None
case BTNode(v,l,r):
return BTNode(v,mirror(r),mirror(l))
class MyTests(unittest.TestCase):
    """Unit tests for bt_nodes, contains_banana and mirror on the example trees."""

    def test_count_nodes(self):
        self.assertEqual(0, bt_nodes(bt1))
        self.assertEqual(1, bt_nodes(bt2))
        self.assertEqual(2, bt_nodes(bt3))
        self.assertEqual(5, bt_nodes(bt4))

    def test_contains_banana(self):
        self.assertEqual(False, contains_banana(bt1))
        self.assertEqual(True, contains_banana(bt4))
        self.assertEqual(False, contains_banana(bt2))
        # Multi-node tree without a banana: every recursive branch must say False.
        self.assertEqual(False, contains_banana(bt3))
        # Banana at the root itself.
        self.assertEqual(True, contains_banana(BTNode("banana", None, None)))
        # Banana reachable only through the root's right subtree.
        self.assertEqual(
            True,
            contains_banana(BTNode(1, None, BTNode("banana", None, None))),
        )

    def test_mirror(self):
        self.assertEqual(bt1, mirror(bt1))
        self.assertEqual(bt2, mirror(bt2))
        self.assertEqual(
            BTNode(4, None, BTNode("thingy", None, None)),
            mirror(bt3),
        )
        self.assertEqual(
            BTNode(
                "swim",
                BTNode("sleep", None, None),
                BTNode("run", BTNode("banana", None, None), BTNode("walk", None, None)),
            ),
            mirror(bt4),
        )
        # Mirroring is its own inverse: flipping twice restores the original.
        for tree in (bt1, bt2, bt3, bt4):
            with self.subTest(tree=tree):
                self.assertEqual(tree, mirror(mirror(tree)))