Skip to content

Commit c4d37ff

Browse files
pirapira and claude committed
Continue Phase 9a: int/bytes subclass interop with builtins
Enable int/bytes subclass instances to work seamlessly with Python builtins and operators. Key changes: hex/bin/oct/range/divmod accept instances with wrappedValue, list indexing/slicing coerces instance indices, comparison operators unwrap instances for sorted() support, and bytes gains __radd__/__lt__/__le__/__gt__/__ge__ dunder dispatch. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6599975 commit c4d37ff

5 files changed

Lines changed: 187 additions & 32 deletions

File tree

LeanPython/Interpreter/Eval.lean

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,47 @@ private def dummySpan : SourceSpan :=
6060
{ start := { line := 0, column := 0, offset := 0 }
6161
stop := { line := 0, column := 0, offset := 0 } }
6262

63-
private def computeSliceIndices (len : Int) (startOpt stopOpt stepOpt : Option Value)
63+
/-- Pure integer view of a primitive `Value`, for use as an index.

    * `.int n`  maps to `some n`
    * `.bool b` maps to `some 1` for `true` and `some 0` for `false`
    * every other variant (including `.instance`) maps to `none`

    This helper never touches the heap, so it cannot see the wrapped payload of an
    int-subclass instance. -/
private def coerceToInt (v : Value) : Option Int :=
  match v with
  | .bool true => some 1
  | .bool false => some 0
  | .int n => some n
  | _ => none
69+
70+
/-- Monadic integer coercion that also understands int-subclass instances.

    For an `.instance`, the instance data is fetched from the heap and its
    `wrappedValue` is inspected: a wrapped `.int n` yields `some n`, anything else
    yields `none`. All remaining variants are delegated to the pure `coerceToInt`
    (ints and bools succeed, everything else is `none`). -/
private partial def coerceToIntM (v : Value) : InterpM (Option Int) := do
  match v with
  | .instance iref =>
    -- The primitive payload of an instance lives in its heap record.
    let data ← heapGetInstanceData iref
    match data.wrappedValue with
    | some (.int n) => return some n
    | _ => return none
  | primitive =>
    return coerceToInt primitive
81+
82+
/-- Resolve the optional components of a slice into concrete `(start, stop, step)`
    bounds for a sequence of length `len`.

    Each component is read through `coerceToIntM`, so plain ints, bools and instances
    whose `wrappedValue` is an int are accepted. A missing component, or one that
    `coerceToIntM` cannot turn into an integer (e.g. `None`), falls back to the
    default for that position: step `1`; start `0` / `len - 1` and stop `len` / `-1`
    for a forward / backward step respectively.

    Explicit bounds are clamped following CPython's slice adjustment rules:
    into `[0, len]` for a positive step and into `[-1, len - 1]` for a negative step.
    The negative-step range matters for reversed slices with out-of-range bounds:
    `xs[:-100:-1]` must run down through index 0 (stop `-1`, the same value used for
    an omitted stop), and `xs[100::-1]` must begin at the last valid index rather
    than at `len`.

    Fails via `throwValueError` when the resolved step is zero. -/
private partial def computeSliceIndices (len : Int) (startOpt stopOpt stepOpt : Option Value)
    : InterpM (Int × Int × Int) := do
  -- Read one slice component; `none` means "absent or not integer-like".
  let component (slot : Option Value) : InterpM (Option Int) :=
    match slot with
    | some v => coerceToIntM v
    | none => pure none
  let stepIn ← component stepOpt
  let step : Int := stepIn.getD 1
  if step == 0 then throwValueError "slice step cannot be zero"
  let backwards : Bool := decide (step < 0)
  -- Clamp an explicit bound. Negative values count from the end of the sequence;
  -- whatever is still out of range afterwards is pinned to the nearest limit that
  -- is valid for the direction of travel.
  let clamp (i : Int) : Int :=
    if i < 0 then
      let shifted := len + i
      if shifted < 0 then (if backwards then -1 else 0) else shifted
    else if i >= len then (if backwards then len - 1 else len)
    else i
  let startIn ← component startOpt
  let start : Int := match startIn with
    | some s => clamp s
    | none => if backwards then len - 1 else 0
  let stopIn ← component stopOpt
  let stop : Int := match stopIn with
    | some s => clamp s
    | none => if backwards then -1 else len
  return (start, stop, step)
