compose_eval/expression/
closure.rs

1use crate::vm::{FlowEvent, TrackedContainer};
2use crate::{Eval, Evaluated, Machine, ValueEvaluatedExtensions};
3use compose_library::diag::{IntoSourceDiagnostic, SourceResult, Spanned, bail, error};
4use compose_library::{
5    Args, Binding, BindingKind, Closure, Func, IntoValue, Library, Scope, Scopes, Value,
6    VariableAccessError, Visibility,
7};
8use compose_syntax::ast::{AstNode, Expr, Ident, Param, ParamKind};
9use compose_syntax::{Label, Span, SyntaxNode, ast};
10use ecow::{EcoString, EcoVec};
11use std::collections::HashMap;
12
13impl Eval for ast::Lambda<'_> {
14    fn eval(self, vm: &mut Machine) -> SourceResult<Evaluated> {
15        let guard = vm.temp_root_guard();
16
17        let mut defaults = Vec::new();
18        for param in self.params().children() {
19            if let ast::ParamKind::Named(named) = param.kind() {
20                defaults.push(named.expr().eval(guard.vm)?.value);
21            }
22        }
23
24        let captured = {
25            let mut errors = EcoVec::new();
26            let mut scope = Scope::new();
27            for capture in self.captures().children() {
28                let span = capture.binding().span();
29                let name = capture.binding().get();
30                let binding = match guard.vm.get(&capture.binding()).cloned() {
31                    Ok(v) => v,
32                    Err(e) => {
33                        let VariableAccessError::Unbound(unbound) = e else {
34                            errors.push(e.into_source_diagnostic(span));
35                            continue;
36                        };
37
38                        let mut err = error!(
39                            span, "unknown variable `{name}` in closure capture list";
40                            label_message: "variable `{name}` is not defined in the outer scope and cannot be captured";
41                        );
42                        if let Some(msg) = unbound.misspellings_hint_message() {
43                            err.hint(msg);
44                        }
45                        errors.push(err);
46
47                        continue;
48                    }
49                };
50
51                let value = match validate_capture(capture, &binding, guard.vm) {
52                    Ok(v) => v,
53                    Err(e) => {
54                        errors.extend(e.into_iter());
55                        continue;
56                    }
57                };
58
59                scope.bind(
60                    name.clone(),
61                    Binding::new(value.clone(), span).with_kind(match capture.is_mut() {
62                        true => BindingKind::Mutable,
63                        false => BindingKind::Immutable { first_assign: None },
64                    }),
65                );
66            }
67
68            if !errors.is_empty() {
69                return Err(errors);
70            }
71
72            scope
73        };
74
75        let unresolved_captures = {
76            let mut visitor = CapturesVisitor::new(
77                &guard.vm.frames.top.scopes,
78                Some(guard.vm.engine.world.library()),
79                &captured,
80            );
81            visitor.visit_lambda(self);
82            visitor.finish()
83        };
84
85        let closure = Closure {
86            name: None,
87            node: self.to_untyped().clone(),
88            defaults,
89            num_pos_params: self
90                .params()
91                .children()
92                .filter(|p| matches!(p.kind(), ast::ParamKind::Pos(_)))
93                .count(),
94            captured,
95            unresolved_captures,
96        };
97
98        if !guard.vm.context.closure_capture.should_defer() {
99            closure.resolve_captures()?
100        }
101
102        Ok(Func::from(closure)
103            .into_value()
104            .spanned(self.span())
105            .mutable())
106    }
107}
108
109fn validate_capture<'a>(
110    capture: ast::Capture,
111    binding: &'a Binding,
112    vm: &mut Machine,
113) -> SourceResult<&'a Value> {
114    let span = capture.binding().span();
115    let name = capture.binding().get();
116    if capture.is_mut() && capture.is_ref() && !binding.kind().is_mut() {
117        bail!(
118            capture.span(), "cannot capture variable `{name}` as `ref mut` because it is not declared as mutable";
119            label_message: "capture is declared as a mutable reference";
120            label: Label::secondary(binding.span(), "was defined as immutable here");
121            note: "captured mutable references must match the mutability of the original declaration";
122            hint: "declare the variable as mutable: `let mut {name} = ...`";
123            hint: "or remove `mut` from the capture: `|ref {name}, ...|";
124        );
125    }
126
127    let value = binding.read_checked(span, vm.sink_mut());
128
129    if capture.is_ref() && !value.is_box() {
130        bail!(
131            span, "cannot capture non reference type by reference";
132            label_message: "this captures by reference";
133            note: "only boxed values can be captured by reference"
134        );
135    }
136
137    Ok(value)
138}
139
140fn define(
141    vm: &mut Machine,
142    ident: Ident,
143    Spanned { value, span }: Spanned<Value>,
144    param: Param,
145) -> SourceResult<()> {
146    if param.is_ref() && !value.is_box() {
147        bail!(
148            ident.span(),
149            "cannot take a reference to a value type. consider boxing it first with `box::new(value)`"
150        );
151    }
152    if !param.is_ref() && value.is_box() {
153        bail!(span, "Cannot bind a box value to a non-reference parameter";
154            label_message: "this parameter is a boxed value";
155            label: Label::primary(ident.span(), "this parameter is declared as an owned value");
156            note: "passing a reference type (like `box`) into an owned parameter is not allowed";
157            hint: "if the parameter intends to modify the original value, mark the parameter as `ref mut`";
158            hint: "if you want to pass the value by ownership, use `.clone_inner()` to create a new copy";
159        );
160    }
161
162    let kind = if param.is_mut() {
163        BindingKind::ParamMut
164    } else {
165        BindingKind::Param
166    };
167    vm.define(ident, value, kind, Visibility::Private)?;
168
169    Ok(())
170}
171
172//noinspection RsUnnecessaryQualifications - False Positive
173pub fn eval_lambda(closure: &Closure, vm: &mut Machine, args: Args) -> SourceResult<Value> {
174    let guard = vm.temp_root_guard();
175    let ast_closure = closure
176        .node
177        .cast::<ast::Lambda>()
178        .expect("closure is not an ast closure");
179    let params = ast_closure.params();
180    let statements = ast_closure.statements();
181
182    // Make sure a gc round is aware that the args are reachable
183    guard.vm.track_tmp_root(&args);
184
185    let result = guard
186        .vm
187        .with_frame(move |vm| {
188            let mut args = args;
189            if let Some(Spanned { value, span }) = &closure.name {
190                vm.try_bind(
191                    value.clone(),
192                    Binding::new(Func::from(closure.clone()), *span),
193                )?;
194            }
195
196            for (k, v) in closure.captured.bindings() {
197                vm.bind(k.clone(), v.clone());
198            }
199
200            let mut defaults = closure.defaults.iter();
201            for p in params.children() {
202                match p.kind() {
203                    ast::ParamKind::Pos(pattern) => match pattern {
204                        ast::Pattern::Single(Expr::Ident(ident)) => {
205                            define(vm, ident, args.expect(&ident)?, p)?;
206                        }
207                        pattern => bail!(pattern.span(), "Patterns not supported in closures yet"),
208                    },
209                    ast::ParamKind::Named(named) => {
210                        let name = named.name();
211                        let default = defaults.next().unwrap();
212                        let value = args
213                            .named(&name)?
214                            .unwrap_or_else(|| Spanned::new(default.clone(), named.expr().span()));
215                        define(vm, name, value, p)?;
216                    }
217                }
218            }
219
220            // Ensure all args have been used
221            args.finish()?;
222
223            let mut output = Value::unit();
224            for statement in statements {
225                output = statement.eval(vm)?.value;
226                match &vm.flow {
227                    None => {}
228                    Some(FlowEvent::Return(_, Some(explicit))) => {
229                        let explicit = explicit.clone();
230                        vm.flow = None;
231                        return Ok(explicit);
232                    }
233                    Some(FlowEvent::Return(_, None)) => {
234                        vm.flow = None;
235                        return Ok(Value::unit());
236                    }
237                    Some(other) => bail!(other.forbidden()),
238                }
239            }
240
241            SourceResult::Ok(output)
242        })
243        .track_tmp_root(guard.vm);
244
245    guard.vm.maybe_gc();
246
247    result
248}
249
250/// Visits a closure and determines which variables are captured implicitly.
251#[derive(Debug)]
252pub struct CapturesVisitor<'a> {
253    /// The external scope that variables might be captured from.
254    external: &'a Scopes<'a>,
255    /// The internal scope of variables defined within the closure.
256    internal: Scopes<'a>,
257    /// The variables that are captured.
258    captures: HashMap<EcoString, Span>,
259}
260
261impl<'a> CapturesVisitor<'a> {
262    pub fn new(external: &'a Scopes<'a>, library: Option<&'a Library>, existing: &Scope) -> Self {
263        let mut inst = Self {
264            external,
265            internal: Scopes::new(library),
266            captures: HashMap::new(),
267        };
268
269        for (k, v) in existing.bindings() {
270            inst.internal.top.bind(k.clone(), v.clone());
271        }
272
273        inst
274    }
275    pub(crate) fn visit_lambda(&mut self, closure: ast::Lambda<'a>) {
276        for param in closure.params().children() {
277            match param.kind() {
278                ParamKind::Pos(pat) => {
279                    for ident in pat.bindings() {
280                        self.bind(ident);
281                    }
282                }
283                ParamKind::Named(named) => {
284                    self.bind(named.name());
285                }
286            }
287        }
288
289        for capture in closure.captures().children() {
290            self.visit(capture.to_untyped());
291        }
292
293        for statement in closure.statements() {
294            self.visit(statement.to_untyped());
295        }
296    }
297
298    pub fn visit(&mut self, node: &'a SyntaxNode) {
299        if let Some(ast::Statement::Let(let_binding)) = node.cast() {
300            if let Some(init) = let_binding.initial_value() {
301                self.visit(init.to_untyped())
302            }
303
304            for ident in let_binding.pattern().bindings() {
305                self.bind(ident);
306            }
307            return;
308        }
309
310        let expr = match node.cast::<Expr>() {
311            Some(expr) => expr,
312            None => {
313                if let Some(named) = node.cast::<ast::Named>() {
314                    // Don't capture the name of a named parameter.
315                    self.visit(named.expr().to_untyped());
316                    return;
317                }
318
319                Expr::default()
320            }
321        };
322
323        match expr {
324            Expr::Ident(ident) => self.capture(ident),
325            Expr::CodeBlock(_) => {
326                self.internal.enter();
327                for child in node.children() {
328                    self.visit(child);
329                }
330                self.internal.exit();
331            }
332            Expr::FieldAccess(access) => {
333                self.visit(access.target().to_untyped());
334            }
335            Expr::Lambda(closure) => {
336                for param in closure.params().children() {
337                    if let ast::ParamKind::Named(named) = param.kind() {
338                        self.visit(named.expr().to_untyped());
339                    }
340                }
341
342                for capture in closure.captures().children() {
343                    self.visit(capture.to_untyped());
344                }
345
346                // NOTE: For now we do not try to analyse the body of the closure.
347                // This is because the closure might try to recursively call itself
348                // and in simple ast walking, that is really hard to resolve correctly.
349                // Any errors in the body will be caught when the outer body is evaluated.
350            }
351
352            Expr::ForLoop(for_loop) => {
353                // Created in outer scope
354                self.visit(for_loop.iterable().to_untyped());
355
356                self.internal.enter();
357                let pattern = for_loop.binding();
358                for ident in pattern.bindings() {
359                    self.bind(ident);
360                }
361
362                self.visit(for_loop.body().to_untyped());
363                self.internal.exit();
364            }
365
366            _ => {
367                // If not an expression or named, just go over all the children
368                for child in node.children() {
369                    self.visit(child);
370                }
371            }
372        }
373    }
374
375    fn bind(&mut self, ident: Ident) {
376        self.internal.top.bind(
377            ident.get().clone(),
378            Binding::new(Value::unit(), ident.span()),
379        );
380    }
381
382    fn capture(&mut self, ident: Ident<'a>) {
383        if self.internal.get(&ident).is_ok() {
384            // Was defined internally, no need to capture
385            return;
386        }
387
388        // If the value does not exist in the external scope, it is not captured.
389        if self.external.get(&ident).is_ok() {
390            self.captures
391                .entry(ident.get().clone())
392                .or_insert(ident.span());
393        }
394    }
395
396    fn finish(self) -> HashMap<EcoString, Span> {
397        self.captures
398    }
399}
400
#[cfg(test)]
mod tests {
    use crate::expression::closure::CapturesVisitor;
    use crate::test::*;
    use compose_library::{Scope, Scopes};
    use compose_syntax::{FileId, parse};

