Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions examples/oop.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
class Counter:
def __init__(self, value: int) -> None:
self.value = value
class User:
def __init__(self, name: str) -> None:
self.name = name

def increment(self, amount: int) -> None:
self.value += amount
def __repr__(self) -> str:
return "User(name=" + self.name + ")"


c = Counter(42)
print(c.value)
class Group:
def __init__(self, members: list[User]) -> None:
self.members = members

c.increment(8)
print(c.value)
def add_member(self, user: User) -> None:
self.members.append(user)

def remove_member(self, user: User) -> None:
self.members.remove(user)


g1 = Group([User("Alice"), User("Bob")])
g1.add_member(User("Charlie"))
print(g1.members)
g1.remove_member(g1.members[0])
print(g1.members)
3 changes: 2 additions & 1 deletion src/lython/dialects/binding/PyDialectPythonBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,6 @@ void ensurePyDialectLoaded(const py11::object &contextObj) {
PYBIND11_MODULE(_site_initialize_0, m) {
m.doc() = "Initialization hooks for the Lython py dialect.";
m.def("register_dialects", &registerPyDialect, py11::arg("registry"));
m.def("context_init_hook", &ensurePyDialectLoaded, py11::arg("context") = py11::none());
m.def("context_init_hook", &ensurePyDialectLoaded,
py11::arg("context") = py11::none());
}
16 changes: 14 additions & 2 deletions src/lython/dialects/cpp/PyDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ namespace py {

void PyDialect::initialize() {
addTypes<IntType, FloatType, BoolType, StrType, ObjectType, NoneType,
TupleType, DictType, ClassType, ExceptionType, TracebackType,
LocationType, FuncSignatureType, FuncType, PrimFuncType>();
TupleType, DictType, ListType, ClassType, ExceptionType,
TracebackType, LocationType, FuncSignatureType, FuncType,
PrimFuncType>();

addOperations<
#define GET_OP_LIST
Expand Down Expand Up @@ -91,6 +92,14 @@ Type PyDialect::parseType(DialectAsmParser &parser) const {
return Type();
return DictType::get(ctx, keyType, valueType);
}
if (keyword == "list") {
if (parser.parseLess())
return Type();
Type elementType;
if (parser.parseType(elementType) || parser.parseGreater())
return Type();
return ListType::get(ctx, elementType);
}
if (keyword == "class") {
if (parser.parseLess())
return Type();
Expand Down Expand Up @@ -217,6 +226,9 @@ void PyDialect::printType(Type type, DialectAsmPrinter &printer) const {
printer << "dict<" << dictTy.getKeyType() << ", "
<< dictTy.getValueType() << ">";
})
.Case<ListType>([&](ListType listTy) {
printer << "list<" << listTy.getElementType() << ">";
})
.Case<ClassType>([&](ClassType classTy) {
printer << "class<\"" << classTy.getClassName() << "\">";
})
Expand Down
32 changes: 25 additions & 7 deletions src/lython/dialects/cpp/PyDialectTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ DictTypeStorage *DictTypeStorage::construct(TypeStorageAllocator &allocator,
DictTypeStorage(key.first, key.second);
}

ListTypeStorage *ListTypeStorage::construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<ListTypeStorage>()) ListTypeStorage(key);
}

ClassTypeStorage *ClassTypeStorage::construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<ClassTypeStorage>())
Expand Down Expand Up @@ -105,6 +110,12 @@ Type DictType::getKeyType() const { return getImpl()->keyType; }

Type DictType::getValueType() const { return getImpl()->valueType; }

ListType ListType::get(MLIRContext *ctx, Type elementType) {
return Base::get(ctx, elementType);
}

Type ListType::getElementType() const { return getImpl()->elementType; }

ClassType ClassType::get(MLIRContext *ctx, ::llvm::StringRef className) {
return Base::get(ctx, className);
}
Expand Down Expand Up @@ -193,6 +204,8 @@ bool isPyTupleType(Type type) { return mlir::isa<TupleType>(type); }

bool isPyDictType(Type type) { return mlir::isa<DictType>(type); }

bool isPyListType(Type type) { return mlir::isa<ListType>(type); }