79105

80106
private def sliceIndices (start stop step : Int) : List Nat :=
@@ -1659,6 +1685,9 @@ partial def callValueDispatch (callee : Value) (args : List Value)
16591685
| "__add__" => match args with
16601686
| [a, b] => return .bytes ((← extractBytes a) ++ (← extractBytes b))
16611687
| _ => throwTypeError "bytes.__add__ takes 2 arguments"
1688+
| "__radd__" => match args with
1689+
| [a, b] => return .bytes ((← extractBytes b) ++ (← extractBytes a))
1690+
| _ => throwTypeError "bytes.__radd__ takes 2 arguments"
16621691
| "__mul__" | "__rmul__" => match args with
16631692
| [a, .int n] => do
16641693
let b ← extractBytes a
@@ -1674,6 +1703,26 @@ partial def callValueDispatch (callee : Value) (args : List Value)
16741703
| "__ne__" => match args with
16751704
| [a, b] => return .bool ((← extractBytes a) != (← extractBytes b))
16761705
| _ => throwTypeError "bytes.__ne__ takes 2 arguments"
1706+
| "__lt__" => match args with
1707+
| [a, b] => do
1708+
let ba ← extractBytes a; let bb ← extractBytes b
1709+
return .bool (ba.toList < bb.toList)
1710+
| _ => throwTypeError "bytes.__lt__ takes 2 arguments"
1711+
| "__le__" => match args with
1712+
| [a, b] => do
1713+
let ba ← extractBytes a; let bb ← extractBytes b
1714+
return .bool (ba.toList < bb.toList || ba == bb)
1715+
| _ => throwTypeError "bytes.__le__ takes 2 arguments"
1716+
| "__gt__" => match args with
1717+
| [a, b] => do
1718+
let ba ← extractBytes a; let bb ← extractBytes b
1719+
return .bool (bb.toList < ba.toList)
1720+
| _ => throwTypeError "bytes.__gt__ takes 2 arguments"
1721+
| "__ge__" => match args with
1722+
| [a, b] => do
1723+
let ba ← extractBytes a; let bb ← extractBytes b
1724+
return .bool (bb.toList < ba.toList || ba == bb)
1725+
| _ => throwTypeError "bytes.__ge__ takes 2 arguments"
16771726
| "__hash__" => match args with
16781727
| [a] => do
16791728
let b ← extractBytes a
@@ -2382,6 +2431,15 @@ partial def getAttributeValue (obj : Value) (attr : String) : InterpM Value := d
23822431

23832432
-- Subscript access
23842433
partial def evalSubscriptValue (obj idx : Value) : InterpM Value := do
2434+
-- Coerce instance indices with wrappedValue to plain .int
2435+
let idx ← do
2436+
match idx with
2437+
| .instance iref =>
2438+
let id_ ← heapGetInstanceData iref
2439+
match id_.wrappedValue with
2440+
| some (.int n) => pure (.int n)
2441+
| _ => pure idx
2442+
| _ => pure idx
23852443
match obj with
23862444
| .list ref => do
23872445
let arr ← heapGetList ref

LeanPython/Runtime/Builtins.lean

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ open LeanPython.Stdlib.Os
3636
open LeanPython.Stdlib.Time
3737
open LeanPython.Stdlib.Logging
3838

