add some typing annotations (#20)

* add some typing annotations

* minor additions

* import Optional

* format with black

* review comments

---------

Co-authored-by: Anton Lydike <me@antonlydike.de>
This commit is contained in:
Sasha Lopoukhine 2023-05-05 17:22:58 +01:00 committed by GitHub
parent d6d3a18aa6
commit 25d059da09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 74 additions and 56 deletions

View File

@ -51,7 +51,7 @@ class UserModeCPU(CPU):
syscall_symbols.update(self.mmu.global_symbols)
self.mmu.global_symbols.update(syscall_symbols)
def step(self, verbose=False):
def step(self, verbose: bool = False):
"""
Execute a single instruction, then return.
"""
@ -91,7 +91,7 @@ class UserModeCPU(CPU):
if launch_debugger:
launch_debug_session(self)
def run(self, verbose=False):
def run(self, verbose: bool = False):
while not self.halted:
self.step(verbose)
@ -102,7 +102,7 @@ class UserModeCPU(CPU):
+ FMT_NONE
)
def setup_stack(self, stack_size=1024 * 4) -> bool:
def setup_stack(self, stack_size: int = 1024 * 4) -> bool:
"""
Create program stack and populate stack pointer
:param stack_size: the size of the required stack, defaults to 4Kib

View File

@ -319,7 +319,7 @@ class MMU:
sec.base = at_addr
self.sections.append(sec)
self._update_state()
return True
return True
def _update_state(self):
"""

View File

@ -84,19 +84,19 @@ class Registers:
def __init__(self, infinite_regs: bool = False):
from .types import Int32
self.vals = defaultdict(lambda: Int32(0))
self.vals: defaultdict[str, Int32] = defaultdict(lambda: Int32(0))
self.last_set = None
self.last_read = None
self.infinite_regs = infinite_regs
def dump(self, full=False):
def dump(self, full: bool = False):
"""
Dump all registers to stdout
:param full: If True, floating point registers are dumped too
"""
named_regs = [self._reg_repr(reg) for reg in Registers.named_registers()]
lines = [[] for i in range(12)]
lines: list[list[str]] = [[] for _ in range(12)]
if not full:
regs = [("a", 8), ("s", 12), ("t", 7)]
else:
@ -142,7 +142,7 @@ class Registers:
+ " ".join(self._reg_repr("a{}".format(i)) for i in range(8))
)
def _reg_repr(self, reg):
def _reg_repr(self, reg: str):
txt = "{:4}=0x{:08X}".format(reg, self.get(reg, False))
if reg == "fp":
reg = "s0"
@ -156,7 +156,7 @@ class Registers:
return FMT_GRAY + txt + FMT_NONE
return txt
def set(self, reg, val: "Int32", mark_set=True) -> bool:
def set(self, reg: str, val: "Int32", mark_set: bool = True) -> bool:
"""
Set a register content to val
:param reg: The register to set
@ -189,7 +189,7 @@ class Registers:
self.vals[reg] = val.unsigned()
return True
def get(self, reg, mark_read=True) -> "Int32":
def get(self, reg: str, mark_read: bool = True) -> "Int32":
"""
Retuns the contents of register reg
:param reg: The register name

View File

@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Any
import re
# define some base type aliases so we can keep track of absolute and relative addresses
@ -6,7 +6,7 @@ T_RelativeAddress = int
T_AbsoluteAddress = int
# parser options are just dictionaries with arbitrary values
T_ParserOpts = Dict[str, any]
T_ParserOpts = Dict[str, Any]
NUMBER_SYMBOL_PATTERN = re.compile(r"^\d+[fb]$")

View File

@ -1,3 +1,4 @@
from typing import Optional
from . import (
MemorySection,
InstructionContext,
@ -16,15 +17,17 @@ class BinaryDataMemorySection(MemorySection):
context: InstructionContext,
owner: str,
base: int = 0,
flags: MemoryFlags = None,
flags: Optional[MemoryFlags] = None,
):
self.name = name
self.base = base
self.context = context
self.size = len(data)
self.flags = flags if flags is not None else MemoryFlags(False, False)
super().__init__(
name,
flags if flags is not None else MemoryFlags(False, False),
len(data),
base,
owner,
context,
)
self.data = data
self.owner = owner
def read(self, offset: T_RelativeAddress, size: int) -> bytearray:
if offset + size > self.size:

View File

@ -7,6 +7,10 @@ from ..config import RunConfig
from ..colors import FMT_RED, FMT_NONE, FMT_ERROR, FMT_CPU
from . import T_AbsoluteAddress, Instruction, Program, ProgramLoader
if typing.TYPE_CHECKING:
from ..MMU import MMU
from ..instructions import InstructionSet
class CPU(ABC):
# static cpu configuration
@ -80,11 +84,11 @@ class CPU(ABC):
)
@abstractmethod
def step(self, verbose=False):
def step(self, verbose: bool = False):
pass
@abstractmethod
def run(self, verbose=False):
def run(self, verbose: bool = False):
pass
def launch(self, program: Program, verbose: bool = False):

View File

