diff --git a/riscemu/priv/CSR.py b/riscemu/priv/CSR.py index 26b204c..ef35763 100644 --- a/riscemu/priv/CSR.py +++ b/riscemu/priv/CSR.py @@ -1,6 +1,7 @@ -from typing import Dict, Union, Callable +from typing import Dict, Union, Callable, Optional from collections import defaultdict -from functools import wraps +from .privmodes import PrivModes +from .Exceptions import IllegalInstructionTrap MSTATUS_OFFSETS = { 'uie': 0, @@ -68,27 +69,22 @@ class CSR: self.listeners = defaultdict(lambda: (lambda x, y: None)) def set(self, addr: Union[str, int], val: int): - if isinstance(addr, str): - if addr not in self.name_to_addr: - print("Unknown CSR register {}".format(addr)) - return - addr = self.name_to_addr[addr] + addr = self._addr_to_name(addr) + if addr is None: + return self.listeners[addr](self.regs[addr], val) self.regs[addr] = val def get(self, addr: Union[str, int]): - if isinstance(addr, str): - if addr not in self.name_to_addr: - print("Unknown CSR register {}".format(addr)) - return - addr = self.name_to_addr[addr] + addr = self._addr_to_name(addr) + if addr is None: + return return self.regs[addr] def set_listener(self, addr: Union[str, int], listener: Callable[[int, int], None]): - if isinstance(addr, str): - if not addr in self.name_to_addr: - print("Unknown CSR register {}".format(addr)) - addr = self.name_to_addr[addr] + addr = self._addr_to_name(addr) + if addr is None: + return self.listeners[addr] = listener # mstatus properties @@ -121,4 +117,20 @@ class CSR: def inner(func: Callable[[int, int], None]): self.set_listener(addr, func) return func - return inner \ No newline at end of file + return inner + + def assert_can_read(self, mode: PrivModes, addr: int): + if (addr >> 8) & 3 > mode.value(): + raise IllegalInstructionTrap() + + def assert_can_write(self, mode: PrivModes, addr: int): + if (addr >> 8) & 3 > mode.value() or addr >> 10 == 11: + raise IllegalInstructionTrap() + + def _addr_to_name(self, addr: Union[str, int]) -> Optional[int]: + if isinstance(addr, str): + if addr not in self.name_to_addr: + print("Unknown CSR register {}".format(addr)) + return None + return self.name_to_addr[addr] + return addr diff --git a/riscemu/priv/PrivRV32I.py b/riscemu/priv/PrivRV32I.py index 3a25d52..f0b5c23 100644 --- a/riscemu/priv/PrivRV32I.py +++ b/riscemu/priv/PrivRV32I.py @@ -21,11 +21,15 @@ class PrivRV32I(RV32I): """ def instruction_csrrw(self, ins: 'LoadedInstruction'): - rd, rs, ind = self.parse_crs_ins(ins) + rd, rs, csr_addr = self.parse_crs_ins(ins) if rd != 'zero': - old_val = int_from_bytes(self.cpu.csr[ind]) + self.cpu.csr.assert_can_read(self.cpu.mode, csr_addr) + old_val = self.cpu.csr.get(csr_addr) self.regs.set(rd, old_val) - self.cpu.csr.set(ind, rs) + if rs != 'zero': + new_val = self.regs.get(rs) + self.cpu.csr.assert_can_write(self.cpu.mode, csr_addr) + self.cpu.csr.set(csr_addr, new_val) def instruction_csrrs(self, ins: 'LoadedInstruction'): INS_NOT_IMPLEMENTED(ins) @@ -122,7 +126,7 @@ class PrivRV32I(RV32I): def parse_crs_ins(self, ins: 'LoadedInstruction'): ASSERT_LEN(ins.args, 3) - return ins.get_reg(0), self.get_reg_content(ins, 1), ins.get_imm(2) + return ins.get_reg(0), ins.get_reg(1), ins.get_imm(2) def parse_mem_ins(self, ins: 'LoadedInstruction') -> Tuple[str, int]: ASSERT_LEN(ins.args, 3)