39+
-- ============================================================
40+
-- Helpers
41+
-- ============================================================
42+
43+
/-- Read an integer out of a `Value`, or fail with a type error.

    Accepted inputs:
    * `.int n`, giving `n`
    * `.bool b`, giving `1` for `true` and `0` for `false`
    * `.instance` whose heap record has `wrappedValue = some (.int n)`, giving `n`

    Anything else is rejected through `throwTypeError` with the message
    `'<typeName v>' cannot be interpreted as an integer`. -/
partial def extractIntValue (v : Value) : InterpM Int := do
  -- First classify the value, then raise from a single place if nothing matched.
  let found : Option Int ← match v with
    | .int n => pure (some n)
    | .bool flag => pure (some (if flag then 1 else 0))
    | .instance iref => do
      let data ← heapGetInstanceData iref
      match data.wrappedValue with
      | some (.int n) => pure (some n)
      | _ => pure none
    | _ => pure none
  match found with
  | some n => return n
  | none => throwTypeError s!"'{typeName v}' cannot be interpreted as an integer"
54+
3955
-- ============================================================
4056
-- Individual builtin implementations
4157
-- ============================================================
@@ -71,10 +87,11 @@ partial def builtinLen (args : List Value) : InterpM Value := do
7187
/-- `range(stop)` or `range(start, stop[, step])` - materializes to list. -/
7288
partial def builtinRange (args : List Value) : InterpM Value := do
7389
let (start_, stop_, step_) ← match args with
74-
| [.int n] => pure (0, n, (1 : Int))
75-
| [.int s, .int e] => pure (s, e, (1 : Int))
76-
| [.int s, .int e, .int st] => pure (s, e, st)
77-
| [.bool b] => pure (0, (if b then 1 else 0 : Int), (1 : Int))
90+
| [a] => do let n ← extractIntValue a; pure (0, n, (1 : Int))
91+
| [a, b] => do let s ← extractIntValue a; let e ← extractIntValue b; pure (s, e, (1 : Int))
92+
| [a, b, c] => do
93+
let s ← extractIntValue a; let e ← extractIntValue b; let st ← extractIntValue c
94+
pure (s, e, st)
7895
| _ => throwTypeError "range() requires int arguments"
7996
if step_ == 0 then
8097
throwValueError "range() arg 3 must not be zero"
@@ -482,9 +499,11 @@ def builtinPow (args : List Value) : InterpM Value := do
482499
| _ => throwTypeError "pow() takes 2 or 3 arguments"
483500

484501
/-- `divmod(a, b)`: returns the 2-tuple `(Int.fdiv a b, Int.fmod a b)`.

    Both arguments are converted with `extractIntValue`, so ints, bools and
    int-subclass instances are accepted; a value it rejects surfaces as its type
    error. A call with any number of arguments other than two fails through
    `throwTypeError`, and a zero divisor fails through `throwZeroDivision`. -/
partial def builtinDivmod (args : List Value) : InterpM Value := do
  -- Arity check first: exactly two arguments are required.
  let [lhs, rhs] := args
    | throwTypeError "divmod() requires two integer arguments"
  let a ← extractIntValue lhs
  let b ← extractIntValue rhs
  if b == 0 then
    throwZeroDivision "integer division or modulo by zero"
  return .tuple #[.int (Int.fdiv a b), .int (Int.fmod a b)]
@@ -634,23 +653,26 @@ partial def callBuiltin (name : String) (args : List Value)
634653
| _ => throwTypeError "input() takes at most 1 argument"
635654
| "hex" => do
636655
match args with
637-
| [.int n] =>
656+
| [v] => do
657+
let n ← extractIntValue v
638658
let digits := Nat.toDigits 16 n.natAbs
639659
let s := String.ofList digits
640660
if n >= 0 then return .str s!"0x{s}"
641661
else return .str s!"-0x{s}"
642662
| _ => throwTypeError "hex() takes exactly one argument"
643663
| "oct" => do
644664
match args with
645-
| [.int n] =>
665+
| [v] => do
666+
let n ← extractIntValue v
646667
let digits := Nat.toDigits 8 n.natAbs
647668
let s := String.ofList digits
648669
if n >= 0 then return .str s!"0o{s}"
649670
else return .str s!"-0o{s}"
650671
| _ => throwTypeError "oct() takes exactly one argument"
651672
| "bin" => do
652673
match args with
653-
| [.int n] =>
674+
| [v] => do
675+
let n ← extractIntValue v
654676
let digits := Nat.toDigits 2 n.natAbs
655677
let s := String.ofList digits
656678
if n >= 0 then return .str s!"0b{s}"

