asn1_rs/
header.rs

1use crate::ber::*;
2use crate::der_constraint_fail_if;
3use crate::error::*;
4#[cfg(feature = "std")]
5use crate::ToDer;
6use crate::{BerParser, Class, DerParser, DynTagged, FromBer, FromDer, Length, Tag, ToStatic};
7use alloc::borrow::Cow;
8use core::convert::TryFrom;
9use nom::bytes::streaming::take;
10
/// BER/DER object header: the identifier octets (class, constructed flag, tag number)
/// together with the decoded length.
#[derive(Clone, Debug)]
pub struct Header<'a> {
    /// Object class: universal, application, context-specific, or private
    pub(crate) class: Class,
    /// Constructed attribute: `true` if the object is constructed, `false` if primitive
    pub(crate) constructed: bool,
    /// Tag number
    pub(crate) tag: Tag,
    /// Object length: a value if definite, or `Length::Indefinite`
    pub(crate) length: Length,

    /// Optionally, the raw encoding of the tag
    ///
    /// This is useful in some cases, where different representations of the same
    /// BER tags have different meanings (BER only).
    ///
    /// Filled in by the `FromBer`/`FromDer` parsers of this module; `None` for headers
    /// built with `Header::new`.
    pub(crate) raw_tag: Option<Cow<'a, [u8]>>,
}
29
30impl<'a> Header<'a> {
31    /// Build a new BER/DER header from the provided values
32    pub const fn new(class: Class, constructed: bool, tag: Tag, length: Length) -> Self {
33        Header {
34            tag,
35            constructed,
36            class,
37            length,
38            raw_tag: None,
39        }
40    }
41
42    /// Build a new BER/DER header from the provided tag, with default values for other fields
43    #[inline]
44    pub const fn new_simple(tag: Tag) -> Self {
45        let constructed = matches!(tag, Tag::Sequence | Tag::Set);
46        Self::new(Class::Universal, constructed, tag, Length::Definite(0))
47    }
48
49    /// Set the class of this `Header`
50    #[inline]
51    pub fn with_class(self, class: Class) -> Self {
52        Self { class, ..self }
53    }
54
55    /// Set the constructed flags of this `Header`
56    #[inline]
57    pub fn with_constructed(self, constructed: bool) -> Self {
58        Self {
59            constructed,
60            ..self
61        }
62    }
63
64    /// Set the tag of this `Header`
65    #[inline]
66    pub fn with_tag(self, tag: Tag) -> Self {
67        Self { tag, ..self }
68    }
69
70    /// Set the length of this `Header`
71    #[inline]
72    pub fn with_length(self, length: Length) -> Self {
73        Self { length, ..self }
74    }
75
76    /// Update header to add reference to raw tag
77    #[inline]
78    pub fn with_raw_tag(self, raw_tag: Option<Cow<'a, [u8]>>) -> Self {
79        Header { raw_tag, ..self }
80    }
81
82    /// Return the class of this header.
83    #[inline]
84    pub const fn class(&self) -> Class {
85        self.class
86    }
87
88    /// Return true if this header has the 'constructed' flag.
89    #[inline]
90    pub const fn constructed(&self) -> bool {
91        self.constructed
92    }
93
94    /// Return the tag of this header.
95    #[inline]
96    pub const fn tag(&self) -> Tag {
97        self.tag
98    }
99
100    /// Return the length of this header.
101    #[inline]
102    pub const fn length(&self) -> Length {
103        self.length
104    }
105
106    /// Return the raw tag encoding, if it was stored in this object
107    #[inline]
108    pub fn raw_tag(&self) -> Option<&[u8]> {
109        self.raw_tag.as_ref().map(|cow| cow.as_ref())
110    }
111
112    /// Test if object is primitive
113    #[inline]
114    pub const fn is_primitive(&self) -> bool {
115        !self.constructed
116    }
117
118    /// Test if object is constructed
119    #[inline]
120    pub const fn is_constructed(&self) -> bool {
121        self.constructed
122    }
123
124    /// Return error if class is not the expected class
125    #[inline]
126    pub const fn assert_class(&self, class: Class) -> Result<()> {
127        self.class.assert_eq(class)
128    }
129
130    /// Return error if tag is not the expected tag
131    #[inline]
132    pub const fn assert_tag(&self, tag: Tag) -> Result<()> {
133        self.tag.assert_eq(tag)
134    }
135
136    /// Return error if object is not primitive
137    #[inline]
138    pub const fn assert_primitive(&self) -> Result<()> {
139        if self.is_primitive() {
140            Ok(())
141        } else {
142            Err(Error::ConstructUnexpected)
143        }
144    }
145
146    /// Return error if object is primitive
147    #[inline]
148    pub const fn assert_constructed(&self) -> Result<()> {
149        if !self.is_primitive() {
150            Ok(())
151        } else {
152            Err(Error::ConstructExpected)
153        }
154    }
155
156    /// Test if object class is Universal
157    #[inline]
158    pub const fn is_universal(&self) -> bool {
159        self.class as u8 == Class::Universal as u8
160    }
161    /// Test if object class is Application
162    #[inline]
163    pub const fn is_application(&self) -> bool {
164        self.class as u8 == Class::Application as u8
165    }
166    /// Test if object class is Context-specific
167    #[inline]
168    pub const fn is_contextspecific(&self) -> bool {
169        self.class as u8 == Class::ContextSpecific as u8
170    }
171    /// Test if object class is Private
172    #[inline]
173    pub const fn is_private(&self) -> bool {
174        self.class as u8 == Class::Private as u8
175    }
176
177    /// Return error if object length is definite
178    #[inline]
179    pub const fn assert_definite(&self) -> Result<()> {
180        if self.length.is_definite() {
181            Ok(())
182        } else {
183            Err(Error::DerConstraintFailed(DerConstraint::IndefiniteLength))
184        }
185    }
186
187    /// Get the content following a BER header
188    #[inline]
189    pub fn parse_ber_content<'i>(&'_ self, i: &'i [u8]) -> ParseResult<'i, &'i [u8]> {
190        // defaults to maximum depth 8
191        // depth is used only if BER, and length is indefinite
192        BerParser::get_object_content(i, self, 8)
193    }
194
195    /// Get the content following a DER header
196    #[inline]
197    pub fn parse_der_content<'i>(&'_ self, i: &'i [u8]) -> ParseResult<'i, &'i [u8]> {
198        self.assert_definite()?;
199        DerParser::get_object_content(i, self, 8)
200    }
201}
202
203impl From<Tag> for Header<'_> {
204    #[inline]
205    fn from(tag: Tag) -> Self {
206        let constructed = matches!(tag, Tag::Sequence | Tag::Set);
207        Self::new(Class::Universal, constructed, tag, Length::Definite(0))
208    }
209}
210
211impl<'a> ToStatic for Header<'a> {
212    type Owned = Header<'static>;
213
214    fn to_static(&self) -> Self::Owned {
215        let raw_tag: Option<Cow<'static, [u8]>> =
216            self.raw_tag.as_ref().map(|b| Cow::Owned(b.to_vec()));
217        Header {
218            tag: self.tag,
219            constructed: self.constructed,
220            class: self.class,
221            length: self.length,
222            raw_tag,
223        }
224    }
225}
226
227impl<'a> FromBer<'a> for Header<'a> {
228    fn from_ber(bytes: &'a [u8]) -> ParseResult<Self> {
229        let (i1, el) = parse_identifier(bytes)?;
230        let class = match Class::try_from(el.0) {
231            Ok(c) => c,
232            Err(_) => unreachable!(), // Cannot fail, we have read exactly 2 bits
233        };
234        let (i2, len) = parse_ber_length_byte(i1)?;
235        let (i3, len) = match (len.0, len.1) {
236            (0, l1) => {
237                // Short form: MSB is 0, the rest encodes the length (which can be 0) (8.1.3.4)
238                (i2, Length::Definite(usize::from(l1)))
239            }
240            (_, 0) => {
241                // Indefinite form: MSB is 1, the rest is 0 (8.1.3.6)
242                // If encoding is primitive, definite form shall be used (8.1.3.2)
243                if el.1 == 0 {
244                    return Err(nom::Err::Error(Error::ConstructExpected));
245                }
246                (i2, Length::Indefinite)
247            }
248            (_, l1) => {
249                // if len is 0xff -> error (8.1.3.5)
250                if l1 == 0b0111_1111 {
251                    return Err(nom::Err::Error(Error::InvalidLength));
252                }
253                let (i3, llen) = take(l1)(i2)?;
254                match bytes_to_u64(llen) {
255                    Ok(l) => {
256                        let l =
257                            usize::try_from(l).or(Err(nom::Err::Error(Error::InvalidLength)))?;
258                        (i3, Length::Definite(l))
259                    }
260                    Err(_) => {
261                        return Err(nom::Err::Error(Error::InvalidLength));
262                    }
263                }
264            }
265        };
266        let constructed = el.1 != 0;
267        let hdr = Header::new(class, constructed, Tag(el.2), len).with_raw_tag(Some(el.3.into()));
268        Ok((i3, hdr))
269    }
270}
271
272impl<'a> FromDer<'a> for Header<'a> {
273    fn from_der(bytes: &'a [u8]) -> ParseResult<Self> {
274        let (i1, el) = parse_identifier(bytes)?;
275        let class = match Class::try_from(el.0) {
276            Ok(c) => c,
277            Err(_) => unreachable!(), // Cannot fail, we have read exactly 2 bits
278        };
279        let (i2, len) = parse_ber_length_byte(i1)?;
280        let (i3, len) = match (len.0, len.1) {
281            (0, l1) => {
282                // Short form: MSB is 0, the rest encodes the length (which can be 0) (8.1.3.4)
283                (i2, Length::Definite(usize::from(l1)))
284            }
285            (_, 0) => {
286                // Indefinite form is not allowed in DER (10.1)
287                return Err(nom::Err::Error(Error::DerConstraintFailed(
288                    DerConstraint::IndefiniteLength,
289                )));
290            }
291            (_, l1) => {
292                // if len is 0xff -> error (8.1.3.5)
293                if l1 == 0b0111_1111 {
294                    return Err(nom::Err::Error(Error::InvalidLength));
295                }
296                // DER(9.1) if len is 0 (indefinite form), obj must be constructed
297                der_constraint_fail_if!(
298                    &i[1..],
299                    len.1 == 0 && el.1 != 1,
300                    DerConstraint::NotConstructed
301                );
302                let (i3, llen) = take(l1)(i2)?;
303                match bytes_to_u64(llen) {
304                    Ok(l) => {
305                        // DER: should have been encoded in short form (< 127)
306                        // XXX der_constraint_fail_if!(i, l < 127);
307                        let l =
308                            usize::try_from(l).or(Err(nom::Err::Error(Error::InvalidLength)))?;
309                        (i3, Length::Definite(l))
310                    }
311                    Err(_) => {
312                        return Err(nom::Err::Error(Error::InvalidLength));
313                    }
314                }
315            }
316        };
317        let constructed = el.1 != 0;
318        let hdr = Header::new(class, constructed, Tag(el.2), len).with_raw_tag(Some(el.3.into()));
319        Ok((i3, hdr))
320    }
321}
322
323impl DynTagged for (Class, bool, Tag) {
324    fn tag(&self) -> Tag {
325        self.2
326    }
327}
328
#[cfg(feature = "std")]
impl ToDer for (Class, bool, Tag) {
    /// Number of bytes of the DER identifier octets for `(class, constructed, tag)`.
    ///
    /// Tag numbers up to 30 fit in one octet. Larger ones use the high-tag-number form:
    /// one leading octet followed by the tag number in base 128 (X.690 8.1.2.4).
    fn to_der_len(&self) -> Result<usize> {
        let (_, _, tag) = self;
        if tag.0 <= 30 {
            return Ok(1);
        }
        // count base-128 digits of the tag number (there is always at least one)
        let mut digits = 1;
        let mut rest = tag.0 >> 7;
        while rest > 0 {
            digits += 1;
            rest >>= 7;
        }
        Ok(1 + digits)
    }