@ -14,8 +14,8 @@ if typing.TYPE_CHECKING:
class RiscemuBaseException(BaseException):
@abstractmethod
def message(self):
pass
def message(self) -> str:
raise NotImplemented
def print_stacktrace(self):
import traceback
@ -27,7 +27,7 @@ class RiscemuBaseException(BaseException):
class ParseException(RiscemuBaseException):
def __init__(self, msg, data=None):
def __init__(self, msg: str, data=None):
super().__init__(msg, data)
self.msg = msg
self.data = data
@ -77,7 +77,7 @@ def ASSERT_IN(a1, a2):
class LinkerException(RiscemuBaseException):
def __init__(self, msg, data):
def __init__(self, msg: str, data):
self.msg = msg
self.data = data
@ -93,7 +93,7 @@ class LinkerException(RiscemuBaseException):
class MemoryAccessException(RiscemuBaseException):
def __init__(self, msg, addr, size, op):
def __init__(self, msg: str, addr, size, op):
super(MemoryAccessException, self).__init__()
self.msg = msg
self.addr = addr
@ -196,5 +196,5 @@ class NumberFormatException(RiscemuBaseException):
# this exception is not printed and simply signals that an interactive debugging session is
class LaunchDebuggerException(RiscemuBaseException):
def message(self):
def message(self) -> str:
return ""

View File

@ -1,4 +1,4 @@
from typing import Union
from typing import Any, Union
from ctypes import c_int32, c_uint32
@ -52,10 +52,10 @@ class Int32:
other = other.value
return self.__class__(self._val.value * other)
def __truediv__(self, other):
def __truediv__(self, other: Any):
return self // other
def __floordiv__(self, other):
def __floordiv__(self, other: Any):
if isinstance(other, Int32):
other = other.value
return self.__class__(self.value // other)
@ -90,10 +90,12 @@ class Int32:
other = other.value
return self.__class__(self.value >> other)
def __eq__(self, other: Union["Int32", int]):
if isinstance(other, Int32):
other = other.value
return self.value == other
def __eq__(self, other: object) -> bool:
if isinstance(other, int):
return self.value == other
elif isinstance(other, Int32):
return self.value == other.value
return False
def __neg__(self):
return self.__class__(-self._val.value)
@ -110,28 +112,28 @@ class Int32:
def __str__(self):
return str(self.value)
def __format__(self, format_spec):
def __format__(self, format_spec: str):
return self.value.__format__(format_spec)
def __hash__(self):
return hash(self.value)
def __gt__(self, other):
def __gt__(self, other: Any):
if isinstance(other, Int32):
other = other.value
return self.value > other
def __lt__(self, other):
def __lt__(self, other: Any):
if isinstance(other, Int32):
other = other.value
return self.value < other
def __le__(self, other):
def __le__(self, other: Any):
if isinstance(other, Int32):
other = other.value
return self.value <= other
def __ge__(self, other):
def __ge__(self, other: Any):
if isinstance(other, Int32):
other = other.value
return self.value >= other
@ -139,38 +141,38 @@ class Int32:
def __bool__(self):
return bool(self.value)
def __cmp__(self, other):
def __cmp__(self, other: Any):
if isinstance(other, Int32):
other = other.value
return self.value.__cmp__(other)
# right handed binary operators
def __radd__(self, other):
def __radd__(self, other: Any):
return self + other
def __rsub__(self, other):
def __rsub__(self, other: Any):
return self.__class__(other) - self
def __rmul__(self, other):
def __rmul__(self, other: Any):
return self * other
def __rtruediv__(self, other):
def __rtruediv__(self, other: Any):
return self.__class__(other) // self
def __rfloordiv__(self, other):
def __rfloordiv__(self, other: Any):
return self.__class__(other) // self
def __rmod__(self, other):
def __rmod__(self, other: Any):
return self.__class__(other) % self
def __rand__(self, other):
def __rand__(self, other: Any):
return self.__class__(other) & self
def __ror__(self, other):
def __ror__(self, other: Any):
return self.__class__(other) | self
def __rxor__(self, other):
def __rxor__(self, other: Any):
return self.__class__(other) ^ self
@property
@ -278,4 +280,6 @@ class UInt32(Int32):
:param ammount: Number of positions to shift
:return: A new Int32 object representing the shifted value (keeps the signed-ness of the source)
"""
return self >> ammount
if isinstance(ammount, Int32):
ammount = ammount.value
return UInt32(self.value >> ammount)

View File

@ -43,11 +43,11 @@ class MemorySection(ABC):
self,
start: T_RelativeAddress,
end: Optional[T_RelativeAddress] = None,
fmt: str = None,
bytes_per_row: int = None,
fmt: Optional[str] = None,
bytes_per_row: Optional[int] = None,
rows: int = 10,
group: int = None,
highlight: int = None,
group: Optional[int] = None,
highlight: Optional[int] = None,
):
"""
Dump the section. If no end is given, the rows around start are printed and start is highlighted.
@ -152,8 +152,15 @@ class MemorySection(ABC):
)
)
def dump_all(self, *args, **kwargs):
self.dump(0, self.size, *args, **kwargs)
def dump_all(
self,
fmt: Optional[str] = None,
bytes_per_row: Optional[int] = None,
rows: int = 10,
group: Optional[int] = None,
highlight: Optional[int] = None,
):
self.dump(0, self.size, fmt, bytes_per_row, rows, group, highlight)
def __repr__(self):
return "{}[{}] at 0x{:08X} (size={}bytes, flags={}, owner={})".format(