#lang shplait

// Start with "class.rhm", "inherit.rhm", and "inherit_parse.rhm"
//
// Part 1:
//   Add `abs`, but in all layers, so we can eventually
//   run a Moe program like
//     abs(-10)
//   to get 10.
//
// Part 2:
//   Allow a class to have 0 or multiple superclasses listed after
//   `extends`. The fields of the superclasses get combined in order
//   before any new fields in the subclass.
//
//   With this change, the `Object` class becomes unnecessary, and we
//   can remove support for `Object` as a superclass.
//
//   A `super` call will need to say which superclass's method is
//   being called, along the lines of
//     (super :: ClassName).method_name(arg)
//            ^^^^^^^^^^^^
//   NOTE(review): the placeholders in this example were missing from
//   the text as received and are reconstructed -- confirm against the
//   assignment handout.

// Expression forms accepted by `interp`.
//   argE / thisE - the current method's argument / receiver
//   newE         - instantiate `class_name` with one value per field
//   getE         - read a field of an object
//   sendE        - call a method, looked up in the receiver's own class
//   ssendE       - call a method, looked up in the explicitly named class
//   absE         - absolute value of an integer
type Exp
| intE(n :: Int)
| plusE(lhs :: Exp,
        rhs :: Exp)
| multE(lhs :: Exp,
        rhs :: Exp)
| argE()
| thisE()
| newE(class_name :: Symbol,
       args :: Listof(Exp))
| getE(obj_exp :: Exp,
       field_name :: Symbol)
| sendE(obj_exp :: Exp,
        method_name :: Symbol,
        arg_exp :: Exp)
| ssendE(obj_exp :: Exp,
         class_name :: Symbol,
         method_name :: Symbol,
         arg_exp :: Exp)
| absE(arg :: Exp)

// A class: its field names (in constructor-argument order) and its
// methods, each a name paired with a body expression.
type Class
| classC(field_names :: Listof(Symbol),
         methods :: Listof(Symbol * Exp))

// Results of interpretation: an integer, or an object that records its
// class name and one value per field.
type Value
| intV(n :: Int)
| objV(class_name :: Symbol,
       fields :: Listof(Value))

// ----------------------------------------