    /// Write the DER identifier octets and return the number of bytes written.
    ///
    /// Uses `write_all`, so a writer that accepts only part of a buffer in one call does
    /// not silently truncate the output.
    ///
    /// # Errors
    ///
    /// `SerializeError::InvalidLength` if the tag number needs more base-128 octets than the
    /// internal buffer holds (nothing is written in that case), or the writer's I/O error.
    fn write_der_header(&self, writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        let (class, constructed, tag) = self;
        // bits 8-7: class, bit 6: constructed flag
        let b0 = (*class as u8) << 6;
        let b0 = b0 | if *constructed { 0b10_0000 } else { 0 };
        if tag.0 > 30 {
            let mut val = tag.0;

            const BUF_SZ: usize = 8;
            let mut buffer = [0u8; BUF_SZ];
            let mut current_index = BUF_SZ - 1;

            // encode from right (last) to left; last byte has the continuation bit clear
            buffer[current_index] = (val & 0x7f) as u8;
            val >>= 7;

            while val > 0 {
                current_index -= 1;
                if current_index == 0 {
                    return Err(SerializeError::InvalidLength);
                }
                // every preceding byte has the continuation bit set
                buffer[current_index] = (val & 0x7f) as u8 | 0x80;
                val >>= 7;
            }

            // first byte: class+constructed+0x1f, then the base-128 tag number
            writer.write_all(&[b0 | 0b1_1111])?;
            let tail = &buffer[current_index..];
            writer.write_all(tail)?;
            Ok(1 + tail.len())
        } else {
            // low-tag-number form: tag number stored in the 5 low bits
            writer.write_all(&[b0 | (tag.0 as u8)])?;
            Ok(1)
        }
    }

