Viewing: flagtools.py


"""
Flag handling routines
Copyright 2015 Cray Inc.  All Rights Reserved
"""


### TBD: The "Simple" in the addSimple* interfaces refers to a flag
### that's a single bit.  It's meant to distinguish from flags that
### have multibit fields, such as the node/zone indices stuck in the
### high end of struct page.flags; or a field that's mostly a pointer
### but with some flags in the low bits.
#
### Adding cases like that will mean redoing most of the
### implementation, but all the current interfaces should be ok, with
### new interfaces added to let users define the non-simple flags.

import uflookup


class FlagSet:
    """A collection of named single-bit flags, with routines for translating
    between flag names and flag values.

    Use it for decoding a flag int to a list of names, looking up the value
    of a flag name, and providing python identifiers for testing by name,
    e.g.,

    jafs = FlagSet() # job_attach flagset
    jafs.addSimpleFlag("disable_affinity_apply")
    if job_attach.flags & jafs.disable_affinity_apply: ...

    The advantages over just using a dict include:
    * Define the values once, and get value->string, string->value,
      and python attributes fs.<name> (the value) and fs.<name>_shift
      (the bit number) as above.
    * The auto-incrementing _next_bit, so consecutively added flags need
      no explicit bit numbers.
    """
    def __init__(self, mapping=None):
        """Create and initialize a FlagSet object

        Arguments:
            mapping:    if specified, a mapping object, e.g. dict, whose
                        key/value pairs are flag names and bit numbers;
                        each pair is added as if by addSimpleFlag(key, value).
        """
        # Public dict of flag names to flag values (not the bit number)
        self.str_to_value = {}
        # Public dict of flag values to flag names; when several names
        # share a value, this holds the first name added.
        self.value_to_str = {}

        # Bit used by the next addSimpleFlag call that gives no explicit bit.
        self._next_bit = 0

        # _sorted_values is so that translating a value to a string
        # will report the strings in the same order every time.  That
        # order is by numerically increasing value.  It is emptied whenever
        # a flag is added and rebuilt lazily by _EnsureSorted.
        self._sorted_values = []
        # Currently unused; kept so the attribute continues to exist.
        self._sorted_strs = []

        if mapping is not None:
            self.addMap(mapping)

    def addSimpleFlag(self, s, bit=None):
        """Add a single-bit flag.

        Defines str_to_value[s], plus attributes self.<s> (the value,
        1 << bit) and self.<s>_shift (the bit number).

        If bit is not specified, uses the bit one greater than the
        previously defined bit.  If multiple flags are defined to use
        the same bit, value_to_str will remember only the first.

        Raises:
            ValueError: s is already a flag name; s + "_shift" is already a
                        flag name; s or s + "_shift" is already an attribute
                        of this object; or bit is negative.
        """
        if s in self.str_to_value:
            raise ValueError("Flag {0} already defined (value {1:x})".format(
                s, self.str_to_value[s]))
        if s + "_shift" in self.str_to_value:
            raise ValueError("Flag {0} conflicts with another "
                             "flag ({1})".format(s, s + "_shift"))

        # Refuse names that would clobber an existing attribute: a method,
        # one of the public dicts, or another flag's <name>/<name>_shift.
        if hasattr(self, s):
            raise ValueError("Value {0} already used by FlagSet object!".
                             format(s))
        if hasattr(self, s + "_shift"):
            raise ValueError("{0}_shift already used by FlagSet object!".
                             format(s))

        if bit is None:
            bit = self._next_bit
        # Compute the value before mutating any state, so a bad bit (a
        # negative shift count raises ValueError) leaves the set unchanged.
        value = 1 << bit
        self._next_bit = bit + 1

        if value not in self.value_to_str:
            self.value_to_str[value] = s
        self.str_to_value[s] = value

        # Invalidate the cached sort order; _EnsureSorted rebuilds it.
        self._sorted_values = []

        setattr(self, s, value)
        setattr(self, s + "_shift", bit)

    def addSimpleFlags(self, *l):
        """Add several single-bit flags, in order, as if by calling
        addSimpleFlag(name) for each (so each takes the next bit)."""
        # An explicit loop rather than map(): in Python 3 map() is lazy,
        # so map(self.addSimpleFlag, l) would never add anything.
        for name in l:
            self.addSimpleFlag(name)

    def addMap(self, mapping):
        """Add flags from a mapping of flag name -> bit number.

        Each pair is added as if by addSimpleFlag(name, bit), in the
        mapping's iteration order."""
        for k, v in mapping.items():
            self.addSimpleFlag(k, v)

    def _EnsureSorted(self):
        """Rebuild _sorted_values (every distinct flag value, ascending)
        if it has been invalidated."""
        if self._sorted_values:
            return
        self._sorted_values = sorted(self.value_to_str)

    def flagsToStringList(self, flagint):
        """Translate a given flag int to a list of flag strings.

        Known flags are listed in increasing value order.  Any bits not
        covered by a known flag are appended as one hex string; e.g. with
        only FOO (0x1) defined, 0x31 -> ["FOO", "0x30"].
        """
        self._EnsureSorted()
        strs = []
        for v in self._sorted_values:
            if flagint & v:
                strs.append(self.value_to_str[v])
                flagint &= ~v
        if flagint != 0:
            strs.append("{0:#x}".format(flagint))
        return strs

    def UFLookup(self, key, **kwargs):
        """Look up key among this set's flag names.

        Delegates to uflookup.UFLookup with str_to_value as the table,
        forwarding any keyword arguments; see that module for the matching
        rules and the return value."""
        return uflookup.UFLookup(self.str_to_value, key, **kwargs)

    # TBD: interface to enable a script --dump-flag-translations argument?



