compose_syntax/ast/
binary.rs

1use crate::ast::{node, Expr};
2use crate::kind::SyntaxKind;
3use crate::precedence::{Precedence, PrecedenceTrait};
4
/// Associativity of a binary operator, as reported by [`BinOp::assoc`].
///
/// The common value traits are derived so callers can print, copy, compare and
/// hash the result without having to `match` on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Assoc {
    /// Left-associative: `a op b op c` groups as `(a op b) op c`.
    Left,
    /// Right-associative: `a op b op c` groups as `a op (b op c)`.
    Right,
}
9
/// The binary operators an expression node can carry.
///
/// Each variant corresponds to one operator token; see `BinOp::from_kind` for
/// the token mapping and `BinOp::descriptive_name` for the textual spelling.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
    // Arithmetic operators.
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,

    // Logical operators.
    /// `&&`
    And,
    /// `||`
    Or,

    // Comparison operators.
    /// `==`
    Eq,
    /// `!=`
    Neq,
    /// `<`
    Lt,
    /// `<=`
    Lte,
    /// `>`
    Gt,
    /// `>=`
    Gte,

    // Bitwise operators.
    /// `&`
    BitAnd,
    /// `|`
    BitOr,
    /// `^`
    BitXor,
    /// `<<`
    BitShl,
    /// `>>`
    BitShr,
}
32
33impl BinOp {
34    pub(crate) fn from_kind(kind: SyntaxKind) -> Option<Self> {
35        Some(match kind {
36            SyntaxKind::Plus => Self::Add,
37            SyntaxKind::Minus => Self::Sub,
38            SyntaxKind::Star => Self::Mul,
39            SyntaxKind::Slash => Self::Div,
40            SyntaxKind::Percent => Self::Mod,
41            SyntaxKind::AmpAmp => Self::And,
42            SyntaxKind::PipePipe => Self::Or,
43
44            SyntaxKind::EqEq => Self::Eq,
45            SyntaxKind::BangEq => Self::Neq,
46            SyntaxKind::Lt => Self::Lt,
47            SyntaxKind::LtEq => Self::Lte,
48            SyntaxKind::Gt => Self::Gt,
49            SyntaxKind::GtEq => Self::Gte,
50
51            SyntaxKind::Amp => Self::BitAnd,
52            SyntaxKind::Pipe => Self::BitOr,
53            SyntaxKind::Hat => Self::BitXor,
54            SyntaxKind::LtLt => Self::BitShl,
55            SyntaxKind::GtGt => Self::BitShr,
56
57            _ => return None,
58        })
59    }
60
61    pub fn descriptive_name(self) -> &'static str {
62        match self {
63            Self::Add => "+",
64            Self::Sub => "-",
65            Self::Mul => "*",
66            Self::Div => "/",
67            Self::Mod => "%",
68            Self::And => "&&",
69            Self::Or => "||",
70            Self::Eq => "==",
71            Self::Neq => "!=",
72            Self::Lt => "<",
73            Self::Lte => "<=",
74            Self::Gt => ">",
75            Self::Gte => ">=",
76            Self::BitAnd => "&",
77            Self::BitOr => "|",
78            Self::BitXor => "^",
79            Self::BitShl => "<<",
80            Self::BitShr => ">>",
81        }
82    }
83
84    pub fn assoc(self) -> Assoc {
85        match self {
86            Self::Add | Self::Sub | Self::Mul | Self::Div | Self::Mod => Assoc::Left,
87            Self::BitAnd | Self::BitOr | Self::BitXor | Self::BitShl | Self::BitShr => Assoc::Left,
88            Self::And | Self::Or => Assoc::Left,
89            Self::Eq | Self::Neq | Self::Lt | Self::Lte | Self::Gt | Self::Gte => Assoc::Left,
90        }
91    }
92}
93
94impl PrecedenceTrait for BinOp {
95    fn precedence(&self) -> Precedence {
96        match self {
97            BinOp::Add | BinOp::Sub => Precedence::Sum,
98            BinOp::Mul | BinOp::Div | BinOp::Mod => Precedence::Product,
99            BinOp::And => Precedence::LogicalAnd,
100            BinOp::Or => Precedence::LogicalOr,
101            BinOp::Eq | BinOp::Neq => Precedence::Equals,
102            BinOp::Lt | BinOp::Lte | BinOp::Gt | BinOp::Gte => Precedence::LessGreater,
103            BinOp::BitAnd => Precedence::BitwiseAnd,
104            BinOp::BitOr => Precedence::BitwiseOr,
105            BinOp::BitXor => Precedence::BitwiseXor,
106            BinOp::BitShl | BinOp::BitShr => Precedence::BitShift,
107        }
108    }
109}
110
111node! {
112    struct Binary
113}
114
115impl<'a> Binary<'a> {
116    pub fn lhs(self) -> Expr<'a> {
117        self.0.cast_first()
118    }
119
120    pub fn rhs(self) -> Expr<'a> {
121        self.0.cast_last()
122    }
123
124    pub fn op(self) -> BinOp {
125        self.0
126            .children()
127            .find_map(|n| BinOp::from_kind(n.kind()))
128            .unwrap_or(BinOp::Add)
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_ast;
    use crate::ast::Int;

    /// A `Plus` token maps to `BinOp::Add`.
    #[test]
    fn test_binop() {
        let binop = BinOp::from_kind(SyntaxKind::Plus).unwrap();
        assert_eq!(binop, BinOp::Add);
    }

    /// Two operators of equal precedence.
    ///
    /// NOTE(review): the expected tree is `1 + (2 + 3)`, i.e. nested to the
    /// right, while `BinOp::assoc` reports `Add` as `Assoc::Left` — confirm
    /// which grouping is intended.
    #[test]
    fn test_precedence_equal() {
        assert_ast!(
            "1 + 2 + 3",
            bin as Binary {
                assert_eq!(bin.op(), BinOp::Add);
                with left: Int = bin.lhs() => {
                    assert_eq!(left.get(), 1);
                }
                with right: Binary = bin.rhs() => {
                    assert_eq!(right.op(), BinOp::Add);
                    with nested_left: Int = right.lhs() => {
                        assert_eq!(nested_left.get(), 2);
                    }
                    with nested_right: Int = right.rhs() => {
                        assert_eq!(nested_right.get(), 3);
                    }
                }
            }
        );
    }

    /// `*` binds tighter than `+`, whichever side of the `+` it appears on.
    #[test]
    fn test_precedence_higher() {
        // Product on the right: 1 + (2 * 3).
        assert_ast!(
            "1 + 2 * 3",
            bin as Binary {
                assert_eq!(bin.op(), BinOp::Add);
                with left: Int = bin.lhs() => {
                    assert_eq!(left.get(), 1);
                }
                with right: Binary = bin.rhs() => {
                    assert_eq!(right.op(), BinOp::Mul);
                    with nested_left: Int = right.lhs() => {
                        assert_eq!(nested_left.get(), 2);
                    }
                    with nested_right: Int = right.rhs() => {
                        assert_eq!(nested_right.get(), 3);
                    }
                }
            }
        );

        // Product on the left: (1 * 2) + 3.
        assert_ast!(
            "1 * 2 + 3",
            bin as Binary {
                assert_eq!(bin.op(), BinOp::Add);
                with left: Binary = bin.lhs() => {
                    assert_eq!(left.op(), BinOp::Mul);
                    with nested_left: Int = left.lhs() => {
                        assert_eq!(nested_left.get(), 1);
                    }
                    with nested_right: Int = left.rhs() => {
                        assert_eq!(nested_right.get(), 2);
                    }
                }
                with right: Int = bin.rhs() => {
                    assert_eq!(right.get(), 3);
                }
            }
        );
    }
}