    /// Explicitly captured outer variables are visible inside closures, both by
    /// value and by (boxed) reference.
    #[test]
    fn capturing() {
        // Single capture by value.
        assert_eval(
            r#"
            let a = 10;
            let f = { |a| => a + 1 };
            assert::eq(f(), 11);
        "#,
        );

        // Mutable reference capture writes through to the box.
        assert_eval(
            r#"
            let mut a = box::new(5);
            let f = { |ref mut a| =>
                *a += 1;
                *a;
            };
            assert::eq(f(), 6);
        "#,
        );

        // Several captures by value.
        assert_eval(
            r#"
            let a = 2;
            let b = 3;
            let f = { |a, b| => a * b };
            assert::eq(f(), 6);
        "#,
        );

        // A reference capture observes later updates to the box.
        assert_eval(
            r#"
            let mut name = box::new("Alice");
            let age = 30;
            let show = { |ref name, age| =>
                "Name: " + *name + ", Age: " + age.to_string();
            };
            assert::eq(show(), "Name: Alice, Age: 30");
            *name = "Bob";
            assert::eq(show(), "Name: Bob, Age: 30");
        "#,
        );
    }

    /// Parses `text`, runs a `CapturesVisitor` over every top-level node, and asserts
    /// that the sorted captured names equal `expected_names`.
    #[track_caller]
    fn test(scopes: &Scopes, existing_scope: &Scope, text: &str, expected_names: &[&str]) {
        let nodes = parse(text, FileId::new("test.comp"));
        let mut visitor = CapturesVisitor::new(scopes, None, existing_scope);

        for node in &nodes {
            assert!(node.errors().is_empty(), "node has errors: {:#?}", node.errors());
            visitor.visit(node);
        }

        let captures = visitor.finish();
        let mut names: Vec<_> = captures.keys().collect();
        names.sort();

        assert_eq!(names, expected_names);
    }