LeanPython/Runtime/Ops.lean

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -594,42 +594,59 @@ private def isIdentical : Value → Value → Bool
594594
| .instance a, .instance b => a == b
595595
| _, _ => false
596596

597+
/-- Prepare a value for primitive comparison.

    If `v` is an `.instance` whose heap record carries a `wrappedValue`, that wrapped
    value is returned. In every other case (a non-instance, or an instance without a
    wrapped value) `v` is returned unchanged. -/
private partial def unwrapForCmp (v : Value) : InterpM Value := do
  -- Only instances can carry a wrapped primitive; everything else passes through.
  let .instance iref := v
    | return v
  let data ← heapGetInstanceData iref
  return data.wrappedValue.getD v
606+
597607
/-- Evaluate a single comparison operation. -/
598-
partial def evalCmpOp (op : CmpOp) (left right : Value) : InterpM Bool :=
608+
partial def evalCmpOp (op : CmpOp) (left right : Value) : InterpM Bool := do
609+
-- Unwrap instances with wrappedValue for comparison dispatch
610+
let left' ← unwrapForCmp left
611+
let right' ← unwrapForCmp right
599612
match op with
600613
| .eq => valueEq left right
601614
| .notEq => do return !(← valueEq left right)
602615
| .lt => do
603-
match ← numericCompare left right with
616+
match ← numericCompare left' right' with
604617
| some .lt => return true
605618
| some _ => return false
606619
| none =>
607-
match left, right with
620+
match left', right' with
608621
| .str a, .str b => return (a < b)
622+
| .bytes a, .bytes b => return (a.toList < b.toList)
609623
| _, _ => throwTypeError s!"'<' not supported between instances of '{typeName left}' and '{typeName right}'"
610624
| .ltE => do
611-
match ← numericCompare left right with
625+
match ← numericCompare left' right' with
612626
| some .gt => return false
613627
| some _ => return true
614628
| none =>
615-
match left, right with
629+
match left', right' with
616630
| .str a, .str b => return (decide (a ≤ b))
631+
| .bytes a, .bytes b => return (a.toList < b.toList || a == b)
617632
| _, _ => throwTypeError s!"'<=' not supported between instances of '{typeName left}' and '{typeName right}'"
618633
| .gt => do
619-
match ← numericCompare left right with
634+
match ← numericCompare left' right' with
620635
| some .gt => return true
621636
| some _ => return false
622637
| none =>
623-
match left, right with
638+
match left', right' with
624639
| .str a, .str b => return (decide (a > b))
640+
| .bytes a, .bytes b => return (b.toList < a.toList)
625641
| _, _ => throwTypeError s!"'>' not supported between instances of '{typeName left}' and '{typeName right}'"
626642
| .gtE => do
627-
match ← numericCompare left right with
643+
match ← numericCompare left' right' with
628644
| some .lt => return false
629645
| some _ => return true
630646
| none =>
631-
match left, right with
647+
match left', right' with
632648
| .str a, .str b => return (decide (a ≥ b))
649+
| .bytes a, .bytes b => return (b.toList < a.toList || a == b)
633650
| _, _ => throwTypeError s!"'>=' not supported between instances of '{typeName left}' and '{typeName right}'"
634651
| .is_ => return (isIdentical left right)
635652
| .isNot => return (!isIdentical left right)

