compose_eval/expression/closure.rs

1use crate::vm::{FlowEvent, TrackedContainer};
2use crate::{Eval, Evaluated, Machine};
3use compose_library::diag::{IntoSourceDiagnostic, SourceResult, Spanned, bail, error};
4use compose_library::{Args, Binding, BindingKind, Closure, Func, Library, Scope, Scopes, Value, VariableAccessError, Visibility};
5use compose_syntax::ast::{AstNode, Expr, Ident, Param, ParamKind};
6use compose_syntax::{Label, Span, SyntaxNode, ast};
7use ecow::{EcoString, EcoVec};
8use std::collections::HashMap;
9
10impl Eval for ast::Lambda<'_> {
11    fn eval(self, vm: &mut Machine) -> SourceResult<Evaluated> {
12        let guard = vm.temp_root_guard();
13
14        let mut defaults = Vec::new();
15        for param in self.params().children() {
16            if let ast::ParamKind::Named(named) = param.kind() {
17                defaults.push(named.expr().eval(guard.vm)?.value);
18            }
19        }
20
21        let captured = {
22            let mut errors = EcoVec::new();
23            let mut scope = Scope::new();
24            for capture in self.captures().children() {
25                let span = capture.binding().span();
26                let name = capture.binding().get();
27                let binding = match guard.vm.get(&capture.binding()).cloned() {
28                    Ok(v) => v,
29                    Err(e) => {
30                        let VariableAccessError::Unbound(unbound) = e else {
31                            errors.push(e.into_source_diagnostic(span));
32                            continue;
33                        };
34
35                        let mut err = error!(
36                            span, "unknown variable `{name}` in closure capture list";
37                            label_message: "variable `{name}` is not defined in the outer scope and cannot be captured";
38                        );
39                        unbound.apply_hint(&mut err);
40                        errors.push(err);
41
42                        continue;
43                    }
44                };
45
46                let value = match validate_capture(capture, &binding, guard.vm) {
47                    Ok(v) => v,
48                    Err(e) => {
49                        errors.extend(e.into_iter());
50                        continue;
51                    }
52                };
53
54                scope.bind(
55                    name.clone(),
56                    Binding::new(value.clone(), span).with_kind(match capture.is_mut() {
57                        true => BindingKind::Mutable,
58                        false => BindingKind::Immutable { first_assign: None },
59                    }),
60                );
61            }
62
63            if !errors.is_empty() {
64                return Err(errors);
65            }
66
67            scope
68        };
69
70        let unresolved_captures = {
71            let mut visitor = CapturesVisitor::new(
72                &guard.vm.frames.top.scopes,
73                Some(guard.vm.engine.world.library()),
74                &captured,
75            );
76            visitor.visit_lambda(self);
77            visitor.finish()
78        };
79
80        let closure = Closure {
81            name: None,
82            node: self.to_untyped().clone(),
83            defaults,
84            num_pos_params: self
85                .params()
86                .children()
87                .filter(|p| matches!(p.kind(), ast::ParamKind::Pos(_)))
88                .count(),
89            captured,
90            unresolved_captures,
91        };
92
93        if !guard.vm.context.closure_capture.should_defer() {
94            closure.resolve()?
95        }
96
97        Ok(Evaluated::mutable(
98            Value::Func(Func::from(closure)).spanned(self.span()),
99        ))
100    }
101}
102
103fn validate_capture<'a>(
104    capture: ast::Capture,
105    binding: &'a Binding,
106    vm: &mut Machine,
107) -> SourceResult<&'a Value> {
108    let span = capture.binding().span();
109    let name = capture.binding().get();
110    if capture.is_mut() && capture.is_ref() && !binding.kind().is_mut() {
111        bail!(
112            capture.span(), "cannot capture variable `{name}` as `ref mut` because it is not declared as mutable";
113            label_message: "capture is declared as a mutable reference";
114            label: Label::secondary(binding.span(), "was defined as immutable here");
115            note: "captured mutable references must match the mutability of the original declaration";
116            hint: "declare the variable as mutable: `let mut {name} = ...`";
117            hint: "or remove `mut` from the capture: `|ref {name}, ...|";
118        );
119    }
120
121    let value = binding.read_checked(span, vm.sink_mut());
122
123    if capture.is_ref() && !value.is_box() {
124        bail!(
125            span, "cannot capture non reference type by reference";
126            label_message: "this captures by reference";
127            note: "only boxed values can be captured by reference"
128        );
129    }
130
131    Ok(value)
132}
133
134fn define(
135    vm: &mut Machine,
136    ident: Ident,
137    Spanned { value, span }: Spanned<Value>,
138    param: Param,
139) -> SourceResult<()> {
140    if param.is_ref() && !value.is_box() {
141        bail!(
142            ident.span(),
143            "cannot take a reference to a value type. consider boxing it first with `box::new(value)`"
144        );
145    }
146    if !param.is_ref() && value.is_box() {
147        bail!(span, "Cannot bind a box value to a non-reference parameter";
148            label_message: "this parameter is a boxed value";
149            label: Label::primary(ident.span(), "this parameter is declared as an owned value");
150            note: "passing a reference type (like `box`) into an owned parameter is not allowed";
151            hint: "if the parameter intends to modify the original value, mark the parameter as `ref mut`";
152            hint: "if you want to pass the value by ownership, use `.clone_inner()` to create a new copy";
153        );
154    }
155
156    let kind = if param.is_mut() {
157        BindingKind::ParamMut
158    } else {
159        BindingKind::Param
160    };
161    vm.define(ident, value, kind, Visibility::Private)?;
162
163    Ok(())
164}
165
166//noinspection RsUnnecessaryQualifications - False Positive
167pub fn eval_lambda(closure: &Closure, vm: &mut Machine, args: Args) -> SourceResult<Value> {
168    let guard = vm.temp_root_guard();
169    let ast_closure = closure
170        .node
171        .cast::<ast::Lambda>()
172        .expect("closure is not an ast closure");
173    let params = ast_closure.params();
174    let statements = ast_closure.statements();
175
176    // Make sure a gc round is aware that the args are reachable
177    guard.vm.track_tmp_root(&args);
178
179    let result = guard
180        .vm
181        .with_frame(move |vm| {
182            let mut args = args;
183            if let Some(Spanned { value, span }) = &closure.name {
184                vm.try_bind(
185                    value.clone(),
186                    Binding::new(Func::from(closure.clone()), *span),
187                )?;
188            }
189
190            for (k, v) in closure.captured.bindings() {
191                vm.bind(k.clone(), v.clone());
192            }
193
194            let mut defaults = closure.defaults.iter();
195            for p in params.children() {
196                match p.kind() {
197                    ast::ParamKind::Pos(pattern) => match pattern {
198                        ast::Pattern::Single(Expr::Ident(ident)) => {
199                            define(vm, ident, args.expect(&ident)?, p)?;
200                        }
201                        pattern => bail!(pattern.span(), "Patterns not supported in closures yet"),
202                    },
203                    ast::ParamKind::Named(named) => {
204                        let name = named.name();
205                        let default = defaults.next().unwrap();
206                        let value = args
207                            .named(&name)?
208                            .unwrap_or_else(|| Spanned::new(default.clone(), named.expr().span()));
209                        define(vm, name, value, p)?;
210                    }
211                }
212            }
213
214            // Ensure all args have been used
215            args.finish()?;
216
217            let mut output = Value::unit();
218            for statement in statements {
219                output = statement.eval(vm)?.value;
220                match &vm.flow {
221                    None => {}
222                    Some(FlowEvent::Return(_, Some(explicit))) => {
223                        let explicit = explicit.clone();
224                        vm.flow = None;
225                        return Ok(explicit)
226                    },
227                    Some(FlowEvent::Return(_, None)) => {
228                        vm.flow = None;
229                        return Ok(Value::unit())
230                    }
231                    Some(other) => bail!(other.forbidden()),
232                }
233            }
234
235            SourceResult::Ok(output)
236        })
237        .track_tmp_root(guard.vm);
238
239    guard.vm.maybe_gc();
240
241    result
242}
243
244/// Visits a closure and determines which variables are captured implicitly.
245#[derive(Debug)]
246pub struct CapturesVisitor<'a> {
247    /// The external scope that variables might be captured from.
248    external: &'a Scopes<'a>,
249    /// The internal scope of variables defined within the closure.
250    internal: Scopes<'a>,
251    /// The variables that are captured.
252    captures: HashMap<EcoString, Span>,
253}
254
255impl<'a> CapturesVisitor<'a> {
256    pub fn new(external: &'a Scopes<'a>, library: Option<&'a Library>, existing: &Scope) -> Self {
257        let mut inst = Self {
258            external,
259            internal: Scopes::new(library),
260            captures: HashMap::new(),
261        };
262
263        for (k, v) in existing.bindings() {
264            inst.internal.top.bind(k.clone(), v.clone());
265        }
266
267        inst
268    }
269    pub(crate) fn visit_lambda(&mut self, closure: ast::Lambda<'a>) {
270        for param in closure.params().children() {
271            match param.kind() {
272                ParamKind::Pos(pat) => {
273                    for ident in pat.bindings() {
274                        self.bind(ident);
275                    }
276                }
277                ParamKind::Named(named) => {
278                    self.bind(named.name());
279                }
280            }
281        }
282
283        for capture in closure.captures().children() {
284            self.visit(capture.to_untyped());
285        }
286
287        for statement in closure.statements() {
288            self.visit(statement.to_untyped());
289        }
290    }
291
292    pub fn visit(&mut self, node: &'a SyntaxNode) {
293        if let Some(ast::Statement::Let(let_binding)) = node.cast() {
294            if let Some(init) = let_binding.initial_value() {
295                self.visit(init.to_untyped())
296            }
297
298            for ident in let_binding.pattern().bindings() {
299                self.bind(ident);
300            }
301            return;
302        }
303
304        let expr = match node.cast::<Expr>() {
305            Some(expr) => expr,
306            None => {
307                if let Some(named) = node.cast::<ast::Named>() {
308                    // Don't capture the name of a named parameter.
309                    self.visit(named.expr().to_untyped());
310                    return;
311                }
312
313                Expr::default()
314            }
315        };
316
317        match expr {
318            Expr::Ident(ident) => self.capture(ident),
319            Expr::CodeBlock(_) => {
320                self.internal.enter();
321                for child in node.children() {
322                    self.visit(child);
323                }
324                self.internal.exit();
325            }
326            Expr::FieldAccess(access) => {
327                self.visit(access.target().to_untyped());
328            }
329            Expr::Lambda(closure) => {
330                for param in closure.params().children() {
331                    if let ast::ParamKind::Named(named) = param.kind() {
332                        self.visit(named.expr().to_untyped());
333                    }
334                }
335
336                for capture in closure.captures().children() {
337                    self.visit(capture.to_untyped());
338                }
339
340                // NOTE: For now we do not try to analyse the body of the closure.
341                // This is because the closure might try to recursively call itself
342                // and in simple ast walking, that is really hard to resolve correctly.
343                // Any errors in the body will be caught when the outer body is evaluated.
344            }
345
346            Expr::ForLoop(for_loop) => {
347                // Created in outer scope
348                self.visit(for_loop.iterable().to_untyped());
349
350                self.internal.enter();
351                let pattern = for_loop.binding();
352                for ident in pattern.bindings() {
353                    self.bind(ident);
354                }
355
356                self.visit(for_loop.body().to_untyped());
357                self.internal.exit();
358            }
359
360            _ => {
361                // If not an expression or named, just go over all the children
362                for child in node.children() {
363                    self.visit(child);
364                }
365            }
366        }
367    }
368
369    fn bind(&mut self, ident: Ident) {
370        self.internal.top.bind(
371            ident.get().clone(),
372            Binding::new(Value::unit(), ident.span()),
373        );
374    }
375
376    fn capture(&mut self, ident: Ident<'a>) {
377        if self.internal.get(&ident).is_ok() {
378            // Was defined internally, no need to capture
379            return;
380        }
381
382        // If the value does not exist in the external scope, it is not captured.
383        if self.external.get(&ident).is_ok() {
384            self.captures
385                .entry(ident.get().clone())
386                .or_insert(ident.span());
387        }
388    }
389
390    fn finish(self) -> HashMap<EcoString, Span> {
391        self.captures
392    }
393}
394
#[cfg(test)]
mod tests {
    use crate::expression::closure::CapturesVisitor;
    use crate::test::*;
    use compose_library::{Scope, Scopes};
    use compose_syntax::{FileId, parse};