def join_flaglist(fl, sep="|", empty="0"):
    """Join a list of flag strings into a single string.

    Arguments:
        fl:     list of flag name strings.
        sep:    separator placed between consecutive names.
        empty:  string returned instead when fl is empty (or otherwise falsy).
    """
    return sep.join(fl) if fl else empty


### Tests

# I'm trying to follow the convention of

#   assertEqual(expectedvalue, function_under_test(args))

# I didn't discover that (on some unittest page) until I was halfway
# through, so I may not have gotten them all in the right order.

if __name__ == '__main__':
    import unittest

    class Test_join_flaglist(unittest.TestCase):
        """Test the join_flaglist function"""

        def assertJoinFlaglistEqual(self, expectedstring, flaglist):
            """Assert join_flaglist(flaglist), with the default sep and
            empty, returns expectedstring"""
            self.assertEqual(expectedstring, join_flaglist(flaglist))

        def test_single_value(self):
            """Test join_flaglist() with a single value"""
            self.assertJoinFlaglistEqual("aflag", ["aflag"])

        def test_two_values(self):
            """Test join_flaglist() with two values"""
            self.assertJoinFlaglistEqual("aflag|bflag", ["aflag", "bflag"])

        def test_three_values(self):
            """Test join_flaglist() with three values"""
            self.assertJoinFlaglistEqual("af|bf|cf", ["af", "bf", "cf"])

        def test_comma_sep(self):
            """Test join_flaglist() with a non-default sep"""
            self.assertEqual("af,bf,cf",
                             join_flaglist(["af", "bf", "cf"], sep=','))

        def test_join_empty(self):
            """Test join_flaglist() with an empty list"""
            self.assertEqual("0", join_flaglist([]))

        def test_join_empty_nondefault(self):
            """Test join_flaglist() with a non-default value of empty"""
            self.assertEqual(" ", join_flaglist([], empty=" "))


    class Test_FlagSet(unittest.TestCase):
        """Base class for the FlagSet tests: gives each test an empty
        FlagSet in self.fs plus the VerifyFlag helper.  Defines no tests
        of its own."""

        def setUp(self):
            self.fs = FlagSet()

        def VerifyFlag(self, string, value):
            """Test string->value, value->string, and the <name> and
            <name>_shift attributes"""
            self.assertEqual(value, self.fs.str_to_value[string])
            self.assertEqual(string, self.fs.value_to_str[value])
            self.assertEqual(value, getattr(self.fs, string))
            self.assertEqual(value, 1<<getattr(self.fs, string+"_shift"))

    class Test_FlagSet_Constructor(Test_FlagSet):
        def test_constructor(self):
            """Test that a new FlagSet is empty and starts at bit 0"""
            self.assertEqual(0, self.fs._next_bit)
            self.assertFalse(self.fs.value_to_str)
            self.assertFalse(self.fs.str_to_value)

    class Test_Add_Simple_Flag(Test_FlagSet):
        def test_add_simple_flag(self):
            """Test that adding a simple flag to an empty FlagSet works"""
            self.fs.addSimpleFlag("FOO")
            self.VerifyFlag("FOO", 1)

        def test_3_add_simple_flag(self):
            """Test multiple addSimpleFlag calls"""
            self.fs.addSimpleFlag("FOO")
            self.fs.addSimpleFlag("BAR")
            self.fs.addSimpleFlag("BAZ")

            self.VerifyFlag("FOO", 1)
            self.VerifyFlag("BAR", 2)
            self.VerifyFlag("BAZ", 4)

            self.assertEqual(1, self.fs.FOO)
            self.assertEqual(2, self.fs.BAR)
            self.assertEqual(4, self.fs.BAZ)

            self.assertEqual(0, self.fs.FOO_shift)
            self.assertEqual(1, self.fs.BAR_shift)
            self.assertEqual(2, self.fs.BAZ_shift)

            self.fs._EnsureSorted()
            self.assertEqual([1, 2, 4], self.fs._sorted_values)

        def test_add_simple_flag_with_value(self):
            """Test addSimpleFlag calls with explicit bit="""
            self.fs.addSimpleFlag("FOO")
            self.fs.addSimpleFlag("BAR", bit=1)
            self.fs.addSimpleFlag("BAZ")
            self.fs.addSimpleFlag("BLAT", bit=17)
            self.fs.addSimpleFlag("FROB")
            self.fs.addSimpleFlag("SNARF", bit=5)

            self.VerifyFlag("FOO", 1)
            self.VerifyFlag("BAR", 2)
            self.VerifyFlag("BAZ", 4)
            self.VerifyFlag("SNARF", 32)
            self.VerifyFlag("BLAT", 1<<17)
            self.VerifyFlag("FROB", 1<<18)

            self.fs._EnsureSorted()
            self.assertEqual([1, 2, 4, 32, 1<<17, 1<<18],
                             self.fs._sorted_values)


        def test_add_simple_flag_dup_name(self):
            """Test exception on duplicate flag name"""
            self.fs.addSimpleFlag("FOO")
            self.assertRaises(ValueError, self.fs.addSimpleFlag, "FOO")

        def test_add_simple_flag_dup_value(self):
            """Test that a duplicate flag value is allowed (no exception)
            and that value_to_str keeps the first name"""
            self.fs.addSimpleFlag("FOO")
            self.fs.addSimpleFlag("BAR", bit=0)

            self.VerifyFlag("FOO", 1)
            self.assertEqual(1, self.fs.str_to_value["BAR"])
            self.assertEqual(1, self.fs.BAR)
            self.assertEqual(0, self.fs.BAR_shift)
            self.assertEqual(["FOO"], self.fs.flagsToStringList(1))

        def test_add_shift_duplicated_name(self):
            """Test that name and name_shift can't both be added"""
            self.fs.addSimpleFlag("FOO_shift")
            self.assertRaises(ValueError, self.fs.addSimpleFlag, "FOO")
            self.assertRaises(ValueError,
                              self.fs.addSimpleFlag, "FOO_shift_shift")

        def test_attr_name_conflict(self):
            """Test that adding a flag won't clobber an object attribute"""
            self.assertRaises(ValueError,
                              self.fs.addSimpleFlag, "addSimpleFlag")

    class Test_Add_Simple_Flags(Test_FlagSet):
        def test_add_simple_flags(self):
            """Test that addSimpleFlags() can add several flags"""

            self.fs.addSimpleFlags("FOO", "BAR", "BAZ")
            self.VerifyFlag("FOO", 1)
            self.VerifyFlag("BAR", 2)
            self.VerifyFlag("BAZ", 4)

    class Test_FlagSet_mapping(Test_FlagSet):
        """FlagSet constructed from a name -> bit-number mapping"""
        def setUp(self):
            self.fs = FlagSet(mapping={"FOO": 9, "BAR": 1})

        def test_constructor(self):
            """Test that the constructor's mapping values are bit numbers"""
            self.VerifyFlag("FOO", 1<<9)
            self.VerifyFlag("BAR", 1<<1)

        def test_addMap(self):
            """Test that addMap() adds to the flags from the constructor"""
            self.fs.addMap({"BAZ": 3, "ZING": 7})

            self.VerifyFlag("FOO", 1<<9)
            self.VerifyFlag("BAR", 1<<1)
            self.VerifyFlag("BAZ", 1<<3)
            self.VerifyFlag("ZING", 1<<7)

    class Test_FlagSet_FBBZZ(Test_FlagSet):
        """FlagSet with FOO, BAR, BAZ at bits 0-2, ZING at bit 13 and
        ZOING at bit 42"""
        def setUp(self):
            self.fs = FlagSet()
            self.fs.addSimpleFlags("FOO", "BAR", "BAZ")
            self.fs.addSimpleFlag("ZING", bit=13)
            self.fs.addSimpleFlag("ZOING", bit=42)

        def Verify_F2SL(self, expectedstrlist, flags):
            """Assert flagsToStringList(flags) returns expectedstrlist"""
            self.assertEqual(expectedstrlist, self.fs.flagsToStringList(flags))

    class Test_FlagSet_FBBZZ_flagsToStringList(Test_FlagSet_FBBZZ):
        def test_F(self):
            self.Verify_F2SL(["FOO"], 1)
        def test_B(self):
            self.Verify_F2SL(["BAR"], 2)
        def test_B2(self):
            self.Verify_F2SL(["BAZ"], 4)
        def test_Z(self):
            self.Verify_F2SL(["ZING"], 1<<13)
        def test_Z2(self):
            self.Verify_F2SL(["ZOING"], 1<<42)

        def test_FB(self):
            self.Verify_F2SL(["FOO", "BAR"], 3)
        def test_FBB(self):
            self.Verify_F2SL(["FOO", "BAR", "BAZ"], 7)
        def test_FB2(self):
            self.Verify_F2SL(["BAR", "BAZ"], 6)

        def test_FBBZZ(self):
            self.Verify_F2SL(["FOO", "BAR", "BAZ", "ZING", "ZOING"],
                             7|1<<13|1<<42)

        def test_unknownflag(self):
            self.Verify_F2SL(["0x10"], 0x10)
        def test_unknownflags(self):
            self.Verify_F2SL(["0x30"], 0x30)
        def test_knownandunknownflags(self):
            self.Verify_F2SL(["FOO", "0x30"], 0x31)


    # Run all unit tests
    unittest.main()