LeanPythonTest/Stdlib.lean

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,3 +826,61 @@ private def assertPyError (source errSubstr : String) : IO Unit := do
826826

827827
-- @classmethod + @override stacking (leanSpec pattern)
828828
#eval assertPy "from abc import ABC, abstractmethod\nfrom typing import override\nclass SSZType(ABC):\n @classmethod\n @abstractmethod\n def is_fixed_size(cls):\n pass\n @classmethod\n @abstractmethod\n def get_byte_length(cls):\n pass\nclass BaseUint(int, SSZType):\n __slots__ = ()\n BITS = 64\n def __new__(cls, value):\n return super().__new__(cls, int(value))\n @classmethod\n @override\n def is_fixed_size(cls):\n return True\n @classmethod\n @override\n def get_byte_length(cls):\n return cls.BITS // 8\nclass Uint64(BaseUint):\n BITS = 64\nprint(Uint64.is_fixed_size())\nprint(Uint64.get_byte_length())" "True\n8\n"
829+
830+
-- ============================================================
831+
-- Phase 9a continued: int/bytes subclass interop with builtins
832+
-- ============================================================
833+
834+
-- hex() on int subclass instance
835+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\nclass Uint64(BaseUint):\n pass\nprint(hex(Uint64(42)))\nprint(hex(Uint64(255)))" "0x2a\n0xff\n"
836+
837+
-- bin() on int subclass instance
838+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\nclass Uint8(BaseUint):\n pass\nprint(bin(Uint8(42)))" "0b101010\n"
839+
840+
-- oct() on int subclass instance
841+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\nclass Uint8(BaseUint):\n pass\nprint(oct(Uint8(42)))" "0o52\n"
842+
843+
-- range() with int subclass instances
844+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\nclass Uint8(BaseUint):\n pass\nresult = list(range(Uint8(5)))\nprint(result)" "[0, 1, 2, 3, 4]\n"
845+
846+
-- range(start, stop, step) with int subclass instances
847+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\nclass Uint8(BaseUint):\n pass\nresult = list(range(Uint8(2), Uint8(8), Uint8(2)))\nprint(result)" "[2, 4, 6]\n"
848+
849+
-- divmod() with int subclass instances
850+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\nclass Uint64(BaseUint):\n pass\nresult = divmod(Uint64(100), Uint64(3))\nprint(result[0])\nprint(result[1])" "33\n1\n"
851+
852+
-- List indexing with int subclass instance
853+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\nclass Uint64(BaseUint):\n pass\ndata = ['a', 'b', 'c', 'd', 'e']\nprint(data[Uint64(2)])\nprint(data[Uint64(0)])\nprint(data[Uint64(4)])" "c\na\ne\n"
854+
855+
-- Slice with int subclass instances
856+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\nclass Uint8(BaseUint):\n pass\ndata = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\nprint(data[Uint8(2):Uint8(7)])" "[2, 3, 4, 5, 6]\n"
857+
858+
-- repr() and str() on int subclass via Python-defined __repr__/__str__
859+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\n def __repr__(self):\n return type(self).__name__ + '(' + str(int(self)) + ')'\n def __str__(self):\n return str(int(self))\nclass Uint64(BaseUint):\n pass\nx = Uint64(42)\nprint(repr(x))\nprint(str(x))" "Uint64(42)\n42\n"
860+
861+
-- int() conversion of int subclass instance
862+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\nclass Uint64(BaseUint):\n pass\nx = Uint64(42)\nresult = int(x)\nprint(result)\nprint(result == 42)" "42\nTrue\n"
863+
864+
-- bytes.__radd__: b"\xff" + bytes_subclass_instance
865+
#eval assertPy "class BaseBytes(bytes):\n def __new__(cls, value=b''):\n return super().__new__(cls, value)\nclass Bytes4(BaseBytes):\n pass\nx = Bytes4(b'\\x01\\x02\\x03\\x04')\nresult = b'\\xff' + x\nprint(len(result))\nprint(result[0])\nprint(result[4])" "5\n255\n4\n"
866+
867+
-- sorted() on bytes subclass instances (lexicographic)
868+
#eval assertPy "class BaseBytes(bytes):\n def __new__(cls, value=b''):\n return super().__new__(cls, value)\nclass Bytes4(BaseBytes):\n pass\na = Bytes4(b'\\x00\\x00\\x00\\x02')\nb = Bytes4(b'\\x00\\x00\\x00\\x01')\nc = Bytes4(b'\\xff\\xff\\xff\\xff')\nresult = sorted([c, a, b])\nprint(bytes(result[0]).hex())\nprint(bytes(result[1]).hex())\nprint(bytes(result[2]).hex())" "00000001\n00000002\nffffffff\n"
869+
870+
-- In-place operators: x += Uint64(5)
871+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\n def __add__(self, other):\n return type(self)(int(self) + int(other))\nclass Uint64(BaseUint):\n pass\nx = Uint64(10)\ny = x\nx += Uint64(5)\nprint(int(x))\nprint(type(x).__name__)\nprint(int(y))" "15\nUint64\n10\n"
872+
873+
-- Reverse operators: plain int + Uint raises TypeError when __radd__ checks type
874+
#eval assertPy "class BaseUint(int):\n def __new__(cls, value):\n return super().__new__(cls, int(value))\n def __radd__(self, other):\n if not isinstance(other, BaseUint):\n raise TypeError('bad type for +')\n return type(self)(int(other) + int(self))\nclass Uint64(BaseUint):\n pass\ntry:\n result = 100 + Uint64(3)\n print('no error')\nexcept TypeError as e:\n print(str(e))" "bad type for +\n"
875+
876+
-- io.BytesIO seek and read round-trip
877+
#eval assertPy "import io\nstream = io.BytesIO()\nstream.write(b'\\x01\\x02\\x03\\x04')\nstream.seek(0)\ndata = stream.read()\nprint(len(data))\nprint(data[0])\nprint(data[3])" "4\n1\n4\n"
878+
879+
-- BaseBytes __add__ returns raw bytes
880+
#eval assertPy "class BaseBytes(bytes):\n def __new__(cls, value=b''):\n return super().__new__(cls, value)\nclass Bytes4(BaseBytes):\n pass\na = Bytes4(b'\\x01\\x02\\x03\\x04')\nb = Bytes4(b'\\x05\\x06\\x07\\x08')\nresult = a + b\nprint(len(result))\nprint(result[0])\nprint(result[7])" "8\n1\n8\n"
881+
882+
-- Boolean type: construction, validation, arithmetic rejection
883+
#eval assertPy "class Boolean(int):\n def __new__(cls, value):\n if not isinstance(value, int):\n raise TypeError('Expected bool or int')\n if value not in (0, 1):\n raise ValueError('Boolean value must be 0 or 1')\n return super().__new__(cls, value)\n def __add__(self, other):\n raise TypeError('Arithmetic not supported for Boolean.')\n def __eq__(self, other):\n return isinstance(other, int) and int(self) == int(other)\nbt = Boolean(True)\nbf = Boolean(False)\nprint(int(bt))\nprint(int(bf))\nprint(bt == 1)\nprint(bf == 0)\ntry:\n bt + bf\nexcept TypeError as e:\n print(str(e))" "1\n0\nTrue\nTrue\nArithmetic not supported for Boolean.\n"
884+
885+
-- Boolean rejects invalid values
886+
#eval assertPyError "class Boolean(int):\n def __new__(cls, value):\n if value not in (0, 1):\n raise ValueError('must be 0 or 1')\n return super().__new__(cls, value)\nBoolean(2)" "must be 0 or 1"

0 commit comments

Comments
 (0)