====== Python CPT Wrappers ======
class CPTWrapper(object):
    '''
    Common base class for the conditional-frequency-table wrappers
    (SimpleDiscreteCPTBN, SQLDiscreteCPTBN). It defines no behavior itself.
    '''


class SimpleDiscreteCPTBN(CPTWrapper):
    '''
    In-memory frequency tables for a discrete Bayesian network.

    For every node, the constructor counts how often each combination of
    the node's value and its parents' values occurs in the joint frequency
    table (jft). stateFreq() returns add-one smoothed counts from those
    tables.
    '''

    def __init__(self, nodeOrder, map, stateSpace, jft):
        '''
        nodeOrder: [node1, node2, node3, ...]
        map: {node1:[], node2:[node1], node3:[node1, node2], ...}
        stateSpace: {node1:['a', 'b', 'c'], node2:[1, 3, 5, 7], node3:['high', 'low'], ...}
        jft: {node1:[0, 2, 3, 1, 1, ...], node2:[3, 3, 1, 1, 2, ...], node3:[1, 1, 0, 0, 0, ...], ...}
            Entry i of every list belongs to data point i. The lists are
            presumably all the same length; only the first node's list is
            used to compute dataPoints. Values must be hashable.
        '''
        self.__nodeMap = map
        self.__orderedNodeList = nodeOrder
        self.__stateSpace = stateSpace
        self.__cfts = self.__JFTtoCFT(nodeOrder, map, jft)
        # Number of data points. An empty node list means there is no data.
        self.dataPoints = len(jft[nodeOrder[0]]) if nodeOrder else 0

    @staticmethod
    def __stateKey(state):
        '''
        Return a hashable key for a {node: value} dict.

        The key does not depend on the dict's insertion order, so
        {'A': 0, 'B': 1} and {'B': 1, 'A': 0} map to the same key.
        '''
        return frozenset(state.items())

    def __JFTtoCFT(self, nodeOrder, map, jft):
        '''
        Count each (node value, parent values) combination in the data.

        return: {frozenset({(node3, 1), (node1, 0), (node2, 3)}): 1, ...}
            Keys are built by __stateKey from {node: value, parent: value, ...}.
            Values are the number of data points with that combination.
        '''
        from collections import defaultdict

        retVal = defaultdict(int)
        for node in nodeOrder:
            parents = map[node]
            for i, value in enumerate(jft[node]):
                # The state for data point i: this node's value plus the
                # value of each of its parents at the same index.
                state = {node: value}
                for parent in parents:
                    state[parent] = jft[parent][i]
                retVal[self.__stateKey(state)] += 1
        return retVal

    def stateFreq(self, state, node):
        '''
        Return add-one smoothed frequency counts for a state.

        state: {node: value, parent1: value, ...}. This is the node's value
            together with its parents' values, in any key order.
        node: the node whose number of possible values is added to the
            denominator.
        return: (count + 1, dataPoints + len(stateSpace[node]))
        '''
        # .get avoids inserting zero-count entries into the defaultdict.
        count = self.__cfts.get(self.__stateKey(state), 0)
        return (count + 1, self.dataPoints + len(self.__stateSpace[node]))
