compose_syntax/ast/
func.rs1use crate::ast::{AstNode, Expr, Ident, Statement, node};
2use crate::kind::SyntaxKind;
3use crate::{Span, SyntaxNode};
4
5node! {
6 struct Lambda
7}
8
9impl<'a> Lambda<'a> {
10 pub fn params(self) -> Params<'a> {
11 self.0.cast_first()
12 }
13
14 pub fn captures(self) -> CaptureList<'a> {
15 self.0.cast_first()
16 }
17
18 pub fn statements(self) -> impl Iterator<Item = Statement<'a>> {
19 self.0
20 .children()
21 .skip_while(|n| n.kind() != SyntaxKind::Arrow)
22 .filter_map(SyntaxNode::cast)
23 }
24}
25
26node! {
27 struct CaptureList
28}
29
30impl<'a> CaptureList<'a> {
31 pub fn children(self) -> impl DoubleEndedIterator<Item = Capture<'a>> {
32 self.0.children().filter_map(SyntaxNode::cast)
33 }
34}
35
36node! {
37 struct Capture
38}
39
40impl<'a> Capture<'a> {
41 pub fn binding(self) -> Ident<'a> {
42 self.0.cast_first()
43 }
44
45 pub fn is_ref(self) -> bool {
46 self.0.children().any(|n| n.kind() == SyntaxKind::Ref)
47 }
48
49 pub fn ref_span(self) -> Option<Span> {
50 self.0
51 .children()
52 .find(|n| n.kind() == SyntaxKind::Ref)
53 .map(|n| n.span())
54 }
55
56 pub fn is_mut(self) -> bool {
57 self.0.children().any(|n| n.kind() == SyntaxKind::Mut)
58 }
59
60 pub fn mut_span(self) -> Option<Span> {
61 self.0
62 .children()
63 .find(|n| n.kind() == SyntaxKind::Mut)
64 .map(|n| n.span())
65 }
66}
67
68node! {
69 struct Params
70}
71
72impl<'a> Params<'a> {
73 pub fn children(self) -> impl DoubleEndedIterator<Item = Param<'a>> {
74 self.0.children().filter_map(SyntaxNode::cast)
75 }
76}
77
78node! {
79 struct Param
80}
81
82impl<'a> Param<'a> {
83 pub fn kind(self) -> ParamKind<'a> {
84 self.0.cast_first()
85 }
86
87 pub fn is_ref(self) -> bool {
88 self.0.children().any(|n| n.kind() == SyntaxKind::Ref)
89 }
90
91 pub fn ref_span(self) -> Option<Span> {
92 self.0
93 .children()
94 .find(|n| n.kind() == SyntaxKind::Ref)
95 .map(|n| n.span())
96 }
97
98 pub fn is_mut(self) -> bool {
99 self.0.children().any(|n| n.kind() == SyntaxKind::Mut)
100 }
101
102 pub fn mut_span(self) -> Option<Span> {
103 self.0
104 .children()
105 .find(|n| n.kind() == SyntaxKind::Mut)
106 .map(|n| n.span())
107 }
108}
109
110#[derive(Debug)]
111pub enum ParamKind<'a> {
112 Pos(Pattern<'a>),
114 Named(Named<'a>),
116}
117
118impl<'a> Default for ParamKind<'a> {
119 fn default() -> Self {
120 Self::Pos(Pattern::default())
121 }
122}
123
124impl<'a> AstNode<'a> for ParamKind<'a> {
125 fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
126 match node.kind() {
127 SyntaxKind::Named => Some(Self::Named(Named::from_untyped(node)?)),
128 _ => node.cast().map(Self::Pos),
129 }
130 }
131
132 fn to_untyped(&self) -> &'a SyntaxNode {
133 match self {
134 Self::Named(n) => n.to_untyped(),
135 Self::Pos(p) => p.to_untyped(),
136 }
137 }
138}
139
140#[derive(Debug, Clone, Copy)]
141pub enum Pattern<'a> {
142 Single(Expr<'a>),
143 PlaceHolder(Underscore<'a>),
144 Destructuring(Destructuring<'a>),
145}
146
147impl<'a> Pattern<'a> {
148 pub fn bindings(self) -> Vec<Ident<'a>> {
149 match self {
150 Pattern::Single(Expr::Ident(i)) => vec![i],
151 Pattern::Destructuring(v) => v.bindings(),
152 _ => vec![],
153 }
154 }
155}
156
157impl<'a> AstNode<'a> for Pattern<'a> {
158 fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
159 match node.kind() {
160 SyntaxKind::Underscore => Some(Self::PlaceHolder(Underscore(node))),
161 SyntaxKind::Destructuring => Some(Self::Destructuring(Destructuring(node))),
162 _ => node.cast().map(Self::Single),
163 }
164 }
165
166 fn to_untyped(&self) -> &'a SyntaxNode {
167 match self {
168 Self::Single(e) => e.to_untyped(),
169 Self::PlaceHolder(u) => u.to_untyped(),
170 Self::Destructuring(d) => d.to_untyped(),
171 }
172 }
173}
174
175impl Default for Pattern<'_> {
176 fn default() -> Self {
177 Self::Single(Expr::default())
178 }
179}
180
// Typed view over an `Underscore` node. `Pattern::from_untyped` wraps such
// nodes as `Pattern::PlaceHolder`, which binds no names.
node! {
    struct Underscore
}
184
185node! {
186 struct Named
187}
188
189impl<'a> Named<'a> {
190 pub fn name(self) -> Ident<'a> {
191 self.0.cast_first()
192 }
193
194 pub fn expr(self) -> Expr<'a> {
195 self.0.cast_last()
196 }
197
198 pub fn pattern(self) -> Pattern<'a> {
202 self.0.cast_last()
203 }
204}
205
206node! {
207 struct Destructuring
208}
209
210impl<'a> Destructuring<'a> {
211 pub fn items(self) -> impl DoubleEndedIterator<Item = DestructuringItem<'a>> {
212 self.0.children().filter_map(SyntaxNode::cast)
213 }
214
215 pub fn bindings(self) -> Vec<Ident<'a>> {
216 self.items()
217 .flat_map(|binding| match binding {
218 DestructuringItem::Named(named) => named.pattern().bindings(),
219 DestructuringItem::Pattern(pattern) => pattern.bindings(),
220 DestructuringItem::Spread(spread) => {
221 spread.sink_ident().into_iter().collect()
222 }
223 })
224 .collect()
225 }
226}
227
228pub enum DestructuringItem<'a> {
229 Pattern(Pattern<'a>),
230 Named(Named<'a>),
231 Spread(Spread<'a>),
232}
233
234node! {
235 struct Spread
236}
237
238impl<'a> Spread<'a> {
239 pub fn expr(self) -> Expr<'a> {
244 self.0.cast_first()
245 }
246
247 pub fn sink_ident(self) -> Option<Ident<'a>> {
252 self.0.try_cast_first()
253 }
254
255 pub fn sink_expr(self) -> Option<Expr<'a>> {
260 self.0.try_cast_first()
261 }
262}
263
264impl<'a> AstNode<'a> for DestructuringItem<'a> {
265 fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
266 match node.kind() {
267 SyntaxKind::Named => Some(Self::Named(Named::from_untyped(node)?)),
268 SyntaxKind::Spread => Some(Self::Spread(Spread(node))),
269 _ => node.cast().map(Self::Pattern),
270 }
271 }
272
273 fn to_untyped(&self) -> &'a SyntaxNode {
274 match self {
275 Self::Named(n) => n.to_untyped(),
276 Self::Pattern(p) => p.to_untyped(),
277 Self::Spread(s) => s.to_untyped(),
278 }
279 }
280}
281
// Parser/AST tests for lambdas, written with the crate's `assert_ast!` macro.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_ast;
    use crate::ast::FuncCall;
    use crate::ast::binary::{BinOp, Binary};

    // A lambda written after a call's parentheses appears among the call's
    // argument items. This checks its single parameter `a` and the body
    // statement `a + b` that follows the arrow.
    #[test]
    fn trailing_lambda() {
        assert_ast!(
            r#"
            foo() { a => a + b; }
            "#,
            call as FuncCall {
                // The callee is the identifier `foo`.
                with callee: Ident = call.callee() => {
                    assert_eq!(callee.get(), "foo");
                }
                // The call's argument items are expected to be the trailing lambda.
                call.args().items() => [
                    trailing_lambda as Lambda {
                        // One parameter whose pattern is the identifier `a`.
                        with params: Params = trailing_lambda.params() => {
                            params.children() => [
                                param as Param {
                                    with pat: Pattern = param.kind() => {
                                        with ident: Ident = pat => {
                                            assert_eq!(ident.get(), "a");
                                        }
                                    }
                                }
                            ]
                        }

                        // The body (after the arrow) is the single binary
                        // expression `a + b`.
                        trailing_lambda.statements() => [
                            binary as Binary {
                                with lhs: Ident = binary.lhs() => {
                                    assert_eq!(lhs.get(), "a");
                                }
                                with rhs: Ident = binary.rhs() => {
                                    assert_eq!(rhs.get(), "b");
                                }
                                assert_eq!(binary.op(), BinOp::Add);
                            }
                        ]
                    }
                ]
            }
        )
    }
}