Viewing: argparse_ext.py


"""
Module which provides extensions for the standard Python argparse module.
Copyright 2015 Cray Inc.  All Rights Reserved
"""

import argparse
import copy

from argparse import _ensure_value, Action, ArgumentTypeError


class ExtendAction(Action):
    """Action to extend a list of argument values

    This action is similar to the standard AppendAction, but uses the
    extend() attribute of lists rather than the append() attribute.  As
    such, it also has an additional requirement:

    -   This action must receive an iterable 'values' argument from the
        parser.  There are two ways to make this happen:

        1.  Use type= to produce an iterable, e.g. type=str or type=list
        2.  Use nargs= to cause the parser to produe a list, which it
            does for any nargs= setting that is not None (default) and
            is not '?'
    """

    def __call__(self, parser, namespace, values, option_string):
        items = copy.copy(_ensure_value(namespace, self.dest, []))

        try:
            items.extend(values)
        except TypeError:
            # Assume the TypeError is because values is not iterable
            raise ArgumentTypeError(
                "argument type '{:s}' is not iterable".format(
                    type(values).__name__))

        setattr(namespace, self.dest, items)


def str2list(string, sep=',', totype=None, choices=None):
    """Split a string into a list with conversion and validation

    Split a string into a list, optionally convert each element to a
    given type and optionally validate that all resulting values are
    in a collection of valid values.
    """

    plural = {False: '', True: 's'}

    # Values should be string or an iterable container of strings.
    # Split values on the separator into a list
    try:
        lst = string.split(sep)
    except AttributeError:
        raise ArgumentTypeError(
            "argument type '{:s}' does not have split() attribute".format(
                type(string).__name__))

    # Perform type conversion
    if totype is not None:
        errs = []
        for i, v in enumerate(lst):
            try:
                lst[i] = totype(v)
            except (TypeError, ValueError):
                errs.append(v)
        if errs:
            msg = "invalid {:s} value{:s}: {!r:s}".format(
                totype.__name__, plural[len(errs) > 1], errs)
            raise ArgumentTypeError(msg)

    # Verify each separate value
    if choices is not None:
        errs = filter(lambda x:x not in choices, lst)
        if errs:
            msg = "invalid choice{:s}: {!r:s} (choose from {!s:s})".format(
                plural[len(errs) > 1], errs, choices)
            raise ArgumentTypeError(msg)

    return lst


def tolist(sep=',', totype=None, choices=None):
    """Returns a parameterized callable for argument parser type conversion

    This function returns a function which accepts a single argument at
    call time and which uses the supplied arguments to modify its conversion
    behavior.
    """
    return lambda x:str2list(x, sep=sep, totype=totype, choices=choices)


if __name__ == '__main__':
    import unittest

    class Test_Action_Base(unittest.TestCase):
        """Create a base class for testing argparse Action classes"""

        def setUp(self):
            """Create the ExtendAction object and args Namespace"""
            self.action = ExtendAction([], dest='dest')
            self.args   = argparse.Namespace()

        def actionRun(self, values):
            """Run the Action instance using values"""
            self.action(None, self.args, values, '')

        def actionEqual(self, values, expected):
            """Run the Action and check the expected results"""
            self.actionRun(values)
            self.assertEqual(self.args.dest, expected)

        def actionArgTypeErr(self, values):
            """Run the Action and verify it raises ArgumentTypeError"""
            self.assertRaises(
                ArgumentTypeError, self.action, None, self.args, values, '')


    class Test_ExtendAction(Test_Action_Base):
        """Test the ExtendAction class"""

        def test_non_iterable(self):
            """Test ExtendAction with a non-iterable type

            This is similar to:
                parser.add_argument('-z', nargs=None, type=int ...)

                parser.parse_args(['-z', '0'])
            """
            self.actionArgTypeErr(0)

        def test_single_value(self):
            """Test ExtendAction with a single value

            This is similar to:
                parser.add_argument('-z', nargs=None ...)

                parser.parse_args(['-z', 'a'])
            """
            self.actionEqual('a', ['a'])

        def test_single_string(self):
            """Test ExtendAction with a single value

            This is similar to:
                parser.add_argument('-z', nargs=None ...)

                parser.parse_args(['-z', 'abc'])
            """
            self.actionEqual('abc', ['a', 'b', 'c'])

        def test_single_value_multiple_calls(self):
            """Test ExtendAction with a single value and multiple calls

            This is similar to:
                parser.add_argument('-z', nargs=None, type=int ...)

                parser.parse_args(['-z', 'a', '-z', 'b'])
            """
            self.actionEqual('a', ['a'])
            self.actionEqual('b', ['a', 'b'])

        def test_value_list(self):
            """Test ExtendAction with a value list

            This is similar to:
                parser.add_argument('-z', nargs=1 ...)

                parser.parse_args(['-z', 'abc'])
            """
            self.actionEqual(['abc'], ['abc'])

        def test_value_list_multiple_calls(self):
            """Test ExtendAction with a single value and multiple calls

            This is similar to:
                parser.add_argument('-z', nargs=1 ...)

                parser.parse_args(['-z', 'abc', '-z', 'def'])
            """
            self.actionRun(['abc'])
            self.actionEqual(['def'], ['abc', 'def'])

        def test_value_list_multiple_values(self):
            """Test ExtendAction with a value list of length > 1

            This is similar to:
                parser.add_argument('-z', nargs=2 ...)
                -or-
                parser.add_argument('-z', nargs='+' ...)
                -or-
                parser.add_argument('-z', nargs='*' ...)

                parser.parse_args(['-z', 'abc', 'def'])
            """
            self.actionEqual(['abc', 'def'], ['abc', 'def'])


    class Test_tolist_str2list(unittest.TestCase):
        """Test the str2list and tolist conversion functions"""

        def test_sep(self):
            """Verify default and non-default separators work"""
            f = tolist()
            self.assertEqual(f('a,b,c'), ['a','b','c'])
            f = tolist(sep=':')
            self.assertEqual(f('a:b:c'), ['a','b','c'])

        def test_non_iterable(self):
            """Verify a non-iterable string is caught"""
            f = tolist()
            self.assertRaises(ArgumentTypeError, f, 0)

        def test_type_conversion(self):
            """Verify type conversion works properly"""
            f = tolist(totype=int)
            self.assertEqual(f('0,1,2'), [0, 1, 2])
            self.assertRaises(
                ArgumentTypeError, f, '1,z,2,q')

        def test_choices(self):
            """Verify the choices validation works properly"""
            f = tolist(totype=int, choices=[0, 1, 2, 3])
            self.assertEqual(f('0,1,2'), [0, 1, 2])
            self.assertRaises(ArgumentTypeError, f, '0,5,2')


    unittest.main()