compose_syntax/ast/
func.rs

1use crate::ast::{AstNode, Expr, Ident, Statement, node};
2use crate::kind::SyntaxKind;
3use crate::{Span, SyntaxNode};
4
5node! {
6    struct Lambda
7}
8
9impl<'a> Lambda<'a> {
10    pub fn params(self) -> Params<'a> {
11        self.0.cast_first()
12    }
13
14    pub fn captures(self) -> CaptureList<'a> {
15        self.0.cast_first()
16    }
17
18    pub fn statements(self) -> impl Iterator<Item = Statement<'a>> {
19        self.0
20            .children()
21            .skip_while(|n| n.kind() != SyntaxKind::Arrow)
22            .filter_map(SyntaxNode::cast)
23    }
24}
25
26node! {
27    struct CaptureList
28}
29
30impl<'a> CaptureList<'a> {
31    pub fn children(self) -> impl DoubleEndedIterator<Item = Capture<'a>> {
32        self.0.children().filter_map(SyntaxNode::cast)
33    }
34}
35
36node! {
37    struct Capture
38}
39
40impl<'a> Capture<'a> {
41    pub fn binding(self) -> Ident<'a> {
42        self.0.cast_first()
43    }
44
45    pub fn is_ref(self) -> bool {
46        self.0.children().any(|n| n.kind() == SyntaxKind::Ref)
47    }
48
49    pub fn ref_span(self) -> Option<Span> {
50        self.0
51            .children()
52            .find(|n| n.kind() == SyntaxKind::Ref)
53            .map(|n| n.span())
54    }
55
56    pub fn is_mut(self) -> bool {
57        self.0.children().any(|n| n.kind() == SyntaxKind::Mut)
58    }
59
60    pub fn mut_span(self) -> Option<Span> {
61        self.0
62            .children()
63            .find(|n| n.kind() == SyntaxKind::Mut)
64            .map(|n| n.span())
65    }
66}
67
68node! {
69    struct Params
70}
71
72impl<'a> Params<'a> {
73    pub fn children(self) -> impl DoubleEndedIterator<Item = Param<'a>> {
74        self.0.children().filter_map(SyntaxNode::cast)
75    }
76}
77
78node! {
79    struct Param
80}
81
82impl<'a> Param<'a> {
83    pub fn kind(self) -> ParamKind<'a> {
84        self.0.cast_first()
85    }
86
87    pub fn is_ref(self) -> bool {
88        self.0.children().any(|n| n.kind() == SyntaxKind::Ref)
89    }
90
91    pub fn ref_span(self) -> Option<Span> {
92        self.0
93            .children()
94            .find(|n| n.kind() == SyntaxKind::Ref)
95            .map(|n| n.span())
96    }
97
98    pub fn is_mut(self) -> bool {
99        self.0.children().any(|n| n.kind() == SyntaxKind::Mut)
100    }
101
102    pub fn mut_span(self) -> Option<Span> {
103        self.0
104            .children()
105            .find(|n| n.kind() == SyntaxKind::Mut)
106            .map(|n| n.span())
107    }
108}
109
110#[derive(Debug)]
111pub enum ParamKind<'a> {
112    // A positional parameter `x`
113    Pos(Pattern<'a>),
114    // A named parameter `help = "try it like this"`
115    Named(Named<'a>),
116}
117
118impl<'a> Default for ParamKind<'a> {
119    fn default() -> Self {
120        Self::Pos(Pattern::default())
121    }
122}
123
124impl<'a> AstNode<'a> for ParamKind<'a> {
125    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
126        match node.kind() {
127            SyntaxKind::Named => Some(Self::Named(Named::from_untyped(node)?)),
128            _ => node.cast().map(Self::Pos),
129        }
130    }
131
132    fn to_untyped(&self) -> &'a SyntaxNode {
133        match self {
134            Self::Named(n) => n.to_untyped(),
135            Self::Pos(p) => p.to_untyped(),
136        }
137    }
138}
139
140#[derive(Debug, Clone, Copy)]
141pub enum Pattern<'a> {
142    Single(Expr<'a>),
143    PlaceHolder(Underscore<'a>),
144    Destructuring(Destructuring<'a>),
145}
146
147impl<'a> Pattern<'a> {
148    pub fn bindings(self) -> Vec<Ident<'a>> {
149        match self {
150            Pattern::Single(Expr::Ident(i)) => vec![i],
151            Pattern::Destructuring(v) => v.bindings(),
152            _ => vec![],
153        }
154    }
155}
156
157impl<'a> AstNode<'a> for Pattern<'a> {
158    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
159        match node.kind() {
160            SyntaxKind::Underscore => Some(Self::PlaceHolder(Underscore(node))),
161            SyntaxKind::Destructuring => Some(Self::Destructuring(Destructuring(node))),
162            _ => node.cast().map(Self::Single),
163        }
164    }
165
166    fn to_untyped(&self) -> &'a SyntaxNode {
167        match self {
168            Self::Single(e) => e.to_untyped(),
169            Self::PlaceHolder(u) => u.to_untyped(),
170            Self::Destructuring(d) => d.to_untyped(),
171        }
172    }
173}
174
175impl Default for Pattern<'_> {
176    fn default() -> Self {
177        Self::Single(Expr::default())
178    }
179}
180
181node! {
182    struct Underscore
183}
184
185node! {
186    struct Named
187}
188
189impl<'a> Named<'a> {
190    pub fn name(self) -> Ident<'a> {
191        self.0.cast_first()
192    }
193
194    pub fn expr(self) -> Expr<'a> {
195        self.0.cast_last()
196    }
197
198    /// The right hand of the pair as a pattern.
199    ///
200    /// This should only be used in `destructuring`
201    pub fn pattern(self) -> Pattern<'a> {
202        self.0.cast_last()
203    }
204}
205
206node! {
207    struct Destructuring
208}
209
210impl<'a> Destructuring<'a> {
211    pub fn items(self) -> impl DoubleEndedIterator<Item = DestructuringItem<'a>> {
212        self.0.children().filter_map(SyntaxNode::cast)
213    }
214
215    pub fn bindings(self) -> Vec<Ident<'a>> {
216        self.items()
217            .flat_map(|binding| match binding {
218                DestructuringItem::Named(named) => named.pattern().bindings(),
219                DestructuringItem::Pattern(pattern) => pattern.bindings(),
220                DestructuringItem::Spread(spread) => {
221                    spread.sink_ident().into_iter().collect()
222                }
223            })
224            .collect()
225    }
226}
227
228pub enum DestructuringItem<'a> {
229    Pattern(Pattern<'a>),
230    Named(Named<'a>),
231    Spread(Spread<'a>),
232}
233
234node! {
235    struct Spread
236}
237
238impl<'a> Spread<'a> {
239    /// The spread expression.
240    ///
241    /// This should only be accessed if this `Spread` is contained in an
242    /// `ArrayItem`, `MapItem`, or `Arg`.
243    pub fn expr(self) -> Expr<'a> {
244        self.0.cast_first()
245    }
246
247    /// The sink identifier, if present.
248    ///
249    /// This should only be accessed if this `Spread` is contained in a
250    /// `Param` or binding `DestructuringItem`.
251    pub fn sink_ident(self) -> Option<Ident<'a>> {
252        self.0.try_cast_first()
253    }
254
255    /// The sink expressions, if present.
256    ///
257    /// This should only be accessed if this `Spread` is contained in a
258    /// `DestructuringItem`.
259    pub fn sink_expr(self) -> Option<Expr<'a>> {
260        self.0.try_cast_first()
261    }
262}
263
264impl<'a> AstNode<'a> for DestructuringItem<'a> {
265    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
266        match node.kind() {
267            SyntaxKind::Named => Some(Self::Named(Named::from_untyped(node)?)),
268            SyntaxKind::Spread => Some(Self::Spread(Spread(node))),
269            _ => node.cast().map(Self::Pattern),
270        }
271    }
272
273    fn to_untyped(&self) -> &'a SyntaxNode {
274        match self {
275            Self::Named(n) => n.to_untyped(),
276            Self::Pattern(p) => p.to_untyped(),
277            Self::Spread(s) => s.to_untyped(),
278        }
279    }
280}
281
// Parser round-trip tests for the lambda / parameter AST wrappers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_ast;
    use crate::ast::FuncCall;
    use crate::ast::binary::{BinOp, Binary};

    /// A block written after a call's argument list becomes a `Lambda`
    /// argument of that call. Its parameters come before `=>` and its body
    /// statements after it.
    #[test]
    fn trailing_lambda() {
        assert_ast!(
            r#"
            foo() { a => a + b; }
            "#,
            call as FuncCall {
                with callee: Ident = call.callee() => {
                    assert_eq!(callee.get(), "foo");
                }
                // The trailing block is the only argument of the call.
                call.args().items() => [
                    trailing_lambda as Lambda {
                        // Single positional parameter `a`.
                        with params: Params = trailing_lambda.params() => {
                            params.children() => [
                                param as Param {
                                    with pat: Pattern = param.kind() => {
                                        with ident: Ident = pat => {
                                            assert_eq!(ident.get(), "a");
                                        }
                                    }
                                }
                            ]
                        }

                        // Body after `=>` is the single statement `a + b`.
                        trailing_lambda.statements() => [
                            binary as Binary {
                                with lhs: Ident = binary.lhs() => {
                                    assert_eq!(lhs.get(), "a");
                                }
                                with rhs: Ident = binary.rhs() => {
                                    assert_eq!(rhs.get(), "b");
                                }
                                assert_eq!(binary.op(), BinOp::Add);
                            }
                        ]
                    }
                ]
            }
        )
    }
}