// Returns the value paired with the first occurrence of `name` in the
// association list `l`; raises "not found: <name>" when there is none.
fun find(l :: Listof(Symbol * ?a), name :: Symbol) :: ?a:
  match l
  | []: error(#'find, "not found: " +& name)
  | cons(p, rst_l):
      if fst(p) == name
      | snd(p)
      | find(rst_l, name)

module test:
  check: find([values(#'a, 1)], #'a)
         ~is 1
  check: find([values(#'a, "apple")], #'a)
         ~is "apple"
  check: find([values(#'a, 1), values(#'b, 2)], #'b)
         ~is 2
  check: find([], #'a)
         ~raises "not found: a"
  check: find([values(#'a, 1)], #'x)
         ~raises "not found: x"

// ----------------------------------------

// Evaluates expression `a` to a value.
//   classes  - every known class, keyed by name
//   this_val - the value that `thisE()` produces
//   arg_val  - the value that `argE()` produces
// Raises "not a number", "Not an Integer", "not an object",
// "wrong field count", or a `find` "not found: ..." error.
def interp :: (Exp, Listof(Symbol * Class), Value, Value) -> Value:
  fun (a, classes, this_val, arg_val):
    // Evaluate a subexpression with the same receiver and argument.
    fun recur(exp):
      interp(exp, classes, this_val, arg_val)
    match a
    | intE(n): intV(n)
    | plusE(l, r): num_plus(recur(l), recur(r))
    | multE(l, r): num_mult(recur(l), recur(r))
    | absE(arg): num_abs(recur(arg))
    | thisE(): this_val
    | argE(): arg_val
    | newE(class_name, field_exps):
        def c = find(classes, class_name)
        def vals = map(recur, field_exps)
        if length(vals) == length(classC.field_names(c))
        | objV(class_name, vals)
        | error(#'interp, "wrong field count")
    | getE(obj_exp, field_name):
        match recur(obj_exp)
        | objV(class_name, fields):
            match find(classes, class_name)
            | classC(field_names, methods):
                // Pair each field name with its value, then look it up.
                find(map2(fun (n, v): values(n, v),
                          field_names,
                          fields),
                     field_name)
        | ~else: error(#'interp, "not an object")
    | sendE(obj_exp, method_name, arg_exp):
        def obj = recur(obj_exp)
        def arg_val = recur(arg_exp)
        match obj
        | objV(class_name, fields):
            // Dynamic dispatch: use the class recorded in the object.
            call_method(class_name, method_name, classes,
                        obj, arg_val)
        | ~else: error(#'interp, "not an object")
    | ssendE(obj_exp, class_name, method_name, arg_exp):
        def obj = recur(obj_exp)
        def arg_val = recur(arg_exp)
        // Static dispatch: use the class named in the expression.
        call_method(class_name, method_name, classes,
                    obj, arg_val)

// Looks up `method_name` in the class named `class_name` and evaluates
// its body with `obj` as `this` and `arg_val` as the argument.
fun call_method(class_name, method_name, classes,
                obj, arg_val):
  match find(classes, class_name)
  | classC(field_names, methods):
      let body_exp = find(methods, method_name):
        interp(body_exp,
               classes,
               obj,
               arg_val)

// Applies the integer operation `op` to two values; raises
// "not a number" unless both are `intV`.
fun num_op(op :: (Int, Int) -> Int,
           l :: Value,
           r :: Value) :: Value:
  cond
  | l is_a intV && r is_a intV:
      intV(op(intV.n(l), intV.n(r)))
  | ~else:
      error(#'interp, "not a number")

// Integer addition on values.
fun num_plus(l :: Value, r :: Value) :: Value:
  num_op(fun (a, b): a+b, l, r)

// Integer multiplication on values.
fun num_mult(l :: Value, r :: Value) :: Value:
  num_op(fun (a, b): a*b, l, r)

// Absolute value of an integer value; raises "Not an Integer" for any
// other kind of value. Kept beside `num_plus` / `num_mult` so that all
// arithmetic on values lives in one place.
fun num_abs(v :: Value) :: Value:
  match v
  | intV(n):
      if n < 0
      | intV(0 - n)
      | intV(n)
  | ~else: error(#'interp, "Not an Integer")

// ----------------------------------------
// Examples

module test:
  // A 2-D point with a handful of methods.
  def posn_class:
    values(#'Posn,
           classC([#'x, #'y],
                  [values(#'mdist,
                          plusE(getE(thisE(), #'x),
                                getE(thisE(), #'y))),
                   values(#'addDist,
                          plusE(sendE(thisE(), #'mdist, intE(0)),
                                sendE(argE(), #'mdist, intE(0)))),
                   values(#'addX,
                          plusE(getE(thisE(), #'x), argE())),
                   values(#'multY,
                          multE(argE(), getE(thisE(), #'y))),
                   values(#'factory12,
                          newE(#'Posn, [intE(1), intE(2)]))]))

  // A 3-D point whose methods delegate to `Posn`'s via `ssendE`.
  def posn3D_class:
    values(#'Posn3D,
           classC([#'x, #'y, #'z],
                  [values(#'mdist,
                          plusE(getE(thisE(), #'z),
                                ssendE(thisE(), #'Posn, #'mdist, argE()))),
                   values(#'addDist,
                          ssendE(thisE(), #'Posn, #'addDist, argE()))]))

  def posn27 = newE(#'Posn, [intE(2), intE(7)])
  def posn531 = newE(#'Posn3D, [intE(5), intE(3), intE(1)])

  // Interprets `a` with both example classes available.
  fun interp_posn(a):
    interp(a, [posn_class, posn3D_class], intV(-1), intV(-1))

// ----------------------------------------

module test:
  check: interp(absE(intE(1)),
                [], objV(#'Object, []), intV(0))
         ~is intV(1)
  check: interp(absE(intE(-1)),
                [], objV(#'Object, []), intV(0))
         ~is intV(1)
  check: interp(absE(intE(0)),
                [], objV(#'Object, []), intV(0))
         ~is intV(0)
  check: interp(absE(intE(-10)),
                [], objV(#'Object, []), intV(0))
         ~is intV(10)
  check: interp(absE(plusE(intE(3), intE(-13))),
                [], objV(#'Object, []), intV(0))
         ~is intV(10)
  check: interp_posn(absE(posn27))
         ~raises "Not an Integer"

  check: interp(intE(10),
                [], objV(#'Object, []), intV(0))
         ~is intV(10)
  check: interp(plusE(intE(10), intE(17)),
                [], objV(#'Object, []), intV(0))
         ~is intV(27)
  check: interp(multE(intE(10), intE(7)),
                [], objV(#'Object, []), intV(0))
         ~is intV(70)

  check: interp_posn(newE(#'Posn, [intE(2), intE(7)]))
         ~is objV(#'Posn, [intV(2), intV(7)])

  check: interp_posn(sendE(posn27, #'mdist, intE(0)))
         ~is intV(9)

  check: interp_posn(sendE(posn27, #'addX, intE(10)))
         ~is intV(12)

  check: interp_posn(sendE(ssendE(posn27, #'Posn, #'factory12, intE(0)),
                           #'multY,
                           intE(15)))
         ~is intV(30)

  check: interp_posn(sendE(posn531, #'mdist, intE(0)))
         ~is intV(9)

  check: interp_posn(sendE(posn531, #'addDist, posn27))
         ~is intV(18)

  check: interp_posn(plusE(intE(1), posn27))
         ~raises "not a number"
  check: interp_posn(getE(intE(1), #'x))
         ~raises "not an object"
  check: interp_posn(sendE(intE(1), #'mdist, intE(0)))
         ~raises "not an object"
  check: interp_posn(sendE(posn27, #'nope, intE(0)))
         ~raises "not found: nope"
  check: interp_posn(newE(#'Posn, [intE(0)]))
         ~raises "wrong field count"