add some typing annotations (#20)

* add some typing annotations

* minor additions

* import Optional

* format with black

* review comments

---------

Co-authored-by: Anton Lydike <me@antonlydike.de>
master
Sasha Lopoukhine 2 years ago committed by GitHub
parent d6d3a18aa6
commit 25d059da09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -51,7 +51,7 @@ class UserModeCPU(CPU):
syscall_symbols.update(self.mmu.global_symbols) syscall_symbols.update(self.mmu.global_symbols)
self.mmu.global_symbols.update(syscall_symbols) self.mmu.global_symbols.update(syscall_symbols)
def step(self, verbose=False): def step(self, verbose: bool = False):
""" """
Execute a single instruction, then return. Execute a single instruction, then return.
""" """
@ -91,7 +91,7 @@ class UserModeCPU(CPU):
if launch_debugger: if launch_debugger:
launch_debug_session(self) launch_debug_session(self)
def run(self, verbose=False): def run(self, verbose: bool = False):
while not self.halted: while not self.halted:
self.step(verbose) self.step(verbose)
@ -102,7 +102,7 @@ class UserModeCPU(CPU):
+ FMT_NONE + FMT_NONE
) )
def setup_stack(self, stack_size=1024 * 4) -> bool: def setup_stack(self, stack_size: int = 1024 * 4) -> bool:
""" """
Create program stack and populate stack pointer Create program stack and populate stack pointer
:param stack_size: the size of the required stack, defaults to 4Kib :param stack_size: the size of the required stack, defaults to 4Kib

@ -319,7 +319,7 @@ class MMU:
sec.base = at_addr sec.base = at_addr
self.sections.append(sec) self.sections.append(sec)
self._update_state() self._update_state()
return True return True
def _update_state(self): def _update_state(self):
""" """

@ -84,19 +84,19 @@ class Registers:
def __init__(self, infinite_regs: bool = False): def __init__(self, infinite_regs: bool = False):
from .types import Int32 from .types import Int32
self.vals = defaultdict(lambda: Int32(0)) self.vals: defaultdict[str, Int32] = defaultdict(lambda: Int32(0))
self.last_set = None self.last_set = None
self.last_read = None self.last_read = None
self.infinite_regs = infinite_regs self.infinite_regs = infinite_regs
def dump(self, full=False): def dump(self, full: bool = False):
""" """
Dump all registers to stdout Dump all registers to stdout
:param full: If True, floating point registers are dumped too :param full: If True, floating point registers are dumped too
""" """
named_regs = [self._reg_repr(reg) for reg in Registers.named_registers()] named_regs = [self._reg_repr(reg) for reg in Registers.named_registers()]
lines = [[] for i in range(12)] lines: list[list[str]] = [[] for _ in range(12)]
if not full: if not full:
regs = [("a", 8), ("s", 12), ("t", 7)] regs = [("a", 8), ("s", 12), ("t", 7)]
else: else:
@ -142,7 +142,7 @@ class Registers:
+ " ".join(self._reg_repr("a{}".format(i)) for i in range(8)) + " ".join(self._reg_repr("a{}".format(i)) for i in range(8))
) )
def _reg_repr(self, reg): def _reg_repr(self, reg: str):
txt = "{:4}=0x{:08X}".format(reg, self.get(reg, False)) txt = "{:4}=0x{:08X}".format(reg, self.get(reg, False))
if reg == "fp": if reg == "fp":
reg = "s0" reg = "s0"
@ -156,7 +156,7 @@ class Registers:
return FMT_GRAY + txt + FMT_NONE return FMT_GRAY + txt + FMT_NONE
return txt return txt
def set(self, reg, val: "Int32", mark_set=True) -> bool: def set(self, reg: str, val: "Int32", mark_set: bool = True) -> bool:
""" """
Set a register content to val Set a register content to val
:param reg: The register to set :param reg: The register to set
@ -189,7 +189,7 @@ class Registers:
self.vals[reg] = val.unsigned() self.vals[reg] = val.unsigned()
return True return True
def get(self, reg, mark_read=True) -> "Int32": def get(self, reg: str, mark_read: bool = True) -> "Int32":
""" """
Retuns the contents of register reg Retuns the contents of register reg
:param reg: The register name :param reg: The register name

@ -1,4 +1,4 @@
from typing import Dict from typing import Dict, Any
import re import re
# define some base type aliases so we can keep track of absolute and relative addresses # define some base type aliases so we can keep track of absolute and relative addresses
@ -6,7 +6,7 @@ T_RelativeAddress = int
T_AbsoluteAddress = int T_AbsoluteAddress = int
# parser options are just dictionaries with arbitrary values # parser options are just dictionaries with arbitrary values
T_ParserOpts = Dict[str, any] T_ParserOpts = Dict[str, Any]
NUMBER_SYMBOL_PATTERN = re.compile(r"^\d+[fb]$") NUMBER_SYMBOL_PATTERN = re.compile(r"^\d+[fb]$")

