From 25d059da090760862f9143478524ed6daf0ab449 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Fri, 5 May 2023 17:22:58 +0100 Subject: [PATCH] add some typing annotations (#20) * add some typing annotations * minor additions * import Optional * format with black * review comments --------- Co-authored-by: Anton Lydike --- riscemu/CPU.py | 6 +-- riscemu/MMU.py | 2 +- riscemu/registers.py | 12 ++--- riscemu/types/__init__.py | 4 +- riscemu/types/binary_data_memory_section.py | 17 ++++--- riscemu/types/cpu.py | 8 +++- riscemu/types/exceptions.py | 12 ++--- riscemu/types/int32.py | 50 +++++++++++---------- riscemu/types/memory_section.py | 19 +++++--- 9 files changed, 74 insertions(+), 56 deletions(-) diff --git a/riscemu/CPU.py b/riscemu/CPU.py index dc4d304..fc84cb0 100644 --- a/riscemu/CPU.py +++ b/riscemu/CPU.py @@ -51,7 +51,7 @@ class UserModeCPU(CPU): syscall_symbols.update(self.mmu.global_symbols) self.mmu.global_symbols.update(syscall_symbols) - def step(self, verbose=False): + def step(self, verbose: bool = False): """ Execute a single instruction, then return. """ @@ -91,7 +91,7 @@ class UserModeCPU(CPU): if launch_debugger: launch_debug_session(self) - def run(self, verbose=False): + def run(self, verbose: bool = False): while not self.halted: self.step(verbose) @@ -102,7 +102,7 @@ class UserModeCPU(CPU): + FMT_NONE ) - def setup_stack(self, stack_size=1024 * 4) -> bool: + def setup_stack(self, stack_size: int = 1024 * 4) -> bool: """ Create program stack and populate stack pointer :param stack_size: the size of the required stack, defaults to 4Kib diff --git a/riscemu/MMU.py b/riscemu/MMU.py index a13cb0c..8bb79aa 100644 --- a/riscemu/MMU.py +++ b/riscemu/MMU.py @@ -319,7 +319,7 @@ class MMU: sec.base = at_addr self.sections.append(sec) self._update_state() - return True + return True def _update_state(self): """ diff --git a/riscemu/registers.py b/riscemu/registers.py index c6150e6..ef7ef51 100644 --- a/riscemu/registers.py +++ b/riscemu/registers.py @@ -84,19 +84,19 @@ class Registers: def __init__(self, infinite_regs: bool = False): from .types import Int32 - self.vals = defaultdict(lambda: Int32(0)) + self.vals: defaultdict[str, Int32] = defaultdict(lambda: Int32(0)) self.last_set = None self.last_read = None self.infinite_regs = infinite_regs - def dump(self, full=False): + def dump(self, full: bool = False): """ Dump all registers to stdout :param full: If True, floating point registers are dumped too """ named_regs = [self._reg_repr(reg) for reg in Registers.named_registers()] - lines = [[] for i in range(12)] + lines: list[list[str]] = [[] for _ in range(12)] if not full: regs = [("a", 8), ("s", 12), ("t", 7)] else: @@ -142,7 +142,7 @@ class Registers: + " ".join(self._reg_repr("a{}".format(i)) for i in range(8)) ) - def _reg_repr(self, reg): + def _reg_repr(self, reg: str): txt = "{:4}=0x{:08X}".format(reg, self.get(reg, False)) if reg == "fp": reg = "s0" @@ -156,7 +156,7 @@ class Registers: return FMT_GRAY + txt + FMT_NONE return txt - def set(self, reg, val: "Int32", mark_set=True) -> bool: + def set(self, reg: str, val: "Int32", mark_set: bool = True) -> bool: """ Set a register content to val :param reg: The register to set @@ -189,7 +189,7 @@ class Registers: self.vals[reg] = val.unsigned() return True - def get(self, reg, mark_read=True) -> "Int32": + def get(self, reg: str, mark_read: bool = True) -> "Int32": """ Retuns the contents of register reg :param reg: The register name diff --git a/riscemu/types/__init__.py b/riscemu/types/__init__.py index 56705dc..a509fd5 100644 --- a/riscemu/types/__init__.py +++ b/riscemu/types/__init__.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Any import re # define some base type aliases so we can keep track of absolute and relative addresses @@ -6,7 +6,7 @@ T_RelativeAddress = int T_AbsoluteAddress = int # parser options are just dictionaries with arbitrary values -T_ParserOpts = Dict[str, any] +T_ParserOpts = Dict[str, Any] NUMBER_SYMBOL_PATTERN = re.compile(r"^\d+[fb]$") diff --git a/riscemu/types/binary_data_memory_section.py b/riscemu/types/binary_data_memory_section.py index 07d32e1..11570a9 100644 --- a/riscemu/types/binary_data_memory_section.py +++ b/riscemu/types/binary_data_memory_section.py @@ -1,3 +1,4 @@ +from typing import Optional from . import ( MemorySection, InstructionContext, @@ -16,15 +17,17 @@ class BinaryDataMemorySection(MemorySection): context: InstructionContext, owner: str, base: int = 0, - flags: MemoryFlags = None, + flags: Optional[MemoryFlags] = None, ): - self.name = name - self.base = base - self.context = context - self.size = len(data) - self.flags = flags if flags is not None else MemoryFlags(False, False) + super().__init__( + name, + flags if flags is not None else MemoryFlags(False, False), + len(data), + base, + owner, + context, + ) self.data = data - self.owner = owner def read(self, offset: T_RelativeAddress, size: int) -> bytearray: if offset + size > self.size: diff --git a/riscemu/types/cpu.py b/riscemu/types/cpu.py index 5a78056..8ee8b6d 100644 --- a/riscemu/types/cpu.py +++ b/riscemu/types/cpu.py @@ -7,6 +7,10 @@ from ..config import RunConfig from ..colors import FMT_RED, FMT_NONE, FMT_ERROR, FMT_CPU from . import T_AbsoluteAddress, Instruction, Program, ProgramLoader +if typing.TYPE_CHECKING: + from ..MMU import MMU + from ..instructions import InstructionSet + class CPU(ABC): # static cpu configuration @@ -80,11 +84,11 @@ class CPU(ABC): ) @abstractmethod - def step(self, verbose=False): + def step(self, verbose: bool = False): pass @abstractmethod - def run(self, verbose=False): + def run(self, verbose: bool = False): pass def launch(self, program: Program, verbose: bool = False): diff --git a/riscemu/types/exceptions.py b/riscemu/types/exceptions.py index 71ee0a8..ad15072 100644 --- a/riscemu/types/exceptions.py +++ b/riscemu/types/exceptions.py @@ -14,8 +14,8 @@ if typing.TYPE_CHECKING: class RiscemuBaseException(BaseException): @abstractmethod - def message(self): - pass + def message(self) -> str: + raise NotImplemented def print_stacktrace(self): import traceback @@ -27,7 +27,7 @@ class RiscemuBaseException(BaseException): class ParseException(RiscemuBaseException): - def __init__(self, msg, data=None): + def __init__(self, msg: str, data=None): super().__init__(msg, data) self.msg = msg self.data = data @@ -77,7 +77,7 @@ def ASSERT_IN(a1, a2): class LinkerException(RiscemuBaseException): - def __init__(self, msg, data): + def __init__(self, msg: str, data): self.msg = msg self.data = data @@ -93,7 +93,7 @@ class LinkerException(RiscemuBaseException): class MemoryAccessException(RiscemuBaseException): - def __init__(self, msg, addr, size, op): + def __init__(self, msg: str, addr, size, op): super(MemoryAccessException, self).__init__() self.msg = msg self.addr = addr @@ -196,5 +196,5 @@ class NumberFormatException(RiscemuBaseException): # this exception is not printed and simply signals that an interactive debugging session is class LaunchDebuggerException(RiscemuBaseException): - def message(self): + def message(self) -> str: return "" diff --git a/riscemu/types/int32.py b/riscemu/types/int32.py index 258d720..211a5dc 100644 --- a/riscemu/types/int32.py +++ b/riscemu/types/int32.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Any, Union from ctypes import c_int32, c_uint32 @@ -52,10 +52,10 @@ class Int32: other = other.value return self.__class__(self._val.value * other) - def __truediv__(self, other): + def __truediv__(self, other: Any): return self // other - def __floordiv__(self, other): + def __floordiv__(self, other: Any): if isinstance(other, Int32): other = other.value return self.__class__(self.value // other) @@ -90,10 +90,12 @@ class Int32: other = other.value return self.__class__(self.value >> other) - def __eq__(self, other: Union["Int32", int]): - if isinstance(other, Int32): - other = other.value - return self.value == other + def __eq__(self, other: object) -> bool: + if isinstance(other, int): + return self.value == other + elif isinstance(other, Int32): + return self.value == other.value + return False def __neg__(self): return self.__class__(-self._val.value) @@ -110,28 +112,28 @@ class Int32: def __str__(self): return str(self.value) - def __format__(self, format_spec): + def __format__(self, format_spec: str): return self.value.__format__(format_spec) def __hash__(self): return hash(self.value) - def __gt__(self, other): + def __gt__(self, other: Any): if isinstance(other, Int32): other = other.value return self.value > other - def __lt__(self, other): + def __lt__(self, other: Any): if isinstance(other, Int32): other = other.value return self.value < other - def __le__(self, other): + def __le__(self, other: Any): if isinstance(other, Int32): other = other.value return self.value <= other - def __ge__(self, other): + def __ge__(self, other: Any): if isinstance(other, Int32): other = other.value return self.value >= other @@ -139,38 +141,38 @@ class Int32: def __bool__(self): return bool(self.value) - def __cmp__(self, other): + def __cmp__(self, other: Any): if isinstance(other, Int32): other = other.value return self.value.__cmp__(other) # right handed binary operators - def __radd__(self, other): + def __radd__(self, other: Any): return self + other - def __rsub__(self, other): + def __rsub__(self, other: Any): return self.__class__(other) - self - def __rmul__(self, other): + def __rmul__(self, other: Any): return self * other - def __rtruediv__(self, other): + def __rtruediv__(self, other: Any): return self.__class__(other) // self - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: Any): return self.__class__(other) // self - def __rmod__(self, other): + def __rmod__(self, other: Any): return self.__class__(other) % self - def __rand__(self, other): + def __rand__(self, other: Any): return self.__class__(other) & self - def __ror__(self, other): + def __ror__(self, other: Any): return self.__class__(other) | self - def __rxor__(self, other): + def __rxor__(self, other: Any): return self.__class__(other) ^ self @property @@ -278,4 +280,6 @@ class UInt32(Int32): :param ammount: Number of positions to shift :return: A new Int32 object representing the shifted value (keeps the signed-ness of the source) """ - return self >> ammount + if isinstance(ammount, Int32): + ammount = ammount.value + return UInt32(self.value >> ammount) diff --git a/riscemu/types/memory_section.py b/riscemu/types/memory_section.py index 3075859..16defb4 100644 --- a/riscemu/types/memory_section.py +++ b/riscemu/types/memory_section.py @@ -43,11 +43,11 @@ class MemorySection(ABC): self, start: T_RelativeAddress, end: Optional[T_RelativeAddress] = None, - fmt: str = None, - bytes_per_row: int = None, + fmt: Optional[str] = None, + bytes_per_row: Optional[int] = None, rows: int = 10, - group: int = None, - highlight: int = None, + group: Optional[int] = None, + highlight: Optional[int] = None, ): """ Dump the section. If no end is given, the rows around start are printed and start is highlighted. @@ -152,8 +152,15 @@ class MemorySection(ABC): ) ) - def dump_all(self, *args, **kwargs): - self.dump(0, self.size, *args, **kwargs) + def dump_all( + self, + fmt: Optional[str] = None, + bytes_per_row: Optional[int] = None, + rows: int = 10, + group: Optional[int] = None, + highlight: Optional[int] = None, + ): + self.dump(0, self.size, fmt, bytes_per_row, rows, group, highlight) def __repr__(self): return "{}[{}] at 0x{:08X} (size={}bytes, flags={}, owner={})".format(