    #[test]
    fn capturing() {
        assert_eval(
            r#"
            let a = 10;
            let f = { |a| => a + 1 };
            assert::eq(f(), 11);
        "#,
        );

        assert_eval(
            r#"
            let mut a = box::new(5);
            let f = { |ref mut a| =>
                *a += 1;
                *a;
            };
            assert::eq(f(), 6);
        "#,
        );

        assert_eval(
            r#"
            let a = 2;
            let b = 3;
            let f = { |a, b| => a * b };
            assert::eq(f(), 6);
        "#,
        );

        assert_eval(
            r#"
            let mut name = box::new("Alice");
            let age = 30;
            let show = { |ref name, age| =>
                "Name: " + *name + ", Age: " + age.to_string();
            };
            assert::eq(show(), "Name: Alice, Age: 30");
            *name = "Bob";
            assert::eq(show(), "Name: Bob, Age: 30");
        "#,
        );
    }

    /// Parses `text`, runs the visitor over it and asserts the sorted captured names.
    #[track_caller]
    fn test(scopes: &Scopes, existing_scope: &Scope, text: &str, result: &[&str]) {
        let mut visitor = CapturesVisitor::new(scopes, None, existing_scope);
        let nodes = parse(text, FileId::new("test.comp"));
        for node in &nodes {
            visitor.visit(node);
        }

        let captures = visitor.finish();
        let mut names: Vec<_> = captures.iter().map(|(k, ..)| k).collect();
        names.sort();

        assert_eq!(names, result);
    }