bool isPyClassType(Type type) { return mlir::isa<ClassType>(type); }

bool isPyExceptionType(Type type) { return mlir::isa<ExceptionType>(type); }
Expand All @@ -207,15 +220,13 @@ bool isPyFuncType(Type type) { return mlir::isa<FuncType>(type); }

bool isPyPrimFuncType(Type type) { return mlir::isa<PrimFuncType>(type); }

bool isCallableType(Type type) {
return mlir::isa<FuncType>(type) || mlir::isa<ClassType>(type);
}
bool isCallableType(Type type) { return mlir::isa<FuncType>(type); }

bool isPyType(Type type) {
return llvm::TypeSwitch<Type, bool>(type)
.Case<IntType, FloatType, BoolType, StrType, ObjectType, NoneType,
TupleType, DictType, ClassType, ExceptionType, TracebackType,
LocationType, FuncType>([](auto) { return true; })
TupleType, DictType, ListType, ClassType, ExceptionType,
TracebackType, LocationType, FuncType>([](auto) { return true; })
.Default([](Type) { return false; });
}

Expand All @@ -228,9 +239,9 @@ bool isSubtypeOf(Type subtype, Type supertype) {
if (subtype == supertype)
return true;

// Top type: T <: !py.object for all T
// Top type: object-world T <: !py.object
if (mlir::isa<ObjectType>(supertype))
return isPyType(subtype);
return isPyType(subtype) && !mlir::isa<ClassType>(subtype);

// Tuple covariance: !py.tuple<S> <: !py.tuple<T> if S <: T
auto subtypeTuple = mlir::dyn_cast<TupleType>(subtype);
Expand All @@ -256,6 +267,13 @@ bool isSubtypeOf(Type subtype, Type supertype) {
supertypeDict.getValueType());
}

auto subtypeList = mlir::dyn_cast<ListType>(subtype);
auto supertypeList = mlir::dyn_cast<ListType>(supertype);
if (subtypeList && supertypeList) {
return isSubtypeOf(subtypeList.getElementType(),
supertypeList.getElementType());
}

// No other subtype relations in v2.1
// TODO(v3): Add class hierarchy support
return false;
Expand Down
44 changes: 35 additions & 9 deletions src/lython/dialects/cpp/PyDialectTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ enum class TypeKind : unsigned {
None,
Tuple,
Dict,
List,
Class, // User-defined class type
Exception,
Traceback,
Expand Down Expand Up @@ -77,6 +78,19 @@ struct DictTypeStorage : public mlir::TypeStorage {
mlir::Type valueType;
};

struct ListTypeStorage : public mlir::TypeStorage {
using KeyTy = mlir::Type;

explicit ListTypeStorage(mlir::Type element) : elementType(element) {}

bool operator==(const KeyTy &key) const { return key == elementType; }

static ListTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
const KeyTy &key);

mlir::Type elementType;
};

struct ClassTypeStorage : public mlir::TypeStorage {
using KeyTy = ::llvm::StringRef;

Expand Down Expand Up @@ -264,9 +278,22 @@ class ClassType : public mlir::Type::TypeBase<ClassType, mlir::Type,
::llvm::StringRef getClassName() const;
};

class ExceptionType
: public mlir::Type::TypeBase<ExceptionType, mlir::Type,
detail::SimpleTypeStorage> {
class ListType : public mlir::Type::TypeBase<ListType, mlir::Type,
detail::ListTypeStorage> {
public:
using Base::Base;
static constexpr ::llvm::StringLiteral name{"py.list"};

static ListType get(mlir::MLIRContext *ctx, mlir::Type elementType);
static bool kindof(unsigned kind) {
return kind == static_cast<unsigned>(TypeKind::List);
}

mlir::Type getElementType() const;
};

class ExceptionType : public mlir::Type::TypeBase<ExceptionType, mlir::Type,
detail::SimpleTypeStorage> {
public:
using Base::Base;
static constexpr ::llvm::StringLiteral name{"py.exception"};
Expand All @@ -277,9 +304,8 @@ class ExceptionType
}
};