    /// An identifier has no content; nothing is written.
    fn write_der_content(&self, _writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        Ok(0)
    }
}
393
394impl DynTagged for Header<'_> {
395    fn tag(&self) -> Tag {
396        self.tag
397    }
398}
399
#[cfg(feature = "std")]
impl ToDer for Header<'_> {
    /// Encoded size of the header: identifier octets plus length octets.
    fn to_der_len(&self) -> Result<usize> {
        let tag_len = (self.class, self.constructed, self.tag).to_der_len()?;
        let len_len = self.length.to_der_len()?;
        Ok(tag_len + len_len)
    }

    /// Write identifier octets (rebuilt from class/constructed/tag) then length octets.
    fn write_der_header(&self, writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        let sz = (self.class, self.constructed, self.tag).write_der_header(writer)?;
        let sz = sz + self.length.write_der_header(writer)?;
        Ok(sz)
    }

    /// A header has no content of its own; nothing is written.
    fn write_der_content(&self, _writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        Ok(0)
    }

    /// Like `write_der_header`, but reuses the stored raw tag encoding when present.
    fn write_der_raw(&self, writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        let sz = match &self.raw_tag {
            Some(t) => {
                // `write` may accept only part of the buffer; `write_all` writes all of it
                writer.write_all(t)?;
                t.len()
            }
            None => (self.class, self.constructed, self.tag).write_der_header(writer)?,
        };
        let sz = sz + self.length.write_der_header(writer)?;
        Ok(sz)
    }
}
428
429/// Compare two BER headers. `len` fields are compared only if both objects have it set (same for `raw_tag`)
430impl<'a> PartialEq<Header<'a>> for Header<'a> {
431    fn eq(&self, other: &Header) -> bool {
432        self.class == other.class
433            && self.tag == other.tag
434            && self.constructed == other.constructed
435            && {
436                if self.length.is_null() && other.length.is_null() {
437                    self.length == other.length
438                } else {
439                    true
440                }
441            }
442            && {
443                // it tag is present for both, compare it
444                if self.raw_tag.as_ref().xor(other.raw_tag.as_ref()).is_none() {
445                    self.raw_tag == other.raw_tag
446                } else {
447                    true
448                }
449            }
450    }
451}
452
453impl Eq for Header<'_> {}
454
#[cfg(test)]
mod tests {
    use crate::*;
    use hex_literal::hex;