import MySQLdb
class SQLDiscreteCPTBN(CPTWrapper):
    '''
    Bayesian-network frequency tables stored in a MySQL database.

    The constructor connects to the database and, for every node, creates
    (unless it already exists) a table "<jftTable>_<node>_<parent1>_...".
    That table holds SUM(frequency) from jftTable grouped by the node and
    its parents. stateFreq() then queries those per-node tables.
    '''

    class BNStateKey(object):
        '''
        String key for a {node: value} dict, ordered by a node list.

        repr() renders only the nodes present in both nodeOrder and state,
        in nodeOrder order, e.g. "{'A':0.3, 'B':8}". Values are rendered
        with str(), so string values appear unquoted.
        '''

        def __init__(self, nodeOrder, state):
            # Plain data attributes. Accessor methods with these same names
            # would be shadowed by the instance attributes and unusable.
            self.nodeOrder = nodeOrder
            self.state = state

        def __repr__(self):
            entries = ["'" + str(node) + "':" + str(self.state[node])
                       for node in self.nodeOrder if node in self.state]
            # join puts ", " only between entries, matching the format above.
            return "{" + ", ".join(entries) + "}"

    def __init__(self, nodeOrder, map, stateSpace, database, username, password, jftTable):
        '''
        nodeOrder: [node1, node2, node3, ...]
        map: {node1:[], node2:[node1], node3:[node1, node2], ...}
        stateSpace: {node1:['a', 'b', 'c'], node2:[1, 3, 5, 7], node3:['high', 'low'], ...}
        database, username, password: MySQL connection settings.
        jftTable: name of the joint frequency table. It must have a column
            for each node and a numeric "frequency" column.
        '''
        self.__database = database
        self.__username = username
        self.__password = password
        self.__nodeMap = map
        self.__orderedNodeList = nodeOrder
        self.__stateSpace = stateSpace
        self.__cfts = self.__JFTtoCFT(nodeOrder, map, jftTable)

    def getNodes(self):
        '''Return the node list in the order given to the constructor.'''
        return self.__orderedNodeList

    def getParents(self, node):
        '''Return the parent list of node from the network map.'''
        # The map is a dict, so it is indexed rather than called.
        return self.__nodeMap[node]

    def getStates(self, node):
        '''Return the list of possible values of node.'''
        return self.__stateSpace[node]

    @staticmethod
    def __tableExists(cursor, tableName):
        '''Return True if a table named exactly tableName exists.'''
        # "_" and "%" are LIKE wildcards, and the generated names use "_"
        # as a separator, so escape them to get an exact-name match.
        pattern = (tableName.replace("\\", "\\\\")
                   .replace("%", "\\%")
                   .replace("_", "\\_"))
        cursor.execute("SHOW TABLES LIKE %s", (pattern,))
        return len(cursor.fetchall()) > 0

    def __JFTtoCFT(self, nodeOrder, map, jftTable):
        '''
        Ensure a grouped frequency table exists for every node.

        return: {node1: "<jftTable>_node1", node2: "<jftTable>_node2_node1", ...}
            This maps each node to the name of its table.
        '''
        retVal = {}
        db = MySQLdb.connect(passwd=self.__password, db=self.__database, user=self.__username)
        try:
            c = db.cursor(cursorclass=MySQLdb.cursors.DictCursor)
            try:
                for node in nodeOrder:
                    columns = [node] + list(map[node])
                    tableName = "_".join([jftTable] + columns)
                    # Only create the table when it is missing, so an
                    # existing table is never re-populated.
                    if not self.__tableExists(c, tableName):
                        nodeListSQL = ", ".join(columns)
                        createSQL = ("CREATE TABLE IF NOT EXISTS " + tableName
                                     + " SELECT " + nodeListSQL
                                     + ", SUM(frequency) as frequency "
                                     + "FROM " + jftTable + " "
                                     + "GROUP BY " + nodeListSQL)
                        c.execute(createSQL)
                        db.commit()
                    retVal[node] = tableName
            finally:
                c.close()
        finally:
            db.close()
        return retVal

    def __buildSelectSQL(self, node, condNodeList, evidence):
        '''
        Build the SELECT statement used by stateFreq().

        return: (sql, params). params holds the evidence values for the
            "%s" placeholders in sql, in placeholder order.
        '''
        nodeListSQL = ", ".join([node] + list(condNodeList))
        sql = ("SELECT " + nodeListSQL
               + ", SUM(frequency) as frequency FROM " + self.__cfts[node])
        parents = self.getParents(node)
        conditions = []
        params = []
        # Only evidence on the node's parents is applied; other entries are
        # ignored. Values go through query parameters so strings are quoted.
        for eNode in evidence:
            if eNode in parents:
                conditions.append(eNode + " = %s")
                params.append(evidence[eNode])
        # Emit WHERE only when at least one condition applies.
        if conditions:
            sql += " WHERE " + " AND ".join(conditions)
        sql += " GROUP BY " + nodeListSQL
        return sql, params

    def stateFreq(self, node, *args, **kwargs):
        '''
        stateFreq(node, [condNodeList], [evidence={}], [cft=True or False])
        -> state table: ({'A':0.3, 'B':8, 'frequency':4}, ...) OR
        -> conditional frequency table: {"{'A':0.3, 'B':8}":4, ...}
        node: The node in this BN to examine
        condNodeList: A conditional node list identifying each of the
        conditional nodes of interest. This parameter is optional and is of
        the form [node1, node2, ...]. Nodes in this list must be a subset of
        the parent nodes of the given node.
        cft: True will cause this method to return a conditional frequency
        table. False will cause this method to return a state table.
        evidence: A dictionary mapping conditional nodes to a value. The nodes
        in the dictionary need not be in the conditional node list. Entries
        for nodes that are not parents of the given node are ignored. This
        parameter is optional.
        CFT keys are repr(BNStateKey(getNodes(), row)) with the frequency
        column removed from the row.
        '''
        condNodeList = args[0] if len(args) > 0 else []
        evidence = kwargs.get('evidence') or {}
        selectSQL, params = self.__buildSelectSQL(node, condNodeList, evidence)
        # Gather results
        db = MySQLdb.connect(passwd=self.__password, db=self.__database, user=self.__username)
        try:
            c = db.cursor(cursorclass=MySQLdb.cursors.DictCursor)
            try:
                c.execute(selectSQL, params or None)
                retVal = c.fetchall()
            finally:
                c.close()
        finally:
            db.close()
        # Build CFT if asked to
        if kwargs.get('cft'):
            newRetVal = {}
            for item in retVal:
                val = item.pop('frequency')
                newRetVal[repr(self.BNStateKey(self.getNodes(), item))] = val
            retVal = newRetVal
        return retVal
~~ODT~~
\\
\\
\\
\\
\\
\\
\\
\\
\\
\\
~~DISCUSSION~~