compose_library/foundations/
func.rs

1use crate::diag::{SourceResult, StrResult, bail};
2use crate::foundations::args::Args;
3use crate::vm::Vm;
4use crate::{Sink, Trace, Value};
5use compose_error_codes::E0010_UNCAPTURED_VARIABLE;
6use compose_library::diag::{Spanned, error};
7use compose_library::{Scope, UntypedRef};
8use compose_macros::{cast, ty};
9use compose_syntax::ast::{AstNode};
10use compose_syntax::{Label, Span, SyntaxNode, ast};
11use compose_utils::Static;
12use ecow::{EcoString, eco_format, eco_vec};
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::LazyLock;
16
17#[derive(Clone, Debug, PartialEq)]
18#[ty(cast)]
19pub struct Func {
20    pub kind: FuncKind,
21    pub span: Span,
22}
23
24impl Func {
25    pub(crate) fn spanned(mut self, span: Span) -> Func {
26        if self.span.is_detached() {
27            self.span = span;
28        }
29        self
30    }
31
32    pub(crate) fn named(mut self, name: Spanned<EcoString>) -> Func {
33        if let FuncKind::Closure(closure) = &mut self.kind {
34            closure.unresolved_captures.remove(&name.value);
35            closure.name = Some(name);
36        }
37        self
38    }
39
40    pub(crate) fn resolve(&self) -> SourceResult<()> {
41        if let FuncKind::Closure(closure) = &self.kind {
42            closure.resolve()?;
43        }
44        Ok(())
45    }
46
47    pub(crate) fn span(&self) -> Span {
48        self.span
49    }
50}
51
52impl fmt::Display for Func {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        match &self.kind {
55            FuncKind::Native(native) => write!(f, "{}", native.0.name),
56            FuncKind::Closure(closure) => closure.fmt(f),
57        }
58    }
59}
60
61impl Trace for Func {
62    fn visit_refs(&self, f: &mut dyn FnMut(UntypedRef)) {
63        match &self.kind {
64            FuncKind::Native(native) => {
65                native.scope.visit_refs(f);
66            }
67            FuncKind::Closure(closure) => {
68                closure.captured.visit_refs(f);
69                closure.defaults.iter().for_each(|v| v.visit_refs(f));
70            }
71        }
72    }
73}
74
75#[derive(Clone, Debug, PartialEq)]
76pub enum FuncKind {
77    Native(Static<NativeFuncData>),
78    Closure(Box<Closure>),
79}
80
81impl Func {
82    pub fn call(&self, vm: &mut dyn Vm, args: Args) -> SourceResult<Value> {
83        vm.call_func(self, args)
84    }
85
86    pub fn scope(&self) -> Option<&'static Scope> {
87        match &self.kind {
88            FuncKind::Native(native) => Some(&native.0.scope),
89            FuncKind::Closure(_) => None,
90        }
91    }
92
93    pub fn name(&self) -> Option<&str> {
94        match &self.kind {
95            FuncKind::Native(native) => Some(native.0.name),
96            FuncKind::Closure(_) => None,
97        }
98    }
99
100    pub fn field(&self, field: &str, access_span: Span, sink: &mut Sink) -> StrResult<&Value> {
101        let scope = self
102            .scope()
103            .ok_or("Cannot access fields on user-defined functions")?;
104        match scope.get(field) {
105            Some(binding) => Ok(binding.read_checked(access_span, sink)),
106            None => match self.name() {
107                Some(name) => bail!("function `{name}` does not contain field `{field}`"),
108                None => bail!("Function does not contain field `{field}`"),
109            },
110        }
111    }
112
113    pub fn path(&self, path: &str, access_span: Span, sink: &mut Sink) -> StrResult<&Value> {
114        let scope = self
115            .scope()
116            .ok_or("Cannot access fields on user-defined functions")?;
117        match scope.get(path) {
118            Some(binding) => Ok(binding.read_checked(access_span, sink)),
119            None => match self.name() {
120                Some(name) => bail!("function `{name}` does not contain associated field `{path}`"),
121                None => bail!("Function does not contain associated field `{path}`"),
122            },
123        }
124    }
125
126    pub fn is_associated_function(&self) -> bool {
127        match self.kind {
128            FuncKind::Native(n) => match n.fn_type {
129                FuncType::Method => false,
130                FuncType::MethodMut => false,
131                FuncType::Associated => true,
132            },
133            FuncKind::Closure(_) => false,
134        }
135    }
136
137    pub fn requires_mut_self(&self) -> bool {
138        match self.kind {
139            FuncKind::Native(n) => match n.fn_type {
140                FuncType::Method => false,
141                FuncType::MethodMut => true,
142                FuncType::Associated => false,
143            },
144            FuncKind::Closure(_) => false,
145        }
146    }
147}
148
149pub trait NativeFunc {
150    fn data() -> &'static NativeFuncData;
151}
152
153#[derive(Debug)]
154pub struct NativeFuncData {
155    pub closure: fn(&mut dyn Vm, &mut Args) -> SourceResult<Value>,
156    pub name: &'static str,
157    pub scope: LazyLock<&'static Scope>,
158    pub fn_type: FuncType,
159}
160
161impl NativeFuncData {
162    pub fn call(&self, vm: &mut dyn Vm, mut args: Args) -> SourceResult<Value> {
163        (self.closure)(vm, &mut args)
164    }
165}
166
167#[derive(Debug, Clone)]
168pub struct Closure {
169    pub node: SyntaxNode,
170    pub defaults: Vec<Value>,
171    pub num_pos_params: usize,
172    pub name: Option<Spanned<EcoString>>,
173    pub captured: Scope,
174    pub unresolved_captures: HashMap<EcoString, Span>,
175}
176
177impl PartialEq for Closure {
178    fn eq(&self, other: &Self) -> bool {
179        self.node == other.node
180    }
181}
182
183impl Closure {
184    pub fn resolve(&self) -> SourceResult<()> {
185        if !self.unresolved_captures.is_empty() {
186            let mut captures = self.unresolved_captures.iter();
187
188            let names = self
189                .unresolved_captures
190                .keys()
191                .map(|name| name.trim())
192                .collect::<Vec<_>>()
193                .join(", ");
194
195            let params = self
196                .node
197                .cast::<ast::Lambda>()
198                .expect("Closure contains non lambda node")
199                .params()
200                .children()
201                .map(|p| p.to_untyped().to_text())
202                .collect::<Vec<_>>()
203                .join(", ");
204
205            let (first_name, first_span) = captures.next().unwrap();
206
207            let mut err = error!(*first_span, "outer variables used in closure but not captured";
208                label_message: "outer variable `{first_name}` used here";
209                hint: "explicitly capture them by adding them to a capture group: `|{names}| ({params}) => ...`";
210                code: &E0010_UNCAPTURED_VARIABLE
211            );
212
213            for (name, span) in captures {
214                err = err.with_label(Label::primary(
215                    *span,
216                    eco_format!("outer variable `{name}` used here"),
217                ));
218            }
219
220            return Err(eco_vec!(err));
221        }
222        Ok(())
223    }
224}
225
226impl fmt::Display for Closure {
227    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228        let closure: ast::Lambda = self.node.cast().expect("closure");
229        let params = closure.params().to_untyped().to_text();
230
231        write!(f, "{} => ...", params)
232    }
233}
234
/// How a native function relates to a receiver value.
///
/// `Func::requires_mut_self` is `true` only for [`FuncType::MethodMut`] and
/// `Func::is_associated_function` is `true` only for
/// [`FuncType::Associated`].
///
/// The enum has no fields, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FuncType {
    /// A method that does not require a mutable receiver.
    Method,
    /// A method that requires a mutable receiver.
    MethodMut,
    /// An associated function.
    Associated,
}
241
242impl From<FuncKind> for Func {
243    fn from(value: FuncKind) -> Self {
244        Self {
245            span: Span::detached(),
246            kind: value,
247        }
248    }
249}
250
251impl From<&'static NativeFuncData> for Func {
252    fn from(data: &'static NativeFuncData) -> Self {
253        FuncKind::Native(Static(data)).into()
254    }
255}
256
257impl From<Closure> for Func {
258    fn from(closure: Closure) -> Self {
259        FuncKind::Closure(Box::new(closure)).into()
260    }
261}
262
/// Static information about a single parameter.
#[derive(Debug)]
pub struct ParamInfo {
    /// The parameter's name.
    pub name: &'static str,
}
267
268cast! {
269    &'static NativeFuncData,
270    self => Func::from(self).into_value(),
271}