    /// Exercise accessors, builders, serialization and content helpers of `Header`
    /// (also used as a coverage test).
    #[test]
    fn methods_header() {
        // INTEGER, length 1, value 0
        let encoded = &hex!("02 01 00");

        // accessors on a parsed header
        let (remaining, parsed) = Header::from_ber(encoded).expect("parsing header failed");
        assert_eq!(parsed.class(), Class::Universal);
        assert_eq!(parsed.tag(), Tag::Integer);
        assert!(parsed.assert_primitive().is_ok());
        assert!(parsed.assert_constructed().is_err());
        assert!(parsed.is_universal());
        assert!(!parsed.is_application());
        assert!(!parsed.is_private());
        assert_eq!(remaining, &encoded[2..]);

        // PartialEq: a simple header matches the parsed one
        let simple = Header::new_simple(Tag::Integer);
        assert_eq!(parsed, simple);

        // builder methods
        let ctx = simple
            .with_class(Class::ContextSpecific)
            .with_constructed(true)
            .with_length(Length::Definite(1));
        assert!(ctx.constructed());
        assert!(ctx.is_constructed());
        assert!(ctx.assert_constructed().is_ok());
        assert!(ctx.is_contextspecific());
        let der = ctx.to_der_vec().expect("serialize failed");
        assert_eq!(&der, &[0xa2, 0x01]);

        // indefinite length
        let indefinite = ctx.with_length(Length::Indefinite);
        assert!(indefinite.assert_definite().is_err());
        let der = indefinite.to_der_vec().expect("serialize failed");
        assert_eq!(&der, &[0xa2, 0x80]);

        // parse_*_content
        let content_hdr = Header::new_simple(Tag(2)).with_length(Length::Definite(1));
        let content = &encoded[2..];
        let (_, ber_content) = content_hdr.parse_ber_content(content).unwrap();
        assert_eq!(ber_content, content);
        let (_, der_content) = content_hdr.parse_der_content(content).unwrap();
        assert_eq!(der_content, content);
    }
}