@ -1,3 +1,4 @@
from typing import Optional
from . import ( from . import (
MemorySection, MemorySection,
InstructionContext, InstructionContext,
@ -16,15 +17,17 @@ class BinaryDataMemorySection(MemorySection):
context: InstructionContext, context: InstructionContext,
owner: str, owner: str,
base: int = 0, base: int = 0,
flags: MemoryFlags = None, flags: Optional[MemoryFlags] = None,
): ):
self.name = name super().__init__(
self.base = base name,
self.context = context flags if flags is not None else MemoryFlags(False, False),
self.size = len(data) len(data),
self.flags = flags if flags is not None else MemoryFlags(False, False) base,
owner,
context,
)
self.data = data self.data = data
self.owner = owner
def read(self, offset: T_RelativeAddress, size: int) -> bytearray: def read(self, offset: T_RelativeAddress, size: int) -> bytearray:
if offset + size > self.size: if offset + size > self.size:

@ -7,6 +7,10 @@ from ..config import RunConfig
from ..colors import FMT_RED, FMT_NONE, FMT_ERROR, FMT_CPU from ..colors import FMT_RED, FMT_NONE, FMT_ERROR, FMT_CPU
from . import T_AbsoluteAddress, Instruction, Program, ProgramLoader from . import T_AbsoluteAddress, Instruction, Program, ProgramLoader
if typing.TYPE_CHECKING:
from ..MMU import MMU
from ..instructions import InstructionSet
class CPU(ABC): class CPU(ABC):
# static cpu configuration # static cpu configuration
@ -80,11 +84,11 @@ class CPU(ABC):
) )
@abstractmethod @abstractmethod
def step(self, verbose=False): def step(self, verbose: bool = False):
pass pass
@abstractmethod @abstractmethod
def run(self, verbose=False): def run(self, verbose: bool = False):
pass pass
def launch(self, program: Program, verbose: bool = False): def launch(self, program: Program, verbose: bool = False):

@ -14,8 +14,8 @@ if typing.TYPE_CHECKING:
class RiscemuBaseException(BaseException): class RiscemuBaseException(BaseException):
@abstractmethod @abstractmethod
def message(self): def message(self) -> str:
pass raise NotImplemented
def print_stacktrace(self): def print_stacktrace(self):
import traceback import traceback
@ -27,7 +27,7 @@ class RiscemuBaseException(BaseException):
class ParseException(RiscemuBaseException): class ParseException(RiscemuBaseException):
def __init__(self, msg, data=None): def __init__(self, msg: str, data=None):
super().__init__(msg, data) super().__init__(msg, data)
self.msg = msg self.msg = msg
self.data = data self.data = data
@ -77,7 +77,7 @@ def ASSERT_IN(a1, a2):
class LinkerException(RiscemuBaseException): class LinkerException(RiscemuBaseException):
def __init__(self, msg, data): def __init__(self, msg: str, data):
self.msg = msg self.msg = msg
self.data = data self.data = data
@ -93,7 +93,7 @@ class LinkerException(RiscemuBaseException):
class MemoryAccessException(RiscemuBaseException): class MemoryAccessException(RiscemuBaseException):
def __init__(self, msg, addr, size, op): def __init__(self, msg: str, addr, size, op):
super(MemoryAccessException, self).__init__() super(MemoryAccessException, self).__init__()
self.msg = msg self.msg = msg
self.addr = addr self.addr = addr
@ -196,5 +196,5 @@ class NumberFormatException(RiscemuBaseException):
# this exception is not printed and simply signals that an interactive debugging session is # this exception is not printed and simply signals that an interactive debugging session is
class LaunchDebuggerException(RiscemuBaseException): class LaunchDebuggerException(RiscemuBaseException):
def message(self): def message(self) -> str:
return "" return ""

