Python DevCenter


Building Decision Trees in Python

The Decision Tree Learning Algorithm

With most of the preliminary information out of the way, you can now look at the actual decision tree algorithm. The following code listing is the main function used to create your decision tree:

def create_decision_tree(data, attributes, target_attr, fitness_func):
    """
    Returns a new decision tree based on the examples given.
    """
    data    = data[:]
    vals    = [record[target_attr] for record in data]
    default = majority_value(data, target_attr)

    # If the dataset is empty or the attributes list is empty, return the
    # default value. When checking the attributes list for emptiness, we
    # need to subtract 1 to account for the target attribute.
    if not data or (len(attributes) - 1) <= 0:
        return default
    # If all the records in the dataset have the same classification,
    # return that classification.
    elif vals.count(vals[0]) == len(vals):
        return vals[0]
    else:
        # Choose the next best attribute to best classify our data
        best = choose_attribute(data, attributes, target_attr,
                                fitness_func)

        # Create a new decision tree/node with the best attribute and an empty
        # dictionary object; we'll fill that up next.
        tree = {best: {}}

        # Create a new decision tree/sub-node for each of the values in the
        # best attribute field
        for val in get_values(data, best):
            # Create a subtree for the current value under the "best" field
            subtree = create_decision_tree(
                get_examples(data, best, val),
                [attr for attr in attributes if attr != best],
                target_attr,
                fitness_func)

            # Add the new subtree to the empty dictionary object in our new
            # tree/node we just created.
            tree[best][val] = subtree

    return tree

The create_decision_tree function starts off by declaring three variables: data, vals, and default. The first, data, is just a shallow copy of the data list passed into the function. I do this because Python hands a function a reference to the original object rather than a copy, so any in-place change to a mutable argument, such as a list or dictionary, is visible to the caller. It's a good rule of thumb to copy such arguments before working on them, to keep from accidentally altering the original data. vals is a list holding the value of the target attribute for each record in the data set, and default holds the value that the function returns when the data set is empty or no attributes remain. That is simply the most frequent value of the target attribute and, thus, the best guess when the decision tree is unable to classify a record.
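To see why the copy matters, here is a minimal sketch. It assumes the records are dictionaries held in a list; the two drop_first helpers and the sample records are hypothetical and exist only for illustration.

```python
def drop_first_in_place(records):
    # Works on the caller's list directly: both names refer to one list object.
    del records[0]
    return records


def drop_first_on_copy(records):
    # Slicing builds a new (shallow) list, so the caller's list is untouched.
    records = records[:]
    del records[0]
    return records


original = [{"age": "young", "buys": "no"}, {"age": "old", "buys": "yes"}]

safe = drop_first_on_copy(original)
len_after_copy = len(original)       # the original list still has both records

unsafe = drop_first_in_place(original)
len_after_in_place = len(original)   # the original list has lost a record
```

Keep in mind that a slice copies only the list itself. The record dictionaries inside it are still shared, so changing a field of a record would still show up in the original data.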

The next lines are the real nitty-gritty of the algorithm. The algorithm uses recursion to create the decision tree, and as such it needs a base case (or, in this case, two base cases) to keep it from recursing forever. What are the base cases for this algorithm? For starters, if either the data or attributes list is empty, then the algorithm has reached a stopping point. The first if statement takes care of this case: if either list is empty, the algorithm returns the default value, which is the most frequent value of the target attribute in the data set. (Actually, for the attributes list, the code checks whether only one attribute is left, because the attributes list also contains the target attribute, which the decision tree never tests; that is what the tree should predict.) The only other case to worry about is when the remaining records in the data list all have the same value for the target attribute, in which case the algorithm returns that value.
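The two base cases can be looked at in isolation with a small sketch. The majority_value helper is not listed in this section, so the version below is an assumed implementation built on collections.Counter; the base_case function and the sample records are also hypothetical.

```python
from collections import Counter


def majority_value(data, target_attr):
    # Assumed helper: the most common value of target_attr, or None if no data.
    counts = Counter(record[target_attr] for record in data)
    return counts.most_common(1)[0][0] if counts else None


def base_case(data, attributes, target_attr):
    """Return (True, leaf) if a base case applies, otherwise (False, None)."""
    vals = [record[target_attr] for record in data]
    # Base case 1: no data left, or only the target attribute remains.
    if not data or len(attributes) <= 1:
        return True, majority_value(data, target_attr)
    # Base case 2: every remaining record has the same classification.
    if vals.count(vals[0]) == len(vals):
        return True, vals[0]
    return False, None


records = [
    {"age": "young", "buys": "no"},
    {"age": "old", "buys": "yes"},
    {"age": "old", "buys": "yes"},
]

only_target_left = base_case(records, ["buys"], "buys")
all_same = base_case(records[1:], ["age", "buys"], "buys")
must_recurse = base_case(records, ["age", "buys"], "buys")
```

In this sketch the first two calls stop at a leaf, while the third reports that the algorithm still has to split the data.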

Those are the base cases. What about the recursive case? When neither base case applies (that is, the data and attributes lists are not empty and the records still have more than one value for the target attribute), the algorithm needs to choose the "next best" attribute for classifying the records in the data set and add it to the decision tree. The choose_attribute function is responsible for picking that attribute. After this, the code creates a new decision tree containing only the newly selected "best" attribute. Then the recursion takes place: each subtree is created by a recursive call to create_decision_tree, and the returned subtree is added to the tree created in the previous step.
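One way choose_attribute could work is sketched below. Its implementation is not listed in this section, so this is an assumption: it treats fitness_func as a callable that takes (data, attr, target_attr) and returns a score where higher is better. The purity function is a hypothetical stand-in heuristic used only to make the example runnable; it is not the ID3 heuristic.

```python
def choose_attribute(data, attributes, target_attr, fitness_func):
    # Assumed behaviour: score every non-target attribute and keep the best.
    candidates = [attr for attr in attributes if attr != target_attr]
    return max(candidates, key=lambda attr: fitness_func(data, attr, target_attr))


def purity(data, attr, target_attr):
    # Hypothetical stand-in heuristic: group the records by their value for
    # attr, then return the fraction of records that agree with the majority
    # classification of their own group.
    groups = {}
    for record in data:
        groups.setdefault(record[attr], []).append(record[target_attr])
    matched = sum(max(vals.count(v) for v in set(vals)) for vals in groups.values())
    return matched / len(data)


records = [
    {"age": "young", "income": "high", "buys": "no"},
    {"age": "young", "income": "low", "buys": "no"},
    {"age": "old", "income": "high", "buys": "yes"},
    {"age": "old", "income": "low", "buys": "yes"},
]

best = choose_attribute(records, ["age", "income", "buys"], "buys", purity)
```

Because the heuristic is passed in as a parameter, a different scoring function can be swapped in without touching the tree-building code.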

The first step in this process is getting the "next best" attribute from the set of available attributes. The call to choose_attribute takes care of this step. The next step is to create a new decision tree containing the chosen attribute as the root node. All that remains after this is to create a subtree for each of the values of the best attribute. The get_values function cycles through the records in the data set and returns a list containing the unique values of the chosen attribute. Next, the code loops through these unique values and creates a subtree for each one by making a recursive call to the create_decision_tree function. The call to get_examples just returns a list of all the records in the data set that have the value val for the attribute named by the best variable. This list of examples is passed to the create_decision_tree function along with the list of remaining attributes (minus the currently selected "next best" attribute). The call to create_decision_tree returns the subtree for the remaining attributes and the subset of data passed into it. All that's left is to add each of these subtrees to the current decision tree and return it.
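The two helpers described above are not listed in this section. The following is a sketch of what they might look like, assuming each record is a dictionary keyed by attribute name; the sample records are made up for illustration.

```python
def get_values(data, attr):
    # Collect the unique values of attr, in order of first appearance.
    seen = []
    for record in data:
        if record[attr] not in seen:
            seen.append(record[attr])
    return seen


def get_examples(data, attr, value):
    # Keep only the records whose attr field equals value.
    return [record for record in data if record[attr] == value]


records = [
    {"age": "young", "buys": "no"},
    {"age": "old", "buys": "yes"},
    {"age": "young", "buys": "yes"},
]

values = get_values(records, "age")
young = get_examples(records, "age", "young")
```

Each value in values leads to one recursive call, and the matching get_examples result is the smaller data set that call receives.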

If you're not used to recursion, this process can seem a bit strange. Take some time to look over the code and make sure that you understand what is happening here. Create a little script to run the function and print out the tree (or just alter the existing code to do so), so you can get a better idea of how it's functioning. It's worth taking your time here, because many programming problems lend themselves to a recursive solution. You just may be adding a very important tool to your programming arsenal.
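As a starting point for such a script, here is a small sketch that prints a tree in the nested-dictionary form that create_decision_tree builds, that is, {attribute: {value: subtree_or_leaf}}. The format_tree function and the hand-written sample tree are assumptions for illustration; substitute the tree returned by your own call.

```python
def format_tree(tree, indent=0):
    # Leaves are plain classification values; internal nodes are dictionaries.
    pad = " " * indent
    if not isinstance(tree, dict):
        return [pad + "-> " + str(tree)]
    lines = []
    for attr, branches in tree.items():
        for value, subtree in branches.items():
            lines.append("%s%s = %s" % (pad, attr, value))
            # Recurse one level deeper for whatever hangs off this branch.
            lines.extend(format_tree(subtree, indent + 4))
    return lines


sample_tree = {
    "age": {
        "young": "no",
        "old": {"income": {"high": "yes", "low": "no"}},
    }
}

lines = format_tree(sample_tree)
print("\n".join(lines))
```

Note that format_tree is itself recursive and mirrors the structure of the tree builder: a leaf is the base case, and every branch triggers another call.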

That's about all there is to the algorithm; everything else is just helper functions to the main algorithm. Most of the functions should be fairly self-explanatory, with the exception of the ID3 heuristic.
