Add support for floats (#22)
Adding a `Float32` datatype is necessary, since python makes no guarantees to the bitwidth of `float` (it's often a double) Also adding the `RV32F` extension with most operations implemented, and support for floating point registers.master
parent
5a23804ad8
commit
be90879f86
@ -1,4 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (riscemu)" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (riscemu)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
|
@ -0,0 +1,203 @@
|
||||
import struct
|
||||
from ctypes import c_float
|
||||
from typing import Union, Any
|
||||
|
||||
bytes_t = bytes
|
||||
|
||||
|
||||
class Float32:
|
||||
__slots__ = ("_val",)
|
||||
|
||||
_val: c_float
|
||||
|
||||
@property
|
||||
def value(self) -> float:
|
||||
"""
|
||||
The value represented by this float
|
||||
"""
|
||||
return self._val.value
|
||||
|
||||
@property
|
||||
def bytes(self) -> bytes:
|
||||
"""
|
||||
The values bit representation (as a bytes object)
|
||||
"""
|
||||
return struct.pack("<f", self.value)
|
||||
|
||||
@property
|
||||
def bits(self) -> int:
|
||||
"""
|
||||
The values bit representation as an int (for easy bit manipulation)
|
||||
"""
|
||||
return int.from_bytes(self.bytes, byteorder="little")
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, val: Union[int, bytes_t, bytearray]):
|
||||
if isinstance(val, int):
|
||||
val = int.to_bytes(byteorder="little")
|
||||
return Float32(val)
|
||||
|
||||
def __init__(
|
||||
self, val: Union[float, c_float, "Float32", bytes_t, bytearray, int] = 0
|
||||
):
|
||||
if isinstance(val, (float, int)):
|
||||
self._val = c_float(val)
|
||||
elif isinstance(val, c_float):
|
||||
self._val = c_float(val.value)
|
||||
elif isinstance(val, (bytes, bytearray)):
|
||||
self._val = c_float(struct.unpack("<f", val)[0])
|
||||
elif isinstance(val, Float32):
|
||||
self._val = val._val
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported value passed to Float32: {} ({})".format(
|
||||
repr(val), type(val)
|
||||
)
|
||||
)
|
||||
|
||||
def __add__(self, other: Union["Float32", float]):
|
||||
if isinstance(other, Float32):
|
||||
other = other.value
|
||||
return self.__class__(self.value + other)
|
||||
|
||||
def __sub__(self, other: Union["Float32", float]):
|
||||
if isinstance(other, Float32):
|
||||
other = other.value
|
||||
return self.__class__(self.value - other)
|
||||
|
||||
def __mul__(self, other: Union["Float32", float]):
|
||||
if isinstance(other, Float32):
|
||||
other = other.value
|
||||
return self.__class__(self.value * other)
|
||||
|
||||
def __truediv__(self, other: Any):
|
||||
return self // other
|
||||
|
||||
def __floordiv__(self, other: Any):
|
||||
if isinstance(other, Float32):
|
||||
other = other.value
|
||||
return self.__class__(self.value // other)
|
||||
|
||||
def __mod__(self, other: Union["Float32", float]):
|
||||
if isinstance(other, Float32):
|
||||
other = other.value
|
||||
return self.__class__(self.value % other)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, (float, int)):
|
||||
return self.value == other
|
||||
elif isinstance(other, Float32):
|
||||
return self.value == other.value
|
||||
return False
|
||||
|
||||
def __neg__(self):
|
||||
return self.__class__(-self.value)
|
||||
|
||||
def __abs__(self):
|
||||
return self.__class__(abs(self.value))
|
||||
|
||||
def __bytes__(self) -> bytes_t:
|
||||
return self.bytes
|
||||
|
||||
def __repr__(self):
|
||||
return "{}({})".format(self.__class__.__name__, self.value)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.value)
|
||||
|
||||
def __format__(self, format_spec: str):
|
||||
return self.value.__format__(format_spec)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.value)
|
||||
|
||||
def __gt__(self, other: Any):
|
||||
if isinstance(other, Float32):
|
||||
other = other.value
|
||||
return self.value > other
|
||||
|
||||
def __lt__(self, other: Any):
|
||||
if isinstance(other, Float32):
|
||||
other = other.value
|
||||
return self.value < other
|
||||
|
||||
def __le__(self, other: Any):
|
||||
if isinstance(other, Float32):
|
||||
other = other.value
|
||||
return self.value <= other
|
||||
|
||||
def __ge__(self, other: Any):
|
||||
if isinstance(other, Float32):
|
||||
other = other.value
|
||||
return self.value >= other
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.value)
|
||||
|
||||
def __cmp__(self, other: Any):
|
||||
if isinstance(other, Float32):
|
||||
other = other.value
|
||||
return self.value.__cmp__(other)
|
||||
|
||||
def __pow__(self, power, modulo=None):
|
||||
if modulo is not None:
|
||||
raise ValueError("Float32 pow with modulo unsupported")
|
||||
return self.__class__(self.value**power)
|
||||
|
||||
# right handed binary operators
|
||||
|
||||
def __radd__(self, other: Any):
|
||||
return self + other
|
||||
|
||||
def __rsub__(self, other: Any):
|
||||
return self.__class__(other) - self
|
||||
|
||||
def __rmul__(self, other: Any):
|
||||
return self * other
|
||||
|
||||
def __rtruediv__(self, other: Any):
|
||||
return self.__class__(other) // self
|
||||
|
||||
def __rfloordiv__(self, other: Any):
|
||||
return self.__class__(other) // self
|
||||
|
||||
def __rmod__(self, other: Any):
|
||||
return self.__class__(other) % self
|
||||
|
||||
def __rand__(self, other: Any):
|
||||
return self.__class__(other) & self
|
||||
|
||||
def __ror__(self, other: Any):
|
||||
return self.__class__(other) | self
|
||||
|
||||
def __rxor__(self, other: Any):
|
||||
return self.__class__(other) ^ self
|
||||
|
||||
# bytewise operators:
|
||||
|
||||
def __and__(self, other: Union["Float32", float, int]):
|
||||
if isinstance(other, float):
|
||||
other = Float32(other)
|
||||
if isinstance(other, Float32):
|
||||
other = other.bits
|
||||
return self.from_bytes(self.bits & other)
|
||||
|
||||
def __or__(self, other: Union["Float32", float]):
|
||||
if isinstance(other, float):
|
||||
other = Float32(other)
|
||||
if isinstance(other, Float32):
|
||||
other = other.bits
|
||||
return self.from_bytes(self.bits | other)
|
||||
|
||||
def __xor__(self, other: Union["Float32", float]):
|
||||
if isinstance(other, float):
|
||||
other = Float32(other)
|
||||
if isinstance(other, Float32):
|
||||
other = other.bits
|
||||
return self.from_bytes(self.bits ^ other)
|
||||
|
||||
def __lshift__(self, other: int):
|
||||
return self.from_bytes(self.bits << other)
|
||||
|
||||
def __rshift__(self, other: int):
|
||||
return self.from_bytes(self.bits >> other)
|
@ -0,0 +1,21 @@
|
||||
import math
|
||||
|
||||
from riscemu.types import Float32
|
||||
|
||||
# pi encoded as a 32bit little endian float
|
||||
PI_BYTES_LE = b"\xdb\x0fI@"
|
||||
|
||||
|
||||
def test_float_serialization():
|
||||
assert Float32(PI_BYTES_LE) == Float32(math.pi)
|
||||
assert Float32(math.pi).bytes == PI_BYTES_LE
|
||||
|
||||
|
||||
def test_random_float_ops():
|
||||
val = Float32(5)
|
||||
assert val**2 == 25
|
||||
assert val // 2 == 2
|
||||
assert val * 3 == 15
|
||||
assert val - 2 == 3
|
||||
assert val * val == 25
|
||||
assert Float32(9) ** 0.5 == 3
|
@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
|
||||
from riscemu.registers import Registers
|
||||
from riscemu.types import Float32
|
||||
|
||||
|
||||
def test_float_regs():
|
||||
r = Registers()
|
||||
# uninitialized register is zero
|
||||
assert r.get_f("fs0") == 0
|
||||
# get/set
|
||||
val = Float32(3.14)
|
||||
r.set_f("fs0", val)
|
||||
assert r.get_f("fs0") == val
|
||||
|
||||
|
||||
def test_unlimited_regs_works():
|
||||
r = Registers(infinite_regs=True)
|
||||
r.get("infinite")
|
||||
r.get_f("finfinite")
|
||||
|
||||
|
||||
def test_unknown_reg_fails():
|
||||
r = Registers(infinite_regs=False)
|
||||
with pytest.raises(RuntimeError, match="Invalid register: az1"):
|
||||
r.get("az1")
|
Loading…
Reference in New Issue