Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions eval/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,11 +536,6 @@ func (c *context) pick(pick *ast.BinaryExpr, x ast.Expr) (Value, error) {
if err != nil {
return nil, err
}
if val.Type() != tagTyp {
return nil, c.error(pick.Right.Span(),
fmt.Sprintf("#%s requires a value of type %s, got %s",
tag, c.reg.String(tagTyp), c.reg.String(val.Type())))
}
return Variant{ref, tag, val}, nil
}
}
Expand Down
3 changes: 1 addition & 2 deletions eval/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ var expressions = []struct {
// | "hello " ++ name -> name
// | _ -> "<empty>" <| "hello Oseg"`, Text("Oseg")},
{`box::empty ; box : #empty`, `#empty`},
// TODO: Cannot infer type of `n -> x * 2`.
// {`typ::fun (n -> x * 2) ; typ : #fun (int -> int)`, `#fun n -> x * 2`},
{`typ::fun (x -> x * 2) ; typ : #fun (int -> int)`, `#fun x -> x * 2`},

// Destructuring.
{`{ a = 1, b = 2 } |> | { a = c, b = d } -> c + d`, `3`},
Expand Down
6 changes: 1 addition & 5 deletions scanner/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,7 @@ func (s *Scanner) bytes() (tok token.Token, span token.Span) {
s.next()
}

if s.offset-offs < 2 {
s.error(s.offset, "too short base64 string")
tok = token.BAD
return
}
// The two chars `~~` encodes an empty byte array.

for (s.offset-offs)%4 > 0 {
if s.ch != '=' {
Expand Down
1 change: 1 addition & 0 deletions scanner/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ var elements = []elt{
{token.TEXT, `"world"`, literal},
{token.BYTE, "~ca", literal},
{token.BYTES, "~~aGVsbG8gd29ybGQ=", literal},
{token.BYTES, "~~", literal},

// Operators and delimiters
{token.ASSIGN, "=", operator},
Expand Down
254 changes: 100 additions & 154 deletions types/infer.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,6 @@ func (s *Scope[T]) Bind(name string, val T) *Scope[T] {
return &Scope[T]{s, name, val}
}

func (s *Scope[T]) Rebind(name string, val T) bool {
if bound := s.Get(name); bound != nil {
bound.val = val
return true
}
return false
}

type TypeScope = *Scope[TypeRef]

type context struct {
Expand All @@ -59,13 +51,20 @@ func (c *context) bind(name string, ref TypeRef) TypeScope {
return c.scope
}

func Infer(se ast.SourceExpr) (string, error) {
var reg Registry
var scope TypeScope
// Unbinds the last bound variable.
func (c *context) unbind() {
c.scope = c.scope.parent
}

func DefaultScope() (reg Registry, scope TypeScope) {
for _, p := range primitives {
scope = scope.Bind(reg.String(p), p)
}
return
}

func Infer(se ast.SourceExpr) (string, error) {
reg, scope := DefaultScope()

ref, err := InferInScope(&reg, scope, se)
if err != nil {
Expand All @@ -91,15 +90,21 @@ func InferInScope(reg *Registry, scope TypeScope, se ast.SourceExpr) (ref TypeRe
}
}()

return context.infer(se.Expr), err
ref = context.infer(se.Expr)
return ref, err
}

func (c *context) infer(expr ast.Expr) TypeRef {
switch x := expr.(type) {
case *ast.Literal:
return literalTypeRef(x.Kind)
case *ast.Ident:
return c.scope.Lookup(c.source.GetString(x.Pos))
name := c.source.GetString(x.Pos)
ref := c.scope.Lookup(name)
if ref == NeverRef {
c.bail(x.Pos, "unbound variable: "+name)
}
return c.reg.Instantiate(ref)
case *ast.WhereExpr:
return c.where(x)
case *ast.ListExpr:
Expand All @@ -108,69 +113,61 @@ func (c *context) infer(expr ast.Expr) TypeRef {
return c.record(x)
case ast.EnumExpr:
return c.enum(x)

case *ast.FuncExpr:
unbound := c.reg.Unbound()
// Hold onto the binding, in case inferring the body rebinds its type.
binding := c.bind(c.source.GetString(x.Arg.Span()), unbound)
// Not sure how to juggle vars vs unbound. :/
binder := c.reg.Var()
c.bind(c.source.GetString(x.Arg.Span()), binder)
defer c.unbind()
ret := c.infer(x.Body)
return c.reg.Func(binding.val, ret)
return c.reg.Func(binder, ret)

case *ast.CallExpr:
// Special-case pick with a value.
if pick, ok := x.Fn.(*ast.BinaryExpr); ok && pick.Op == token.PICK {
return c.pick(pick, x.Arg)
}

typ := c.infer(x.Fn)
res := c.reg.Var()
fn := c.infer(x.Fn)
arg := c.infer(x.Arg)
c.ensure(x, fn, c.reg.Func(arg, res))
return res

if !typ.IsFunction() {
// If the argument is an unbound identifier, rebind it.
id, ok := x.Fn.(*ast.Ident)
if ok && typ.IsUnbound() {
name := c.source.GetString(id.Pos)
// Let's steal the now unused (?) unbound.
fn := c.reg.Func(arg, typ)
if c.scope.Rebind(name, fn) {
return typ
}

// Let's try to rebind a type.
}

c.bail(x.Span(), fmt.Sprintf("cannot call non-function %s", c.reg.String(typ)))
}
fn := c.reg.GetFunc(typ)

ref := c.call(fn, arg)
if ref != NeverRef {
return ref
case *ast.BinaryExpr:
if x.Op == token.PICK {
return c.pick(x, nil)
}

c.bail(x.Span(), fmt.Sprintf("cannot call %s with %s", c.reg.String(typ), c.reg.String(arg)))

case *ast.BinaryExpr:
left := c.infer(x.Left)
right := c.infer(x.Right)
switch x.Op {
case token.PICK:
return c.pick(x, nil)
case token.PREPEND:
return c.pend(x.Left, x.Right, left, right)
case token.APPEND:
return c.pend(x.Right, x.Left, right, left)
case token.CONCAT:
if left == TextRef {
if right == TextRef {
return TextRef
} else if right.IsUnbound() {

} else {
return NeverRef
}
if left == TextRef || right == TextRef {
c.ensure(x, left, right)
return TextRef
}
case token.ADD:
if left == IntRef {
return c.ensure(x.Right, right, IntRef)
if left == BytesRef || right == BytesRef {
c.ensure(x, left, right)
return BytesRef
}
if right == IntRef {
return c.ensure(x.Left, left, IntRef)
// Local var to ensure left and right are lists.
a := c.reg.List(c.reg.Var())
c.ensure(x, left, right)
c.ensure(x, left, a)
return a
case token.ADD, token.SUB, token.MUL:
if left == FloatRef || right == FloatRef {
c.ensure(x, left, right)
return FloatRef
}
// Assume int, like ML does.
c.ensure(x.Left, left, IntRef)
return c.ensure(x.Right, right, IntRef)
}
panic(fmt.Sprintf("can't infer binary expression %s", x.Op.String()))
}
Expand All @@ -179,78 +176,37 @@ func (c *context) infer(expr ast.Expr) TypeRef {
}

func (c *context) ensure(x ast.Expr, got, want TypeRef) TypeRef {
if got == want {
return got
}

if got.IsUnbound() {
c.rebind(x, want)
return want
}

if c.isAssignable(want, got) {
return got
}

c.bail(x.Span(), fmt.Sprintf("expected %s, got %s", c.reg.String(want), c.reg.String(got)))
return NeverRef
}

func (c *context) call(fn FuncRef, arg TypeRef) TypeRef {
if c.isAssignable(fn.Arg, arg) {
return fn.Result
}

if fn.Arg.IsUnbound() {
return c.reg.Bind(fn.Result, fn.Arg, arg)
}

if fn.Arg.IsFunction() && arg.IsFunction() {
afn := c.reg.GetFunc(fn.Arg)
bfn := c.reg.GetFunc(arg)

// If completely unbound, replace with arg.
if afn.Arg.IsUnbound() && afn.Result.IsUnbound() {
res := c.reg.Bind(fn.Result, afn.Result, bfn.Result)
return c.reg.Bind(res, afn.Arg, bfn.Arg)
}
}

return NeverRef
}

func (c *context) isAssignable(a, b TypeRef) bool {
if a == b {
return true
}
if got != want {
// Really? Must make this API better.
defer func() {
if pnc := recover(); pnc != nil {
if msg, ok := pnc.(string); ok {
c.bail(x.Span(), msg)
} else {
panic(pnc)
}
}
}()

aTag, _ := a.extract()
switch aTag {
case listTag:
if b.IsList() && c.reg.GetList(b).IsUnbound() {
return true
}
c.reg.unify(got, want)
}

return false
return want
}

func (c *context) where(x *ast.WhereExpr) TypeRef {
name := c.source.GetString(x.Id.Pos)
if x.Typ == nil {
// If there is no type annotation, we can infer it from the value.
c.bind(name, c.infer(x.Val))
return c.infer(x.Expr)
}

tRef := c.typ(x.Typ)
vRef := c.infer(x.Val)
if tRef != vRef {
c.bail(x.Val.Span(), fmt.Sprintf("cannot assign %s to %s", c.reg.String(vRef), c.reg.String(tRef)))
tyVal := c.infer(x.Val)

// If there's an annotation, make sure it matches the inferred type.
if x.Typ != nil {
c.ensure(x.Typ, tyVal, c.typ(x.Typ))
}

c.bind(name, tRef)
return c.infer(x.Expr)
c.bind(name, c.reg.generalize(tyVal))
defer c.unbind()
tyExpr := c.infer(x.Expr)
return tyExpr
}

func (c *context) typ(x ast.Expr) TypeRef {
Expand All @@ -273,42 +229,26 @@ func (c *context) typ(x ast.Expr) TypeRef {
return NeverRef
}

func (c *context) list(x *ast.ListExpr) (res TypeRef) {
func (c *context) list(x *ast.ListExpr) TypeRef {
res := NeverRef

for _, v := range x.Elements {
typ := c.infer(v)
if typ == res {
continue
} else if res == NeverRef {

if res == NeverRef {
res = typ
} else if typ.IsUnbound() {
c.rebind(v, res)
} else {
c.bail(v.Span(), "list elements must all be of type "+c.reg.String(res))
// Bad list.
return NeverRef
continue
}

c.ensure(v, res, typ)
}

if res == NeverRef {
res = c.reg.Unbound()
res = c.reg.Var()
}
return c.reg.List(res)
}

// Re-binds the type of expresion x, or fails.
func (c *context) rebind(x ast.Expr, ref TypeRef) {
name := c.source.GetString(x.Span())
_, ok := x.(*ast.Ident)
if ok {
s := c.scope.Get(name)
if s != nil {
// Let's steal the now unused (?) generic.
s.val = ref
return
}
}
c.bail(x.Span(), fmt.Sprintf("can't rebind type of non-identifier %s", name))
}

func (c *context) record(x *ast.RecordExpr) TypeRef {
// If there is a rest/spread, our type is equal to that.
if x.Rest != nil {
Expand Down Expand Up @@ -377,14 +317,7 @@ func (c *context) pick(x *ast.BinaryExpr, val ast.Expr) TypeRef {
}
} else {
valRef := c.infer(val)
// TODO: check assignability instead
if typ != valRef {
// Wrong type.
c.bail(val.Span(),
fmt.Sprintf("cannot assign %s to #%s which needs %s",
c.reg.String(valRef), tag, c.reg.String(typ)))
return NeverRef
}
c.ensure(val, valRef, typ)
}

return ref
Expand Down Expand Up @@ -412,3 +345,16 @@ func literalTypeRef(tok token.Token) TypeRef {

return NeverRef
}

// Either pre-pend or ap-pend.
func (c *context) pend(singleX, listX ast.Expr, single, list TypeRef) TypeRef {
// Special-case bytes.
if single == ByteRef || list == BytesRef {
c.ensure(singleX, single, ByteRef)
c.ensure(listX, list, BytesRef)
return BytesRef
}

c.ensure(singleX, c.reg.List(single), list)
return list
}
Loading
Loading