    #[test]
    fn test_captures_visitor() {
        let mut scopes = Scopes::new(None);
        scopes.top.define("f", 0i64);
        scopes.top.define("x", 0i64);
        scopes.top.define("y", 0i64);
        scopes.top.define("z", 0i64);
        // `a` is also in `existing` below: it exists outside but is explicitly captured.
        scopes.top.define("a", 0i64);
        let s = &scopes;

        let mut existing = Scope::new();
        existing.define("a", 0i64);
        existing.define("b", 0i64);
        existing.define("c", 0i64);
        let e = &existing;

        test(s, e, "{ x => x * 2; }", &[]);

        // let binding
        test(s, e, "let t = x;", &["x"]);
        test(s, e, "let x = x;", &["x"]);
        test(s, e, "let x;", &[]);
        test(s, e, "let x = 2; x + y;", &["y"]);
        test(s, e, "x + y", &["x", "y"]);

        // assignment
        test(s, e, "x += y;", &["x", "y"]);
        test(s, e, "x = y;", &["x", "y"]);

        // closure definition
        // Closure bodies are ignored
        test(s, e, "let f = { => x + y; }", &[]);
        // with capture
        test(s, e, "let f = { |x| => x + y; }", &["x"]);
        test(s, e, "let f = { |x| => f(); }", &["x"]);
        // with params
        test(s, e, "let f = { x, y, z => f(); }", &[]);
        // named params
        test(
            s,
            e,
            "let f = { x = x, y = y, z = z => f(); }",
            &["x", "y", "z"],
        );

        // for loop
        test(s, e, "for (x in y) { x + z; };", &["y", "z"]);
        test(s, e, "for (x in y) { x; }; x", &["x", "y"]);

        // block
        test(s, e, "{ x; };", &["x"]);
        test(s, e, "{ let x; x; };", &[]);
        test(s, e, "{ let x; x; }; x;", &["x"]);

        // field access
        test(s, e, "x.y.f(z);", &["x", "z"]);

        // parenthesized
        test(s, e, "(x + z);", &["x", "z"]);
        test(s, e, "(({ x => x + y }) + y);", &["y"]);

        // explicitly captured (pre-existing) names are never implicit captures,
        // even when they also exist in the external scope
        test(s, e, "a + x;", &["x"]);
        test(s, e, "{ a + y; };", &["y"]);
        test(s, e, "let g = { |a| => a; }", &[]);
    }
}