diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 9f156107..a8e47600 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -145,33 +145,21 @@ loop: cmp := operators.ComparatorFromByte(int(op) - int(OpAllEq)) res, err := operators.ArrayAll(ctx, cmp, reg[src1], reg[src2]) - if err == nil { - reg[dst] = res - } else if _, catch := vm.tryCatch(vm.pc); catch { - reg[dst] = runtime.False - } else { + if err := vm.setOrTryCatch(dst, res, err); err != nil { return nil, err } case OpAnyEq, OpAnyNe, OpAnyGt, OpAnyGte, OpAnyLt, OpAnyLte, OpAnyIn: cmp := operators.ComparatorFromByte(int(op) - int(OpAnyEq)) res, err := operators.ArrayAny(ctx, cmp, reg[src1], reg[src2]) - if err == nil { - reg[dst] = res - } else if _, catch := vm.tryCatch(vm.pc); catch { - reg[dst] = runtime.False - } else { + if err := vm.setOrTryCatch(dst, res, err); err != nil { return nil, err } case OpNoneEq, OpNoneNe, OpNoneGt, OpNoneGte, OpNoneLt, OpNoneLte, OpNoneIn: cmp := operators.ComparatorFromByte(int(op) - int(OpNoneEq)) res, err := operators.ArrayNone(ctx, cmp, reg[src1], reg[src2]) - if err == nil { - reg[dst] = res - } else if _, catch := vm.tryCatch(vm.pc); catch { - reg[dst] = runtime.False - } else { + if err := vm.setOrTryCatch(dst, res, err); err != nil { return nil, err } case OpLoadArray: @@ -214,11 +202,7 @@ loop: arg := reg[src2] out, err := vm.loadIndex(ctx, src, arg) - if err == nil { - reg[dst] = out - } else if op == OpLoadIndexOptional { - reg[dst] = runtime.None - } else { + if err := vm.setOrOptional(dst, out, err, op == OpLoadIndexOptional); err != nil { return nil, err } @@ -227,11 +211,7 @@ loop: arg := reg[src2] out, err := vm.loadKey(ctx, src, arg) - if err == nil { - reg[dst] = out - } else if op == OpLoadKeyOptional { - reg[dst] = runtime.None - } else { + if err := vm.setOrOptional(dst, out, err, op == OpLoadKeyOptional); err != nil { return nil, err } @@ -251,111 +231,49 @@ loop: out, err = vm.loadKey(ctx, src, runtime.ToString(prop)) } - if err == nil { - reg[dst] = out - } else if op == OpLoadPropertyOptional { - reg[dst] = runtime.None - } else { + if err := vm.setOrOptional(dst, out, err, op == OpLoadPropertyOptional); err != nil { return nil, err } + case OpCall, OpProtectedCall: out, err := vm.callv(ctx, vm.pc-1, src1, src2) - if err == nil { - reg[dst] = out - } else if op == OpProtectedCall { - reg[dst] = runtime.None - } else if catch, ok := vm.tryCatch(vm.pc); ok { - reg[dst] = runtime.None - - if catch[2] > 0 { - vm.pc = catch[2] - } - } else { + if err := vm.setCallResult(op, dst, out, err); err != nil { return nil, err } + case OpCall0, OpProtectedCall0: out, err := vm.call0(ctx, vm.pc-1) - if err == nil { - reg[dst] = out - } else if op == OpProtectedCall0 { - reg[dst] = runtime.None - } else if catch, ok := vm.tryCatch(vm.pc); ok { - reg[dst] = runtime.None - - if catch[2] > 0 { - vm.pc = catch[2] - } - } else { + if err := vm.setCallResult(op, dst, out, err); err != nil { return nil, err } case OpCall1, OpProtectedCall1: out, err := vm.call1(ctx, vm.pc-1, src1) - if err == nil { - reg[dst] = out - } else if op == OpProtectedCall1 { - reg[dst] = runtime.None - } else if catch, ok := vm.tryCatch(vm.pc); ok { - reg[dst] = runtime.None - - if catch[2] > 0 { - vm.pc = catch[2] - } - } else { + if err := vm.setCallResult(op, dst, out, err); err != nil { return nil, err } case OpCall2, OpProtectedCall2: out, err := vm.call2(ctx, vm.pc-1, src1, src2) - if err == nil { - reg[dst] = out - } else if op == OpProtectedCall2 { - reg[dst] = runtime.None - } else if catch, ok := vm.tryCatch(vm.pc); ok { - reg[dst] = runtime.None - - if catch[2] > 0 { - vm.pc = catch[2] - } - } else { + if err := vm.setCallResult(op, dst, out, err); err != nil { return nil, err } case OpCall3, OpProtectedCall3: out, err := vm.call3(ctx, vm.pc-1, src1) - if err == nil { - reg[dst] = out - } else if op == OpProtectedCall3 { - reg[dst] = runtime.None - } else if catch, ok := vm.tryCatch(vm.pc); ok { - reg[dst] = runtime.None - - if catch[2] > 0 { - vm.pc = catch[2] - } - } else { + if err := vm.setCallResult(op, dst, out, err); err != nil { return nil, err } case OpCall4, OpProtectedCall4: out, err := vm.call4(ctx, vm.pc-1, src1) - if err == nil { - reg[dst] = out - } else if op == OpProtectedCall4 { - reg[dst] = runtime.None - } else if catch, ok := vm.tryCatch(vm.pc); ok { - reg[dst] = runtime.None - - if catch[2] > 0 { - vm.pc = catch[2] - } - } else { + if err := vm.setCallResult(op, dst, out, err); err != nil { return nil, err } @@ -530,7 +448,7 @@ loop: var timeout runtime.Int if reg[src2] != nil && reg[src2] != runtime.None { - t, err := runtime.CastInt(reg[src1]) + t, err := runtime.CastInt(reg[src2]) if err != nil { if _, catch := vm.tryCatch(vm.pc); catch { diff --git a/pkg/vm/vm_helpers.go b/pkg/vm/vm_helpers.go index aa51421e..b2292932 100644 --- a/pkg/vm/vm_helpers.go +++ b/pkg/vm/vm_helpers.go @@ -177,3 +177,74 @@ func (vm *VM) castSubscribeArgs(dst, eventName, opts runtime.Value) (runtime.Obs return observable, eventNameStr, options, nil } + +func (vm *VM) setOrTryCatch(dst Operand, val runtime.Value, err error) error { + reg := vm.registers.Values + + if err == nil { + reg[dst] = val + + return nil + } + + if _, catch := vm.tryCatch(vm.pc); catch { + reg[dst] = runtime.None + + return nil + } + + return err +} + +func (vm *VM) setCallResult(op Opcode, dst Operand, out runtime.Value, err error) error { + reg := vm.registers.Values + + if err == nil { + reg[dst] = out + + return nil + } + + if isProtectedCall(op) { + reg[dst] = runtime.None + + return nil + } + + if catch, ok := vm.tryCatch(vm.pc); ok { + reg[dst] = runtime.None + + if catch[2] > 0 { + vm.pc = catch[2] + } + + return nil + } + + return err +} + +func (vm *VM) setOrOptional(dst Operand, val runtime.Value, err error, optional bool) error { + if err == nil { + vm.registers.Values[dst] = val + + return nil + } + + if optional { + vm.registers.Values[dst] = runtime.None + + return nil + } + + return err +} + +func isProtectedCall(op Opcode) bool { + switch op { + case OpProtectedCall, OpProtectedCall0, OpProtectedCall1, OpProtectedCall2, OpProtectedCall3, OpProtectedCall4: + return true + default: + return false + } +}