    #[test]
    fn test_captures_visitor() {
        // Names visible in the surrounding code.
        let mut scopes = Scopes::new(None);
        for name in ["f", "x", "y", "z"] {
            scopes.top.define(name, 0i64);
        }
        let s = &scopes;

        // Names already captured explicitly.
        let mut existing = Scope::new();
        for name in ["a", "b", "c"] {
            existing.define(name, 0i64);
        }
        let e = &existing;

        test(s, e, "{ x => x * 2; }", &[]);

        // let binding
        test(s, e, "let t = x;", &["x"]);
        test(s, e, "let x = x;", &["x"]);
        test(s, e, "let x;", &[]);
        test(s, e, "let x = 2; x + y;", &["y"]);
        test(s, e, "x + y", &["x", "y"]);

        // assignment
        test(s, e, "x += y;", &["x", "y"]);
        test(s, e, "x = y;", &["x", "y"]);

        // closure definition: nested closure bodies are not analysed
        test(s, e, "let f = { => x + y; }", &[]);
        // with capture list
        test(s, e, "let f = { |x| => x + y; }", &["x"]);
        test(s, e, "let f = { |x| => f(); }", &["x"]);
        // with positional params
        test(s, e, "let f = { x, y, z => f(); }", &[]);
        // named params: default values are captured
        test(
            s,
            e,
            "let f = { x: x, y: y, z: z => f(); }",
            &["x", "y", "z"],
        );

        // for loop
        test(s, e, "for (x in y) { x + z; };", &["y", "z"]);
        test(s, e, "for (x in y) { x; }; x", &["x", "y"]);

        // block scoping
        test(s, e, "{ x; };", &["x"]);
        test(s, e, "{ let x; x; };", &[]);
        test(s, e, "{ let x; x; }; x;", &["x"]);

        // field access
        test(s, e, "x.y.f(z);", &["x", "z"]);

        // parenthesized
        test(s, e, "(x + z);", &["x", "z"]);
        test(s, e, "(({ x => x + y }) + y);", &["y"]);
    }
}