class TracebackType
: public mlir::Type::TypeBase<TracebackType, mlir::Type,
detail::SimpleTypeStorage> {
class TracebackType : public mlir::Type::TypeBase<TracebackType, mlir::Type,
detail::SimpleTypeStorage> {
public:
using Base::Base;
static constexpr ::llvm::StringLiteral name{"py.traceback"};
Expand All @@ -290,9 +316,8 @@ class TracebackType
}
};

class LocationType
: public mlir::Type::TypeBase<LocationType, mlir::Type,
detail::SimpleTypeStorage> {
class LocationType : public mlir::Type::TypeBase<LocationType, mlir::Type,
detail::SimpleTypeStorage> {
public:
using Base::Base;
static constexpr ::llvm::StringLiteral name{"py.location"};
Expand Down Expand Up @@ -364,6 +389,7 @@ bool isPyObjectType(mlir::Type type);
bool isPyNoneType(mlir::Type type);
bool isPyTupleType(mlir::Type type);
bool isPyDictType(mlir::Type type);
bool isPyListType(mlir::Type type);
bool isPyClassType(mlir::Type type);
bool isPyExceptionType(mlir::Type type);
bool isPyTracebackType(mlir::Type type);
Expand Down
84 changes: 21 additions & 63 deletions src/lython/dialects/cpp/PyVerifier/Call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ LogicalResult CallOp::verify() {
Type callableType = getCallable().getType();

if (!isCallableType(callableType))
return emitOpError("callable operand must be !py.func or !py.class");
return emitOpError("callable operand must be !py.func");

FailureOr<ThrowEffect> effectOr =
resolveCallableThrowEffect(getOperation(), getCallable());
Expand All @@ -21,51 +21,6 @@ LogicalResult CallOp::verify() {
FuncSignatureType signature;
if (FuncType funcTy = dyn_cast<FuncType>(callableType)) {
signature = funcTy.getSignature();
} else if (ClassType classTy = dyn_cast<ClassType>(callableType)) {
ClassOp classOp = lookupClassSymbol(getOperation(), classTy);
if (!classOp)
return emitOpError("unable to resolve class '")
<< classTy.getClassName() << "'";
FuncOp callMethod = lookupMethodByName(classOp, "__call__");
if (!callMethod) {
emitOpError("class '")
<< classTy.getClassName() << "' does not define a '__call__' method";
return failure();
}
mlir::TypeAttr fnTypeAttr = callMethod.getFunctionTypeAttr();
if (!fnTypeAttr) {
callMethod.emitOpError("requires 'function_type' attribute");
return failure();
}
FuncSignatureType methodSig =
dyn_cast<FuncSignatureType>(fnTypeAttr.getValue());
if (!methodSig) {
callMethod.emitOpError("'function_type' must be a FuncSignatureType");
return failure();
}

mlir::ArrayRef<mlir::Type> positionalTypes = methodSig.getPositionalTypes();
if (positionalTypes.empty()) {
callMethod.emitOpError(
"__call__ must declare a positional 'self' parameter");
return failure();
}

Type selfType = positionalTypes.front();
if (selfType != classTy && !isPyObjectType(selfType)) {
callMethod.emitOpError(
"first positional parameter must be of type !py.class<")
<< classOp.getSymName() << "> or !py.object";
return failure();
}

signature = FuncSignatureType::get(
methodSig.getContext(),
ArrayRef<Type>(positionalTypes.begin() + 1, positionalTypes.end()),
methodSig.getKwOnlyTypes(),
methodSig.hasVararg() ? methodSig.getVarargType() : Type(),
methodSig.hasKwarg() ? methodSig.getKwargType() : Type(),
methodSig.getResultTypes());
} else {
return emitOpError("unexpected callable type");
}
Expand All @@ -77,8 +32,11 @@ LogicalResult CallOp::verify() {
bool homogeneous = tupleElems.size() == 1;

mlir::ArrayRef<mlir::Type> positionalTypes = signature.getPositionalTypes();
unsigned minPositionalCount =
getMinimumPositionalCountForCallable(signature, getCallable());
if (!signature.hasVararg()) {
if (tupleElems.size() != positionalTypes.size())
if (tupleElems.size() < minPositionalCount ||
tupleElems.size() > positionalTypes.size())
return emitOpError(
"posargs length mismatch with callee positional parameters");
for (auto [elemType, expected] : llvm::zip(tupleElems, positionalTypes))
Expand Down Expand Up @@ -143,10 +101,10 @@ LogicalResult CallOp::verify() {
}

LogicalResult CallVectorOp::verify() {
ArrayAttr expectedArgNamesAttr;
ArrayAttr expectedKwNamesAttr;
FailureOr<FuncSignatureType> signatureOr =
resolveCallableSignature(getOperation(), getCallable(),
expectedKwNamesAttr);
FailureOr<FuncSignatureType> signatureOr = resolveCallableSignature(
getOperation(), getCallable(), expectedArgNamesAttr, expectedKwNamesAttr);
if (failed(signatureOr))
return failure();
FuncSignatureType signature = *signatureOr;
Expand All @@ -159,9 +117,9 @@ LogicalResult CallVectorOp::verify() {
if (effect == ThrowEffect::MayThrow)
return emitOpError("maythrow callee must be invoked with py.invoke");

if (failed(verifyVectorCallOperands(getOperation(), signature, getPosargs(),
getKwnames(), getKwvalues(),
expectedKwNamesAttr)))
if (failed(verifyVectorCallOperands(
getOperation(), signature, getCallable(), getPosargs(), getKwnames(),
getKwvalues(), expectedArgNamesAttr, expectedKwNamesAttr)))
return failure();

auto resultTypes = signature.getResultTypes();
Expand All @@ -176,10 +134,10 @@ LogicalResult CallVectorOp::verify() {
}

LogicalResult InvokeOp::verify() {
ArrayAttr expectedArgNamesAttr;
ArrayAttr expectedKwNamesAttr;
FailureOr<FuncSignatureType> signatureOr =
resolveCallableSignature(getOperation(), getCallable(),
expectedKwNamesAttr);
FailureOr<FuncSignatureType> signatureOr = resolveCallableSignature(
getOperation(), getCallable(), expectedArgNamesAttr, expectedKwNamesAttr);
if (failed(signatureOr))
return failure();
FuncSignatureType signature = *signatureOr;
Expand All @@ -192,15 +150,14 @@ LogicalResult InvokeOp::verify() {
if (effect == ThrowEffect::NoThrow)
return emitOpError("nothrow callee must be invoked with py.call");

if (failed(verifyVectorCallOperands(getOperation(), signature, getPosargs(),
getKwnames(), getKwvalues(),
expectedKwNamesAttr)))
if (failed(verifyVectorCallOperands(
getOperation(), signature, getCallable(), getPosargs(), getKwnames(),
getKwvalues(), expectedArgNamesAttr, expectedKwNamesAttr)))
return failure();

auto resultTypes = signature.getResultTypes();
bool allNone = llvm::all_of(resultTypes, [](Type type) {
return isPyNoneType(type);
});
bool allNone =
llvm::all_of(resultTypes, [](Type type) { return isPyNoneType(type); });
if (getNormalDestOperands().empty() && allNone) {
// Allow dropping !py.none results for statement invokes.
} else {
Expand Down Expand Up @@ -253,7 +210,8 @@ LogicalResult NativeCallOp::verify() {

if (fnType.getNumResults() != getNumResults())
return emitOpError("result count mismatch with callee signature");
for (auto [result, expected] : llvm::zip(getResultTypes(), fnType.getResults())) {
for (auto [result, expected] :
llvm::zip(getResultTypes(), fnType.getResults())) {
if (result != expected)
return emitOpError("result types must match callee return types");
if (isPyType(expected))
Expand Down
2 changes: 2 additions & 0 deletions src/lython/dialects/cpp/PyVerifier/Casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ LogicalResult UpcastOp::verify() {

if (!isPyType(inputType))
return emitOpError("input must be a !py.* type");
if (isa<ClassType>(inputType))
return emitOpError("static class instances cannot be upcast to !py.object");
if (!isPyObjectType(resultType))
return emitOpError("result must be of type !py.object");

Expand Down
Loading
Loading