// compose_syntax/ast/func.rs

1use crate::ast::{node, AstNode, Expr, Ident, Statement};
2use crate::kind::SyntaxKind;
3use crate::{Span, SyntaxNode};
4
5node! {
6    struct Lambda
7}
8
9impl<'a> Lambda<'a> {
10    pub fn params(self) -> Params<'a> {
11        self.0.cast_first()
12    }
13
14    pub fn captures(self) -> CaptureList<'a> {
15        self.0.cast_first()
16    }
17
18    pub fn statements(self) -> impl Iterator<Item=Statement<'a>> {
19        self.0
20            .children()
21            .skip_while(|n| n.kind() != SyntaxKind::Arrow)
22            .filter_map(SyntaxNode::cast)
23    }
24}
25
26node! {
27    struct CaptureList
28}
29
30impl<'a> CaptureList<'a> {
31    pub fn children(self) -> impl DoubleEndedIterator<Item=Capture<'a>> {
32        self.0.children().filter_map(SyntaxNode::cast)
33    }
34}
35
36node! {
37    struct Capture
38}
39
40impl<'a> Capture<'a> {
41    pub fn binding(self) -> Ident<'a> {
42        self.0.cast_first()
43    }
44
45    pub fn is_ref(self) -> bool {
46        self.0.children().any(|n| n.kind() == SyntaxKind::Ref)
47    }
48
49    pub fn ref_span(self) -> Option<Span> {
50        self.0
51            .children()
52            .find(|n| n.kind() == SyntaxKind::Ref)
53            .map(|n| n.span())
54    }
55
56    pub fn is_mut(self) -> bool {
57        self.0.children().any(|n| n.kind() == SyntaxKind::Mut)
58    }
59
60    pub fn mut_span(self) -> Option<Span> {
61        self.0
62            .children()
63            .find(|n| n.kind() == SyntaxKind::Mut)
64            .map(|n| n.span())
65    }
66}
67
68node! {
69    struct Params
70}
71
72impl<'a> Params<'a> {
73    pub fn children(self) -> impl DoubleEndedIterator<Item=Param<'a>> {
74        self.0.children().filter_map(SyntaxNode::cast)
75    }
76}
77
78node! {
79    struct Param
80}
81
82impl<'a> Param<'a> {
83    pub fn kind(self) -> ParamKind<'a> {
84        self.0.cast_first()
85    }
86
87    pub fn is_ref(self) -> bool {
88        self.0.children().any(|n| n.kind() == SyntaxKind::Ref)
89    }
90
91    pub fn ref_span(self) -> Option<Span> {
92        self.0
93            .children()
94            .find(|n| n.kind() == SyntaxKind::Ref)
95            .map(|n| n.span())
96    }
97
98    pub fn is_mut(self) -> bool {
99        self.0.children().any(|n| n.kind() == SyntaxKind::Mut)
100    }
101
102    pub fn mut_span(self) -> Option<Span> {
103        self.0
104            .children()
105            .find(|n| n.kind() == SyntaxKind::Mut)
106            .map(|n| n.span())
107    }
108}
109
110#[derive(Debug)]
111pub enum ParamKind<'a> {
112    // A positional parameter `x`
113    Pos(Pattern<'a>),
114    // A named parameter `help = "try it like this"`
115    Named(Named<'a>),
116}
117
118impl<'a> Default for ParamKind<'a> {
119    fn default() -> Self {
120        Self::Pos(Pattern::default())
121    }
122}
123
124impl<'a> AstNode<'a> for ParamKind<'a> {
125    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
126        match node.kind() {
127            SyntaxKind::Named => Some(Self::Named(Named::from_untyped(node)?)),
128            _ => node.cast().map(Self::Pos),
129        }
130    }
131
132    fn to_untyped(&self) -> &'a SyntaxNode {
133        match self {
134            Self::Named(n) => n.to_untyped(),
135            Self::Pos(p) => p.to_untyped(),
136        }
137    }
138}
139
140#[derive(Debug, Clone, Copy)]
141pub enum Pattern<'a> {
142    Single(Expr<'a>),
143    PlaceHolder(Underscore<'a>),
144    Destructuring(Destructuring<'a>),
145}
146
147impl<'a> Pattern<'a> {
148    pub fn bindings(self) -> Vec<Ident<'a>> {
149        match self {
150            Pattern::Single(Expr::Ident(i)) => vec![i],
151            Pattern::Destructuring(v) => v.bindings(),
152            _ => vec![],
153        }
154    }
155}
156
157impl<'a> AstNode<'a> for Pattern<'a> {
158    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
159        match node.kind() {
160            SyntaxKind::Underscore => Some(Self::PlaceHolder(Underscore(node))),
161            SyntaxKind::Destructuring => Some(Self::Destructuring(Destructuring(node))),
162            _ => node.cast().map(Self::Single),
163        }
164    }
165
166    fn to_untyped(&self) -> &'a SyntaxNode {
167        match self {
168            Self::Single(e) => e.to_untyped(),
169            Self::PlaceHolder(u) => u.to_untyped(),
170            Self::Destructuring(d) => d.to_untyped(),
171        }
172    }
173}
174
175impl Default for Pattern<'_> {
176    fn default() -> Self {
177        Self::Single(Expr::default())
178    }
179}
180
181node! {
182    struct Underscore
183}
184
185node! {
186    struct Named
187}
188
189impl<'a> Named<'a> {
190    pub fn name(self) -> Ident<'a> {
191        self.0.cast_first()
192    }
193
194    pub fn expr(self) -> Expr<'a> {
195        self.0.cast_last()
196    }
197
198    pub fn pattern(self) -> Pattern<'a> {
199        self.0.cast_last()
200    }
201}
202
203node! {
204    struct Destructuring
205}
206
207impl<'a> Destructuring<'a> {
208    pub fn items(self) -> impl DoubleEndedIterator<Item=DestructuringItem<'a>> {
209        self.0.children().filter_map(SyntaxNode::cast)
210    }
211
212    pub fn bindings(self) -> Vec<Ident<'a>> {
213        self.items()
214            .flat_map(|binding| match binding {
215                DestructuringItem::Named(named) => named.pattern().bindings(),
216                DestructuringItem::Pattern(pattern) => pattern.bindings(),
217            })
218            .collect()
219    }
220}
221
222pub enum DestructuringItem<'a> {
223    Pattern(Pattern<'a>),
224    Named(Named<'a>),
225}
226
227impl<'a> AstNode<'a> for DestructuringItem<'a> {
228    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
229        match node.kind() {
230            SyntaxKind::Named => Some(Self::Named(Named::from_untyped(node)?)),
231            _ => node.cast().map(Self::Pattern),
232        }
233    }
234
235    fn to_untyped(&self) -> &'a SyntaxNode {
236        match self {
237            Self::Named(n) => n.to_untyped(),
238            Self::Pattern(p) => p.to_untyped(),
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_ast;
    use crate::ast::binary::{BinOp, Binary};
    use crate::ast::FuncCall;

    /// A lambda written after a call's parentheses appears as one of the
    /// call's arguments. Its parameter and body statement must be intact.
    #[test]
    fn trailing_lambda() {
        assert_ast!(
            r#"
            foo() { a => a + b; }
            "#,
            func_call as FuncCall {
                with callee: Ident = func_call.callee() => {
                    assert_eq!(callee.get(), "foo");
                }
                func_call.args().items() => [
                    lambda as Lambda {
                        with params: Params = lambda.params() => {
                            params.children() => [
                                param as Param {
                                    with pattern: Pattern = param.kind() => {
                                        with name: Ident = pattern => {
                                            assert_eq!(name.get(), "a");
                                        }
                                    }
                                }
                            ]
                        }

                        lambda.statements() => [
                            sum as Binary {
                                with left: Ident = sum.lhs() => {
                                    assert_eq!(left.get(), "a");
                                }
                                with right: Ident = sum.rhs() => {
                                    assert_eq!(right.get(), "b");
                                }
                                assert_eq!(sum.op(), BinOp::Add);
                            }
                        ]
                    }
                ]
            }
        )
    }
}