@ -1,4 +1,4 @@
from typing import Union from typing import Any, Union
from ctypes import c_int32, c_uint32 from ctypes import c_int32, c_uint32
@ -52,10 +52,10 @@ class Int32:
other = other.value other = other.value
return self.__class__(self._val.value * other) return self.__class__(self._val.value * other)
def __truediv__(self, other): def __truediv__(self, other: Any):
return self // other return self // other
def __floordiv__(self, other): def __floordiv__(self, other: Any):
if isinstance(other, Int32): if isinstance(other, Int32):
other = other.value other = other.value
return self.__class__(self.value // other) return self.__class__(self.value // other)
@ -90,10 +90,12 @@ class Int32:
other = other.value other = other.value
return self.__class__(self.value >> other) return self.__class__(self.value >> other)
def __eq__(self, other: Union["Int32", int]): def __eq__(self, other: object) -> bool:
if isinstance(other, Int32): if isinstance(other, int):
other = other.value return self.value == other
return self.value == other elif isinstance(other, Int32):
return self.value == other.value
return False
def __neg__(self): def __neg__(self):
return self.__class__(-self._val.value) return self.__class__(-self._val.value)
@ -110,28 +112,28 @@ class Int32:
def __str__(self): def __str__(self):
return str(self.value) return str(self.value)
def __format__(self, format_spec): def __format__(self, format_spec: str):
return self.value.__format__(format_spec) return self.value.__format__(format_spec)
def __hash__(self): def __hash__(self):
return hash(self.value) return hash(self.value)
def __gt__(self, other): def __gt__(self, other: Any):
if isinstance(other, Int32): if isinstance(other, Int32):
other = other.value other = other.value
return self.value > other return self.value > other
def __lt__(self, other): def __lt__(self, other: Any):
if isinstance(other, Int32): if isinstance(other, Int32):
other = other.value other = other.value
return self.value < other return self.value < other
def __le__(self, other): def __le__(self, other: Any):
if isinstance(other, Int32): if isinstance(other, Int32):
other = other.value other = other.value
return self.value <= other return self.value <= other
def __ge__(self, other): def __ge__(self, other: Any):
if isinstance(other, Int32): if isinstance(other, Int32):
other = other.value other = other.value
return self.value >= other return self.value >= other
@ -139,38 +141,38 @@ class Int32:
def __bool__(self): def __bool__(self):
return bool(self.value) return bool(self.value)
def __cmp__(self, other): def __cmp__(self, other: Any):
if isinstance(other, Int32): if isinstance(other, Int32):
other = other.value other = other.value
return self.value.__cmp__(other) return self.value.__cmp__(other)
# right handed binary operators # right handed binary operators
def __radd__(self, other): def __radd__(self, other: Any):
return self + other return self + other
def __rsub__(self, other): def __rsub__(self, other: Any):
return self.__class__(other) - self return self.__class__(other) - self
def __rmul__(self, other): def __rmul__(self, other: Any):
return self * other return self * other
def __rtruediv__(self, other): def __rtruediv__(self, other: Any):
return self.__class__(other) // self return self.__class__(other) // self
def __rfloordiv__(self, other): def __rfloordiv__(self, other: Any):
return self.__class__(other) // self return self.__class__(other) // self
def __rmod__(self, other): def __rmod__(self, other: Any):
return self.__class__(other) % self return self.__class__(other) % self
def __rand__(self, other): def __rand__(self, other: Any):
return self.__class__(other) & self return self.__class__(other) & self
def __ror__(self, other): def __ror__(self, other: Any):
return self.__class__(other) | self return self.__class__(other) | self
def __rxor__(self, other): def __rxor__(self, other: Any):
return self.__class__(other) ^ self return self.__class__(other) ^ self
@property @property
@ -278,4 +280,6 @@ class UInt32(Int32):
:param ammount: Number of positions to shift :param ammount: Number of positions to shift
:return: A new Int32 object representing the shifted value (keeps the signed-ness of the source) :return: A new Int32 object representing the shifted value (keeps the signed-ness of the source)
""" """
return self >> ammount if isinstance(ammount, Int32):
ammount = ammount.value
return UInt32(self.value >> ammount)

@ -43,11 +43,11 @@ class MemorySection(ABC):
self, self,
start: T_RelativeAddress, start: T_RelativeAddress,
end: Optional[T_RelativeAddress] = None, end: Optional[T_RelativeAddress] = None,
fmt: str = None, fmt: Optional[str] = None,
bytes_per_row: int = None, bytes_per_row: Optional[int] = None,
rows: int = 10, rows: int = 10,
group: int = None, group: Optional[int] = None,
highlight: int = None, highlight: Optional[int] = None,
): ):
""" """
Dump the section. If no end is given, the rows around start are printed and start is highlighted. Dump the section. If no end is given, the rows around start are printed and start is highlighted.
@ -152,8 +152,15 @@ class MemorySection(ABC):
) )
) )
def dump_all(self, *args, **kwargs): def dump_all(
self.dump(0, self.size, *args, **kwargs) self,
fmt: Optional[str] = None,
bytes_per_row: Optional[int] = None,
rows: int = 10,
group: Optional[int] = None,
highlight: Optional[int] = None,
):
self.dump(0, self.size, fmt, bytes_per_row, rows, group, highlight)
def __repr__(self): def __repr__(self):
return "{}[{}] at 0x{:08X} (size={}bytes, flags={}, owner={})".format( return "{}[{}] at 0x{:08X} (size={}bytes, flags={}, owner={})".format(

Loading…
Cancel
Save